Plugins
solver_util.h
Go to the documentation of this file.
1 //
2 // Created by julian on 6/4/24.
3 //
4 
5 #ifndef UG4_SOLVER_UTIL_H
6 #define UG4_SOLVER_UTIL_H
7 
8 #ifdef UG_USE_JSON
9 
10 
11 #include <nlohmann/json.hpp>
12 
13 #include "common/common.h"
14 #include "common/log.h"
16 #include <memory>
17 #include <unordered_map>
18 #include <variant>
19 #include <stdexcept>
22 // include solver components
31 namespace ug{
32 namespace Util {
33 
34  void CondAbort(bool condition, std::string message){
35  UG_ASSERT(!condition, "ERROR in util.solver: " << message);
36  }
37  template<typename TAlgebra>
39 
40  typedef typename TAlgebra::vector_type vector_type;
41  // set parameters from descriptor attributes
42  bool verbose = false;
43  if(descriptor.contains("verbose")){
44  verbose = descriptor["verbose"];
45  }
46  std::string name = "standard";
47  if(descriptor.contains("type")){
48  name = descriptor["type"];
49  }
50  number iterations = 100;
51  if(descriptor.contains("iterations")){
52  iterations = descriptor["iterations"];
53  }
54  number reduction = 1e-6;
55  if(descriptor.contains("reduction")){
56  reduction = descriptor["reduction"];
57  }
58  number absolute = 1e-12;
59  if(descriptor.contains("absolute")){
60  absolute = descriptor["absolute"];
61  }
62  bool suppress_unsuccessful = false;
63  if(descriptor.contains("suppress_unsuccessful")){
64  suppress_unsuccessful = descriptor["suppress_unsuccessful"];
65  }
66 
67  // create convergence check
68  SmartPtr<StdConvCheck<vector_type>> convCheck = make_sp<StdConvCheck<vector_type>>(
69  new StdConvCheck<vector_type>(iterations,
70  absolute,
71  reduction,
72  verbose,
73  suppress_unsuccessful));
74  return convCheck;
75 
76  }
77  // Variant type for storing different solver components
78  template<typename TDomain, typename TAlgebra>
79  using SolverComponent = std::variant<
92  // missing ElementGaussSeidel
93  >;
94 
95  template<typename TDomain, typename TAlgebra>
96  class SolverUtil{
98  /*
99  * components are stored as std::variant and populate an unordered
100  * map for access. Calls to various subroutines are made
101  * */
102  public:
103 
104  void setComponent(const std::string& key, SolverComponent<TDomain, TAlgebra> component){
105  components[key] = component;
106  }
107 
108  SolverComponent<TDomain, TAlgebra> getComponent(const std::string& key) const {
109  auto it = components.find(key);
110  if (it != components.end()){
111  return it->second;
112  }
113  UG_THROW("SolverUtil.getComponent(" << key << "): no component named " << key << " found");
114  }
115 
116  bool hasComponent(const std::string& key) const{
117 
118  auto it = components.find(key);
119  if (it != components.end()){
120  return true;
121  }
122  return false;
123 
124  }
125 
126  template<typename T>
127  SmartPtr<T> getComponentAs(const std::string& key) const {
128  return std::get<SmartPtr<T>>(getComponent(key));
129  }
130 
131  private:
132  std::unordered_map<std::string, SolverComponent<TDomain, TAlgebra>> components;
133  };
134 
135 
136 
137 
138  template<typename TDomain, typename TAlgebra>
139  SmartPtr<SolverUtil<TDomain, TAlgebra>> CreateSolver(nlohmann::json& solverDesc,
140  SolverUtil<TDomain, TAlgebra> solverutil){
141 
142  //PrepareSolverUtil<TDomain, TAlgebra>();
143  std::string type = "linear";
144  if(solverDesc.contains("type")){
145  type = solverDesc["type"];
146  }
147  if(type == "newton"){
148  auto newton_solver= make_sp(new NewtonSolver<TAlgebra>());
149  // call createlinearsolver
150 
151  // get descriptor for linear solver
152  if(solverDesc.contains("linSolver")){
153  //CreateLinearSolver<TDomain, TAlgebra>(solverDesc, solverutil);
154  }
155 
156  // line search
157  if(solverDesc.contains("lineSearch")){
158  //CreateLineSearch(solverDesc["lineSearch"], solverutil);
159  }
160  newton_solver->set_convergence_check(CreateConvCheck<TAlgebra>(solverDesc["convCheck"]), solverutil);
161  }
162 
163  }
164 
165 
166 
167 
168 
169  template<typename TDomain, typename TAlgebra>
170  SmartPtr<LinearSolver<typename TAlgebra::vector_type>> CreateLinearSolver(nlohmann::json& desc,
171  SolverUtil<TDomain, TAlgebra> solverutil){
172  typedef typename TAlgebra::vector_type TVector;
173  typedef LinearSolver<TVector> TLinSolv;
174 
175  // TODO: is preset
176 
177  bool create_precond = false;
178  bool create_conv_check = false;
179 
180  // if no descriptor given, create default linear solver
181  if(!desc.contains("LinearSolver")){
182 
183  SmartPtr<ILU<TAlgebra>> ilu = make_sp<ILU<TAlgebra>>(new ILU<TAlgebra>());
184  SmartPtr<StdConvCheck<TVector>> convCheck = make_sp<StdConvCheck<TVector>>(
185  new StdConvCheck<TVector>(100, 1e-9, 1e-12));
186  SmartPtr<TLinSolv> default_linear_solver = make_sp(new TLinSolv());
187  default_linear_solver->set_convergence_check(convCheck);
188  default_linear_solver->set_preconditioner(ilu);
189  return default_linear_solver;
190 
191  }
192 
193 
194 
195  }
196 
197  template<typename TDomain, typename TAlgebra>
199  CreatePreconditioner(nlohmann::json &desc, SolverUtil<TDomain, TAlgebra> &solverutil) {
200  typedef typename TAlgebra::vector_type TVector;
201  typedef IPreconditioner<TAlgebra> TPrecond;
202 
203  // TODO: Implement preconditioner creation behavior based on 'desc' and 'solverutil'
204  // TODO: Check if preset
205 
206  nlohmann::json json_default_preconds = json_predefined_defaults::solvers.at("preconditioner");
207 
208  SmartPtr <TPrecond> preconditioner;
209 
211 
212  if(solverutil.hasComponent("approxSpace")){
213  //approxSpace = solverutil.getComponent("approxSpace");
214  }
215 
216  std::string type = desc["type"];
217  if(type == "ilu"){
218  // create ilu
219  typedef ILU<TAlgebra> TILU;
220  SmartPtr<TILU> ILU = make_sp(new TILU());
221 
222  // configure ilu
223  number beta = json_default_preconds["ilu"]["beta"];
224  if(desc.contains("beta")){
225  beta = desc["beta"];
226  }
227  ILU->set_beta(beta);
228 
229  number damping = json_default_preconds["ilu"]["damping"];
230  if(desc.contains("damping")){
231  damping = desc["damping"];
232  }
233  ILU->set_damp(damping);
234 
235  bool sort = json_default_preconds["ilu"]["sort"];
236  if(desc.contains("sort")){
237  sort = desc["sort"];
238  }
239  ILU->set_sort(sort);
240 
241  number sortEps = json_default_preconds["ilu"]["sortEps"];
242  if(desc.contains("sortEps")){
243  sortEps = desc["sortEps"];
244  }
245  ILU->set_sort_eps(sortEps);
246 
247  number inversionEps = json_default_preconds["ilu"]["inversionEps"];
248  if(desc.contains("inversionEps")){
249  inversionEps = desc["inversionEps"];
250  }
251  ILU->set_inversion_eps(inversionEps);
252 
253  bool consistentInterfaces = json_default_preconds["ilu"]["consistentInterfaces"];
254  if(desc.contains("consistentInterfaces")){
255  consistentInterfaces = desc["consistentInterfaces"];
256  }
257  ILU->enable_consistent_interfaces(consistentInterfaces);
258 
259  bool overlap = json_default_preconds["ilu"]["overlap"];
260  if(desc.contains("overlap")){
261  overlap = desc["overlap"];
262  }
263  ILU->enable_overlap(overlap);
264  // set ordering (no default value)
265  // TODO: ordering = CreateOrdering
266  // TODO: precond.set_ordering_algorithm(ordering)
267 
268  // Cast ILU to IPreconditioner for return
269  preconditioner = ILU.template cast_static<TPrecond>();
270 
271  }
272  else if(type == "ilut"){
273  // create ilut
274  typedef ILUTPreconditioner<TAlgebra> TILUT;
275 
276  number threshold = json_default_preconds["ilut"]["threshold"];
277  if(desc["ilut"].contains("threshold")){
278  threshold = desc["ilut"]["threshold"];
279  }
280  SmartPtr<TILUT> ILUT = make_sp(new TILUT(threshold));
281 
282  // TODO: ordering = CreateOrdering
283  // TODO: precond.set_ordering_algorithm(ordering)
284 
285 
286  }
287  else if(type == "jac"){
288  //TODO:Duy createjac
289  }
290  else if(type == "gs"){
291 
292  UG_LOG("creating gauss seidel\n")
293  typedef GaussSeidel<TAlgebra> TGS;
294  SmartPtr<TGS> GS = make_sp(new TGS());
295  UG_LOG("consistentInterfaces default\n")
296  bool consistentInterfaces = json_default_preconds["gs"]["consistentInterfaces"];
297  UG_LOG("consistentInterfaces desc\n")
298  if(desc["gs"].contains("consistentInterfaces")){
299  consistentInterfaces = desc["gs"]["consistentInterfaces"];
300  }
301  UG_LOG("enable consistentInterfaces\n")
302  GS->enable_consistent_interfaces(consistentInterfaces);
303 
304 
305  bool overlap = json_default_preconds["gs"]["overlap"];
306  if(desc["gs"].contains("overlap")){
307  overlap = desc["gs"]["overlap"];
308  }
309  UG_LOG("enable overlap\n")
310  GS->enable_overlap(overlap);
311  preconditioner = GS.template cast_static<TPrecond>();
312  }
313  else if(type == "sgs"){
314  //TODO:Tim create sgs
315  }
316  else if(type == "egs"){
317  //TODO:Lukas
318  }
319  else if(type == "cgs"){
320  //TODO:Myrto
321  }
322  else if(type == "ssc"){
323  //TODO:Duy create ssc
324  }
325  else if(type == "gmg"){
326  //TODO: Julian
327  }
328  else if(type == "schur"){
329  //TODO:Duy create schur
330  }
331 
332 // return Preconditioner
333  return preconditioner;
334  }
335 
336  template<typename TAlgebra>
338  // typedef for convenience
339  typedef StandardLineSearch<typename TAlgebra::vector_type> line_search_type;
341 
342  // load defaults
343  nlohmann::json json_default_lineSearch = json_predefined_defaults::solvers["lineSearch"];
344 
345  // handle type of line search
346  // default
347  std::string type = "standard";
348  // input type
349  if(desc.contains("type")){
350  type = desc["type"];
351  }
352 
353  if(type == "standard"){
354  // handle parameters of standard line search
355  int maxSteps = json_default_lineSearch[type]["maxSteps"];
356  if(desc.contains("maxSteps")){
357  maxSteps = desc["maxSteps"];
358  }
359  number lambdaStart = json_default_lineSearch[type]["lambdaStart"];
360  if(desc.contains("lambdaStart")){
361  lambdaStart = desc["lambdaStart"];
362  }
363  number lambdaReduce = json_default_lineSearch[type]["lambdaReduce"];
364  if(desc.contains("lambdaReduce")){
365  lambdaReduce = desc["lambdaReduce"];
366  }
367  bool acceptBest = json_default_lineSearch[type]["acceptBest"];
368  if(desc.contains("acceptBest")){
369  acceptBest = desc["acceptBest"];
370  }
371  bool checkAll = json_default_lineSearch[type]["checkAll"];
372  if(desc.contains("checkAll")){
373  checkAll = desc["checkAll"];
374  }
375  // create line search with chosen parameters
376  ls = make_sp(new line_search_type(maxSteps,
377  lambdaStart,
378  lambdaReduce,
379  acceptBest,
380  checkAll));
381  // set conditional parameters
382  if(desc.contains("verbose")){
383  bool verbose = desc["verbose"];
384  ls->set_verbose(verbose);
385  }
386  if(desc.contains("suffDesc")){
387  number suffDesc = desc["suffDesc"];
388  ls->set_suff_descent_factor(suffDesc);
389  }
390  if(desc.contains("maxDefect")){
391  number maxDefect = desc["maxDefect"];
392  ls->set_maximum_defect(maxDefect);
393  }
394 
395  }
396  // force exit if line search is invalid
397  CondAbort(ls.invalid(), "Invalid line-search specified: " + type);
398 
399  return ls;
400  }
401 
402  template<typename TDomain, typename TAlgebra>
403  void PrepateSolverUtil(nlohmann::json& desc, nlohmann::json& solverutil){
404 
405  typedef SolverUtil<TDomain, TAlgebra> TSolverUtil;
406  // Create SolverUtil container class
407  SmartPtr<TSolverUtil> solv_util = make_sp(new TSolverUtil());
408 
409  if(solverutil.contains("ApproxSpace")){
410  solv_util->setComponent("ApproxSpace",solverutil["ApproxSpace"]);
411  }
412  }
413 
414 /*
415 * Helper class to provide c++ util functions in lua.
416 * We can instaciate this object in lua via
417 * local functionProvider = SolverUtilFunctionProvider()
418 * and automatically get the correct templated class,
419 * e.g SolverUtilFunctionProvider2dCPU1
420 * this means functionprovider automatically chooses
421 * the correct templated util functions!
422 */
423 template<typename TDomain, typename TAlgebra>
424 class SolverUtilFunctionProvider{
425 public:
426  typedef typename TAlgebra::vector_type vector_type;
427  typedef typename TAlgebra::matrix_type matrix_type;
428  const static int dim = TDomain::dim;
429 
430  SolverUtilFunctionProvider(){};
431 
432  SmartPtr <IPreconditioner<TAlgebra>> GetCreatePreconditioner(nlohmann::json &desc, SolverUtil<TDomain, TAlgebra> &solverutil){
433  return CreatePreconditioner<TDomain, TAlgebra>(desc, solverutil);
434  }
435 
436  SmartPtr<StandardLineSearch<vector_type>> GetCreateLineSearch(nlohmann::json &desc){
437  return CreateLineSearch<TAlgebra>(desc);
438  }
439 
440 
441 };
442 } //namespace util
443 } //namespace ug
444 
445 #endif // UG_USE_JSON
446 #endif //UG4_SOLVER_UTIL_H
function util test CreateConvCheck(convCheckDesc, solverutil)
function util test CreatePreconditioner(desc, solverutil)
function util test CreateLineSearch(desc)
function util test CreateSolver(descriptor, solverutil)
bool invalid() const
ParallelMatrix< SparseMatrix< double > > matrix_type
ParallelVector< Vector< double > > vector_type
static const int dim
Variant::Type type()
#define UG_ASSERT(expr, msg)
#define UG_THROW(msg)
#define UG_LOG(msg)
double number
bool contains(std::string str, std::string search)
SmartPtr< T, FreePolicy > make_sp(T *inst)
function ProblemDisc new(problemDesc, dom)