Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 02ae456

Browse files
authored
Improve environment variable handling in unittests (#18424)
This PR makes it easy to create unittests that require specific settings of environment variables, while avoiding the pitfalls (discussed in the "Further Comments" section). This PR can be considered a recasting and expansion of the great vision of @larroy in creating the `EnvManager` class in #13140.

In its base form, the facility is a drop-in replacement for `EnvManager`, and is called `environment`:

    with environment('MXNET_MY_NEW_FEATURE', '1'):
        <test with feature enabled>
    with environment('MXNET_MY_NEW_FEATURE', '0'):
        <test with feature disabled>

Like `EnvManager`, this facility takes care of the save/restore of the previous environment variable state, including when exceptions are raised. In addition, this PR introduces the following features:

- A similarly-named unittest decorator: `@with_environment(key, value)`.
- The ability to pass in multiple env vars as a dict (as is needed for some tests) in both forms, so for example:

      with environment({'MXNET_FEATURE_A': '1', 'MXNET_FEATURE_B': '1'}):
          <test with both features enabled>

- Works on Windows! This PR includes a wrapping of the backend's `setenv()` and `getenv()` functions, and uses this direct access to the backend environment to keep it in sync with the Python environment. This works around the problem that the C Runtime on Windows gets a snapshot of the Python environment at startup that is immutable from Python.
- `with environment()` has a simple implementation using the `@contextmanager` decorator.
- Tests are included that validate the facility works with all combinations of before_val/set_val, namely unset/unset, unset/set, set/unset, set/set.

There were 5 unittests previously using `EnvManager`, and this PR shifts those uses to `with environment():`, while converting over 20 other ad-hoc uses of `os.environ[]` within the unittests. This PR also enables those unittests that were bypassed on Windows (due to the inability to set environment variables) to run on all platforms.
Further Comments

Environment variables are a double-edged sword: they enable useful operating modes for testing, debugging or niche applications, but like all features they must be tested. The correct approach for testing with a particular env var setting is:

    def set_env_var(key, value):
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = value

    old_env_var_value = os.environ.get(env_var_name)
    try:
        set_env_var(env_var_name, test_env_var_value)
        <perform test>
    finally:
        set_env_var(env_var_name, old_env_var_value)

The above code makes no assumption about whether the before-test and within-test state of the env var is set or unset, and restores the prior environment even if the test raises an exception. This represents a lot of boilerplate code that could easily be mishandled. The `with environment()` context makes it simple to handle all this properly. If an entire unittest wants a forced env var setting, then using the `@with_environment()` decorator avoids the extra level of indentation that wrapping the whole test body in `with environment()` would require.
1 parent 18af71e commit 02ae456

22 files changed

+491
-380
lines changed

include/mxnet/c_api_test.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,22 @@ MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);
7575
MXNET_DLL int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name);
7676

7777

78+
/*!
79+
* \brief Get the value of an environment variable as seen by the backend.
80+
* \param name The name of the environment variable
81+
* \param value The returned value of the environment variable
82+
*/
83+
MXNET_DLL int MXGetEnv(const char* name,
84+
const char** value);
85+
86+
/*!
87+
* \brief Set the value of an environment variable from the backend.
88+
* \param name The name of the environment variable
89+
* \param value The desired value to set the environment variable `name`
90+
*/
91+
MXNET_DLL int MXSetEnv(const char* name,
92+
const char* value);
93+
7894
#ifdef __cplusplus
7995
}
8096
#endif // __cplusplus

python/mxnet/test_utils.py

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numbers
2525
import sys
2626
import os
27+
import platform
2728
import errno
2829
import logging
2930
import bz2
@@ -48,7 +49,7 @@
4849
from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
4950
from .symbol import Symbol
5051
from .symbol.numpy import _Symbol as np_symbol
51-
from .util import use_np, use_np_default_dtype # pylint: disable=unused-import
52+
from .util import use_np, use_np_default_dtype, getenv, setenv # pylint: disable=unused-import
5253
from .runtime import Features
5354
from .numpy_extension import get_cuda_compute_capability
5455

@@ -1920,27 +1921,6 @@ def get_bz2_data(data_dir, data_name, url, data_origin_name):
19201921
bz_file.close()
19211922
os.remove(data_origin_name)
19221923

1923-
def set_env_var(key, val, default_val=""):
1924-
"""Set environment variable
1925-
1926-
Parameters
1927-
----------
1928-
1929-
key : str
1930-
Env var to set
1931-
val : str
1932-
New value assigned to the env var
1933-
default_val : str, optional
1934-
Default value returned if the env var doesn't exist
1935-
1936-
Returns
1937-
-------
1938-
str
1939-
The value of env var before it is set to the new value
1940-
"""
1941-
prev_val = os.environ.get(key, default_val)
1942-
os.environ[key] = val
1943-
return prev_val
19441924

19451925
def same_array(array1, array2):
19461926
"""Check whether two NDArrays sharing the same memory block
@@ -1965,9 +1945,11 @@ def same_array(array1, array2):
19651945
array1[:] -= 1
19661946
return same(array1.asnumpy(), array2.asnumpy())
19671947

1948+
19681949
@contextmanager
19691950
def discard_stderr():
1970-
"""Discards error output of a routine if invoked as:
1951+
"""
1952+
Discards error output of a routine if invoked as:
19711953
19721954
with discard_stderr():
19731955
...
@@ -2400,22 +2382,79 @@ def same_symbol_structure(sym1, sym2):
24002382
return True
24012383

24022384

2403-
class EnvManager(object):
2404-
"""Environment variable setter and unsetter via with idiom"""
2405-
def __init__(self, key, val):
2406-
self._key = key
2407-
self._next_val = val
2408-
self._prev_val = None
2385+
@contextmanager
2386+
def environment(*args):
2387+
"""
2388+
Environment variable setter and unsetter via `with` idiom.
24092389
2410-
def __enter__(self):
2411-
self._prev_val = os.environ.get(self._key)
2412-
os.environ[self._key] = self._next_val
2390+
Takes a specification of env var names and desired values and adds those
2391+
settings to the environment in advance of running the body of the `with`
2392+
statement. The original environment state is restored afterwards, even
2393+
if exceptions are raised in the `with` body.
24132394
2414-
def __exit__(self, ptype, value, trace):
2415-
if self._prev_val:
2416-
os.environ[self._key] = self._prev_val
2417-
else:
2418-
del os.environ[self._key]
2395+
Parameters
2396+
----------
2397+
args:
2398+
if 2 args are passed:
2399+
name, desired_value strings of the single env var to update, or
2400+
if 1 arg is passed:
2401+
a dict of name:desired_value for env var's to update
2402+
2403+
"""
2404+
2405+
# On Linux, env var changes made through python's os.environ are seen
2406+
# by the backend. On Windows though, the C runtime gets a snapshot
2407+
# of the environment that cannot be altered by os.environ. Here we
2408+
# check, using a wrapped version of the backend's getenv(), that
2409+
# the desired env var value is seen by the backend, and otherwise use
2410+
# a wrapped setenv() to establish that value in the backend.
2411+
2412+
# Also on Windows, a set env var can never have the value '', since
2413+
# the command 'set FOO= ' is used to unset the variable. Perhaps
2414+
# as a result, the wrapped dmlc::GetEnv() routine returns the same
2415+
# value for unset variables and those set to ''. As a result, we
2416+
# ignore discrepancy.
2417+
def validate_backend_setting(name, value, can_use_setenv=True):
2418+
backend_value = getenv(name)
2419+
if value == backend_value or \
2420+
value == '' and backend_value is None and platform.system() == 'Windows':
2421+
return
2422+
if not can_use_setenv:
2423+
raise RuntimeError('Could not set env var {}={} within C Runtime'.format(name, value))
2424+
setenv(name, value)
2425+
validate_backend_setting(name, value, can_use_setenv=False)
2426+
2427+
# Core routine to alter environment from a dict of env_var_name, env_var_value pairs
2428+
def set_environ(env_var_dict):
2429+
for env_var_name, env_var_value in env_var_dict.items():
2430+
if env_var_value is None:
2431+
os.environ.pop(env_var_name, None)
2432+
else:
2433+
os.environ[env_var_name] = env_var_value
2434+
validate_backend_setting(env_var_name, env_var_value)
2435+
2436+
# Create env_var name:value dict from the two calling methods of this routine
2437+
if len(args) == 1 and isinstance(args[0], dict):
2438+
env_vars = args[0]
2439+
else:
2440+
assert len(args) == 2, 'Expecting one dict arg or two args: env var name and value'
2441+
env_vars = {args[0]: args[1]}
2442+
2443+
# Take a snapshot of the existing environment variable state
2444+
# for those variables to be changed. get() return None for unset keys.
2445+
snapshot = {x: os.environ.get(x) for x in env_vars.keys()}
2446+
2447+
# Alter the environment per the env_vars dict
2448+
set_environ(env_vars)
2449+
2450+
# Now run the wrapped code
2451+
try:
2452+
yield
2453+
finally:
2454+
# the backend engines may still be referencing the changed env var state
2455+
mx.nd.waitall()
2456+
# reinstate original env_var state per the snapshot taken earlier
2457+
set_environ(snapshot)
24192458

24202459

24212460
def collapse_sum_like(a, shape):

python/mxnet/util.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import inspect
2222
import threading
2323

24-
from .base import _LIB, check_call
24+
from .base import _LIB, check_call, c_str, py_str
2525

2626

2727
_np_ufunc_default_kwargs = {
@@ -913,6 +913,7 @@ def get_cuda_compute_capability(ctx):
913913
.format(ret, error_str.value.decode()))
914914
return cc_major.value * 10 + cc_minor.value
915915

916+
916917
def default_array(source_array, ctx=None, dtype=None):
917918
"""Creates an array from any object exposing the default(nd or np) array interface.
918919
@@ -1144,3 +1145,35 @@ def set_np_default_dtype(is_np_default_dtype=True): # pylint: disable=redefined
11441145
prev = ctypes.c_bool()
11451146
check_call(_LIB.MXSetIsNumpyDefaultDtype(ctypes.c_bool(is_np_default_dtype), ctypes.byref(prev)))
11461147
return prev.value
1148+
1149+
1150+
def getenv(name):
1151+
"""Get the setting of an environment variable from the C Runtime.
1152+
1153+
Parameters
1154+
----------
1155+
name : string type
1156+
The environment variable name
1157+
1158+
Returns
1159+
-------
1160+
value : string
1161+
The value of the environment variable, or None if not set
1162+
"""
1163+
ret = ctypes.c_char_p()
1164+
check_call(_LIB.MXGetEnv(c_str(name), ctypes.byref(ret)))
1165+
return None if ret.value is None else py_str(ret.value)
1166+
1167+
1168+
def setenv(name, value):
1169+
"""Set an environment variable in the C Runtime.
1170+
1171+
Parameters
1172+
----------
1173+
name : string type
1174+
The environment variable name
1175+
value : string type
1176+
The desired value to set the environment value to
1177+
"""
1178+
passed_value = None if value is None else c_str(value)
1179+
check_call(_LIB.MXSetEnv(c_str(name), passed_value))

src/c_api/c_api_test.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,25 @@ int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name) {
106106
}
107107
API_END();
108108
}
109+
110+
int MXGetEnv(const char* name,
111+
const char** value) {
112+
API_BEGIN();
113+
*value = getenv(name);
114+
API_END();
115+
}
116+
117+
int MXSetEnv(const char* name,
118+
const char* value) {
119+
API_BEGIN();
120+
#ifdef _WIN32
121+
auto value_arg = (value == nullptr) ? "" : value;
122+
_putenv_s(name, value_arg);
123+
#else
124+
if (value == nullptr)
125+
unsetenv(name);
126+
else
127+
setenv(name, value, 1);
128+
#endif
129+
API_END();
130+
}

tests/python/gpu/test_device.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
import pytest
2121
import os
2222
import logging
23-
24-
from mxnet.test_utils import EnvManager
23+
from mxnet.test_utils import environment
2524

2625
shapes = [(10), (100), (1000), (10000), (100000), (2,2), (2,3,4,5,6,7,8)]
2726
keys = [1,2,3,4,5,6,7]
@@ -51,16 +50,15 @@ def check_dense_pushpull(kv_type):
5150
for x in range(n_gpus):
5251
assert(np.sum(np.abs((res[x]-n_gpus).asnumpy()))==0)
5352

54-
kvstore_tree_array_bound = 'MXNET_KVSTORE_TREE_ARRAY_BOUND'
55-
kvstore_usetree_values = ['','1']
56-
kvstore_usetree = 'MXNET_KVSTORE_USETREE'
57-
for _ in range(2):
53+
kvstore_tree_array_bound_values = [None, '1']
54+
kvstore_usetree_values = [None, '1']
55+
for y in kvstore_tree_array_bound_values:
5856
for x in kvstore_usetree_values:
59-
with EnvManager(kvstore_usetree, x):
57+
with environment({'MXNET_KVSTORE_USETREE': x,
58+
'MXNET_KVSTORE_TREE_ARRAY_BOUND': y}):
6059
check_dense_pushpull('local')
6160
check_dense_pushpull('device')
62-
os.environ[kvstore_tree_array_bound] = '1'
63-
del os.environ[kvstore_tree_array_bound]
61+
6462

6563
if __name__ == '__main__':
6664
test_device_pushpull()

tests/python/gpu/test_fusion.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import sys
1819
import os
1920
import random
21+
import itertools
2022
import mxnet as mx
2123
import numpy as np
2224
from mxnet import autograd, gluon
2325
from mxnet.test_utils import *
2426

2527
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
2628
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
27-
from common import with_seed
29+
from common import setup_module, teardown_module, with_seed
2830

2931
def check_fused_symbol(sym, **kwargs):
3032
inputs = sym.list_inputs()
@@ -44,10 +46,10 @@ def check_fused_symbol(sym, **kwargs):
4446
data = {inp : kwargs[inp].astype(dtype) for inp in inputs}
4547
for grad_req in ['write', 'add']:
4648
type_dict = {inp : dtype for inp in inputs}
47-
os.environ["MXNET_USE_FUSION"] = "0"
48-
orig_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
49-
os.environ["MXNET_USE_FUSION"] = "1"
50-
fused_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
49+
with environment('MXNET_USE_FUSION', '0'):
50+
orig_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
51+
with environment('MXNET_USE_FUSION', '1'):
52+
fused_exec = test_sym._simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes)
5153
fwd_orig = orig_exec.forward(is_train=True, **data)
5254
out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
5355
orig_exec.backward(out_grads=out_grads)
@@ -231,6 +233,7 @@ def check_other_ops():
231233
arr2 = mx.random.uniform(shape=(2,2,2,3))
232234
check_fused_symbol(mx.sym.broadcast_like(a, b, lhs_axes=[0], rhs_axes=[0]), a=arr1, b=arr2)
233235

236+
234237
def check_leakyrelu_ops():
235238
a = mx.sym.Variable('a')
236239
b = mx.sym.Variable('b')
@@ -331,18 +334,18 @@ def hybrid_forward(self, F, x, y, z):
331334

332335
arrays = {}
333336
for use_fusion in ('0', '1'):
334-
os.environ['MXNET_USE_FUSION'] = use_fusion
335-
arrays[use_fusion] = {}
336-
n = Block()
337-
n.hybridize(static_alloc=static_alloc)
338-
args = [arg.copyto(mx.gpu()) for arg in arg_data]
339-
for arg in args:
340-
arg.attach_grad()
341-
with autograd.record():
342-
r = n(*args)
343-
arrays[use_fusion]['result'] = r
344-
r.backward()
345-
for i, arg in enumerate(args):
346-
arrays[use_fusion][i] = arg.grad
337+
with environment('MXNET_USE_FUSION', use_fusion):
338+
arrays[use_fusion] = {}
339+
n = Block()
340+
n.hybridize(static_alloc=static_alloc)
341+
args = [arg.copyto(mx.gpu()) for arg in arg_data]
342+
for arg in args:
343+
arg.attach_grad()
344+
with autograd.record():
345+
r = n(*args)
346+
arrays[use_fusion]['result'] = r
347+
r.backward()
348+
for i, arg in enumerate(args):
349+
arrays[use_fusion][i] = arg.grad
347350
for key in ['result'] + list(range(len(arg_data))):
348351
assert_allclose(arrays['0'][key].asnumpy(), arrays['1'][key].asnumpy())

0 commit comments

Comments
 (0)