ug4
Loading...
Searching...
No Matches
lua_user_data_impl.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2012-2015: G-CSC, Goethe University Frankfurt
3 * Author: Andreas Vogel
4 *
5 * This file is part of UG4.
6 *
7 * UG4 is free software: you can redistribute it and/or modify it under the
8 * terms of the GNU Lesser General Public License version 3 (as published by the
9 * Free Software Foundation) with the following additional attribution
10 * requirements (according to LGPL/GPL v3 §7):
11 *
12 * (1) The following notice must be displayed in the Appropriate Legal Notices
13 * of covered and combined works: "Based on UG4 (www.ug4.org/license)".
14 *
15 * (2) The following notice must be displayed at a prominent place in the
16 * terminal output of covered works: "Based on UG4 (www.ug4.org/license)".
17 *
18 * (3) The following bibliography is recommended for citation and must be
19 * preserved in all covered files:
20 * "Reiter, S., Vogel, A., Heppner, I., Rupp, M., and Wittum, G. A massively
21 * parallel geometric multigrid solver on hierarchically distributed grids.
22 * Computing and visualization in science 16, 4 (2013), 151-164"
23 * "Vogel, A., Reiter, S., Rupp, M., Nägel, A., and Wittum, G. UG4 -- a novel
24 * flexible software system for simulating pde based models on high performance
25 * computers. Computing and visualization in science 16, 4 (2013), 165-179"
26 *
27 * This program is distributed in the hope that it will be useful,
28 * but WITHOUT ANY WARRANTY; without even the implied warranty of
29 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30 * GNU Lesser General Public License for more details.
31 */
32
33#ifndef __H__UG_BRIDGE__BRIDGES__USER_DATA__USER_DATA_IMPL_
34#define __H__UG_BRIDGE__BRIDGES__USER_DATA__USER_DATA_IMPL_
35
36#ifdef UG_FOR_LUA
37#include "lua_user_data.h"
38#endif
41
42#include "info_commands.h"
44
45#if 0
46#define PROFILE_CALLBACK() PROFILE_FUNC_GROUP("luacallback")
47#define PROFILE_CALLBACK_BEGIN(name) PROFILE_BEGIN_GROUP(name, "luacallback")
48#define PROFILE_CALLBACK_END() PROFILE_END()
49#else
50#define PROFILE_CALLBACK()
51#define PROFILE_CALLBACK_BEGIN(name)
52#define PROFILE_CALLBACK_END()
53#endif
54namespace ug{
55
56#ifdef USE_LUA2C
57 extern bool useLuaCompiler;
58#endif
59
60
61
63// LuaUserData
65
66template <typename TData, int dim, typename TRet>
68{
69 std::stringstream ss;
70 ss << "function name(";
71 if(dim >= 1) ss << "x";
72 if(dim >= 2) ss << ", y";
73 if(dim >= 3) ss << ", z";
74 ss << ", t, si)\n ... \n return ";
76 ss << lua_traits<TRet>::signature() << ", ";
77 ss << lua_traits<TData>::signature();
78 ss << "\nend";
79 return ss.str();
80}
81
82
83template <typename TData, int dim, typename TRet>
85{
86 std::stringstream ss;
87 ss << "Lua";
88 if(lua_traits<TRet>::size > 0) ss << "Cond";
89 ss << "User" << lua_traits<TData>::name() << dim << "d";
90 return ss.str();
91}
92
93template <typename TData, int dim, typename TRet>
95 : m_callbackName(luaCallback), m_bFromFactory(false)
96{
97// get lua state
99
100// obtain a reference
101 lua_getglobal(m_L, m_callbackName.c_str());
102
103// make sure that the reference is valid
104 if(lua_isnil(m_L, -1)){
105 UG_THROW(name() << ": Specified lua callback "
106 "does not exist: " << m_callbackName);
107 }
108
109// store reference to lua function
110 m_callbackRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
111
112// make a test run
114
115 #ifdef USE_LUA2C
116 if(useLuaCompiler) m_luaComp.create(luaCallback);
117 #endif
118}
119
120template <typename TData, int dim, typename TRet>
122 : m_callbackName("__anonymous__lua__function__"), m_bFromFactory(false)
123{
124// get lua state
126
127// store reference to lua function
128 m_callbackRef = handle.ref;
129
130// make a test run
132
133 #ifdef USE_LUA2C
134// UG_THROW("LuaFunctionHandle usage currently not supported with LUA2C.");
135 if(useLuaCompiler) m_luaComp.create(m_callbackName.c_str(), &handle);
136 #endif
137}
138
139
140template <typename TData, int dim, typename TRet>
142check_callback_returns(lua_State* L, int callbackRef, const char* callName, const bool bThrow)
143{
145// get current stack level
146 const int level = lua_gettop(L);
147
148// dummy values to invoke the callback once
149 MathVector<dim> x; x = 0.0;
150 number time = 0.0;
151 int si = 0;
152
153// push the callback function on the stack
154 lua_rawgeti(L, LUA_REGISTRYINDEX, callbackRef);
155
156// push space coordinates on stack
157 lua_traits<MathVector<dim> >::push(L, x);
158
159// push time on stack
161
162// push subset on stack
164
165// compute total args size
166 const int argSize = lua_traits<MathVector<dim> >::size
169
170// compute total return size
172
173// call lua function
174 if(lua_pcall(L, argSize, LUA_MULTRET, 0) != 0)
175 UG_THROW(name() << ": Error while "
176 "testing callback '" << callName << "',"
177 " lua message: "<< lua_tostring(L, -1));
178
179 // get number of results
180 const int numResults = lua_gettop(L) - level;
181
182// success flag
183 bool bRet = true;
184
185// if number of results is wrong return error
186 if(numResults != retSize){
187 if(bThrow){
188 UG_THROW(name() << ": Number of return values incorrect "
189 "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
190 "\nRequired: "<<retSize<<", passed: "<<numResults
191 <<". Use signature as follows:\n"
192 << signature());
193 }
194 else{
195 bRet = false;
196 }
197 }
198
199// check return value
201 if(bThrow){
202 UG_THROW(name() << ": Data values type incorrect "
203 "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
204 "\nUse signature as follows:\n"
205 << signature());
206 }
207 else{
208 bRet = false;
209 }
210 }
211
212// read return flag (may be void)
213 if(!lua_traits<TRet>::check(L, -retSize)){
214 if(bThrow){
215 UG_THROW("LuaUserData: Return values type incorrect "
216 "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
217 "\nUse signature as follows:\n"
218 << signature());
219 }
220 else{
221 bRet = false;
222 }
223 }
224
225// pop values
226 lua_pop(L, numResults);
227
228// return match
229 return bRet;
230}
231
232template <typename TData, int dim, typename TRet>
234check_callback_returns(LuaFunctionHandle handle, const bool bThrow)
235{
237// get lua state
239
240// forward call
241 bool bRet = check_callback_returns(L, handle.ref, "__lua_function_handle__", bThrow);
242
243// return match
244 return bRet;
245}
246
247template <typename TData, int dim, typename TRet>
249check_callback_returns(const char* callName, const bool bThrow)
250{
252// get lua state
254
255// obtain a reference
256 lua_getglobal(L, callName);
257
258// check if reference is valid
259 if(lua_isnil(L, -1)) {
260 if(bThrow) {
261 UG_THROW(name() << ": Cannot find specified lua callback "
262 " with name: "<<callName);
263 }
264 else {
265 return false;
266 }
267 }
268
269// get reference
270 int callbackRef = luaL_ref(L, LUA_REGISTRYINDEX);
271
272// forward call
273 bool bRet = check_callback_returns(L, callbackRef, callName, bThrow);
274
275// free reference to callback
276 luaL_unref(L, LUA_REGISTRYINDEX, callbackRef);
277
278// return match
279 return bRet;
280}
281
282template <typename TData, int dim, typename TRet>
284evaluate(TData& D, const MathVector<dim>& x, number time, int si) const
285{
287 #ifdef USE_LUA2C
288 if(useLuaCompiler && m_luaComp.is_valid())
289 {
290 double d[dim+2];
291 for(int i=0; i<dim; i++)
292 d[i] = x[i];
293 d[dim] = time;
294 d[dim+1] = si;
295 double ret[lua_traits<TData>::size+1];
296 m_luaComp.call(ret, d);
297 //TData D2;
298 TRet *t=NULL;
299 lua_traits<TData>::read(D, ret, t);
300 return lua_traits<TRet>::do_return(ret[0]);
301 }
302 else
303 #endif
304 {
305 // push the callback function on the stack
306 lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_callbackRef);
307
308 // push space coordinates on stack
309 lua_traits<MathVector<dim> >::push(m_L, x);
310
311 // push time on stack
312 lua_traits<number>::push(m_L, time);
313
314 // push subset index on stack
315 lua_traits<int>::push(m_L, si);
316
317 // compute total args size
318 const int argSize = lua_traits<MathVector<dim> >::size
321
322 // compute total return size
324
325 // call lua function
326 if(lua_pcall(m_L, argSize, retSize, 0) != 0)
327 UG_THROW(name() << "::operator(...): Error while "
328 "running callback '" << m_callbackName << "',"
329 " lua message: "<< lua_tostring(m_L, -1)<<".\n"
330 "Use signature as follows:\n"
331 << signature());
332
333 bool res = false;
334 try{
335 // read return value
337
338 // read return flag (may be void)
339 lua_traits<TRet>::read(m_L, res, -retSize);
340 }
341 UG_CATCH_THROW(name() << "::operator(...): Error while running "
342 "callback '" << m_callbackName << "'.\n"
343 "Use signature as follows:\n"
344 << signature());
345
346 // pop values
347 lua_pop(m_L, retSize);
348
349 // forward flag
350 return lua_traits<TRet>::do_return(res);
351 }
352}
353
354template <typename TData, int dim, typename TRet>
356{
357// free reference to callback
358 luaL_unref(m_L, LUA_REGISTRYINDEX, m_callbackRef);
359
360 if(m_bFromFactory)
362}
363
365// LuaUserDataFactory
367
368template <typename TData, int dim, typename TRet>
371{
373 typedef std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> > Map;
374 typedef typename Map::iterator iterator;
375
376// check for element
377 iterator iter = m_mData.find(name);
378
379// if name does not exist, create new one
380 if(iter == m_mData.end())
381 {
384
385 // the LuaUserData must remember to unregister itself at destruction
386 sp->set_created_from_factory(true);
387
388 // NOTE AND WARNING: This is very hacky and dangerous. We only do this
389 // since we exactly know what we are doing and everything is save and
390 // only in protected or private area. However, if you once want to change
391 // this code, please be aware, that we store here plain pointers and
392 // associated reference counters of a SmartPtr. This should not be done
393 // in general and this kind of coding is not recommended at all. Please
394 // use different approaches whenever possible.
395 std::pair<LuaUserData<TData,dim,TRet>*, int*>& data = m_mData[name];
396 data.first = sp.get();
397 data.second = sp.refcount_ptr();
398
399 return sp;
400 }
401// else return present data
402 {
403 // NOTE AND WARNING: This is very hacky and dangerous. We only do this
404 // since we exactly know what we are doing and everything is save and
405 // only in protected or private area. However, if you once want to change
406 // this code, please be aware, that we store here plain pointers and
407 // associated reference counters of a SmartPtr. This should not be done
408 // in general and this kind of coding is not recommended at all. Please
409 // use different approaches whenever possible.
410 std::pair<LuaUserData<TData,dim,TRet>*, int*>& data = iter->second;
411 return SmartPtr<LuaUserData<TData,dim,TRet> >(data.first, data.second);
413}
414
415template <typename TData, int dim, typename TRet>
416void
418{
419 typedef std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> > Map;
420 typedef typename Map::iterator iterator;
421
422// check for element
423 iterator iter = m_mData.find(name);
424
425// if name does not exist, create new one
426 if(iter == m_mData.end())
427 UG_THROW("LuaUserDataFactory: trying to remove non-registered"
428 " data with name: "<<name);
429
430 m_mData.erase(iter);
431}
432
433
434// instantiation of static member
435template <typename TData, int dim, typename TRet>
436std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> >
437LuaUserDataFactory<TData,dim,TRet>::m_mData = std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> >();
438
440// LuaUserFunction
442
443template <typename TData, int dim, typename TDataIn>
445LuaUserFunction(const char* luaCallback, size_t numArgs)
446 : m_numArgs(numArgs), m_bPosTimeNeed(false)
447{
449 m_cbValueRef = LUA_NOREF;
450 m_cbDerivRef.clear();
451 m_cbDerivName.clear();
452 set_lua_value_callback(luaCallback, numArgs);
453 #ifdef USE_LUA2C
454 if(useLuaCompiler) m_luaComp.create(luaCallback);
455 #endif
456}
457
458template <typename TData, int dim, typename TDataIn>
460LuaUserFunction(const char* luaCallback, size_t numArgs, bool bPosTimeNeed)
461 : m_numArgs(numArgs), m_bPosTimeNeed(bPosTimeNeed)
462{
464 m_cbValueRef = LUA_NOREF;
465 m_cbDerivRef.clear();
466 m_cbDerivName.clear();
467 set_lua_value_callback(luaCallback, numArgs);
468 #ifdef USE_LUA2C
469 m_luaComp_Deriv.clear();
470 #endif
471}
472
473
474template <typename TData, int dim, typename TDataIn>
476LuaUserFunction(LuaFunctionHandle handle, size_t numArgs)
477 : m_numArgs(numArgs), m_bPosTimeNeed(false)
478{
480 m_cbValueRef = LUA_NOREF;
481 m_cbDerivRef.clear();
482 m_cbDerivName.clear();
483 set_lua_value_callback(handle, numArgs);
484 #ifdef USE_LUA2C
485 if(useLuaCompiler){
486 UG_LOG("WARNING (in LuaUserFunction): LUA2C compiler "
487 "can't be executed for FunctionHandle.\n");
488 }
489 #endif
490}
491
492template <typename TData, int dim, typename TDataIn>
494LuaUserFunction(LuaFunctionHandle handle, size_t numArgs, bool bPosTimeNeed)
495 : m_numArgs(numArgs), m_bPosTimeNeed(bPosTimeNeed)
496{
498 m_cbValueRef = LUA_NOREF;
499 m_cbDerivRef.clear();
500 m_cbDerivName.clear();
501 set_lua_value_callback(handle, numArgs);
502 #ifdef USE_LUA2C
503 m_luaComp_Deriv.clear();
504 #endif
505}
506
507
508
509template <typename TData, int dim, typename TDataIn>
511{
512// free reference to callback
513 free_callback_ref();
514
515// free references to derivate callbacks
516 for(size_t i = 0; i < m_numArgs; ++i){
517 free_deriv_callback_ref(i);
518 }
519}
520
521template <typename TData, int dim, typename TDataIn>
523{
524 if(m_cbValueRef != LUA_NOREF){
525 luaL_unref(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
526 m_cbValueRef = LUA_NOREF;
527 }
528}
529
530template <typename TData, int dim, typename TDataIn>
532{
533 if(m_cbDerivRef[arg] != LUA_NOREF){
534 luaL_unref(m_L, LUA_REGISTRYINDEX, m_cbDerivRef[arg]);
535 m_cbDerivRef[arg] = LUA_NOREF;
536 }
537}
538
539
540template <typename TData, int dim, typename TDataIn>
541void LuaUserFunction<TData,dim,TDataIn>::set_lua_value_callback(const char* luaCallback, size_t numArgs)
542{
543// store name (string) of callback
544 m_cbValueName = luaCallback;
545
546// obtain a reference
547 lua_getglobal(m_L, m_cbValueName.c_str());
548
549// make sure that the reference is valid
550 if(lua_isnil(m_L, -1)){
551 UG_THROW("LuaUserFunction::set_lua_value_callback(...):"
552 "Specified callback does not exist: " << m_cbValueName);
553 }
554
555// if a callback was already set, we have to free the old one
556 free_callback_ref();
557
558// store reference to lua function
559 m_cbValueRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
560
561// remember number of arguments to be used
562 m_numArgs = numArgs;
563 m_cbDerivName.resize(numArgs);
564 m_cbDerivRef.resize(numArgs, LUA_NOREF);
565
566// set num inputs for linker
567 set_num_input(numArgs);
568
569 #ifdef USE_LUA2C
570 m_luaComp_Deriv.resize(numArgs);
571 #endif
572}
573
574template <typename TData, int dim, typename TDataIn>
576set_lua_value_callback(LuaFunctionHandle handle, size_t numArgs)
577{
578// store name (string) of callback
579 m_cbValueName = "__anonymous__lua__function__";
580
581// if a callback was already set, we have to free the old one
582 free_callback_ref();
583
584// store reference to lua function
585 m_cbValueRef = handle.ref;
586
587// remember number of arguments to be used
588 m_numArgs = numArgs;
589 m_cbDerivName.resize(numArgs);
590 m_cbDerivRef.resize(numArgs, LUA_NOREF);
591
592// set num inputs for linker
593 set_num_input(numArgs);
594
595 #ifdef USE_LUA2C
596 m_luaComp_Deriv.resize(numArgs);
597 #endif
598}
599
600template <typename TData, int dim, typename TDataIn>
601void LuaUserFunction<TData,dim,TDataIn>::set_deriv(size_t arg, const char* luaCallback)
602{
603// check number of arg
604 if(arg >= m_numArgs)
605 UG_THROW("LuaUserFunction::set_lua_deriv_callback: Trying "
606 "to set a derivative for argument " << arg <<", that "
607 "does not exist. Number of arguments is "<<m_numArgs);
608
609// store name (string) of callback
610 m_cbDerivName[arg] = luaCallback;
611
612// free old reference
613 free_deriv_callback_ref(arg);
614
615// obtain a reference
616 lua_getglobal(m_L, m_cbDerivName[arg].c_str());
617
618// make sure that the reference is valid
619 if(lua_isnil(m_L, -1)){
620 UG_THROW("LuaUserFunction::set_lua_deriv_callback(...):"
621 "Specified callback does not exist: " << m_cbDerivName[arg]);
622 }
623
624// store reference to lua function
625 m_cbDerivRef[arg] = luaL_ref(m_L, LUA_REGISTRYINDEX);
626
627 #ifdef USE_LUA2C
628 if(useLuaCompiler) m_luaComp_Deriv[arg].create(luaCallback);
629 #endif
630
631}
632
633template <typename TData, int dim, typename TDataIn>
635{
636// check number of arg
637 if(arg >= m_numArgs)
638 UG_THROW("LuaUserFunction::set_lua_deriv_callback: Trying "
639 "to set a derivative for argument " << arg <<", that "
640 "does not exist. Number of arguments is "<<m_numArgs);
641
642// store name (string) of callback
643 m_cbDerivName[arg] = std::string("__anonymous__lua__function__");
644
645// free old reference
646 free_deriv_callback_ref(arg);
647
648// store reference to lua function
649 m_cbDerivRef[arg] = handle.ref;
650
651 #ifdef USE_LUA2C
652 // if(useLuaCompiler) m_luaComp_Deriv[arg].create(luaCallback);
653 #endif
654
655}
656
657
658
659
660template <typename TData, int dim, typename TDataIn>
661void LuaUserFunction<TData,dim,TDataIn>::operator() (TData& out, int numArgs, ...) const
662{
664 #ifdef USE_LUA2C
665 if(useLuaCompiler && m_luaComp.is_valid())
666 {
667 double d[20];
668 // get list of arguments
669 va_list ap2;
670 va_start(ap2, numArgs);
671
672 // read all arguments and push them to the lua stack
673 for(int i = 0; i < numArgs; ++i)
674 d[i] = va_arg(ap2, double);
675 va_end(ap2);
676
677 double ret[lua_traits<TData>::size+1];
678
679 UG_ASSERT(m_luaComp.num_in() == numArgs && m_luaComp.num_out() == lua_traits<TData>::size,
680 m_luaComp.name() << ", " << m_luaComp.num_in() << " != " << numArgs << " or " << m_luaComp.num_out() << " != " << lua_traits<TData>::size);
681 m_luaComp.call(ret, d);
682 //TData D2;
683 void *t=NULL;
684 //TData out2;
685 lua_traits<TData>::read(out, ret, t);
686 return;
687 }
688 else
689 #endif
690 {
691 UG_ASSERT(numArgs == (int)m_numArgs, "Number of arguments mismatched.");
692
693 // push the callback function on the stack
694 lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
695
696 // get list of arguments
697 va_list ap;
698 va_start(ap, numArgs);
699
700 // read all arguments and push them to the lua stack
701 for(int i = 0; i < numArgs; ++i)
702 {
703 // cast data
704 TDataIn val = va_arg(ap, TDataIn);
705
706 // push data to lua stack
708 }
709
710 // end read in of parameters
711 va_end(ap);
712
713 // compute total args size
714 size_t argSize = lua_traits<TDataIn>::size * numArgs;
715
716 // compute total return size
717 size_t retSize = lua_traits<TData>::size;
718
719 // call lua function
720 if(lua_pcall(m_L, argSize, retSize, 0) != 0)
721 UG_THROW("LuaUserFunction::operator(...): Error while "
722 "running callback '" << m_cbValueName << "',"
723 " lua message: "<< lua_tostring(m_L, -1));
724
725 try{
726 // read return value
727 lua_traits<TData>::read(m_L, out);
728 UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
729 }
730 UG_CATCH_THROW("LuaUserFunction::operator(...): Error while running "
731 "callback '" << m_cbValueName << "'");
732
733 // pop values
734 lua_pop(m_L, retSize);
735 }
736}
737
738
739template <typename TData, int dim, typename TDataIn>
740void LuaUserFunction<TData,dim,TDataIn>::eval_value(TData& out, const std::vector<TDataIn>& dataIn,
741 const MathVector<dim>& x, number time, int si) const
742{
744 #ifdef USE_LUA2C
745 if(useLuaCompiler && m_luaComp.is_valid())
746 {
747 double d[20];
748
749 // read all arguments and push them to the lua stack
750 for(size_t i = 0; i < dataIn.size(); ++i)
751 d[i] = dataIn[i];
752 if(m_bPosTimeNeed){
753 for(int i=0; i<dim; i++)
754 d[i+m_numArgs] = x[i];
755 d[dim+m_numArgs]=time;
756 d[dim+m_numArgs+1]=si;
757 UG_ASSERT(dim+m_numArgs+1 < 20, m_luaComp.name());
758 }
759
760 double ret[lua_traits<TData>::size];
761 m_luaComp.call(ret, d);
762 //TData D2;
763 void *t=NULL;
764 //TData out2;
765 UG_ASSERT(m_luaComp.num_out() == lua_traits<TData>::size, m_luaComp.name() << ", " << m_luaComp.num_out() << " != " << lua_traits<TData>::size);
766 lua_traits<TData>::read(out, ret, t);
767 return;
768 }
769 else
770 #endif
771 {
772 UG_ASSERT(dataIn.size() == m_numArgs, "Number of arguments mismatched.");
773
774 // push the callback function on the stack
775 lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
776
777 // read all arguments and push them to the lua stack
778 for(size_t i = 0; i < dataIn.size(); ++i)
779 {
780 // push data to lua stack
781 lua_traits<TDataIn>::push(m_L, dataIn[i]);
782 }
783
784 // if needed, read additional coordinate, time and subset index arguments and push them to the lua stack
785 if(m_bPosTimeNeed){
786 lua_traits<MathVector<dim> >::push(m_L, x);
787 lua_traits<number>::push(m_L, time);
788 lua_traits<int>::push(m_L, si);
789 }
790
791 // compute total args size
792 size_t argSize = lua_traits<TDataIn>::size * dataIn.size();
793 if(m_bPosTimeNeed){
794 argSize += lua_traits<MathVector<dim> >::size
797 }
798
799 // compute total return size
800 size_t retSize = lua_traits<TData>::size;
801
802 // call lua function
803 if(lua_pcall(m_L, argSize, retSize, 0) != 0)
804 UG_THROW("LuaUserFunction::eval_value(...): Error while "
805 "running callback '" << m_cbValueName << "',"
806 " lua message: "<< lua_tostring(m_L, -1));
807
808 try{
809 // read return value
810 lua_traits<TData>::read(m_L, out);
811 UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
812 }
813 UG_CATCH_THROW("LuaUserFunction::eval_value(...): Error while "
814 "running callback '" << m_cbValueName << "'");
815
816 // pop values
817 lua_pop(m_L, retSize);
818 }
819}
820
821
822template <typename TData, int dim, typename TDataIn>
823void LuaUserFunction<TData,dim,TDataIn>::eval_deriv(TData& out, const std::vector<TDataIn>& dataIn,
824 const MathVector<dim>& x, number time, int si, size_t arg) const
825{
827 #ifdef USE_LUA2C
828 if(useLuaCompiler && m_luaComp_Deriv[arg].is_valid()
829 && dim+m_numArgs+1 < 20 && m_luaComp_Deriv[arg].num_out() == lua_traits<TData>::size)
830 {
831 const bridge::LUACompiler &luaComp = m_luaComp_Deriv[arg];
832 double d[25];
833 UG_ASSERT(dim+m_numArgs+1 < 20, luaComp.name());
834 for(size_t i=0; i<m_numArgs; i++)
835 d[i] = dataIn[i];
836 if(m_bPosTimeNeed){
837 for(int i=0; i<dim; i++)
838 d[i+m_numArgs] = x[i];
839 d[dim+m_numArgs]=time;
840 d[dim+m_numArgs+1]=si;
841 UG_ASSERT(dim+m_numArgs+1 < 20, luaComp.name());
842 }
844 luaComp.name() << " has wrong number of outputs: is " << luaComp.num_out() << ", needs " << lua_traits<TData>::size);
845 double ret[lua_traits<TData>::size+1];
846 luaComp.call(ret, d);
847 //TData D2;
848 void *t=NULL;
849 //TData out2;
850 lua_traits<TData>::read(out, ret, t);
851 return;
852 }
853 else
854 #endif
855 {
856 UG_ASSERT(dataIn.size() == m_numArgs, "Number of arguments mismatched.");
857 UG_ASSERT(arg < m_numArgs, "Argument does not exist.");
858
859 // push the callback function on the stack
860 lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbDerivRef[arg]);
861
862 // read all arguments and push them to the lua stack
863 for(size_t i = 0; i < dataIn.size(); ++i)
864 {
865 // push data to lua stack
866 lua_traits<TDataIn>::push(m_L, dataIn[i]);
867 }
868
869 // if needed, read additional coordinate, time and subset index arguments and push them to the lua stack
870 if(m_bPosTimeNeed){
871 lua_traits<MathVector<dim> >::push(m_L, x);
872 lua_traits<number>::push(m_L, time);
873 lua_traits<int>::push(m_L, si);
874 }
875
876 // compute total args size
877 size_t argSize = lua_traits<TDataIn>::size * dataIn.size();
878 if(m_bPosTimeNeed){
879 argSize += lua_traits<MathVector<dim> >::size
882 }
883
884 // compute total return size
885 size_t retSize = lua_traits<TData>::size;
886
887 // call lua function
888 if(lua_pcall(m_L, argSize, retSize, 0) != 0)
889 UG_THROW("LuaUserFunction::eval_deriv: Error while "
890 "running callback '" << m_cbDerivName[arg] << "',"
891 " lua message: "<< lua_tostring(m_L, -1) );
892
893 try{
894 // read return value
895 lua_traits<TData>::read(m_L, out);
896 UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
897 }
898 UG_CATCH_THROW("LuaUserFunction::eval_deriv(...): Error while "
899 "running callback '" << m_cbDerivName[arg] << "'");
900
901 // pop values
902 lua_pop(m_L, retSize);
903 }
904}
905
906
907template <typename TData, int dim, typename TDataIn>
909evaluate (TData& value,
910 const MathVector<dim>& globIP,
911 number time, int si) const
912{
914// vector of data for all inputs
915 std::vector<TDataIn> vDataIn(this->num_input());
916
917// gather all input data for this ip
918 for(size_t c = 0; c < vDataIn.size(); ++c)
919 (*m_vpUserData[c])(vDataIn[c], globIP, time, si);
920
921// evaluate data at ip
922 eval_value(value, vDataIn, globIP, time, si);
923
924 UG_COND_THROW(IsFiniteAndNotTooBig(value)==false, value);
925}
926
927template <typename TData, int dim, typename TDataIn>
928template <int refDim>
930evaluate(TData vValue[],
931 const MathVector<dim> vGlobIP[],
932 number time, int si,
933 GridObject* elem,
934 const MathVector<dim> vCornerCoords[],
935 const MathVector<refDim> vLocIP[],
936 const size_t nip,
937 LocalVector* u,
938 const MathMatrix<refDim, dim>* vJT) const
939{
941// vector of data for all inputs
942 std::vector<TDataIn> vDataIn(this->num_input());
943
944// gather all input data for this ip
945 for(size_t ip = 0; ip < nip; ++ip)
946 {
947 for(size_t c = 0; c < vDataIn.size(); ++c)
948 (*m_vpUserData[c])(vDataIn[c], vGlobIP[ip], time, si, elem, vCornerCoords, vLocIP[ip], u);
949
950 // evaluate data at ip
951 eval_value(vValue[ip], vDataIn, vGlobIP[ip], time, si);
952 UG_COND_THROW(IsFiniteAndNotTooBig(vValue[ip])==false, vValue[ip]);
953 }
954}
955
956template <typename TData, int dim, typename TDataIn>
957template <int refDim>
959eval_and_deriv(TData vValue[],
960 const MathVector<dim> vGlobIP[],
961 number time, int si,
962 GridObject* elem,
963 const MathVector<dim> vCornerCoords[],
964 const MathVector<refDim> vLocIP[],
965 const size_t nip,
966 LocalVector* u,
967 bool bDeriv,
968 int s,
969 std::vector<std::vector<TData> > vvvDeriv[],
970 const MathMatrix<refDim, dim>* vJT)
971{
973// vector of data for all inputs
974 std::vector<TDataIn> vDataIn(this->num_input());
975
976 for(size_t ip = 0; ip < nip; ++ip)
977 {
978 // gather all input data for this ip
979 for(size_t c = 0; c < vDataIn.size(); ++c)
980 vDataIn[c] = m_vpUserData[c]->value(this->series_id(c,s), ip);
981
982 // evaluate data at ip
983 eval_value(vValue[ip], vDataIn, vGlobIP[ip], time, si);
984 }
985
986// check if derivative is required
987 if(!bDeriv || this->zero_derivative()) return;
988
989// clear all derivative values
990 this->set_zero(vvvDeriv, nip);
991
992// loop all inputs
993 for(size_t c = 0; c < vDataIn.size(); ++c)
994 {
995 // check if we have the derivative w.r.t. this input, and the input has derivative
996 if(m_cbDerivRef[c] == LUA_NOREF || m_vpUserData[c]->zero_derivative()) continue;
997
998 // loop ips
999 for(size_t ip = 0; ip < nip; ++ip)
1000 {
1001 // gather all input data for this ip
1002 for(size_t i = 0; i < vDataIn.size(); ++i)
1003 vDataIn[i] = m_vpUserData[i]->value(this->series_id(c,s), ip); //< series_id(c,s) or series_id(i,s)
1004
1005 // data of derivative w.r.t. one component at ip-values
1006 TData derivVal;
1007
1008 // evaluate data at ip
1009 eval_deriv(derivVal, vDataIn, vGlobIP[ip], time, si, c);
1010
1011 // loop functions
1012 for(size_t fct = 0; fct < this->input_num_fct(c); ++fct)
1013 {
1014 // get common fct id for this function
1015 const size_t commonFct = this->input_common_fct(c, fct);
1016
1017 // loop dofs
1018 for(size_t dof = 0; dof < this->num_sh(fct); ++dof)
1019 {
1021 mult_add(vvvDeriv[ip][commonFct][dof],
1022 derivVal,
1023 m_vpDependData[c]->deriv(this->series_id(c,s), ip, fct, dof));
1024 UG_COND_THROW(IsFiniteAndNotTooBig(vvvDeriv[ip][commonFct][dof])==false, vvvDeriv[ip][commonFct][dof]);
1025 }
1026 }
1027 }
1028 }
1029}
1030
1036template <typename TData, int dim, typename TDataIn>
1038{
1039// resize arrays
1040 m_vpUserData.resize(num);
1041 m_vpDependData.resize(num);
1042
1043// forward size to base class
1044 base_type::set_num_input(num);
1045}
1046
1047template <typename TData, int dim, typename TDataIn>
1050{
1051 UG_ASSERT(i < m_vpUserData.size(), "Input not needed");
1052 UG_ASSERT(i < m_vpDependData.size(), "Input not needed");
1053
1054// check input number
1055 if(i >= this->num_input())
1056 UG_THROW("LuaUserFunction::set_input: Only " << this->num_input()
1057 << " inputs can be set. Use 'set_num_input' to increase"
1058 " the number of needed inputs.");
1059
1060// remember userdata
1061 m_vpUserData[i] = data;
1062
1063// cast to dependent data
1064 m_vpDependData[i] = data.template cast_dynamic<DependentUserData<TDataIn, dim> >();
1065
1066// forward to base class
1067 base_type::set_input(i, data, data);
1068}
1069
1070template <typename TData, int dim, typename TDataIn>
1072{
1073 set_input(i, CreateConstUserData<dim>(val, TDataIn()));
1074}
1075
1076
1078// LuaFunction
1080
1081template <typename TData, typename TDataIn>
1083{
1085 m_cbValueRef = LUA_NOREF;
1086}
1087
1088template <typename TData, typename TDataIn>
1089void LuaFunction<TData,TDataIn>::set_lua_callback(const char* luaCallback, size_t numArgs)
1090{
1091// store name (string) of callback
1092 m_cbValueName = luaCallback;
1093
1094// obtain a reference
1095 lua_getglobal(m_L, m_cbValueName.c_str());
1096
1097// make sure that the reference is valid
1098 if(lua_isnil(m_L, -1)){
1099 UG_THROW("LuaFunction::set_lua_callback(...):"
1100 "Specified lua callback does not exist: " << m_cbValueName);
1101 }
1102
1103// store reference to lua function
1104 m_cbValueRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
1105
1106// remember number of arguments to be used
1107 m_numArgs = numArgs;
1108}
1109
1110template <typename TData, typename TDataIn>
1111void LuaFunction<TData,TDataIn>::operator() (TData& out, int numArgs, ...)
1112{
1113 PROFILE_CALLBACK_BEGIN(operatorBracket);
1114 UG_ASSERT(numArgs == (int)m_numArgs, "Number of arguments mismatched.");
1115
1116// push the callback function on the stack
1117 lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
1118
1119// get list of arguments
1120 va_list ap;
1121 va_start(ap, numArgs);
1122
1123// read all arguments and push them to the lua stack
1124 for(int i = 0; i < numArgs; ++i)
1125 {
1126 // cast data
1127 TDataIn val = va_arg(ap, TDataIn);
1128
1129 // push data to lua stack
1130 lua_traits<TDataIn>::push(m_L, val);
1131 }
1132
1133// end read in of parameters
1134 va_end(ap);
1135
1136// compute total args size
1137 size_t argSize = lua_traits<TDataIn>::size * numArgs;
1138
1139// compute total return size
1140 size_t retSize = lua_traits<TData>::size;
1141
1142// call lua function
1143 if(lua_pcall(m_L, argSize, retSize, 0) != 0)
1144 UG_THROW("LuaFunction::operator(...): Error while "
1145 "running callback '" << m_cbValueName << "',"
1146 " lua message: "<< lua_tostring(m_L, -1));
1147
1148 try{
1149 // read return value
1150 lua_traits<TData>::read(m_L, out);
1151 UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
1152 }
1153 UG_CATCH_THROW("LuaFunction::operator(...): Error while running "
1154 "callback '" << m_cbValueName << "'");
1155
1156// pop values
1157 lua_pop(m_L, retSize);
1158
1160}
1161
1162
1163
1164} // end namespace ug
1165
1166#endif /* LUA_USER_DATA_IMPL_H_ */
parameterString s
location name
Definition checkpoint_util.lua:128
Definition smart_pointer.h:108
T * get()
returns encapsulated pointer
Definition smart_pointer.h:197
int * refcount_ptr() const
WARNING: this method is DANGEROUS!
Definition smart_pointer.h:263
Type based UserData.
Definition user_data.h:501
The base class for all geometric objects, such as vertices, edges, faces, volumes,...
Definition grid_base_objects.h:157
Definition local_algebra.h:198
Handle for a lua reference.
Definition lua_function_handle.h:40
int ref
Definition lua_function_handle.h:42
int m_cbValueRef
reference to lua function
Definition lua_user_data.h:422
virtual void operator()(TData &out, int numArgs,...)
evaluates the data
Definition lua_user_data_impl.h:1111
lua_State * m_L
lua state
Definition lua_user_data.h:425
LuaFunction()
constructor
Definition lua_user_data_impl.h:1082
void set_lua_callback(const char *luaCallback, size_t numArgs)
sets the Lua function used to compute the data
Definition lua_user_data_impl.h:1089
Factory providing LuaUserData.
Definition lua_user_data.h:180
static void remove(const std::string &name)
removes the user data
Definition lua_user_data_impl.h:417
static SmartPtr< LuaUserData< TData, dim, TRet > > provide_or_create(const std::string &name)
returns new Data if not already created, already existing else
Definition lua_user_data_impl.h:370
provides data specified in the lua script
Definition lua_user_data.h:96
static bool check_callback_returns(const char *callName, const bool bThrow=false)
returns true if callback has correct return values
Definition lua_user_data_impl.h:249
lua_State * m_L
lua state
Definition lua_user_data.h:157
int m_callbackRef
reference to lua function
Definition lua_user_data.h:147
static std::string signature()
returns string of required callback signature
Definition lua_user_data_impl.h:67
static std::string name()
returns name of UserData
Definition lua_user_data_impl.h:84
std::string m_callbackName
callback name as string
Definition lua_user_data.h:144
LuaUserData(const char *luaCallback)
Constructor.
Definition lua_user_data_impl.h:94
TRet evaluate(TData &D, const MathVector< dim > &x, number time, int si) const
evaluates the data at a given point and time
Definition lua_user_data_impl.h:284
virtual ~LuaUserData()
}
Definition lua_user_data_impl.h:355
LuaUserFunction(const char *luaCallback, size_t numArgs)
constructor
Definition lua_user_data_impl.h:445
void eval_value(TData &out, const std::vector< TDataIn > &dataIn, const MathVector< dim > &x, number time, int si) const
evaluates the data at a given point and time
Definition lua_user_data_impl.h:740
virtual ~LuaUserFunction()
destructor frees the reference
Definition lua_user_data_impl.h:510
void free_deriv_callback_ref(size_t arg)
frees callback-references for derivative callbacks
Definition lua_user_data_impl.h:531
void evaluate(TData &value, const MathVector< dim > &globIP, number time, int si) const
Definition lua_user_data_impl.h:909
void set_input(size_t i, SmartPtr< CplUserData< TDataIn, dim > > data)
set input value for parameter i
Definition lua_user_data_impl.h:1049
std::vector< int > m_cbDerivRef
Definition lua_user_data.h:346
void eval_deriv(TData &out, const std::vector< TDataIn > &dataIn, const MathVector< dim > &x, number time, int si, size_t arg) const
evaluates the data at a given point and time
Definition lua_user_data_impl.h:823
void eval_and_deriv(TData vValue[], const MathVector< dim > vGlobIP[], number time, int si, GridObject *elem, const MathVector< dim > vCornerCoords[], const MathVector< refDim > vLocIP[], const size_t nip, LocalVector *u, bool bDeriv, int s, std::vector< std::vector< TData > > vvvDeriv[], const MathMatrix< refDim, dim > *vJT=NULL)
Definition lua_user_data_impl.h:959
lua_State * m_L
lua state
Definition lua_user_data.h:349
void free_callback_ref()
frees the callback-reference, if a callback was set.
Definition lua_user_data_impl.h:522
int m_cbValueRef
reference to lua function
Definition lua_user_data.h:345
void set_deriv(size_t arg, const char *luaCallback)
sets the Lua function used to compute the derivative
Definition lua_user_data_impl.h:601
std::vector< std::string > m_cbDerivName
Definition lua_user_data.h:342
void set_lua_value_callback(const char *luaCallback, size_t numArgs)
sets the Lua function used to compute the data
Definition lua_user_data_impl.h:541
void set_num_input(size_t num)
set number of needed inputs
Definition lua_user_data_impl.h:1037
virtual void operator()(TData &out, int numArgs,...) const
evaluates the data
Definition lua_user_data_impl.h:661
A class for fixed size, dense matrices.
Definition math_matrix.h:63
a mathematical Vector with N entries.
Definition math_vector.h:97
Definition lua_compiler.h:50
const std::string & name() const
Definition lua_compiler.h:86
int num_out() const
Definition lua_compiler.h:81
bool call(double *ret, const double *in) const
Definition lua_compiler.cpp:263
#define UG_ASSERT(expr, msg)
Definition assert.h:70
#define UG_CATCH_THROW(msg)
Definition error.h:64
#define UG_THROW(msg)
Definition error.h:57
#define UG_LOG(msg)
Definition log.h:367
#define UG_COND_THROW(cond, msg)
UG_COND_THROW(cond, msg) : performs a UG_THROW(msg) if cond == true.
Definition error.h:61
double number
Definition types.h:124
struct lua_State lua_State
Definition lua_table_handle.h:40
#define PROFILE_CALLBACK()
Definition lua_user_data_impl.h:50
#define PROFILE_CALLBACK_END()
Definition lua_user_data_impl.h:52
#define PROFILE_CALLBACK_BEGIN(name)
Definition lua_user_data_impl.h:51
string GetLUAScriptFunctionDefined(const char *functionName)
returns file and line of defined script function
Definition info_commands.cpp:368
lua_State * GetDefaultLuaState()
returns the default lua state
Definition lua_util.cpp:242
the ug namespace
bool useLuaCompiler
Definition info_commands.cpp:93
bool IsFiniteAndNotTooBig(double d)
Definition number_util.h:39
SmartPtr< T, FreePolicy > make_sp(T *inst)
returns a SmartPtr for the passed raw pointer
Definition smart_pointer.h:836
static void mult_add(TData &out, const TData &in1, const TDataIn &s)
computes out += s * in1 (with appropriate '*')
Lua Traits to push/pop on lua stack.
Definition lua_traits.h:79