ug4
lua_user_data_impl.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2012-2015: G-CSC, Goethe University Frankfurt
3  * Author: Andreas Vogel
4  *
5  * This file is part of UG4.
6  *
7  * UG4 is free software: you can redistribute it and/or modify it under the
8  * terms of the GNU Lesser General Public License version 3 (as published by the
9  * Free Software Foundation) with the following additional attribution
10  * requirements (according to LGPL/GPL v3 §7):
11  *
12  * (1) The following notice must be displayed in the Appropriate Legal Notices
13  * of covered and combined works: "Based on UG4 (www.ug4.org/license)".
14  *
15  * (2) The following notice must be displayed at a prominent place in the
16  * terminal output of covered works: "Based on UG4 (www.ug4.org/license)".
17  *
18  * (3) The following bibliography is recommended for citation and must be
19  * preserved in all covered files:
20  * "Reiter, S., Vogel, A., Heppner, I., Rupp, M., and Wittum, G. A massively
21  * parallel geometric multigrid solver on hierarchically distributed grids.
22  * Computing and visualization in science 16, 4 (2013), 151-164"
23  * "Vogel, A., Reiter, S., Rupp, M., Nägel, A., and Wittum, G. UG4 -- a novel
24  * flexible software system for simulating pde based models on high performance
25  * computers. Computing and visualization in science 16, 4 (2013), 165-179"
26  *
27  * This program is distributed in the hope that it will be useful,
28  * but WITHOUT ANY WARRANTY; without even the implied warranty of
29  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
30  * GNU Lesser General Public License for more details.
31  */
32 
33 #ifndef __H__UG_BRIDGE__BRIDGES__USER_DATA__USER_DATA_IMPL_
34 #define __H__UG_BRIDGE__BRIDGES__USER_DATA__USER_DATA_IMPL_
35 
36 #include "lua_user_data.h"
39 
40 #include "info_commands.h"
42 
// Profiling hooks for Lua callbacks.
// The first branch is the active one: every hook expands to nothing, so
// callbacks are not profiled. Change the condition to 0 to route the hooks
// into the "luacallback" profiling group instead.
#if 1
#define PROFILE_CALLBACK()
#define PROFILE_CALLBACK_BEGIN(name)
#define PROFILE_CALLBACK_END()
#else
#define PROFILE_CALLBACK()           PROFILE_FUNC_GROUP("luacallback")
#define PROFILE_CALLBACK_BEGIN(name) PROFILE_BEGIN_GROUP(name, "luacallback")
#define PROFILE_CALLBACK_END()       PROFILE_END()
#endif
52 namespace ug{
53 
#if defined(USE_LUA2C)
	// Global switch (defined in another translation unit) that, when true,
	// lets the evaluation routines use the LUA2C-compiled callbacks.
	extern bool useLuaCompiler;
#endif
57 
58 
59 
61 // LuaUserData
63 
64 template <typename TData, int dim, typename TRet>
66 {
67  std::stringstream ss;
68  ss << "function name(";
69  if(dim >= 1) ss << "x";
70  if(dim >= 2) ss << ", y";
71  if(dim >= 3) ss << ", z";
72  ss << ", t, si)\n ... \n return ";
73  if(lua_traits<TRet>::size != 0)
74  ss << lua_traits<TRet>::signature() << ", ";
75  ss << lua_traits<TData>::signature();
76  ss << "\nend";
77  return ss.str();
78 }
79 
80 
81 template <typename TData, int dim, typename TRet>
83 {
84  std::stringstream ss;
85  ss << "Lua";
86  if(lua_traits<TRet>::size > 0) ss << "Cond";
87  ss << "User" << lua_traits<TData>::name() << dim << "d";
88  return ss.str();
89 }
90 
91 template <typename TData, int dim, typename TRet>
93  : m_callbackName(luaCallback), m_bFromFactory(false)
94 {
95 // get lua state
97 
98 // obtain a reference
99  lua_getglobal(m_L, m_callbackName.c_str());
100 
101 // make sure that the reference is valid
102  if(lua_isnil(m_L, -1)){
103  UG_THROW(name() << ": Specified lua callback "
104  "does not exist: " << m_callbackName);
105  }
106 
107 // store reference to lua function
108  m_callbackRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
109 
110 // make a test run
112 
113  #ifdef USE_LUA2C
114  if(useLuaCompiler) m_luaComp.create(luaCallback);
115  #endif
116 }
117 
118 template <typename TData, int dim, typename TRet>
120  : m_callbackName("__anonymous__lua__function__"), m_bFromFactory(false)
121 {
122 // get lua state
124 
125 // store reference to lua function
126  m_callbackRef = handle.ref;
127 
128 // make a test run
130 
131  #ifdef USE_LUA2C
132 // UG_THROW("LuaFunctionHandle usage currently not supported with LUA2C.");
133  if(useLuaCompiler) m_luaComp.create(m_callbackName.c_str(), &handle);
134  #endif
135 }
136 
137 
138 template <typename TData, int dim, typename TRet>
140 check_callback_returns(lua_State* L, int callbackRef, const char* callName, const bool bThrow)
141 {
143 // get current stack level
144  const int level = lua_gettop(L);
145 
146 // dummy values to invoke the callback once
147  MathVector<dim> x; x = 0.0;
148  number time = 0.0;
149  int si = 0;
150 
151 // push the callback function on the stack
152  lua_rawgeti(L, LUA_REGISTRYINDEX, callbackRef);
153 
154 // push space coordinates on stack
155  lua_traits<MathVector<dim> >::push(L, x);
156 
157 // push time on stack
158  lua_traits<number>::push(L, time);
159 
160 // push subset on stack
161  lua_traits<int>::push(L, si);
162 
163 // compute total args size
164  const int argSize = lua_traits<MathVector<dim> >::size
167 
168 // compute total return size
169  const int retSize = lua_traits<TData>::size + lua_traits<TRet>::size;
170 
171 // call lua function
172  if(lua_pcall(L, argSize, LUA_MULTRET, 0) != 0)
173  UG_THROW(name() << ": Error while "
174  "testing callback '" << callName << "',"
175  " lua message: "<< lua_tostring(L, -1));
176 
177  // get number of results
178  const int numResults = lua_gettop(L) - level;
179 
180 // success flag
181  bool bRet = true;
182 
183 // if number of results is wrong return error
184  if(numResults != retSize){
185  if(bThrow){
186  UG_THROW(name() << ": Number of return values incorrect "
187  "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
188  "\nRequired: "<<retSize<<", passed: "<<numResults
189  <<". Use signature as follows:\n"
190  << signature());
191  }
192  else{
193  bRet = false;
194  }
195  }
196 
197 // check return value
198  if(!lua_traits<TData>::check(L)){
199  if(bThrow){
200  UG_THROW(name() << ": Data values type incorrect "
201  "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
202  "\nUse signature as follows:\n"
203  << signature());
204  }
205  else{
206  bRet = false;
207  }
208  }
209 
210 // read return flag (may be void)
211  if(!lua_traits<TRet>::check(L, -retSize)){
212  if(bThrow){
213  UG_THROW("LuaUserData: Return values type incorrect "
214  "for callback\n"<<callName<< " (" << bridge::GetLUAScriptFunctionDefined(callName) << ")"
215  "\nUse signature as follows:\n"
216  << signature());
217  }
218  else{
219  bRet = false;
220  }
221  }
222 
223 // pop values
224  lua_pop(L, numResults);
225 
226 // return match
227  return bRet;
228 }
229 
230 template <typename TData, int dim, typename TRet>
232 check_callback_returns(LuaFunctionHandle handle, const bool bThrow)
233 {
235 // get lua state
237 
238 // forward call
239  bool bRet = check_callback_returns(L, handle.ref, "__lua_function_handle__", bThrow);
240 
241 // return match
242  return bRet;
243 }
244 
245 template <typename TData, int dim, typename TRet>
247 check_callback_returns(const char* callName, const bool bThrow)
248 {
250 // get lua state
252 
253 // obtain a reference
254  lua_getglobal(L, callName);
255 
256 // check if reference is valid
257  if(lua_isnil(L, -1)) {
258  if(bThrow) {
259  UG_THROW(name() << ": Cannot find specified lua callback "
260  " with name: "<<callName);
261  }
262  else {
263  return false;
264  }
265  }
266 
267 // get reference
268  int callbackRef = luaL_ref(L, LUA_REGISTRYINDEX);
269 
270 // forward call
271  bool bRet = check_callback_returns(L, callbackRef, callName, bThrow);
272 
273 // free reference to callback
274  luaL_unref(L, LUA_REGISTRYINDEX, callbackRef);
275 
276 // return match
277  return bRet;
278 }
279 
280 template <typename TData, int dim, typename TRet>
282 evaluate(TData& D, const MathVector<dim>& x, number time, int si) const
283 {
285  #ifdef USE_LUA2C
286  if(useLuaCompiler && m_luaComp.is_valid())
287  {
288  double d[dim+2];
289  for(int i=0; i<dim; i++)
290  d[i] = x[i];
291  d[dim] = time;
292  d[dim+1] = si;
293  double ret[lua_traits<TData>::size+1];
294  m_luaComp.call(ret, d);
295  //TData D2;
296  TRet *t=NULL;
297  lua_traits<TData>::read(D, ret, t);
298  return lua_traits<TRet>::do_return(ret[0]);
299  }
300  else
301  #endif
302  {
303  // push the callback function on the stack
304  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_callbackRef);
305 
306  // push space coordinates on stack
307  lua_traits<MathVector<dim> >::push(m_L, x);
308 
309  // push time on stack
310  lua_traits<number>::push(m_L, time);
311 
312  // push subset index on stack
313  lua_traits<int>::push(m_L, si);
314 
315  // compute total args size
316  const int argSize = lua_traits<MathVector<dim> >::size
319 
320  // compute total return size
321  const int retSize = lua_traits<TData>::size + lua_traits<TRet>::size;
322 
323  // call lua function
324  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
325  UG_THROW(name() << "::operator(...): Error while "
326  "running callback '" << m_callbackName << "',"
327  " lua message: "<< lua_tostring(m_L, -1)<<".\n"
328  "Use signature as follows:\n"
329  << signature());
330 
331  bool res = false;
332  try{
333  // read return value
334  lua_traits<TData>::read(m_L, D);
335 
336  // read return flag (may be void)
337  lua_traits<TRet>::read(m_L, res, -retSize);
338  }
339  UG_CATCH_THROW(name() << "::operator(...): Error while running "
340  "callback '" << m_callbackName << "'.\n"
341  "Use signature as follows:\n"
342  << signature());
343 
344  // pop values
345  lua_pop(m_L, retSize);
346 
347  // forward flag
348  return lua_traits<TRet>::do_return(res);
349  }
350 }
351 
352 template <typename TData, int dim, typename TRet>
354 {
355 // free reference to callback
356  luaL_unref(m_L, LUA_REGISTRYINDEX, m_callbackRef);
357 
358  if(m_bFromFactory)
360 }
361 
363 // LuaUserDataFactory
365 
366 template <typename TData, int dim, typename TRet>
369 {
371  typedef std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> > Map;
372  typedef typename Map::iterator iterator;
373 
374 // check for element
375  iterator iter = m_mData.find(name);
376 
377 // if name does not exist, create new one
378  if(iter == m_mData.end())
379  {
381  = make_sp(new LuaUserData<TData,dim,TRet>(name.c_str()));
382 
383  // the LuaUserData must remember to unregister itself at destruction
384  sp->set_created_from_factory(true);
385 
386  // NOTE AND WARNING: This is very hacky and dangerous. We only do this
387  // since we know exactly what we are doing and everything is safe and
388  // only in protected or private area. However, if you once want to change
389  // this code, please be aware, that we store here plain pointers and
390  // associated reference counters of a SmartPtr. This should not be done
391  // in general and this kind of coding is not recommended at all. Please
392  // use different approaches whenever possible.
393  std::pair<LuaUserData<TData,dim,TRet>*, int*>& data = m_mData[name];
394  data.first = sp.get();
395  data.second = sp.refcount_ptr();
396 
397  return sp;
398  }
399 // else return present data
400  {
401  // NOTE AND WARNING: This is very hacky and dangerous. We only do this
402  // since we know exactly what we are doing and everything is safe and
403  // only in protected or private area. However, if you once want to change
404  // this code, please be aware, that we store here plain pointers and
405  // associated reference counters of a SmartPtr. This should not be done
406  // in general and this kind of coding is not recommended at all. Please
407  // use different approaches whenever possible.
408  std::pair<LuaUserData<TData,dim,TRet>*, int*>& data = iter->second;
409  return SmartPtr<LuaUserData<TData,dim,TRet> >(data.first, data.second);
410  }
411 }
412 
413 template <typename TData, int dim, typename TRet>
414 void
416 {
417  typedef std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> > Map;
418  typedef typename Map::iterator iterator;
419 
420 // check for element
421  iterator iter = m_mData.find(name);
422 
423 // if name does not exist, report an error
424  if(iter == m_mData.end())
425  UG_THROW("LuaUserDataFactory: trying to remove non-registered"
426  " data with name: "<<name);
427 
428  m_mData.erase(iter);
429 }
430 
431 
432 // instantiation of static member
433 template <typename TData, int dim, typename TRet>
434 std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> >
435 LuaUserDataFactory<TData,dim,TRet>::m_mData = std::map<std::string, std::pair<LuaUserData<TData,dim,TRet>*, int*> >();
436 
438 // LuaUserFunction
440 
441 template <typename TData, int dim, typename TDataIn>
443 LuaUserFunction(const char* luaCallback, size_t numArgs)
444  : m_numArgs(numArgs), m_bPosTimeNeed(false)
445 {
447  m_cbValueRef = LUA_NOREF;
448  m_cbDerivRef.clear();
449  m_cbDerivName.clear();
450  set_lua_value_callback(luaCallback, numArgs);
451  #ifdef USE_LUA2C
452  if(useLuaCompiler) m_luaComp.create(luaCallback);
453  #endif
454 }
455 
456 template <typename TData, int dim, typename TDataIn>
458 LuaUserFunction(const char* luaCallback, size_t numArgs, bool bPosTimeNeed)
459  : m_numArgs(numArgs), m_bPosTimeNeed(bPosTimeNeed)
460 {
462  m_cbValueRef = LUA_NOREF;
463  m_cbDerivRef.clear();
464  m_cbDerivName.clear();
465  set_lua_value_callback(luaCallback, numArgs);
466  #ifdef USE_LUA2C
467  m_luaComp_Deriv.clear();
468  #endif
469 }
470 
471 
472 template <typename TData, int dim, typename TDataIn>
474 LuaUserFunction(LuaFunctionHandle handle, size_t numArgs)
475  : m_numArgs(numArgs), m_bPosTimeNeed(false)
476 {
478  m_cbValueRef = LUA_NOREF;
479  m_cbDerivRef.clear();
480  m_cbDerivName.clear();
481  set_lua_value_callback(handle, numArgs);
482  #ifdef USE_LUA2C
483  if(useLuaCompiler){
484  UG_LOG("WARNING (in LuaUserFunction): LUA2C compiler "
485  "can't be executed for FunctionHandle.\n");
486  }
487  #endif
488 }
489 
490 template <typename TData, int dim, typename TDataIn>
492 LuaUserFunction(LuaFunctionHandle handle, size_t numArgs, bool bPosTimeNeed)
493  : m_numArgs(numArgs), m_bPosTimeNeed(bPosTimeNeed)
494 {
496  m_cbValueRef = LUA_NOREF;
497  m_cbDerivRef.clear();
498  m_cbDerivName.clear();
499  set_lua_value_callback(handle, numArgs);
500  #ifdef USE_LUA2C
501  m_luaComp_Deriv.clear();
502  #endif
503 }
504 
505 
506 
507 template <typename TData, int dim, typename TDataIn>
509 {
510 // free reference to callback
511  free_callback_ref();
512 
513 // free references to derivate callbacks
514  for(size_t i = 0; i < m_numArgs; ++i){
515  free_deriv_callback_ref(i);
516  }
517 }
518 
519 template <typename TData, int dim, typename TDataIn>
521 {
522  if(m_cbValueRef != LUA_NOREF){
523  luaL_unref(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
524  m_cbValueRef = LUA_NOREF;
525  }
526 }
527 
528 template <typename TData, int dim, typename TDataIn>
530 {
531  if(m_cbDerivRef[arg] != LUA_NOREF){
532  luaL_unref(m_L, LUA_REGISTRYINDEX, m_cbDerivRef[arg]);
533  m_cbDerivRef[arg] = LUA_NOREF;
534  }
535 }
536 
537 
538 template <typename TData, int dim, typename TDataIn>
539 void LuaUserFunction<TData,dim,TDataIn>::set_lua_value_callback(const char* luaCallback, size_t numArgs)
540 {
541 // store name (string) of callback
542  m_cbValueName = luaCallback;
543 
544 // obtain a reference
545  lua_getglobal(m_L, m_cbValueName.c_str());
546 
547 // make sure that the reference is valid
548  if(lua_isnil(m_L, -1)){
549  UG_THROW("LuaUserFunction::set_lua_value_callback(...):"
550  "Specified callback does not exist: " << m_cbValueName);
551  }
552 
553 // if a callback was already set, we have to free the old one
554  free_callback_ref();
555 
556 // store reference to lua function
557  m_cbValueRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
558 
559 // remember number of arguments to be used
560  m_numArgs = numArgs;
561  m_cbDerivName.resize(numArgs);
562  m_cbDerivRef.resize(numArgs, LUA_NOREF);
563 
564 // set num inputs for linker
565  set_num_input(numArgs);
566 
567  #ifdef USE_LUA2C
568  m_luaComp_Deriv.resize(numArgs);
569  #endif
570 }
571 
572 template <typename TData, int dim, typename TDataIn>
574 set_lua_value_callback(LuaFunctionHandle handle, size_t numArgs)
575 {
576 // store name (string) of callback
577  m_cbValueName = "__anonymous__lua__function__";
578 
579 // if a callback was already set, we have to free the old one
580  free_callback_ref();
581 
582 // store reference to lua function
583  m_cbValueRef = handle.ref;
584 
585 // remember number of arguments to be used
586  m_numArgs = numArgs;
587  m_cbDerivName.resize(numArgs);
588  m_cbDerivRef.resize(numArgs, LUA_NOREF);
589 
590 // set num inputs for linker
591  set_num_input(numArgs);
592 
593  #ifdef USE_LUA2C
594  m_luaComp_Deriv.resize(numArgs);
595  #endif
596 }
597 
598 template <typename TData, int dim, typename TDataIn>
599 void LuaUserFunction<TData,dim,TDataIn>::set_deriv(size_t arg, const char* luaCallback)
600 {
601 // check number of arg
602  if(arg >= m_numArgs)
603  UG_THROW("LuaUserFunction::set_lua_deriv_callback: Trying "
604  "to set a derivative for argument " << arg <<", that "
605  "does not exist. Number of arguments is "<<m_numArgs);
606 
607 // store name (string) of callback
608  m_cbDerivName[arg] = luaCallback;
609 
610 // free old reference
611  free_deriv_callback_ref(arg);
612 
613 // obtain a reference
614  lua_getglobal(m_L, m_cbDerivName[arg].c_str());
615 
616 // make sure that the reference is valid
617  if(lua_isnil(m_L, -1)){
618  UG_THROW("LuaUserFunction::set_lua_deriv_callback(...):"
619  "Specified callback does not exist: " << m_cbDerivName[arg]);
620  }
621 
622 // store reference to lua function
623  m_cbDerivRef[arg] = luaL_ref(m_L, LUA_REGISTRYINDEX);
624 
625  #ifdef USE_LUA2C
626  if(useLuaCompiler) m_luaComp_Deriv[arg].create(luaCallback);
627  #endif
628 
629 }
630 
631 template <typename TData, int dim, typename TDataIn>
633 {
634 // check number of arg
635  if(arg >= m_numArgs)
636  UG_THROW("LuaUserFunction::set_lua_deriv_callback: Trying "
637  "to set a derivative for argument " << arg <<", that "
638  "does not exist. Number of arguments is "<<m_numArgs);
639 
640 // store name (string) of callback
641  m_cbDerivName[arg] = std::string("__anonymous__lua__function__");
642 
643 // free old reference
644  free_deriv_callback_ref(arg);
645 
646 // store reference to lua function
647  m_cbDerivRef[arg] = handle.ref;
648 
649  #ifdef USE_LUA2C
650  // if(useLuaCompiler) m_luaComp_Deriv[arg].create(luaCallback);
651  #endif
652 
653 }
654 
655 
656 
657 
658 template <typename TData, int dim, typename TDataIn>
659 void LuaUserFunction<TData,dim,TDataIn>::operator() (TData& out, int numArgs, ...) const
660 {
662  #ifdef USE_LUA2C
663  if(useLuaCompiler && m_luaComp.is_valid())
664  {
665  double d[20];
666  // get list of arguments
667  va_list ap2;
668  va_start(ap2, numArgs);
669 
670  // read all arguments and push them to the lua stack
671  for(int i = 0; i < numArgs; ++i)
672  d[i] = va_arg(ap2, double);
673  va_end(ap2);
674 
675  double ret[lua_traits<TData>::size+1];
676 
677  UG_ASSERT(m_luaComp.num_in() == numArgs && m_luaComp.num_out() == lua_traits<TData>::size,
678  m_luaComp.name() << ", " << m_luaComp.num_in() << " != " << numArgs << " or " << m_luaComp.num_out() << " != " << lua_traits<TData>::size);
679  m_luaComp.call(ret, d);
680  //TData D2;
681  void *t=NULL;
682  //TData out2;
683  lua_traits<TData>::read(out, ret, t);
684  return;
685  }
686  else
687  #endif
688  {
689  UG_ASSERT(numArgs == (int)m_numArgs, "Number of arguments mismatched.");
690 
691  // push the callback function on the stack
692  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
693 
694  // get list of arguments
695  va_list ap;
696  va_start(ap, numArgs);
697 
698  // read all arguments and push them to the lua stack
699  for(int i = 0; i < numArgs; ++i)
700  {
701  // cast data
702  TDataIn val = va_arg(ap, TDataIn);
703 
704  // push data to lua stack
705  lua_traits<TDataIn>::push(m_L, val);
706  }
707 
708  // end read in of parameters
709  va_end(ap);
710 
711  // compute total args size
712  size_t argSize = lua_traits<TDataIn>::size * numArgs;
713 
714  // compute total return size
715  size_t retSize = lua_traits<TData>::size;
716 
717  // call lua function
718  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
719  UG_THROW("LuaUserFunction::operator(...): Error while "
720  "running callback '" << m_cbValueName << "',"
721  " lua message: "<< lua_tostring(m_L, -1));
722 
723  try{
724  // read return value
725  lua_traits<TData>::read(m_L, out);
726  UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
727  }
728  UG_CATCH_THROW("LuaUserFunction::operator(...): Error while running "
729  "callback '" << m_cbValueName << "'");
730 
731  // pop values
732  lua_pop(m_L, retSize);
733  }
734 }
735 
736 
737 template <typename TData, int dim, typename TDataIn>
738 void LuaUserFunction<TData,dim,TDataIn>::eval_value(TData& out, const std::vector<TDataIn>& dataIn,
739  const MathVector<dim>& x, number time, int si) const
740 {
742  #ifdef USE_LUA2C
743  if(useLuaCompiler && m_luaComp.is_valid())
744  {
745  double d[20];
746 
747  // read all arguments and push them to the lua stack
748  for(size_t i = 0; i < dataIn.size(); ++i)
749  d[i] = dataIn[i];
750  if(m_bPosTimeNeed){
751  for(int i=0; i<dim; i++)
752  d[i+m_numArgs] = x[i];
753  d[dim+m_numArgs]=time;
754  d[dim+m_numArgs+1]=si;
755  UG_ASSERT(dim+m_numArgs+1 < 20, m_luaComp.name());
756  }
757 
758  double ret[lua_traits<TData>::size];
759  m_luaComp.call(ret, d);
760  //TData D2;
761  void *t=NULL;
762  //TData out2;
763  UG_ASSERT(m_luaComp.num_out() == lua_traits<TData>::size, m_luaComp.name() << ", " << m_luaComp.num_out() << " != " << lua_traits<TData>::size);
764  lua_traits<TData>::read(out, ret, t);
765  return;
766  }
767  else
768  #endif
769  {
770  UG_ASSERT(dataIn.size() == m_numArgs, "Number of arguments mismatched.");
771 
772  // push the callback function on the stack
773  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
774 
775  // read all arguments and push them to the lua stack
776  for(size_t i = 0; i < dataIn.size(); ++i)
777  {
778  // push data to lua stack
779  lua_traits<TDataIn>::push(m_L, dataIn[i]);
780  }
781 
782  // if needed, read additional coordinate, time and subset index arguments and push them to the lua stack
783  if(m_bPosTimeNeed){
784  lua_traits<MathVector<dim> >::push(m_L, x);
785  lua_traits<number>::push(m_L, time);
786  lua_traits<int>::push(m_L, si);
787  }
788 
789  // compute total args size
790  size_t argSize = lua_traits<TDataIn>::size * dataIn.size();
791  if(m_bPosTimeNeed){
792  argSize += lua_traits<MathVector<dim> >::size
795  }
796 
797  // compute total return size
798  size_t retSize = lua_traits<TData>::size;
799 
800  // call lua function
801  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
802  UG_THROW("LuaUserFunction::eval_value(...): Error while "
803  "running callback '" << m_cbValueName << "',"
804  " lua message: "<< lua_tostring(m_L, -1));
805 
806  try{
807  // read return value
808  lua_traits<TData>::read(m_L, out);
809  UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
810  }
811  UG_CATCH_THROW("LuaUserFunction::eval_value(...): Error while "
812  "running callback '" << m_cbValueName << "'");
813 
814  // pop values
815  lua_pop(m_L, retSize);
816  }
817 }
818 
819 
820 template <typename TData, int dim, typename TDataIn>
821 void LuaUserFunction<TData,dim,TDataIn>::eval_deriv(TData& out, const std::vector<TDataIn>& dataIn,
822  const MathVector<dim>& x, number time, int si, size_t arg) const
823 {
825  #ifdef USE_LUA2C
826  if(useLuaCompiler && m_luaComp_Deriv[arg].is_valid()
827  && dim+m_numArgs+1 < 20 && m_luaComp_Deriv[arg].num_out() == lua_traits<TData>::size)
828  {
829  const bridge::LUACompiler &luaComp = m_luaComp_Deriv[arg];
830  double d[25];
831  UG_ASSERT(dim+m_numArgs+1 < 20, luaComp.name());
832  for(size_t i=0; i<m_numArgs; i++)
833  d[i] = dataIn[i];
834  if(m_bPosTimeNeed){
835  for(int i=0; i<dim; i++)
836  d[i+m_numArgs] = x[i];
837  d[dim+m_numArgs]=time;
838  d[dim+m_numArgs+1]=si;
839  UG_ASSERT(dim+m_numArgs+1 < 20, luaComp.name());
840  }
842  luaComp.name() << " has wrong number of outputs: is " << luaComp.num_out() << ", needs " << lua_traits<TData>::size);
843  double ret[lua_traits<TData>::size+1];
844  luaComp.call(ret, d);
845  //TData D2;
846  void *t=NULL;
847  //TData out2;
848  lua_traits<TData>::read(out, ret, t);
849  return;
850  }
851  else
852  #endif
853  {
854  UG_ASSERT(dataIn.size() == m_numArgs, "Number of arguments mismatched.");
855  UG_ASSERT(arg < m_numArgs, "Argument does not exist.");
856 
857  // push the callback function on the stack
858  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbDerivRef[arg]);
859 
860  // read all arguments and push them to the lua stack
861  for(size_t i = 0; i < dataIn.size(); ++i)
862  {
863  // push data to lua stack
864  lua_traits<TDataIn>::push(m_L, dataIn[i]);
865  }
866 
867  // if needed, read additional coordinate, time and subset index arguments and push them to the lua stack
868  if(m_bPosTimeNeed){
869  lua_traits<MathVector<dim> >::push(m_L, x);
870  lua_traits<number>::push(m_L, time);
871  lua_traits<int>::push(m_L, si);
872  }
873 
874  // compute total args size
875  size_t argSize = lua_traits<TDataIn>::size * dataIn.size();
876  if(m_bPosTimeNeed){
877  argSize += lua_traits<MathVector<dim> >::size
880  }
881 
882  // compute total return size
883  size_t retSize = lua_traits<TData>::size;
884 
885  // call lua function
886  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
887  UG_THROW("LuaUserFunction::eval_deriv: Error while "
888  "running callback '" << m_cbDerivName[arg] << "',"
889  " lua message: "<< lua_tostring(m_L, -1) );
890 
891  try{
892  // read return value
893  lua_traits<TData>::read(m_L, out);
894  UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
895  }
896  UG_CATCH_THROW("LuaUserFunction::eval_deriv(...): Error while "
897  "running callback '" << m_cbDerivName[arg] << "'");
898 
899  // pop values
900  lua_pop(m_L, retSize);
901  }
902 }
903 
904 
905 template <typename TData, int dim, typename TDataIn>
907 evaluate (TData& value,
908  const MathVector<dim>& globIP,
909  number time, int si) const
910 {
912 // vector of data for all inputs
913  std::vector<TDataIn> vDataIn(this->num_input());
914 
915 // gather all input data for this ip
916  for(size_t c = 0; c < vDataIn.size(); ++c)
917  (*m_vpUserData[c])(vDataIn[c], globIP, time, si);
918 
919 // evaluate data at ip
920  eval_value(value, vDataIn, globIP, time, si);
921 
922  UG_COND_THROW(IsFiniteAndNotTooBig(value)==false, value);
923 }
924 
925 template <typename TData, int dim, typename TDataIn>
926 template <int refDim>
928 evaluate(TData vValue[],
929  const MathVector<dim> vGlobIP[],
930  number time, int si,
931  GridObject* elem,
932  const MathVector<dim> vCornerCoords[],
933  const MathVector<refDim> vLocIP[],
934  const size_t nip,
935  LocalVector* u,
936  const MathMatrix<refDim, dim>* vJT) const
937 {
939 // vector of data for all inputs
940  std::vector<TDataIn> vDataIn(this->num_input());
941 
942 // gather all input data for this ip
943  for(size_t ip = 0; ip < nip; ++ip)
944  {
945  for(size_t c = 0; c < vDataIn.size(); ++c)
946  (*m_vpUserData[c])(vDataIn[c], vGlobIP[ip], time, si, elem, vCornerCoords, vLocIP[ip], u);
947 
948  // evaluate data at ip
949  eval_value(vValue[ip], vDataIn, vGlobIP[ip], time, si);
950  UG_COND_THROW(IsFiniteAndNotTooBig(vValue[ip])==false, vValue[ip]);
951  }
952 }
953 
954 template <typename TData, int dim, typename TDataIn>
955 template <int refDim>
957 eval_and_deriv(TData vValue[],
958  const MathVector<dim> vGlobIP[],
959  number time, int si,
960  GridObject* elem,
961  const MathVector<dim> vCornerCoords[],
962  const MathVector<refDim> vLocIP[],
963  const size_t nip,
964  LocalVector* u,
965  bool bDeriv,
966  int s,
967  std::vector<std::vector<TData> > vvvDeriv[],
968  const MathMatrix<refDim, dim>* vJT)
969 {
971 // vector of data for all inputs
972  std::vector<TDataIn> vDataIn(this->num_input());
973 
974  for(size_t ip = 0; ip < nip; ++ip)
975  {
976  // gather all input data for this ip
977  for(size_t c = 0; c < vDataIn.size(); ++c)
978  vDataIn[c] = m_vpUserData[c]->value(this->series_id(c,s), ip);
979 
980  // evaluate data at ip
981  eval_value(vValue[ip], vDataIn, vGlobIP[ip], time, si);
982  }
983 
984 // check if derivative is required
985  if(!bDeriv || this->zero_derivative()) return;
986 
987 // clear all derivative values
988  this->set_zero(vvvDeriv, nip);
989 
990 // loop all inputs
991  for(size_t c = 0; c < vDataIn.size(); ++c)
992  {
993  // check if we have the derivative w.r.t. this input, and the input has derivative
994  if(m_cbDerivRef[c] == LUA_NOREF || m_vpUserData[c]->zero_derivative()) continue;
995 
996  // loop ips
997  for(size_t ip = 0; ip < nip; ++ip)
998  {
999  // gather all input data for this ip
1000  for(size_t i = 0; i < vDataIn.size(); ++i)
1001  vDataIn[i] = m_vpUserData[i]->value(this->series_id(c,s), ip); //< series_id(c,s) or series_id(i,s)
1002 
1003  // data of derivative w.r.t. one component at ip-values
1004  TData derivVal;
1005 
1006  // evaluate data at ip
1007  eval_deriv(derivVal, vDataIn, vGlobIP[ip], time, si, c);
1008 
1009  // loop functions
1010  for(size_t fct = 0; fct < this->input_num_fct(c); ++fct)
1011  {
1012  // get common fct id for this function
1013  const size_t commonFct = this->input_common_fct(c, fct);
1014 
1015  // loop dofs
1016  for(size_t dof = 0; dof < this->num_sh(fct); ++dof)
1017  {
1019  mult_add(vvvDeriv[ip][commonFct][dof],
1020  derivVal,
1021  m_vpDependData[c]->deriv(this->series_id(c,s), ip, fct, dof));
1022  UG_COND_THROW(IsFiniteAndNotTooBig(vvvDeriv[ip][commonFct][dof])==false, vvvDeriv[ip][commonFct][dof]);
1023  }
1024  }
1025  }
1026  }
1027 }
1028 
1034 template <typename TData, int dim, typename TDataIn>
1036 {
1037 // resize arrays
1038  m_vpUserData.resize(num);
1039  m_vpDependData.resize(num);
1040 
1041 // forward size to base class
1042  base_type::set_num_input(num);
1043 }
1044 
1045 template <typename TData, int dim, typename TDataIn>
1048 {
1049  UG_ASSERT(i < m_vpUserData.size(), "Input not needed");
1050  UG_ASSERT(i < m_vpDependData.size(), "Input not needed");
1051 
1052 // check input number
1053  if(i >= this->num_input())
1054  UG_THROW("LuaUserFunction::set_input: Only " << this->num_input()
1055  << " inputs can be set. Use 'set_num_input' to increase"
1056  " the number of needed inputs.");
1057 
1058 // remember userdata
1059  m_vpUserData[i] = data;
1060 
1061 // cast to dependent data
1062  m_vpDependData[i] = data.template cast_dynamic<DependentUserData<TDataIn, dim> >();
1063 
1064 // forward to base class
1065  base_type::set_input(i, data, data);
1066 }
1067 
1068 template <typename TData, int dim, typename TDataIn>
1070 {
1071  set_input(i, CreateConstUserData<dim>(val, TDataIn()));
1072 }
1073 
1074 
1076 // LuaFunction
1078 
1079 template <typename TData, typename TDataIn>
1081 {
1083  m_cbValueRef = LUA_NOREF;
1084 }
1085 
1086 template <typename TData, typename TDataIn>
1087 void LuaFunction<TData,TDataIn>::set_lua_callback(const char* luaCallback, size_t numArgs)
1088 {
1089 // store name (string) of callback
1090  m_cbValueName = luaCallback;
1091 
1092 // obtain a reference
1093  lua_getglobal(m_L, m_cbValueName.c_str());
1094 
1095 // make sure that the reference is valid
1096  if(lua_isnil(m_L, -1)){
1097  UG_THROW("LuaFunction::set_lua_callback(...):"
1098  "Specified lua callback does not exist: " << m_cbValueName);
1099  }
1100 
1101 // store reference to lua function
1102  m_cbValueRef = luaL_ref(m_L, LUA_REGISTRYINDEX);
1103 
1104 // remember number of arguments to be used
1105  m_numArgs = numArgs;
1106 }
1107 
1108 template <typename TData, typename TDataIn>
1109 void LuaFunction<TData,TDataIn>::operator() (TData& out, int numArgs, ...)
1110 {
1111  PROFILE_CALLBACK_BEGIN(operatorBracket);
1112  UG_ASSERT(numArgs == (int)m_numArgs, "Number of arguments mismatched.");
1113 
1114 // push the callback function on the stack
1115  lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_cbValueRef);
1116 
1117 // get list of arguments
1118  va_list ap;
1119  va_start(ap, numArgs);
1120 
1121 // read all arguments and push them to the lua stack
1122  for(int i = 0; i < numArgs; ++i)
1123  {
1124  // cast data
1125  TDataIn val = va_arg(ap, TDataIn);
1126 
1127  // push data to lua stack
1128  lua_traits<TDataIn>::push(m_L, val);
1129  }
1130 
1131 // end read in of parameters
1132  va_end(ap);
1133 
1134 // compute total args size
1135  size_t argSize = lua_traits<TDataIn>::size * numArgs;
1136 
1137 // compute total return size
1138  size_t retSize = lua_traits<TData>::size;
1139 
1140 // call lua function
1141  if(lua_pcall(m_L, argSize, retSize, 0) != 0)
1142  UG_THROW("LuaFunction::operator(...): Error while "
1143  "running callback '" << m_cbValueName << "',"
1144  " lua message: "<< lua_tostring(m_L, -1));
1145 
1146  try{
1147  // read return value
1148  lua_traits<TData>::read(m_L, out);
1149  UG_COND_THROW(IsFiniteAndNotTooBig(out)==false, out);
1150  }
1151  UG_CATCH_THROW("LuaFunction::operator(...): Error while running "
1152  "callback '" << m_cbValueName << "'");
1153 
1154 // pop values
1155  lua_pop(m_L, retSize);
1156 
1158 }
1159 
1160 
1161 
1162 } // end namespace ug
1163 
1164 #endif /* LUA_USER_DATA_IMPL_H_ */
parameterString s
location name
Definition: checkpoint_util.lua:128
Definition: smart_pointer.h:108
T * get()
returns encapsulated pointer
Definition: smart_pointer.h:197
int * refcount_ptr() const
WARNING: this method is DANGEROUS!
Definition: smart_pointer.h:263
Type based UserData.
Definition: user_data.h:475
The base class for all geometric objects, such as vertices, edges, faces, volumes,...
Definition: grid_base_objects.h:157
Definition: local_algebra.h:198
Handle for a lua reference.
Definition: lua_function_handle.h:40
int ref
Definition: lua_function_handle.h:42
int m_cbValueRef
reference to lua function
Definition: lua_user_data.h:422
virtual void operator()(TData &out, int numArgs,...)
evaluates the data
Definition: lua_user_data_impl.h:1109
lua_State * m_L
lua state
Definition: lua_user_data.h:425
LuaFunction()
constructor
Definition: lua_user_data_impl.h:1080
void set_lua_callback(const char *luaCallback, size_t numArgs)
sets the Lua function used to compute the data
Definition: lua_user_data_impl.h:1087
Factory providing LuaUserData.
Definition: lua_user_data.h:180
static void remove(const std::string &name)
removes the user data
Definition: lua_user_data_impl.h:415
static SmartPtr< LuaUserData< TData, dim, TRet > > provide_or_create(const std::string &name)
returns new Data if not already created, already existing else
Definition: lua_user_data_impl.h:368
provides data specified in the lua script
Definition: lua_user_data.h:96
static bool check_callback_returns(const char *callName, const bool bThrow=false)
returns true if callback has correct return values
Definition: lua_user_data_impl.h:247
lua_State * m_L
lua state
Definition: lua_user_data.h:157
int m_callbackRef
reference to lua function
Definition: lua_user_data.h:147
static std::string signature()
returns string of required callback signature
Definition: lua_user_data_impl.h:65
static std::string name()
returns name of UserData
Definition: lua_user_data_impl.h:82
std::string m_callbackName
callback name as string
Definition: lua_user_data.h:144
LuaUserData(const char *luaCallback)
Constructor.
Definition: lua_user_data_impl.h:92
TRet evaluate(TData &D, const MathVector< dim > &x, number time, int si) const
evaluates the data at a given point and time
Definition: lua_user_data_impl.h:282
virtual ~LuaUserData()
}
Definition: lua_user_data_impl.h:353
LuaUserFunction(const char *luaCallback, size_t numArgs)
constructor
Definition: lua_user_data_impl.h:443
void eval_value(TData &out, const std::vector< TDataIn > &dataIn, const MathVector< dim > &x, number time, int si) const
evaluates the data at a given point and time
Definition: lua_user_data_impl.h:738
virtual ~LuaUserFunction()
destructor frees the reference
Definition: lua_user_data_impl.h:508
void free_deriv_callback_ref(size_t arg)
frees callback references for derivative callbacks
Definition: lua_user_data_impl.h:529
void evaluate(TData &value, const MathVector< dim > &globIP, number time, int si) const
Definition: lua_user_data_impl.h:907
void set_input(size_t i, SmartPtr< CplUserData< TDataIn, dim > > data)
set input value for parameter i
Definition: lua_user_data_impl.h:1047
std::vector< int > m_cbDerivRef
Definition: lua_user_data.h:346
void eval_deriv(TData &out, const std::vector< TDataIn > &dataIn, const MathVector< dim > &x, number time, int si, size_t arg) const
evaluates the data at a given point and time
Definition: lua_user_data_impl.h:821
void eval_and_deriv(TData vValue[], const MathVector< dim > vGlobIP[], number time, int si, GridObject *elem, const MathVector< dim > vCornerCoords[], const MathVector< refDim > vLocIP[], const size_t nip, LocalVector *u, bool bDeriv, int s, std::vector< std::vector< TData > > vvvDeriv[], const MathMatrix< refDim, dim > *vJT=NULL)
Definition: lua_user_data_impl.h:957
lua_State * m_L
lua state
Definition: lua_user_data.h:349
void free_callback_ref()
frees the callback-reference, if a callback was set.
Definition: lua_user_data_impl.h:520
int m_cbValueRef
reference to lua function
Definition: lua_user_data.h:345
void set_deriv(size_t arg, const char *luaCallback)
sets the Lua function used to compute the derivative
Definition: lua_user_data_impl.h:599
std::vector< std::string > m_cbDerivName
Definition: lua_user_data.h:342
void set_lua_value_callback(const char *luaCallback, size_t numArgs)
sets the Lua function used to compute the data
Definition: lua_user_data_impl.h:539
void set_num_input(size_t num)
set number of needed inputs
Definition: lua_user_data_impl.h:1035
virtual void operator()(TData &out, int numArgs,...) const
evaluates the data
Definition: lua_user_data_impl.h:659
A class for fixed size, dense matrices.
Definition: math_matrix.h:52
Definition: lua_compiler.h:50
const std::string & name() const
Definition: lua_compiler.h:86
int num_out() const
Definition: lua_compiler.h:81
bool call(double *ret, const double *in) const
Definition: lua_compiler.cpp:263
function util FileDummy read(...) error("io.open_0 does not support read.") end
static const int dim
#define UG_ASSERT(expr, msg)
Definition: assert.h:70
#define UG_CATCH_THROW(msg)
Definition: error.h:64
#define UG_THROW(msg)
Definition: error.h:57
#define UG_LOG(msg)
Definition: log.h:367
#define UG_COND_THROW(cond, msg)
UG_COND_THROW(cond, msg) : performs a UG_THROW(msg) if cond == true.
Definition: error.h:61
double number
Definition: types.h:124
struct lua_State lua_State
Definition: lua_table_handle.h:40
#define PROFILE_CALLBACK()
Definition: lua_user_data_impl.h:48
#define PROFILE_CALLBACK_END()
Definition: lua_user_data_impl.h:50
#define PROFILE_CALLBACK_BEGIN(name)
Definition: lua_user_data_impl.h:49
string GetLUAScriptFunctionDefined(const char *functionName)
returns file and line of defined script function
Definition: info_commands.cpp:368
lua_State * GetDefaultLuaState()
returns the default lua state
Definition: lua_util.cpp:242
the ug namespace
bool useLuaCompiler
Definition: info_commands.cpp:93
bool IsFiniteAndNotTooBig(double d)
Definition: number_util.h:40
SmartPtr< T, FreePolicy > make_sp(T *inst)
returns a SmartPtr for the passed raw pointer
Definition: smart_pointer.h:836
static void mult_add(TData &out, const TData &in1, const TDataIn &s)
computes out += s * in1 (with appropriate '*')
Lua Traits to push/pop on lua stack.
Definition: lua_traits.h:79