1515except ImportError : # pragma: no cover
1616 _NUMEXPR_INSTALLED = False
1717
18+ _TEST_MODE = None
19+ _TEST_RESULT = None
1820_USE_NUMEXPR = _NUMEXPR_INSTALLED
1921_evaluate = None
2022_where = None
@@ -55,9 +57,10 @@ def set_numexpr_threads(n=None):
5557
5658def _evaluate_standard (op , op_str , a , b , raise_on_error = True , ** eval_kwargs ):
5759 """ standard evaluation """
60+ if _TEST_MODE :
61+ _store_test_result (False )
5862 return op (a , b )
5963
60-
6164def _can_use_numexpr (op , op_str , a , b , dtype_check ):
6265 """ return a boolean if we WILL be using numexpr """
6366 if op_str is not None :
@@ -88,11 +91,8 @@ def _evaluate_numexpr(op, op_str, a, b, raise_on_error=False, **eval_kwargs):
8891
8992 if _can_use_numexpr (op , op_str , a , b , 'evaluate' ):
9093 try :
91- a_value , b_value = a , b
92- if hasattr (a_value , 'values' ):
93- a_value = a_value .values
94- if hasattr (b_value , 'values' ):
95- b_value = b_value .values
94+ a_value = getattr (a , "values" , a )
95+ b_value = getattr (b , "values" , b )
9696 result = ne .evaluate ('a_value %s b_value' % op_str ,
9797 local_dict = {'a_value' : a_value ,
9898 'b_value' : b_value },
@@ -104,6 +104,9 @@ def _evaluate_numexpr(op, op_str, a, b, raise_on_error=False, **eval_kwargs):
104104 if raise_on_error :
105105 raise
106106
107+ if _TEST_MODE :
108+ _store_test_result (result is not None )
109+
107110 if result is None :
108111 result = _evaluate_standard (op , op_str , a , b , raise_on_error )
109112
@@ -119,13 +122,9 @@ def _where_numexpr(cond, a, b, raise_on_error=False):
119122 if _can_use_numexpr (None , 'where' , a , b , 'where' ):
120123
121124 try :
122- cond_value , a_value , b_value = cond , a , b
123- if hasattr (cond_value , 'values' ):
124- cond_value = cond_value .values
125- if hasattr (a_value , 'values' ):
126- a_value = a_value .values
127- if hasattr (b_value , 'values' ):
128- b_value = b_value .values
125+ cond_value = getattr (cond , 'values' , cond )
126+ a_value = getattr (a , 'values' , a )
127+ b_value = getattr (b , 'values' , b )
129128 result = ne .evaluate ('where(cond_value, a_value, b_value)' ,
130129 local_dict = {'cond_value' : cond_value ,
131130 'a_value' : a_value ,
@@ -189,3 +188,28 @@ def where(cond, a, b, raise_on_error=False, use_numexpr=True):
189188 if use_numexpr :
190189 return _where (cond , a , b , raise_on_error = raise_on_error )
191190 return _where_standard (cond , a , b , raise_on_error = raise_on_error )
191+
192+
193+ def set_test_mode (v = True ):
194+ """
195+ Keeps track of whether numexpr was used. Stores an additional ``True`` for
196+ every successful use of evaluate with numexpr since the last
197+ ``get_test_result``
198+ """
199+ global _TEST_MODE , _TEST_RESULT
200+ _TEST_MODE = v
201+ _TEST_RESULT = []
202+
203+
204+ def _store_test_result (used_numexpr ):
205+ global _TEST_RESULT
206+ if used_numexpr :
207+ _TEST_RESULT .append (used_numexpr )
208+
209+
210+ def get_test_result ():
211+ """get test result and reset test_results"""
212+ global _TEST_RESULT
213+ res = _TEST_RESULT
214+ _TEST_RESULT = []
215+ return res
0 commit comments