ug4
lua_user_data_impl.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2012-2015: G-CSC, Goethe University Frankfurt
3  * Author: Andreas Vogel
4  *
5  * This file is part of UG4.
6  *
7  * UG4 is free software: you can redistribute it and/or modify it under the
8  * terms of the GNU Lesser General Public License version 3 (as published by the
9  * Free Software Foundation) with the following additional attribution
10  * requirements (according to LGPL/GPL v3 §7):
11  *
12  * (1) The following notice must be displayed in the Appropriate Legal Notices
13  * of covered and combined works: "Based on UG4 (www.ug4.org/license)".
14  *
15  * (2) The following notice must be displayed at a prominent place in the
16  * terminal output of covered works: "Based on UG4 (www.ug4.org/license)".
17  *
18  * (3) The following bibliography is recommended for citation and must be
19  * preserved in all covered files:
20  * "Reiter, S., Vogel, A., Heppner, I., Rupp, M., and Wittum, G. A massively
21  * parallel geometric multigrid solver on hierarchically distributed grids.
22  * Computing and visualization in science 16, 4 (2013), 151-164"
23  * "Vogel, A., Reiter, S., Rupp, M., Nägel, A., and Wittum, G. UG4 -- a novel
24  * flexible software system for simulating pde based models on high performance
25  * computers. Computing and visualization in science 16, 4 (2013), 165-179"
26  *
27  * This program is distributed in the hope that it will be useful,
28  * but WITHOUT ANY WARRANTY; without even the implied warranty of
29  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30  * GNU Lesser General Public License for more details.
31  */
32 
33 #ifndef __H__UG_BRIDGE__BRIDGES__USER_DATA__USER_DATA_IMPL_
34 #define __H__UG_BRIDGE__BRIDGES__USER_DATA__USER_DATA_IMPL_
35 
36 #ifdef UG_FOR_LUA
37 #include "lua_user_data.h"
38 #endif
41 
42 #include "info_commands.h"
44 
45 #if 0
46 #define PROFILE_CALLBACK() PROFILE_FUNC_GROUP("luacallback")
47 #define PROFILE_CALLBACK_BEGIN(name) PROFILE_BEGIN_GROUP(name, "luacallback")
48 #define PROFILE_CALLBACK_END() PROFILE_END()
49 #else
50 #define PROFILE_CALLBACK()
51 #define PROFILE_CALLBACK_BEGIN(name)
52 #define PROFILE_CALLBACK_END()
53 #endif
54 namespace ug{
55 
56 #ifdef USE_LUA2C
57  extern bool useLuaCompiler;
58 #endif
59 
60 
61 
63 // LuaUserData
65 
66 template <typename TData, int dim, typename TRet>
68 {
69  std::stringstream ss;
70  ss << "function name(";
71  if(dim >= 1) ss << "x";
72  if(dim >= 2) ss << ", y";
73  if(dim >= 3) ss << ", z";
74  ss << ", t, si)\n ... \n return ";
75  if(lua_traits<TRet>::size != 0)
76  ss << lua_traits<TRet>::signature() << ", ";
77  ss << lua_traits<TData>::signature();
78  ss << "\nend";
79  return ss.str();
80 }
81 
82 
83 template <typename TData, int dim, typename TRet>
85 {
86  std::stringstream ss;
87  ss << "Lua";
88  if(lua_traits<TRet>::size > 0) ss << "Cond";
89  ss << "User" << lua_traits<TData>::name() << dim << "d";
90  return ss.str();
91 }
92 
93 template <typename TData, int dim, typename TRet>
95  : m_callbackName(luaCallback), m_bFromFactory(false)
96 {
97 // get lua state
99 
100 // obtain a reference
101  lua_getglobal(m_L, m_callbackName.c_str());
102 
103 // make sure that the reference is valid
104  if(lua_isnil(m_L, -1)){
105  UG_THROW(name() << ": Specified lua callback "
106  "does not exist: " << m_callbackName);
107  }
108 
109 // store reference to lua function
110  m_callbackRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
111 
112 // make a test run
114 
115  #ifdef USE_LUA2C
116  if(useLuaCompiler) m_luaComp.create(luaCallback);
117  #endif
118 }
119 
120 template <typename TData, int dim, typename TRet>
122  : m_callbackName("__anonymous__lua__function__"), m_bFromFactory(false)
123 {
124 // get lua state
126 
127 // store reference to lua function
128  m_callbackRef = handle.ref;
129 
130 // make a test run
132 
133  #ifdef USE_LUA2C
134 // UG_THROW("LuaFunctionHandle usage currently not supported with LUA2C.");
135  if(useLuaCompiler) m_luaComp.create(m_callbackName.c_str(), &handle);
136  #endif
137 }
138 
139 
140 template <typename TData, int dim, typename TRet>
142 check_callback_returns(lua_State* L, int callbackRef, const char* callName, const bool bThrow)
143 {
145 // get current stack level
146  const int level = lua_gettop(L);
147 
148 // dummy values to invoke the callback once
149  MathVector<dim> x; x = 0.0;
150  number time = 0.0;
151  int si = 0;
152 
153 // push the callback function on the stack
154  lua_rawgeti(L, LUA_REGISTRYINDEX, callbackRef);
155 
156 // push space coordinates on stack
157  lua_traits<MathVector<dim> >::push(L, x);
158 
159 // push time on stack
160  lua_traits<number>::push(L, time);
161 
162 // push subset on stack
163  lua_traits<int>::push(L, si);
164 
165 // compute total args size
166  const int argSize = lua_traits<MathVector<dim> >::size
169 
170 // compute total return size
171  const int retSize = lua_traits<TData>::size + lua_traits<TRet>::size;
172 
173 // call lua function
174  if(lua_pcall(L, argSize, LUA_MULTRET, 0) != 0)
175  UG_THROW(name() << ": Error while "
176  "testing callback '" << callName << "',"
177  " lua message: "<< lua_tostring(L, -1));
178 
179  // get number of results
180  const int numResults = lua_gettop(L) - level;
181 
182 // success flag
183  bool bRet = true;
184 
185 // if number of results is wrong return error
186  if(numResults != retSize){
187  if(bThrow){
188  UG_THROW(name() << ": Number of return values incorrect "
189  "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
190  "\nRequired: "<<retSize<<", passed: "<<numResults
191  <<". Use signature as follows:\n"
192  << signature());
193  }
194  else{
195  bRet = false;
196  }
197  }
198 
199 // check return value
200  if(!lua_traits<TData>::check(L)){
201  if(bThrow){
202  UG_THROW(name() << ": Data values type incorrect "
203  "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
204  "\nUse signature as follows:\n"
205  << signature());
206  }
207  else{
208  bRet = false;
209  }
210  }
211 
212 // read return flag (may be void)
213  if(!lua_traits<TRet>::check(L, -retSize)){
214  if(bThrow){
215  UG_THROW("LuaUserData: Return values type incorrect "
216  "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
217  "\nUse signature as follows:\n"
218  << signature());
219  }
220  else{
221  bRet = false;
222  }
223  }
224 
225 // pop values
226  lua_pop(L, numResults);
227 
228 // return match
229  return bRet;
230 }
231 
232 template <typename TData, int dim, typename TRet>
234 check_callback_returns(LuaFunctionHandle handle, const bool bThrow)
235 {
237 // get lua state
239 
240 // forward call
241  bool bRet = check_callback_returns(L, handle.ref, "__lua_function_handle__", bThrow);
242 
243 // return match
244  return bRet;
245 }
246 
247 template <typename TData, int dim, typename TRet>
249 check_callback_returns(const char* callName, const bool bThrow)
250 {
252 // get lua state
254 
255 // obtain a reference
256  lua_getglobal(L, callName);
257 
258 // check if reference is valid
259  if(lua_isnil(L, -1)) {
260  if(bThrow) {
261  UG_THROW(name() << ": Cannot find specified lua callback "
262  " with name: "<<callName);
263  }
264  else {
265  return false;
266  }
267  }
268 
269 // get reference
270  int callbackRef = luaL_ref(L, LUA_REGISTRYINDEX);
271 
272 // forward call
273  bool bRet = check_callback_returns(L, callbackRef, callName, bThrow);
274 
275 // free reference to callback
276  luaL_unref(L, LUA_REGISTRYINDEX, callbackRef);
277 
278 // return match
279  return bRet;
280 }
281 
282 template <typename TData, int dim, typename TRet>
284 evaluate(TData& D, const MathVector<dim>& x, number time, int si) const
285 {
287  #ifdef USE_LUA2C
288  if(useLuaCompiler && m_luaComp.is_valid())
289  {
290  double d[dim+2];
291  for(int i=0; i<dim; i++)
292  d[i] = x[i];
293  d[dim] = time;
294  d[dim+1] = si;
295  double ret[lua_traits<TData>::size+1];
296  m_luaComp.call(ret, d);
297  //TData D2;
298  TRet *t=NULL;
299  lua_traits<TData>::read(D, ret, t);
300  return lua_traits<TRet>::do_return(ret[0]);
301  }
302  else
303  #endif
304  {
305  // push the callback function on the stack
306  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_callbackRef);
307 
308  // push space coordinates on stack
309  lua_traits<MathVector<dim> >::push(m_L, x);
310 
311  // push time on stack
312  lua_traits<number>::push(m_L, time);
313 
314  // push subset index on stack
315  lua_traits<int>::push(m_L, si);
316 
317  // compute total args size
318  const int argSize = lua_traits<MathVector<dim> >::size
321 
322  // compute total return size
323  const int retSize = lua_traits<TData>::size + lua_traits<TRet>::size;
324 
325  // call lua function
326  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
327  UG_THROW(name() << "::operator(...): Error while "
328  "running callback '" << m_callbackName << "',"
329  " lua message: "<< lua_tostring(m_L, -1)<<".\n"
330  "Use signature as follows:\n"
331  << signature());
332 
333  bool res = false;
334  try{
335  // read return value
336  lua_traits<TData>::read(m_L, D);
337 
338  // read return flag (may be void)
339  lua_traits<TRet>::read(m_L, res, -retSize);
340  }
341  UG_CATCH_THROW(name() << "::operator(...): Error while running "
342  "callback '" << m_callbackName << "'.\n"
343  "Use signature as follows:\n"
344  << signature());
345 
346  // pop values
347  lua_pop(m_L, retSize);
348 
349  // forward flag
350  return lua_traits<TRet>::do_return(res);
351  }
352 }
353 
354 template <typename TData, int dim, typename TRet>
356 {
357 // free reference to callback
358  luaL_unref(m_L, LUA_REGISTRYINDEX, m_callbackRef);
359 
360  if(m_bFromFactory)
362 }
363 
365 // LuaUserDataFactory
367 
368 template <typename TData, int dim, typename TRet>
371 {
373  typedef std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> > Map;
374  typedef typename Map::iterator iterator;
375 
376 // check for element
377  iterator iter = m_mData.find(name);
378 
379 // if name does not exist, create new one
380  if(iter == m_mData.end())
381  {
383  = make_sp(new LuaUserData<TData,dim,TRet>(name.c_str()));
384 
385  // the LuaUserData must remember to unregister itself at destruction
386  sp->set_created_from_factory(true);
387 
388  // NOTE AND WARNING: This is very hacky and dangerous. We only do this
389  // because we know exactly what we are doing and everything is safe and
390  // confined to protected or private areas. However, if you ever want to
391  // change this code, please be aware that we store plain pointers and the
392  // associated reference counters of a SmartPtr here. This should not be
393  // done in general and this kind of coding is not recommended at all.
394  // Please use different approaches whenever possible.
395  std::pair<LuaUserData<TData,dim,TRet>*, int*>& data = m_mData[name];
396  data.first = sp.get();
397  data.second = sp.refcount_ptr();
398 
399  return sp;
400  }
401 // else return present data
402  {
403  // NOTE AND WARNING: This is very hacky and dangerous. We only do this
404  // because we know exactly what we are doing and everything is safe and
405  // confined to protected or private areas. However, if you ever want to
406  // change this code, please be aware that we store plain pointers and the
407  // associated reference counters of a SmartPtr here. This should not be
408  // done in general and this kind of coding is not recommended at all.
409  // Please use different approaches whenever possible.
410  std::pair<LuaUserData<TData,dim,TRet>*, int*>& data = iter->second;
411  return SmartPtr<LuaUserData<TData,dim,TRet> >(data.first, data.second);
412  }
413 }
414 
415 template <typename TData, int dim, typename TRet>
416 void
418 {
419  typedef std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> > Map;
420  typedef typename Map::iterator iterator;
421 
422 // check for element
423  iterator iter = m_mData.find(name);
424 
425 // if name does not exist, there is nothing to remove: report an error
426  if(iter == m_mData.end())
427  UG_THROW("LuaUserDataFactory: trying to remove non-registered"
428  " data with name: "<<name);
429 
430  m_mData.erase(iter);
431 }
432 
433 
434 // instantiation of static member
435 template <typename TData, int dim, typename TRet>
436 std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> >
437 LuaUserDataFactory<TData,dim,TRet>::m_mData = std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> >();
438 
440 // LuaUserFunction
442 
443 template <typename TData, int dim, typename TDataIn>
445 LuaUserFunction(const char* luaCallback, size_t numArgs)
446  : m_numArgs(numArgs), m_bPosTimeNeed(false)
447 {
449  m_cbValueRef = LUA_NOREF;
450  m_cbDerivRef.clear();
451  m_cbDerivName.clear();
452  set_lua_value_callback(luaCallback, numArgs);
453  #ifdef USE_LUA2C
454  if(useLuaCompiler) m_luaComp.create(luaCallback);
455  #endif
456 }
457 
458 template <typename TData, int dim, typename TDataIn>
460 LuaUserFunction(const char* luaCallback, size_t numArgs, bool bPosTimeNeed)
461  : m_numArgs(numArgs), m_bPosTimeNeed(bPosTimeNeed)
462 {
464  m_cbValueRef = LUA_NOREF;
465  m_cbDerivRef.clear();
466  m_cbDerivName.clear();
467  set_lua_value_callback(luaCallback, numArgs);
468  #ifdef USE_LUA2C
469  m_luaComp_Deriv.clear();
470  #endif
471 }
472 
473 
474 template <typename TData, int dim, typename TDataIn>
476 LuaUserFunction(LuaFunctionHandle handle, size_t numArgs)
477  : m_numArgs(numArgs), m_bPosTimeNeed(false)
478 {
480  m_cbValueRef = LUA_NOREF;
481  m_cbDerivRef.clear();
482  m_cbDerivName.clear();
483  set_lua_value_callback(handle, numArgs);
484  #ifdef USE_LUA2C
485  if(useLuaCompiler){
486  UG_LOG("WARNING (in LuaUserFunction): LUA2C compiler "
487  "can't be executed for FunctionHandle.\n");
488  }
489  #endif
490 }
491 
492 template <typename TData, int dim, typename TDataIn>
494 LuaUserFunction(LuaFunctionHandle handle, size_t numArgs, bool bPosTimeNeed)
495  : m_numArgs(numArgs), m_bPosTimeNeed(bPosTimeNeed)
496 {
498  m_cbValueRef = LUA_NOREF;
499  m_cbDerivRef.clear();
500  m_cbDerivName.clear();
501  set_lua_value_callback(handle, numArgs);
502  #ifdef USE_LUA2C
503  m_luaComp_Deriv.clear();
504  #endif
505 }
506 
507 
508 
509 template <typename TData, int dim, typename TDataIn>
511 {
512 // free reference to callback
513  free_callback_ref();
514 
515 // free references to derivative callbacks
516  for(size_t i = 0; i < m_numArgs; ++i){
517  free_deriv_callback_ref(i);
518  }
519 }
520 
521 template <typename TData, int dim, typename TDataIn>
523 {
524  if(m_cbValueRef != LUA_NOREF){
525  luaL_unref(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
526  m_cbValueRef = LUA_NOREF;
527  }
528 }
529 
530 template <typename TData, int dim, typename TDataIn>
532 {
533  if(m_cbDerivRef[arg] != LUA_NOREF){
534  luaL_unref(m_L, LUA_REGISTRYINDEX, m_cbDerivRef[arg]);
535  m_cbDerivRef[arg] = LUA_NOREF;
536  }
537 }
538 
539 
540 template <typename TData, int dim, typename TDataIn>
541 void LuaUserFunction<TData,dim,TDataIn>::set_lua_value_callback(const char* luaCallback, size_t numArgs)
542 {
543 // store name (string) of callback
544  m_cbValueName = luaCallback;
545 
546 // obtain a reference
547  lua_getglobal(m_L, m_cbValueName.c_str());
548 
549 // make sure that the reference is valid
550  if(lua_isnil(m_L, -1)){
551  UG_THROW("LuaUserFunction::set_lua_value_callback(...):"
552  "Specified callback does not exist: " << m_cbValueName);
553  }
554 
555 // if a callback was already set, we have to free the old one
556  free_callback_ref();
557 
558 // store reference to lua function
559  m_cbValueRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
560 
561 // remember number of arguments to be used
562  m_numArgs = numArgs;
563  m_cbDerivName.resize(numArgs);
564  m_cbDerivRef.resize(numArgs, LUA_NOREF);
565 
566 // set num inputs for linker
567  set_num_input(numArgs);
568 
569  #ifdef USE_LUA2C
570  m_luaComp_Deriv.resize(numArgs);
571  #endif
572 }
573 
574 template <typename TData, int dim, typename TDataIn>
576 set_lua_value_callback(LuaFunctionHandle handle, size_t numArgs)
577 {
578 // store name (string) of callback
579  m_cbValueName = "__anonymous__lua__function__";
580 
581 // if a callback was already set, we have to free the old one
582  free_callback_ref();
583 
584 // store reference to lua function
585  m_cbValueRef = handle.ref;
586 
587 // remember number of arguments to be used
588  m_numArgs = numArgs;
589  m_cbDerivName.resize(numArgs);
590  m_cbDerivRef.resize(numArgs, LUA_NOREF);
591 
592 // set num inputs for linker
593  set_num_input(numArgs);
594 
595  #ifdef USE_LUA2C
596  m_luaComp_Deriv.resize(numArgs);
597  #endif
598 }
599 
600 template <typename TData, int dim, typename TDataIn>
601 void LuaUserFunction<TData,dim,TDataIn>::set_deriv(size_t arg, const char* luaCallback)
602 {
603 // check number of arg
604  if(arg >= m_numArgs)
605  UG_THROW("LuaUserFunction::set_lua_deriv_callback: Trying "
606  "to set a derivative for argument " << arg <<", that "
607  "does not exist. Number of arguments is "<<m_numArgs);
608 
609 // store name (string) of callback
610  m_cbDerivName[arg] = luaCallback;
611 
612 // free old reference
613  free_deriv_callback_ref(arg);
614 
615 // obtain a reference
616  lua_getglobal(m_L, m_cbDerivName[arg].c_str());
617 
618 // make sure that the reference is valid
619  if(lua_isnil(m_L, -1)){
620  UG_THROW("LuaUserFunction::set_lua_deriv_callback(...):"
621  "Specified callback does not exist: " << m_cbDerivName[arg]);
622  }
623 
624 // store reference to lua function
625  m_cbDerivRef[arg] = luaL_ref(m_L, LUA_REGISTRYINDEX);
626 
627  #ifdef USE_LUA2C
628  if(useLuaCompiler) m_luaComp_Deriv[arg].create(luaCallback);
629  #endif
630 
631 }
632 
633 template <typename TData, int dim, typename TDataIn>
635 {
636 // check number of arg
637  if(arg >= m_numArgs)
638  UG_THROW("LuaUserFunction::set_lua_deriv_callback: Trying "
639  "to set a derivative for argument " << arg <<", that "
640  "does not exist. Number of arguments is "<<m_numArgs);
641 
642 // store name (string) of callback
643  m_cbDerivName[arg] = std::string("__anonymous__lua__function__");
644 
645 // free old reference
646  free_deriv_callback_ref(arg);
647 
648 // store reference to lua function
649  m_cbDerivRef[arg] = handle.ref;
650 
651  #ifdef USE_LUA2C
652  // if(useLuaCompiler) m_luaComp_Deriv[arg].create(luaCallback);
653  #endif
654 
655 }
656 
657 
658 
659 
660 template <typename TData, int dim, typename TDataIn>
661 void LuaUserFunction<TData,dim,TDataIn>::operator() (TData& out, int numArgs, ...) const
662 {
664  #ifdef USE_LUA2C
665  if(useLuaCompiler && m_luaComp.is_valid())
666  {
667  double d[20];
668  // get list of arguments
669  va_list ap2;
670  va_start(ap2, numArgs);
671 
672  // read all arguments and push them to the lua stack
673  for(int i = 0; i < numArgs; ++i)
674  d[i] = va_arg(ap2, double);
675  va_end(ap2);
676 
677  double ret[lua_traits<TData>::size+1];
678 
679  UG_ASSERT(m_luaComp.num_in() == numArgs && m_luaComp.num_out() == lua_traits<TData>::size,
680  m_luaComp.name() << ", " << m_luaComp.num_in() << " != " << numArgs << " or " << m_luaComp.num_out() << " != " << lua_traits<TData>::size);
681  m_luaComp.call(ret, d);
682  //TData D2;
683  void *t=NULL;
684  //TData out2;
685  lua_traits<TData>::read(out, ret, t);
686  return;
687  }
688  else
689  #endif
690  {
691  UG_ASSERT(numArgs == (int)m_numArgs, "Number of arguments mismatched.");
692 
693  // push the callback function on the stack
694  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
695 
696  // get list of arguments
697  va_list ap;
698  va_start(ap, numArgs);
699 
700  // read all arguments and push them to the lua stack
701  for(int i = 0; i < numArgs; ++i)
702  {
703  // cast data
704  TDataIn val = va_arg(ap, TDataIn);
705 
706  // push data to lua stack
707  lua_traits<TDataIn>::push(m_L, val);
708  }
709 
710  // end read in of parameters
711  va_end(ap);
712 
713  // compute total args size
714  size_t argSize = lua_traits<TDataIn>::size * numArgs;
715 
716  // compute total return size
717  size_t retSize = lua_traits<TData>::size;
718 
719  // call lua function
720  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
721  UG_THROW("LuaUserFunction::operator(...): Error while "
722  "running callback '" << m_cbValueName << "',"
723  " lua message: "<< lua_tostring(m_L, -1));
724 
725  try{
726  // read return value
727  lua_traits<TData>::read(m_L, out);
728  UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
729  }
730  UG_CATCH_THROW("LuaUserFunction::operator(...): Error while running "
731  "callback '" << m_cbValueName << "'");
732 
733  // pop values
734  lua_pop(m_L, retSize);
735  }
736 }
737 
738 
739 template <typename TData, int dim, typename TDataIn>
740 void LuaUserFunction<TData,dim,TDataIn>::eval_value(TData& out, const std::vector<TDataIn>& dataIn,
741  const MathVector<dim>& x, number time, int si) const
742 {
744  #ifdef USE_LUA2C
745  if(useLuaCompiler && m_luaComp.is_valid())
746  {
747  double d[20];
748 
749  // read all arguments and push them to the lua stack
750  for(size_t i = 0; i < dataIn.size(); ++i)
751  d[i] = dataIn[i];
752  if(m_bPosTimeNeed){
753  for(int i=0; i<dim; i++)
754  d[i+m_numArgs] = x[i];
755  d[dim+m_numArgs]=time;
756  d[dim+m_numArgs+1]=si;
757  UG_ASSERT(dim+m_numArgs+1 < 20, m_luaComp.name());
758  }
759 
760  double ret[lua_traits<TData>::size];
761  m_luaComp.call(ret, d);
762  //TData D2;
763  void *t=NULL;
764  //TData out2;
765  UG_ASSERT(m_luaComp.num_out() == lua_traits<TData>::size, m_luaComp.name() << ", " << m_luaComp.num_out() << " != " << lua_traits<TData>::size);
766  lua_traits<TData>::read(out, ret, t);
767  return;
768  }
769  else
770  #endif
771  {
772  UG_ASSERT(dataIn.size() == m_numArgs, "Number of arguments mismatched.");
773 
774  // push the callback function on the stack
775  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
776 
777  // read all arguments and push them to the lua stack
778  for(size_t i = 0; i < dataIn.size(); ++i)
779  {
780  // push data to lua stack
781  lua_traits<TDataIn>::push(m_L, dataIn[i]);
782  }
783 
784  // if needed, read additional coordinate, time and subset index arguments and push them to the lua stack
785  if(m_bPosTimeNeed){
786  lua_traits<MathVector<dim> >::push(m_L, x);
787  lua_traits<number>::push(m_L, time);
788  lua_traits<int>::push(m_L, si);
789  }
790 
791  // compute total args size
792  size_t argSize = lua_traits<TDataIn>::size * dataIn.size();
793  if(m_bPosTimeNeed){
794  argSize += lua_traits<MathVector<dim> >::size
797  }
798 
799  // compute total return size
800  size_t retSize = lua_traits<TData>::size;
801 
802  // call lua function
803  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
804  UG_THROW("LuaUserFunction::eval_value(...): Error while "
805  "running callback '" << m_cbValueName << "',"
806  " lua message: "<< lua_tostring(m_L, -1));
807 
808  try{
809  // read return value
810  lua_traits<TData>::read(m_L, out);
811  UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
812  }
813  UG_CATCH_THROW("LuaUserFunction::eval_value(...): Error while "
814  "running callback '" << m_cbValueName << "'");
815 
816  // pop values
817  lua_pop(m_L, retSize);
818  }
819 }
820 
821 
822 template <typename TData, int dim, typename TDataIn>
823 void LuaUserFunction<TData,dim,TDataIn>::eval_deriv(TData& out, const std::vector<TDataIn>& dataIn,
824  const MathVector<dim>& x, number time, int si, size_t arg) const
825 {
827  #ifdef USE_LUA2C
828  if(useLuaCompiler && m_luaComp_Deriv[arg].is_valid()
829  && dim+m_numArgs+1 < 20 && m_luaComp_Deriv[arg].num_out() == lua_traits<TData>::size)
830  {
831  const bridge::LUACompiler &luaComp = m_luaComp_Deriv[arg];
832  double d[25];
833  UG_ASSERT(dim+m_numArgs+1 < 20, luaComp.name());
834  for(size_t i=0; i<m_numArgs; i++)
835  d[i] = dataIn[i];
836  if(m_bPosTimeNeed){
837  for(int i=0; i<dim; i++)
838  d[i+m_numArgs] = x[i];
839  d[dim+m_numArgs]=time;
840  d[dim+m_numArgs+1]=si;
841  UG_ASSERT(dim+m_numArgs+1 < 20, luaComp.name());
842  }
844  luaComp.name() << " has wrong number of outputs: is " << luaComp.num_out() << ", needs " << lua_traits<TData>::size);
845  double ret[lua_traits<TData>::size+1];
846  luaComp.call(ret, d);
847  //TData D2;
848  void *t=NULL;
849  //TData out2;
850  lua_traits<TData>::read(out, ret, t);
851  return;
852  }
853  else
854  #endif
855  {
856  UG_ASSERT(dataIn.size() == m_numArgs, "Number of arguments mismatched.");
857  UG_ASSERT(arg < m_numArgs, "Argument does not exist.");
858 
859  // push the callback function on the stack
860  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbDerivRef[arg]);
861 
862  // read all arguments and push them to the lua stack
863  for(size_t i = 0; i < dataIn.size(); ++i)
864  {
865  // push data to lua stack
866  lua_traits<TDataIn>::push(m_L, dataIn[i]);
867  }
868 
869  // if needed, read additional coordinate, time and subset index arguments and push them to the lua stack
870  if(m_bPosTimeNeed){
871  lua_traits<MathVector<dim> >::push(m_L, x);
872  lua_traits<number>::push(m_L, time);
873  lua_traits<int>::push(m_L, si);
874  }
875 
876  // compute total args size
877  size_t argSize = lua_traits<TDataIn>::size * dataIn.size();
878  if(m_bPosTimeNeed){
879  argSize += lua_traits<MathVector<dim> >::size
882  }
883 
884  // compute total return size
885  size_t retSize = lua_traits<TData>::size;
886 
887  // call lua function
888  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
889  UG_THROW("LuaUserFunction::eval_deriv: Error while "
890  "running callback '" << m_cbDerivName[arg] << "',"
891  " lua message: "<< lua_tostring(m_L, -1) );
892 
893  try{
894  // read return value
895  lua_traits<TData>::read(m_L, out);
896  UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
897  }
898  UG_CATCH_THROW("LuaUserFunction::eval_deriv(...): Error while "
899  "running callback '" << m_cbDerivName[arg] << "'");
900 
901  // pop values
902  lua_pop(m_L, retSize);
903  }
904 }
905 
906 
907 template <typename TData, int dim, typename TDataIn>
909 evaluate (TData& value,
910  const MathVector<dim>& globIP,
911  number time, int si) const
912 {
914 // vector of data for all inputs
915  std::vector<TDataIn> vDataIn(this->num_input());
916 
917 // gather all input data for this ip
918  for(size_t c = 0; c < vDataIn.size(); ++c)
919  (*m_vpUserData[c])(vDataIn[c], globIP, time, si);
920 
921 // evaluate data at ip
922  eval_value(value, vDataIn, globIP, time, si);
923 
924  UG_COND_THROW(IsFiniteAndNotTooBig(value)==false, value);
925 }
926 
927 template <typename TData, int dim, typename TDataIn>
928 template <int refDim>
930 evaluate(TData vValue[],
931  const MathVector<dim> vGlobIP[],
932  number time, int si,
933  GridObject* elem,
934  const MathVector<dim> vCornerCoords[],
935  const MathVector<refDim> vLocIP[],
936  const size_t nip,
937  LocalVector* u,
938  const MathMatrix<refDim, dim>* vJT) const
939 {
941 // vector of data for all inputs
942  std::vector<TDataIn> vDataIn(this->num_input());
943 
944 // gather all input data for this ip
945  for(size_t ip = 0; ip < nip; ++ip)
946  {
947  for(size_t c = 0; c < vDataIn.size(); ++c)
948  (*m_vpUserData[c])(vDataIn[c], vGlobIP[ip], time, si, elem, vCornerCoords, vLocIP[ip], u);
949 
950  // evaluate data at ip
951  eval_value(vValue[ip], vDataIn, vGlobIP[ip], time, si);
952  UG_COND_THROW(IsFiniteAndNotTooBig(vValue[ip])==false, vValue[ip]);
953  }
954 }
955 
956 template <typename TData, int dim, typename TDataIn>
957 template <int refDim>
959 eval_and_deriv(TData vValue[],
960  const MathVector<dim> vGlobIP[],
961  number time, int si,
962  GridObject* elem,
963  const MathVector<dim> vCornerCoords[],
964  const MathVector<refDim> vLocIP[],
965  const size_t nip,
966  LocalVector* u,
967  bool bDeriv,
968  int s,
969  std::vector<std::vector<TData> > vvvDeriv[],
970  const MathMatrix<refDim, dim>* vJT)
971 {
973 // vector of data for all inputs
974  std::vector<TDataIn> vDataIn(this->num_input());
975 
976  for(size_t ip = 0; ip < nip; ++ip)
977  {
978  // gather all input data for this ip
979  for(size_t c = 0; c < vDataIn.size(); ++c)
980  vDataIn[c] = m_vpUserData[c]->value(this->series_id(c,s), ip);
981 
982  // evaluate data at ip
983  eval_value(vValue[ip], vDataIn, vGlobIP[ip], time, si);
984  }
985 
986 // check if derivative is required
987  if(!bDeriv || this->zero_derivative()) return;
988 
989 // clear all derivative values
990  this->set_zero(vvvDeriv, nip);
991 
992 // loop all inputs
993  for(size_t c = 0; c < vDataIn.size(); ++c)
994  {
995  // check if we have the derivative w.r.t. this input, and the input has derivative
996  if(m_cbDerivRef[c] == LUA_NOREF || m_vpUserData[c]->zero_derivative()) continue;
997 
998  // loop ips
999  for(size_t ip = 0; ip < nip; ++ip)
1000  {
1001  // gather all input data for this ip
1002  for(size_t i = 0; i < vDataIn.size(); ++i)
1003  vDataIn[i] = m_vpUserData[i]->value(this->series_id(c,s), ip); //< series_id(c,s) or series_id(i,s)
1004 
1005  // data of derivative w.r.t. one component at ip-values
1006  TData derivVal;
1007 
1008  // evaluate data at ip
1009  eval_deriv(derivVal, vDataIn, vGlobIP[ip], time, si, c);
1010 
1011  // loop functions
1012  for(size_t fct = 0; fct < this->input_num_fct(c); ++fct)
1013  {
1014  // get common fct id for this function
1015  const size_t commonFct = this->input_common_fct(c, fct);
1016 
1017  // loop dofs
1018  for(size_t dof = 0; dof < this->num_sh(fct); ++dof)
1019  {
1021  mult_add(vvvDeriv[ip][commonFct][dof],
1022  derivVal,
1023  m_vpDependData[c]->deriv(this->series_id(c,s), ip, fct, dof));
1024  UG_COND_THROW(IsFiniteAndNotTooBig(vvvDeriv[ip][commonFct][dof])==false, vvvDeriv[ip][commonFct][dof]);
1025  }
1026  }
1027  }
1028  }
1029 }
1030 
1036 template <typename TData, int dim, typename TDataIn>
1038 {
1039 // resize arrays
1040  m_vpUserData.resize(num);
1041  m_vpDependData.resize(num);
1042 
1043 // forward size to base class
1044  base_type::set_num_input(num);
1045 }
1046 
1047 template <typename TData, int dim, typename TDataIn>
1050 {
1051  UG_ASSERT(i < m_vpUserData.size(), "Input not needed");
1052  UG_ASSERT(i < m_vpDependData.size(), "Input not needed");
1053 
1054 // check input number
1055  if(i >= this->num_input())
1056  UG_THROW("LuaUserFunction::set_input: Only " << this->num_input()
1057  << " inputs can be set. Use 'set_num_input' to increase"
1058  " the number of needed inputs.");
1059 
1060 // remember userdata
1061  m_vpUserData[i] = data;
1062 
1063 // cast to dependent data
1064  m_vpDependData[i] = data.template cast_dynamic<DependentUserData<TDataIn, dim> >();
1065 
1066 // forward to base class
1067  base_type::set_input(i, data, data);
1068 }
1069 
1070 template <typename TData, int dim, typename TDataIn>
1072 {
1073  set_input(i, CreateConstUserData<dim>(val, TDataIn()));
1074 }
1075 
1076 
1078 // LuaFunction
1080 
1081 template <typename TData, typename TDataIn>
1083 {
1085  m_cbValueRef = LUA_NOREF;
1086 }
1087 
1088 template <typename TData, typename TDataIn>
1089 void LuaFunction<TData,TDataIn>::set_lua_callback(const char* luaCallback, size_t numArgs)
1090 {
1091 // store name (string) of callback
1092  m_cbValueName = luaCallback;
1093 
1094 // obtain a reference
1095  lua_getglobal(m_L, m_cbValueName.c_str());
1096 
1097 // make sure that the reference is valid
1098  if(lua_isnil(m_L, -1)){
1099  UG_THROW("LuaFunction::set_lua_callback(...):"
1100  "Specified lua callback does not exist: " << m_cbValueName);
1101  }
1102 
1103 // store reference to lua function
1104  m_cbValueRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
1105 
1106 // remember number of arguments to be used
1107  m_numArgs = numArgs;
1108 }
1109 
1110 template <typename TData, typename TDataIn>
1111 void LuaFunction<TData,TDataIn>::operator() (TData& out, int numArgs, ...)
1112 {
1113  PROFILE_CALLBACK_BEGIN(operatorBracket);
1114  UG_ASSERT(numArgs == (int)m_numArgs, "Number of arguments mismatched.");
1115 
1116 // push the callback function on the stack
1117  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
1118 
1119 // get list of arguments
1120  va_list ap;
1121  va_start(ap, numArgs);
1122 
1123 // read all arguments and push them to the lua stack
1124  for(int i = 0; i < numArgs; ++i)
1125  {
1126  // cast data
1127  TDataIn val = va_arg(ap, TDataIn);
1128 
1129  // push data to lua stack
1130  lua_traits<TDataIn>::push(m_L, val);
1131  }
1132 
1133 // end read in of parameters
1134  va_end(ap);
1135 
1136 // compute total args size
1137  size_t argSize = lua_traits<TDataIn>::size * numArgs;
1138 
1139 // compute total return size
1140  size_t retSize = lua_traits<TData>::size;
1141 
1142 // call lua function
1143  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
1144  UG_THROW("LuaFunction::operator(...): Error while "
1145  "running callback '" << m_cbValueName << "',"
1146  " lua message: "<< lua_tostring(m_L, -1));
1147 
1148  try{
1149  // read return value
1150  lua_traits<TData>::read(m_L, out);
1151  UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
1152  }
1153  UG_CATCH_THROW("LuaFunction::operator(...): Error while running "
1154  "callback '" << m_cbValueName << "'");
1155 
1156 // pop values
1157  lua_pop(m_L, retSize);
1158 
1160 }
1161 
1162 
1163 
1164 } // end namespace ug
1165 
1166 #endif /* LUA_USER_DATA_IMPL_H_ */
parameterString s
location name
Definition: checkpoint_util.lua:128
Definition: smart_pointer.h:108
T * get()
returns encapsulated pointer
Definition: smart_pointer.h:197
int * refcount_ptr() const
WARNING: this method is DANGEROUS!
Definition: smart_pointer.h:263
Type based UserData.
Definition: user_data.h:501
The base class for all geometric objects, such as vertices, edges, faces, volumes,...
Definition: grid_base_objects.h:157
Definition: local_algebra.h:198
Handle for a lua reference.
Definition: lua_function_handle.h:40
int ref
Definition: lua_function_handle.h:42
int m_cbValueRef
reference to lua function
Definition: lua_user_data.h:422
virtual void operator()(TData &out, int numArgs,...)
evaluates the data
Definition: lua_user_data_impl.h:1111
lua_State * m_L
lua state
Definition: lua_user_data.h:425
LuaFunction()
constructor
Definition: lua_user_data_impl.h:1082
void set_lua_callback(const char *luaCallback, size_t numArgs)
sets the Lua function used to compute the data
Definition: lua_user_data_impl.h:1089
Factory providing LuaUserData.
Definition: lua_user_data.h:180
static void remove(const std::string &name)
removes the user data
Definition: lua_user_data_impl.h:417
static SmartPtr< LuaUserData< TData, dim, TRet > > provide_or_create(const std::string &name)
returns the already existing data if it was created before, otherwise newly created data
Definition: lua_user_data_impl.h:370
provides data specified in the lua script
Definition: lua_user_data.h:96
static bool check_callback_returns(const char *callName, const bool bThrow=false)
returns true if callback has correct return values
Definition: lua_user_data_impl.h:249
lua_State * m_L
lua state
Definition: lua_user_data.h:157
int m_callbackRef
reference to lua function
Definition: lua_user_data.h:147
static std::string signature()
returns string of required callback signature
Definition: lua_user_data_impl.h:67
static std::string name()
returns name of UserData
Definition: lua_user_data_impl.h:84
std::string m_callbackName
callback name as string
Definition: lua_user_data.h:144
LuaUserData(const char *luaCallback)
Constructor.
Definition: lua_user_data_impl.h:94
TRet evaluate(TData &D, const MathVector< dim > &x, number time, int si) const
evaluates the data at a given point and time
Definition: lua_user_data_impl.h:284
virtual ~LuaUserData()
destructor
Definition: lua_user_data_impl.h:355
LuaUserFunction(const char *luaCallback, size_t numArgs)
constructor
Definition: lua_user_data_impl.h:445
void eval_value(TData &out, const std::vector< TDataIn > &dataIn, const MathVector< dim > &x, number time, int si) const
evaluates the data at a given point and time
Definition: lua_user_data_impl.h:740
virtual ~LuaUserFunction()
destructor frees the reference
Definition: lua_user_data_impl.h:510
void free_deriv_callback_ref(size_t arg)
frees callback-references for derivative callbacks
Definition: lua_user_data_impl.h:531
void evaluate(TData &value, const MathVector< dim > &globIP, number time, int si) const
Definition: lua_user_data_impl.h:909
void set_input(size_t i, SmartPtr< CplUserData< TDataIn, dim > > data)
set input value for parameter i
Definition: lua_user_data_impl.h:1049
std::vector< int > m_cbDerivRef
Definition: lua_user_data.h:346
void eval_deriv(TData &out, const std::vector< TDataIn > &dataIn, const MathVector< dim > &x, number time, int si, size_t arg) const
evaluates the data at a given point and time
Definition: lua_user_data_impl.h:823
void eval_and_deriv(TData vValue[], const MathVector< dim > vGlobIP[], number time, int si, GridObject *elem, const MathVector< dim > vCornerCoords[], const MathVector< refDim > vLocIP[], const size_t nip, LocalVector *u, bool bDeriv, int s, std::vector< std::vector< TData > > vvvDeriv[], const MathMatrix< refDim, dim > *vJT=NULL)
Definition: lua_user_data_impl.h:959
lua_State * m_L
lua state
Definition: lua_user_data.h:349
void free_callback_ref()
frees the callback-reference, if a callback was set.
Definition: lua_user_data_impl.h:522
int m_cbValueRef
reference to lua function
Definition: lua_user_data.h:345
void set_deriv(size_t arg, const char *luaCallback)
sets the Lua function used to compute the derivative
Definition: lua_user_data_impl.h:601
std::vector< std::string > m_cbDerivName
Definition: lua_user_data.h:342
void set_lua_value_callback(const char *luaCallback, size_t numArgs)
sets the Lua function used to compute the data
Definition: lua_user_data_impl.h:541
void set_num_input(size_t num)
set number of needed inputs
Definition: lua_user_data_impl.h:1037
virtual void operator()(TData &out, int numArgs,...) const
evaluates the data
Definition: lua_user_data_impl.h:661
A class for fixed size, dense matrices.
Definition: math_matrix.h:52
Definition: lua_compiler.h:50
const std::string & name() const
Definition: lua_compiler.h:86
int num_out() const
Definition: lua_compiler.h:81
bool call(double *ret, const double *in) const
Definition: lua_compiler.cpp:263
function util FileDummy read(...) error("io.open_0 does not support read.") end
static const int dim
#define UG_ASSERT(expr, msg)
Definition: assert.h:70
#define UG_CATCH_THROW(msg)
Definition: error.h:64
#define UG_THROW(msg)
Definition: error.h:57
#define UG_LOG(msg)
Definition: log.h:367
#define UG_COND_THROW(cond, msg)
UG_COND_THROW(cond, msg) : performs a UG_THROW(msg) if cond == true.
Definition: error.h:61
double number
Definition: types.h:124
struct lua_State lua_State
Definition: lua_table_handle.h:40
#define PROFILE_CALLBACK()
Definition: lua_user_data_impl.h:50
#define PROFILE_CALLBACK_END()
Definition: lua_user_data_impl.h:52
#define PROFILE_CALLBACK_BEGIN(name)
Definition: lua_user_data_impl.h:51
string GetLUAScriptFunctionDefined(const char *functionName)
returns file and line of defined script function
Definition: info_commands.cpp:368
lua_State * GetDefaultLuaState()
returns the default lua state
Definition: lua_util.cpp:242
the ug namespace
bool useLuaCompiler
Definition: info_commands.cpp:93
bool IsFiniteAndNotTooBig(double d)
Definition: number_util.h:39
SmartPtr< T, FreePolicy > make_sp(T *inst)
returns a SmartPtr for the passed raw pointer
Definition: smart_pointer.h:836
static void mult_add(TData &out, const TData &in1, const TDataIn &s)
computes out += s * in1 (with appropriate '*')
Lua Traits to push/pop on lua stack.
Definition: lua_traits.h:79