Loading [MathJax]/extensions/tex2jax.js
Plugins
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
solver_util.h
Go to the documentation of this file.
1//
2// Created by julian on 6/4/24.
3//
4
5#ifndef UG4_SOLVER_UTIL_H
6#define UG4_SOLVER_UTIL_H
7
8#ifdef UG_USE_JSON
9
10
11#include <nlohmann/json.hpp>
12
13#include "common/common.h"
14#include "common/log.h"
16#include <memory>
17#include <unordered_map>
18#include <variant>
19#include <stdexcept>
22// include solver components
31namespace ug{
32namespace Util {
33
34 void CondAbort(bool condition, std::string message){
35 UG_ASSERT(!condition, "ERROR in util.solver: " << message);
36 }
// Creates a StdConvCheck for TAlgebra's vector type from the JSON descriptor
// `descriptor`. Recognized keys and the values used when they are absent:
//   verbose (false), type (unused, "standard"), iterations (100),
//   reduction (1e-6), absolute (1e-12), suppress_unsuccessful (false).
// NOTE(review): the function signature (original line 38) is not visible in
// this listing; the body only reads from `descriptor`.
37 template<typename TAlgebra>
39
40 typedef typename TAlgebra::vector_type vector_type;
41 // set parameters from descriptor attributes
42 bool verbose = false;
43 if(descriptor.contains("verbose")){
44 verbose = descriptor["verbose"];
45 }
// NOTE(review): `name` is read but never used -- only a StdConvCheck is built.
46 std::string name = "standard";
47 if(descriptor.contains("type")){
48 name = descriptor["type"];
49 }
// NOTE(review): `iterations` is a floating-point `number` but is passed as the
// first StdConvCheck constructor argument -- confirm an integer step count is
// expected there.
50 number iterations = 100;
51 if(descriptor.contains("iterations")){
52 iterations = descriptor["iterations"];
53 }
54 number reduction = 1e-6;
55 if(descriptor.contains("reduction")){
56 reduction = descriptor["reduction"];
57 }
58 number absolute = 1e-12;
59 if(descriptor.contains("absolute")){
60 absolute = descriptor["absolute"];
61 }
62 bool suppress_unsuccessful = false;
63 if(descriptor.contains("suppress_unsuccessful")){
64 suppress_unsuccessful = descriptor["suppress_unsuccessful"];
65 }
66
67 // create convergence check
// argument order: iterations, absolute, reduction, verbose, suppress_unsuccessful
68 SmartPtr<StdConvCheck<vector_type>> convCheck = make_sp<StdConvCheck<vector_type>>(
69 new StdConvCheck<vector_type>(iterations,
70 absolute,
71 reduction,
72 verbose,
73 suppress_unsuccessful));
74 return convCheck;
75
76 }
77 // Variant type for storing different solver components
// Each alternative is expected to be a SmartPtr<T> to a solver component:
// SolverUtil::getComponentAs<T>() extracts entries with std::get<SmartPtr<T>>.
// NOTE(review): the list of alternatives is not visible in this listing;
// per the comment below, ElementGaussSeidel still has to be added.
78 template<typename TDomain, typename TAlgebra>
79 using SolverComponent = std::variant<
92 // missing ElementGaussSeidel
93 >;
94
95 template<typename TDomain, typename TAlgebra>
96 class SolverUtil{
98 /*
99 * components are stored as std::variant and populate an unordered
100 * map for access. Calls to various subroutines are made
101 * */
102 public:
103
104 void setComponent(const std::string& key, SolverComponent<TDomain, TAlgebra> component){
105 components[key] = component;
106 }
107
108 SolverComponent<TDomain, TAlgebra> getComponent(const std::string& key) const {
109 auto it = components.find(key);
110 if (it != components.end()){
111 return it->second;
112 }
113 UG_THROW("SolverUtil.getComponent(" << key << "): no component named " << key << " found");
114 }
115
116 bool hasComponent(const std::string& key) const{
117
118 auto it = components.find(key);
119 if (it != components.end()){
120 return true;
121 }
122 return false;
123
124 }
125
126 template<typename T>
127 SmartPtr<T> getComponentAs(const std::string& key) const {
128 return std::get<SmartPtr<T>>(getComponent(key));
129 }
130
131 private:
132 std::unordered_map<std::string, SolverComponent<TDomain, TAlgebra>> components;
133 };
134
135
136
137
// Builds a solver from the JSON descriptor `solverDesc`.
// "type" defaults to "linear"; only "newton" is handled so far. It creates a
// NewtonSolver<TAlgebra> and sets its convergence check from solverDesc["convCheck"].
// The "linSolver" and "lineSearch" entries are detected, but the calls that
// would build them are still commented out.
// NOTE(review): no return statement is reached on any path although the return
// type is non-void -- using the result is undefined behavior. The created
// newton_solver is neither stored nor returned, and the declared return type
// (a SolverUtil) does not match the solver being built; confirm intent.
138 template<typename TDomain, typename TAlgebra>
139 SmartPtr<SolverUtil<TDomain, TAlgebra>> CreateSolver(nlohmann::json& solverDesc,
140 SolverUtil<TDomain, TAlgebra> solverutil){
141
142 //PrepareSolverUtil<TDomain, TAlgebra>();
143 std::string type = "linear";
144 if(solverDesc.contains("type")){
145 type = solverDesc["type"];
146 }
147 if(type == "newton"){
148 auto newton_solver= make_sp(new NewtonSolver<TAlgebra>());
149 // call createlinearsolver
150
151 // get descriptor for linear solver
152 if(solverDesc.contains("linSolver")){
153 //CreateLinearSolver<TDomain, TAlgebra>(solverDesc, solverutil);
154 }
155
156 // line search
157 if(solverDesc.contains("lineSearch")){
158 //CreateLineSearch(solverDesc["lineSearch"], solverutil);
159 }
// NOTE(review): `solverutil` is passed as a second argument to
// set_convergence_check, not to CreateConvCheck -- this looks like a misplaced
// parenthesis; check both signatures. Also, operator[] on the non-const
// `solverDesc` inserts a null "convCheck" entry when the key is absent.
160 newton_solver->set_convergence_check(CreateConvCheck<TAlgebra>(solverDesc["convCheck"]), solverutil);
161 }
162
163 }
164
165
166
167
168
169 template<typename TDomain, typename TAlgebra>
170 SmartPtr<LinearSolver<typename TAlgebra::vector_type>> CreateLinearSolver(nlohmann::json& desc,
171 SolverUtil<TDomain, TAlgebra> solverutil){
172 typedef typename TAlgebra::vector_type TVector;
173 typedef LinearSolver<TVector> TLinSolv;
174
175 // TODO: is preset
176
177 bool create_precond = false;
178 bool create_conv_check = false;
179
180 // if no descriptor given, create default linear solver
181 if(!desc.contains("LinearSolver")){
182
183 SmartPtr<ILU<TAlgebra>> ilu = make_sp<ILU<TAlgebra>>(new ILU<TAlgebra>());
184 SmartPtr<StdConvCheck<TVector>> convCheck = make_sp<StdConvCheck<TVector>>(
185 new StdConvCheck<TVector>(100, 1e-9, 1e-12));
186 SmartPtr<TLinSolv> default_linear_solver = make_sp(new TLinSolv());
187 default_linear_solver->set_convergence_check(convCheck);
188 default_linear_solver->set_preconditioner(ilu);
189 return default_linear_solver;
190
191 }
192
193
194
195 }
196
197 template<typename TDomain, typename TAlgebra>
198 SmartPtr <IPreconditioner<TAlgebra>>
199 CreatePreconditioner(nlohmann::json &desc, SolverUtil<TDomain, TAlgebra> &solverutil) {
200 typedef typename TAlgebra::vector_type TVector;
201 typedef IPreconditioner<TAlgebra> TPrecond;
202
203 // TODO: Implement preconditioner creation behavior based on 'desc' and 'solverutil'
204 // TODO: Check if preset
205
206 nlohmann::json json_default_preconds = json_predefined_defaults::solvers.at("preconditioner");
207
208 SmartPtr <TPrecond> preconditioner;
209
211
212 if(solverutil.hasComponent("approxSpace")){
213 //approxSpace = solverutil.getComponent("approxSpace");
214 }
215
216 std::string type = desc["type"];
217 if(type == "ilu"){
218 // create ilu
219 typedef ILU<TAlgebra> TILU;
220 SmartPtr<TILU> ILU = make_sp(new TILU());
221
222 // configure ilu
223 number beta = json_default_preconds["ilu"]["beta"];
224 if(desc.contains("beta")){
225 beta = desc["beta"];
226 }
227 ILU->set_beta(beta);
228
229 number damping = json_default_preconds["ilu"]["damping"];
230 if(desc.contains("damping")){
231 damping = desc["damping"];
232 }
233 ILU->set_damp(damping);
234
235 bool sort = json_default_preconds["ilu"]["sort"];
236 if(desc.contains("sort")){
237 sort = desc["sort"];
238 }
239 ILU->set_sort(sort);
240
241 number sortEps = json_default_preconds["ilu"]["sortEps"];
242 if(desc.contains("sortEps")){
243 sortEps = desc["sortEps"];
244 }
245 ILU->set_sort_eps(sortEps);
246
247 number inversionEps = json_default_preconds["ilu"]["inversionEps"];
248 if(desc.contains("inversionEps")){
249 inversionEps = desc["inversionEps"];
250 }
251 ILU->set_inversion_eps(inversionEps);
252
253 bool consistentInterfaces = json_default_preconds["ilu"]["consistentInterfaces"];
254 if(desc.contains("consistentInterfaces")){
255 consistentInterfaces = desc["consistentInterfaces"];
256 }
257 ILU->enable_consistent_interfaces(consistentInterfaces);
258
259 bool overlap = json_default_preconds["ilu"]["overlap"];
260 if(desc.contains("overlap")){
261 overlap = desc["overlap"];
262 }
263 ILU->enable_overlap(overlap);
264 // set ordering (no default value)
265 // TODO: ordering = CreateOrdering
266 // TODO: precond.set_ordering_algorithm(ordering)
267
268 // Cast ILU to IPreconditioner for return
269 preconditioner = ILU.template cast_static<TPrecond>();
270
271 }
272 else if(type == "ilut"){
273 // create ilut
274 typedef ILUTPreconditioner<TAlgebra> TILUT;
275
276 number threshold = json_default_preconds["ilut"]["threshold"];
277 if(desc["ilut"].contains("threshold")){
278 threshold = desc["ilut"]["threshold"];
279 }
280 SmartPtr<TILUT> ILUT = make_sp(new TILUT(threshold));
281
282 // TODO: ordering = CreateOrdering
283 // TODO: precond.set_ordering_algorithm(ordering)
284
285
286 }
287 else if(type == "jac"){
288 //TODO:Duy createjac
289 }
290 else if(type == "gs"){
291
292 UG_LOG("creating gauss seidel\n")
293 typedef GaussSeidel<TAlgebra> TGS;
294 SmartPtr<TGS> GS = make_sp(new TGS());
295 UG_LOG("consistentInterfaces default\n")
296 bool consistentInterfaces = json_default_preconds["gs"]["consistentInterfaces"];
297 UG_LOG("consistentInterfaces desc\n")
298 if(desc["gs"].contains("consistentInterfaces")){
299 consistentInterfaces = desc["gs"]["consistentInterfaces"];
300 }
301 UG_LOG("enable consistentInterfaces\n")
302 GS->enable_consistent_interfaces(consistentInterfaces);
303
304
305 bool overlap = json_default_preconds["gs"]["overlap"];
306 if(desc["gs"].contains("overlap")){
307 overlap = desc["gs"]["overlap"];
308 }
309 UG_LOG("enable overlap\n")
310 GS->enable_overlap(overlap);
311 preconditioner = GS.template cast_static<TPrecond>();
312 }
313 else if(type == "sgs"){
314 //TODO:Tim create sgs
315 }
316 else if(type == "egs"){
317 //TODO:Lukas
318 }
319 else if(type == "cgs"){
320 //TODO:Myrto
321 }
322 else if(type == "ssc"){
323 //TODO:Duy create ssc
324 }
325 else if(type == "gmg"){
326 //TODO: Julian
327 }
328 else if(type == "schur"){
329 //TODO:Duy create schur
330 }
331
332// return Preconditioner
333 return preconditioner;
334 }
335
// Creates a line search from the JSON descriptor `desc`.
// Only type "standard" (the default when "type" is absent) is handled. It builds a
// StandardLineSearch from maxSteps, lambdaStart, lambdaReduce, acceptBest and
// checkAll, taking any value missing from `desc` from
// json_predefined_defaults::solvers["lineSearch"]["standard"], and then optionally
// sets verbose, suffDesc and maxDefect.
// For any other type `ls` is never assigned and is passed to CondAbort as invalid.
// NOTE(review): the signature and the declaration of `ls` (original lines 337 and
// 340) are not visible in this listing. SolverUtilFunctionProvider calls this as
// CreateLineSearch<TAlgebra>(desc) and expects SmartPtr<StandardLineSearch<vector_type>>;
// presumably `ls` starts out as an empty SmartPtr -- confirm.
336 template<typename TAlgebra>
338 // typedef for convenience
339 typedef StandardLineSearch<typename TAlgebra::vector_type> line_search_type;
341
342 // load defaults
343 nlohmann::json json_default_lineSearch = json_predefined_defaults::solvers["lineSearch"];
344
345 // handle type of line search
346 // default
347 std::string type = "standard";
348 // input type
349 if(desc.contains("type")){
350 type = desc["type"];
351 }
352
353 if(type == "standard"){
354 // handle parameters of standard line search
355 int maxSteps = json_default_lineSearch[type]["maxSteps"];
356 if(desc.contains("maxSteps")){
357 maxSteps = desc["maxSteps"];
358 }
359 number lambdaStart = json_default_lineSearch[type]["lambdaStart"];
360 if(desc.contains("lambdaStart")){
361 lambdaStart = desc["lambdaStart"];
362 }
363 number lambdaReduce = json_default_lineSearch[type]["lambdaReduce"];
364 if(desc.contains("lambdaReduce")){
365 lambdaReduce = desc["lambdaReduce"];
366 }
367 bool acceptBest = json_default_lineSearch[type]["acceptBest"];
368 if(desc.contains("acceptBest")){
369 acceptBest = desc["acceptBest"];
370 }
371 bool checkAll = json_default_lineSearch[type]["checkAll"];
372 if(desc.contains("checkAll")){
373 checkAll = desc["checkAll"];
374 }
375 // create line search with chosen parameters
376 ls = make_sp(new line_search_type(maxSteps,
377 lambdaStart,
378 lambdaReduce,
379 acceptBest,
380 checkAll));
381 // set conditional parameters
// these setters are only called when the key is present; otherwise the
// StandardLineSearch keeps its own settings
382 if(desc.contains("verbose")){
383 bool verbose = desc["verbose"];
384 ls->set_verbose(verbose);
385 }
386 if(desc.contains("suffDesc")){
387 number suffDesc = desc["suffDesc"];
388 ls->set_suff_descent_factor(suffDesc);
389 }
390 if(desc.contains("maxDefect")){
391 number maxDefect = desc["maxDefect"];
392 ls->set_maximum_defect(maxDefect);
393 }
394
395 }
396 // force exit if line search is invalid
397 CondAbort(ls.invalid(), "Invalid line-search specified: " + type);
398
399 return ls;
400 }
401
// Intended to create a SolverUtil registry and populate it from `solverutil`.
// Currently it only copies the "ApproxSpace" entry into a local registry.
// NOTE(review): several issues to confirm before use:
//  - the name looks like a typo; CreateSolver refers to PrepareSolverUtil in a
//    commented-out call;
//  - `desc` is unused, and the local `solv_util` is discarded on return, so the
//    function currently has no observable effect;
//  - CreatePreconditioner queries the key "approxSpace" (lower-case a), while this
//    function stores "ApproxSpace";
//  - setComponent expects a SolverComponent, but is passed a nlohmann::json value.
402 template<typename TDomain, typename TAlgebra>
403 void PrepateSolverUtil(nlohmann::json& desc, nlohmann::json& solverutil){
404
405 typedef SolverUtil<TDomain, TAlgebra> TSolverUtil;
406 // Create SolverUtil container class
407 SmartPtr<TSolverUtil> solv_util = make_sp(new TSolverUtil());
408
409 if(solverutil.contains("ApproxSpace")){
410 solv_util->setComponent("ApproxSpace",solverutil["ApproxSpace"]);
411 }
412 }
413
414/*
415* Helper class to provide c++ util functions in lua.
416* We can instaciate this object in lua via
417* local functionProvider = SolverUtilFunctionProvider()
418* and automatically get the correct templated class,
419* e.g SolverUtilFunctionProvider2dCPU1
420* this means functionprovider automatically chooses
421* the correct templated util functions!
422*/
423template<typename TDomain, typename TAlgebra>
424class SolverUtilFunctionProvider{
425public:
426 typedef typename TAlgebra::vector_type vector_type;
427 typedef typename TAlgebra::matrix_type matrix_type;
428 const static int dim = TDomain::dim;
429
430 SolverUtilFunctionProvider(){};
431
432 SmartPtr <IPreconditioner<TAlgebra>> GetCreatePreconditioner(nlohmann::json &desc, SolverUtil<TDomain, TAlgebra> &solverutil){
433 return CreatePreconditioner<TDomain, TAlgebra>(desc, solverutil);
434 }
435
436 SmartPtr<StandardLineSearch<vector_type>> GetCreateLineSearch(nlohmann::json &desc){
437 return CreateLineSearch<TAlgebra>(desc);
438 }
439
440
441};
442} //namespace util
443} //namespace ug
444
445#endif // UG_USE_JSON
446#endif //UG4_SOLVER_UTIL_H
function util test CreateConvCheck(convCheckDesc, solverutil)
function util test CreatePreconditioner(desc, solverutil)
function util test CreateLineSearch(desc)
function util test CreateSolver(descriptor, solverutil)
bool invalid() const
Variant::Type type()
#define UG_ASSERT(expr, msg)
#define UG_THROW(msg)
#define UG_LOG(msg)
double number
CPUAlgebra::matrix_type matrix_type
Sparse matrix type as defined in lib_algebra/cpu_algebra/sparsematrix.h.
Definition demo_plugin.cpp:43
CPUAlgebra::vector_type vector_type
Vector type as defined in lib_algebra/cpu_algebra/vector.h.
Definition demo_plugin.cpp:45
bool contains(std::string str, std::string search)
SmartPtr< T, FreePolicy > make_sp(T *inst)
function ProblemDisc new(problemDesc, dom)