diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py index 15d38f81f2e3..17044ddaef06 100644 --- a/python/mxnet/attribute.py +++ b/python/mxnet/attribute.py @@ -18,10 +18,12 @@ # coding: utf-8 """Attribute scoping support for symbolic API.""" from __future__ import absolute_import +import threading +import warnings -from .base import string_types +from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass -class AttrScope(object): +class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)): """Attribute manager for scoping. User can also inherit this object to change naming behavior. @@ -31,7 +33,7 @@ class AttrScope(object): kwargs The attributes to set for all symbol creations in the scope. """ - current = None + _current = threading.local() def __init__(self, **kwargs): self._old_scope = None @@ -64,15 +66,35 @@ def get(self, attr): def __enter__(self): # pylint: disable=protected-access - self._old_scope = AttrScope.current - attr = AttrScope.current._attr.copy() + if not hasattr(AttrScope._current, "value"): + AttrScope._current.value = AttrScope() + self._old_scope = AttrScope._current.value + attr = AttrScope._current.value._attr.copy() attr.update(self._attr) self._attr = attr - AttrScope.current = self + AttrScope._current.value = self return self def __exit__(self, ptype, value, trace): assert self._old_scope - AttrScope.current = self._old_scope + AttrScope._current.value = self._old_scope -AttrScope.current = AttrScope() + #pylint: disable=no-self-argument + @classproperty + def current(cls): + warnings.warn("AttrScope.current has been deprecated. " + "It is advised to use the `with` statement with AttrScope.", + DeprecationWarning) + if not hasattr(AttrScope._current, "value"): + cls._current.value = AttrScope() + return cls._current.value + + @current.setter + def current(cls, val): + warnings.warn("AttrScope.current has been deprecated. " + "It is advised to use the `with` statement with AttrScope.", + DeprecationWarning) + cls._current.value = val + #pylint: enable=no-self-argument + +AttrScope._current.value = AttrScope() diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 9790e090e387..0fb73b3c7dda 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -16,7 +16,7 @@ # under the License. # coding: utf-8 -# pylint: disable=invalid-name, no-member, trailing-comma-tuple +# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument """ctypes library of mxnet and helper functions.""" from __future__ import absolute_import @@ -98,6 +98,67 @@ class MXCallbackList(ctypes.Structure): ('contexts', ctypes.POINTER(ctypes.c_void_p)) ] +# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property +class _MXClassPropertyDescriptor(object): + def __init__(self, fget, fset=None): + self.fget = fget + self.fset = fset + + def __get__(self, obj, clas=None): + if clas is None: + clas = type(obj) + return self.fget.__get__(obj, clas)() + + def __set__(self, obj, value): + if not self.fset: + raise MXNetError("cannot use the setter: %s to set attribute" % obj.__name__) + if inspect.isclass(obj): + type_ = obj + obj = None + else: + type_ = type(obj) + return self.fset.__get__(obj, type_)(value) + + def setter(self, func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + self.fset = func + return self + +class _MXClassPropertyMetaClass(type): + def __setattr__(cls, key, value): + if key in cls.__dict__: + obj = cls.__dict__.get(key) + if obj and isinstance(obj, _MXClassPropertyDescriptor): + return obj.__set__(cls, value) + + return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value) + +# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py +#pylint: disable=unused-argument +def with_metaclass(meta, *bases): + """Create a base class with a metaclass.""" + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. + class metaclass(type): + + def __new__(cls, name, this_bases, d): + return meta(name, bases, d) + + @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) + return type.__new__(metaclass, 'temporary_class', (), {}) +#pylint: enable=unused-argument + +def classproperty(func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + + return _MXClassPropertyDescriptor(func) + + def _load_lib(): """Load library by searching possible path.""" @@ -227,6 +288,7 @@ def c_str_array(strings): arr[:] = [s.encode('utf-8') for s in strings] return arr + def c_array(ctype, values): """Create ctypes array from a Python array. diff --git a/python/mxnet/context.py b/python/mxnet/context.py index eb47614e3335..5861890f40c1 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -18,8 +18,11 @@ # coding: utf-8 """Context management API of mxnet.""" from __future__ import absolute_import +import threading +import warnings +from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass -class Context(object): +class Context(with_metaclass(_MXClassPropertyMetaClass, object)): """Constructs a context. MXNet can run operations on CPU and different GPUs. @@ -61,7 +64,7 @@ class Context(object): gpu(1) """ # static class variable - default_ctx = None + _default_ctx = threading.local() devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'} devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5} def __init__(self, device_type, device_id=0): @@ -109,15 +112,37 @@ def __repr__(self): return self.__str__() def __enter__(self): - self._old_ctx = Context.default_ctx - Context.default_ctx = self + if not hasattr(Context._default_ctx, "value"): + Context._default_ctx.value = Context('cpu', 0) + self._old_ctx = Context._default_ctx.value + Context._default_ctx.value = self return self def __exit__(self, ptype, value, trace): - Context.default_ctx = self._old_ctx + Context._default_ctx.value = self._old_ctx + + #pylint: disable=no-self-argument + @classproperty + def default_ctx(cls): + warnings.warn("Context.default_ctx has been deprecated. " + "Please use Context.current_context() instead. " + "Please use test_utils.set_default_context to set a default context", + DeprecationWarning) + if not hasattr(Context._default_ctx, "value"): + cls._default_ctx.value = Context('cpu', 0) + return cls._default_ctx.value + + @default_ctx.setter + def default_ctx(cls, val): + warnings.warn("Context.default_ctx has been deprecated. " + "Please use Context.current_context() instead. " + "Please use test_utils.set_default_context to set a default context", + DeprecationWarning) + cls._default_ctx.value = val + #pylint: enable=no-self-argument # initialize the default context in Context -Context.default_ctx = Context('cpu', 0) +Context._default_ctx.value = Context('cpu', 0) def cpu(device_id=0): @@ -234,4 +259,6 @@ def current_context(): ------- default_ctx : Context """ - return Context.default_ctx + if not hasattr(Context._default_ctx, "value"): + Context._default_ctx.value = Context('cpu', 0) + return Context._default_ctx.value diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index abc474850f24..7e4127250a09 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -20,6 +20,7 @@ """Base container class for all neural network models.""" __all__ = ['Block', 'HybridBlock', 'SymbolBlock'] +import threading import copy import warnings import re @@ -35,7 +36,7 @@ class _BlockScope(object): """Scope for collecting child `Block` s.""" - _current = None + _current = threading.local() def __init__(self, block): self._block = block @@ -46,10 +47,10 @@ def __init__(self, block): @staticmethod def create(prefix, params, hint): """Creates prefix and params for new `Block`.""" - current = _BlockScope._current + current = getattr(_BlockScope._current, "value", None) if current is None: if prefix is None: - prefix = _name.NameManager.current.get(None, hint) + '_' + prefix = _name.NameManager._current.value.get(None, hint) + '_' if params is None: params = ParameterDict(prefix) else: @@ -70,8 +71,8 @@ def create(prefix, params, hint): def __enter__(self): if self._block._empty_prefix: return self - self._old_scope = _BlockScope._current - _BlockScope._current = self + self._old_scope = getattr(_BlockScope._current, "value", None) + _BlockScope._current.value = self self._name_scope = _name.Prefix(self._block.prefix) self._name_scope.__enter__() return self @@ -81,7 +82,7 @@ def __exit__(self, ptype, value, trace): return self._name_scope.__exit__(ptype, value, trace) self._name_scope = None - _BlockScope._current = self._old_scope + _BlockScope._current.value = self._old_scope def _flatten(args, inout_str): diff --git a/python/mxnet/name.py b/python/mxnet/name.py index 966d38280ef7..4149d1db2731 100644 --- a/python/mxnet/name.py +++ b/python/mxnet/name.py @@ -18,13 +18,16 @@ # coding: utf-8 """Automatic naming support for symbolic API.""" from __future__ import absolute_import +import threading +import warnings +from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass -class NameManager(object): +class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)): """NameManager to do automatic naming. Developers can also inherit from this class to change naming behavior. """ - current = None + _current = threading.local() def __init__(self): self._counter = {} @@ -62,14 +65,30 @@ def get(self, name, hint): return name def __enter__(self): - self._old_manager = NameManager.current - NameManager.current = self + if not hasattr(NameManager._current, "value"): + NameManager._current.value = NameManager() + self._old_manager = NameManager._current.value + NameManager._current.value = self return self def __exit__(self, ptype, value, trace): assert self._old_manager - NameManager.current = self._old_manager - + NameManager._current.value = self._old_manager + + #pylint: disable=no-self-argument + @classproperty + def current(cls): + warnings.warn("NameManager.current has been deprecated. " + "It is advised to use the `with` statement with NameManager.", + DeprecationWarning) + if not hasattr(NameManager._current, "value"): + cls._current.value = NameManager() + return cls._current.value + + @current.setter + def current(cls, val): + cls._current.value = val + #pylint: enable=no-self-argument class Prefix(NameManager): """A name manager that attaches a prefix to all names. @@ -92,4 +111,4 @@ def get(self, name, hint): return self._prefix + name # initialize the default name manager -NameManager.current = NameManager() +NameManager._current.value = NameManager() diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 7bfb3c79b35f..007b3c82def2 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -37,7 +37,7 @@ from ..base import c_array, c_array_buf, c_handle_array, mx_real_t from ..base import mx_uint, NDArrayHandle, check_call from ..base import ctypes2buffer -from ..context import Context +from ..context import Context, current_context from . import _internal from . import op from ._internal import NDArrayBase @@ -2261,7 +2261,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs): The shape of the empty array. ctx : Context, optional An optional device context. - Defaults to the current default context (``mxnet.Context.default_ctx``). + Defaults to the current default context (``mxnet.context.current_context()``). dtype : str or numpy.dtype, optional An optional value type (default is `float32`). out : NDArray, optional @@ -2283,7 +2283,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context.default_ctx + ctx = current_context() dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) @@ -2439,7 +2439,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): array([2, 2, 2, 4, 4, 4], dtype=int32) """ if ctx is None: - ctx = Context.default_ctx + ctx = current_context() return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, dtype=dtype, ctx=str(ctx)) # pylint: enable= no-member, protected-access, too-many-arguments @@ -3666,7 +3666,7 @@ def zeros(shape, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context.default_ctx + ctx = current_context() dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) @@ -3705,7 +3705,7 @@ def eye(N, M=0, k=0, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context.default_ctx + ctx = current_context() dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._eye(N=N, M=M, k=k, ctx=ctx, dtype=dtype, **kwargs) @@ -3733,7 +3733,7 @@ def empty(shape, ctx=None, dtype=None): if isinstance(shape, int): shape = (shape, ) if ctx is None: - ctx = Context.default_ctx + ctx = current_context() if dtype is None: dtype = mx_real_t return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype)) diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py index c7355c2e46d0..9c02b8e2cf27 100644 --- a/python/mxnet/ndarray/sparse.py +++ b/python/mxnet/ndarray/sparse.py @@ -42,7 +42,7 @@ from ..base import _LIB, numeric_types from ..base import c_array_buf, mx_real_t, integer_types from ..base import mx_uint, NDArrayHandle, check_call -from ..context import Context +from ..context import Context, current_context from . import _internal from . import op try: @@ -977,7 +977,7 @@ def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None, # pylint: disable= no-member, protected-access storage_type = 'csr' # context - ctx = Context.default_ctx if ctx is None else ctx + ctx = current_context() if ctx is None else ctx # types dtype = _prepare_default_dtype(data, dtype) indptr_type = _STORAGE_AUX_TYPES[storage_type][0] if indptr_type is None else indptr_type @@ -1140,7 +1140,7 @@ def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None, """Create a `RowSparseNDArray` based on data and indices""" storage_type = 'row_sparse' # context - ctx = Context.default_ctx if ctx is None else ctx + ctx = current_context() if ctx is None else ctx # types dtype = _prepare_default_dtype(data, dtype) indices_type = _STORAGE_AUX_TYPES[storage_type][0] if indices_type is None else indices_type @@ -1529,7 +1529,7 @@ def zeros(stype, shape, ctx=None, dtype=None, **kwargs): if stype == 'default': return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs) if ctx is None: - ctx = Context.default_ctx + ctx = current_context() dtype = mx_real_t if dtype is None else dtype if stype == 'row_sparse' or stype == 'csr': aux_types = _STORAGE_AUX_TYPES[stype] @@ -1562,7 +1562,7 @@ def empty(stype, shape, ctx=None, dtype=None): if isinstance(shape, int): shape = (shape, ) if ctx is None: - ctx = Context.default_ctx + ctx = current_context() if dtype is None: dtype = mx_real_t assert(stype is not None) @@ -1603,7 +1603,7 @@ def array(source_array, ctx=None, dtype=None): >>> mx.nd.sparse.array(mx.nd.sparse.zeros('row_sparse', (3, 2))) """ - ctx = Context.default_ctx if ctx is None else ctx + ctx = current_context() if ctx is None else ctx if isinstance(source_array, NDArray): assert(source_array.stype != 'default'), \ "Please use `tostype` to create RowSparseNDArray or CSRNDArray from an NDArray" diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py index 6f9e868e2321..3e81dcf3a6c9 100644 --- a/python/mxnet/symbol/register.py +++ b/python/mxnet/symbol/register.py @@ -113,9 +113,9 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) dtype_name, dtype_name, dtype_name)) code.append(""" attr = kwargs.pop('attr', None) - kwargs.update(AttrScope.current.get(attr)) + kwargs.update(AttrScope._current.value.get(attr)) name = kwargs.pop('name', None) - name = NameManager.current.get(name, '%s') + name = NameManager._current.value.get(name, '%s') _ = kwargs.pop('out', None) keys = [] vals = [] @@ -141,7 +141,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) def %s(%s):"""%(func_name, ', '.join(signature))) if not signature_only: code.append(""" - kwargs.update(AttrScope.current.get(attr)) + kwargs.update(AttrScope._current.value.get(attr)) sym_kwargs = dict() _keys = [] _vals = [] @@ -172,7 +172,7 @@ def %s(%s):"""%(func_name, ', '.join(signature))) _vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) code.append(""" - name = NameManager.current.get(name, '%s') + name = NameManager._current.value.get(name, '%s') return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name)"""%( func_name.lower(), handle.value)) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 1ab7cf87bf50..49023db2fe0c 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -37,7 +37,7 @@ from ..base import mx_uint, py_str, string_types from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle from ..base import check_call, MXNetError, NotImplementedForSymbol -from ..context import Context +from ..context import Context, current_context from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from ..ndarray import _ndarray_cls @@ -1767,7 +1767,7 @@ def eval(self, ctx=None, **kwargs): the result will be a list with one element. """ if ctx is None: - ctx = Context.default_ctx + ctx = current_context() return self.bind(ctx, kwargs).forward() def reshape(self, *args, **kwargs): @@ -2448,7 +2448,7 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, handle = SymbolHandle() check_call(_LIB.MXSymbolCreateVariable(c_str(name), ctypes.byref(handle))) ret = Symbol(handle) - attr = AttrScope.current.get(attr) + attr = AttrScope._current.value.get(attr) attr = {} if attr is None else attr if shape is not None: attr['__shape__'] = str(shape) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index aa388c14ea1e..bcdcc9c64080 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -44,7 +44,7 @@ # in rare cases requests may be not installed pass import mxnet as mx -from .context import Context +from .context import Context, current_context from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from .ndarray import array from .symbol import Symbol @@ -54,12 +54,12 @@ def default_context(): """Get default context for regression test.""" # _TODO: get context from environment variable to support # testing with GPUs - return Context.default_ctx + return current_context() def set_default_context(ctx): """Set default context.""" - Context.default_ctx = ctx + Context._default_ctx.value = ctx def default_dtype(): diff --git a/tests/nightly/test_tlocal_racecondition.py b/tests/nightly/test_tlocal_racecondition.py new file mode 100644 index 000000000000..d43c45937c05 --- /dev/null +++ b/tests/nightly/test_tlocal_racecondition.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from mxnet import gluon +from mxnet import image +from mxnet import nd +import numpy as np +import logging + +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +root_url = ('https://apache-mxnet.s3-accelerate.amazonaws.com/' + 'gluon/dataset/pikachu/') +data_dir = './data/pikachu/' +dataset = {'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8', + 'train.idx': 'dcf7318b2602c06428b9988470c731621716c393', + 'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'} +for k, v in dataset.items(): + gluon.utils.download(root_url+k, data_dir+k, sha1_hash=v) + +T = 1 +devs = [mx.gpu(i) for i in range(4)] +data_shape = 224 * T +batch_size = 20 * len(devs) +rgb_mean = np.array([1,2,3]) + +class_names = ['pikachu'] +num_class = len(class_names) + +def get_iterators(data_shape, batch_size): + train_iter = image.ImageDetIter( + batch_size=batch_size, + data_shape=(3, data_shape, data_shape), + path_imgrec=data_dir+'train.rec', + path_imgidx=data_dir+'train.idx', + shuffle=True, + mean=True, + rand_crop=1, + min_object_covered=0.95, + max_attempts=200) + val_iter = image.ImageDetIter( + batch_size=batch_size, + data_shape=(3, data_shape, data_shape), + path_imgrec=data_dir+'val.rec', + shuffle=False, + mean=True) + return train_iter, val_iter, class_names, num_class + +train_data, test_data, class_names, num_class = get_iterators( + data_shape, batch_size) + + +class MyCustom(mx.operator.CustomOp): + def __init__(self): + super(MyCustom, self).__init__() + def forward(self, is_train, req, in_data, out_data, aux): + self.assign(out_data[0], req[0], 0) + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + self.assign(in_grad[0], req[0], 0) + self.assign(in_grad[1], req[1], 0) + +@mx.operator.register("MyCustom") +class MyCustomProp(mx.operator.CustomOpProp): + def __init__(self): + super(MyCustomProp, self).__init__(need_top_grad = False) + def list_arguments(self): + return ["data", "label"] + def list_outputs(self): + return ["loss"] + def infer_shape(self, in_shape): + return [in_shape[0], in_shape[1]], [(1, )], [] + def infer_type(self, in_type): + dtype = in_type[0] + return [dtype, dtype], [dtype], [] + def create_operator(self, ctx, shapes, dtypes): + return MyCustom() + +class MyMetric(mx.metric.EvalMetric): + def __init__(self): + super(MyMetric, self).__init__("MyMetric") + self.name = ['empty'] + def update(self, labels, preds): + pass + def get(self): + return self.name, [0] + +if __name__ == '__main__': + x = mx.sym.Variable("data") + label = mx.sym.Variable("label") + x = mx.sym.FullyConnected(data = x, num_hidden = 100) + label = mx.sym.Reshape(data = label, shape = (0, -1)) + sym = mx.sym.Custom(data = x, label = label, op_type = "MyCustom") + model = mx.module.Module(context = devs, symbol = sym, data_names = ('data',), label_names = ('label',)) + model.fit(train_data = train_data, begin_epoch = 0, num_epoch = 20, allow_missing = True, batch_end_callback = mx.callback.Speedometer(batch_size, 5), eval_metric = MyMetric()) diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index 800426c035b2..5618e11a0400 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -42,7 +42,7 @@ def test_box_nms_backward(data, grad, expected, thresh=0.5, topk=-1, coord=2, sc op = mx.contrib.sym.box_nms(in_var, overlap_thresh=thresh, topk=topk, coord_start=coord, score_index=score, id_index=cid, force_suppress=force, in_format=in_format, out_format=out_format) - exe = op.bind(ctx=mx.context.Context.default_ctx, args=[arr_data], args_grad=[arr_grad]) + exe = op.bind(ctx=default_context(), args=[arr_data], args_grad=[arr_grad]) exe.forward(is_train=True) exe.backward(mx.nd.array(grad)) assert_almost_equal(arr_grad.asnumpy(), expected) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1db836b0918c..b7c5e49cda0e 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3714,7 +3714,7 @@ def test_tile_backward(): reps2 = 2 reps = (reps1, reps2) test = mx.sym.tile(data, reps=reps) - exe = test.bind(ctx=mx.context.Context.default_ctx, args=[arr_data], args_grad=[arr_grad]) + exe = test.bind(ctx=default_context(), args=[arr_data], args_grad=[arr_grad]) npout_grad = np.random.randint(0, 10, n1 * n2 * reps1 * reps2).reshape(n1 * reps1, n2 * reps2) out_grad = mx.nd.array(npout_grad) exe.backward(out_grad) @@ -4421,7 +4421,7 @@ def test_psroipooling(): output_dim=num_classes, name='test_op') rtol, atol = 1e-2, 1e-3 # By now we only have gpu implementation - if mx.Context.default_ctx.device_type == 'gpu': + if default_context().device_type == 'gpu': check_numeric_gradient(op, [im_data, rois_data], rtol=rtol, atol=atol, grad_nodes=grad_nodes, ctx=mx.gpu(0)) @@ -4459,7 +4459,7 @@ def test_deformable_convolution(): else: rtol, atol = 0.05, 1e-3 # By now we only have gpu implementation - if mx.Context.default_ctx.device_type == 'gpu': + if default_context().device_type == 'gpu': check_numeric_gradient(op, [im_data, offset_data, weight, bias], rtol=rtol, atol=atol, grad_nodes=grad_nodes, ctx=mx.gpu(0)) @@ -4495,7 +4495,7 @@ def test_deformable_psroipooling(): else: rtol, atol = 1e-2, 1e-3 # By now we only have gpu implementation - if mx.Context.default_ctx.device_type == 'gpu': + if default_context().device_type == 'gpu': check_numeric_gradient(op, [im_data, rois_data, offset_data], rtol=rtol, atol=atol, grad_nodes=grad_nodes, ctx=mx.gpu(0)) diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py new file mode 100644 index 000000000000..a571a25ab2a6 --- /dev/null +++ b/tests/python/unittest/test_thread_local.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import threading +import mxnet as mx +from mxnet import context, attribute, name +from mxnet.gluon import block +from mxnet.context import Context +from mxnet.attribute import AttrScope +from mxnet.name import NameManager +from mxnet.test_utils import set_default_context + +def test_context(): + ctx_list = [] + ctx_list.append(Context.default_ctx) + def f(): + set_default_context(mx.gpu(11)) + ctx_list.append(Context.default_ctx) + thread = threading.Thread(target=f) + thread.start() + thread.join() + assert Context.devtype2str[ctx_list[0].device_typeid] == "cpu" + assert ctx_list[0].device_id == 0 + assert Context.devtype2str[ctx_list[1].device_typeid] == "gpu" + assert ctx_list[1].device_id == 11 + + event = threading.Event() + status = [False] + def g(): + with mx.cpu(10): + event.wait() + if Context.default_ctx.device_id == 10: + status[0] = True + thread = threading.Thread(target=g) + thread.start() + Context.default_ctx = Context("cpu", 11) + event.set() + thread.join() + event.clear() + assert status[0], "Spawned thread didn't set the correct context" + +def test_attrscope(): + attrscope_list = [] + AttrScope.current = AttrScope(y="hi", z="hey") + attrscope_list.append(AttrScope.current) + def f(): + AttrScope.current = AttrScope(x="hello") + attrscope_list.append(AttrScope.current) + thread = threading.Thread(target=f) + thread.start() + thread.join() + assert len(attrscope_list[0]._attr) == 2 + assert attrscope_list[1]._attr["x"] == "hello" + + event = threading.Event() + status = [False] + def g(): + with mx.AttrScope(x="hello"): + event.wait() + if "hello" in AttrScope.current._attr.values(): + status[0] = True + thread = threading.Thread(target=g) + thread.start() + AttrScope.current = AttrScope(x="hi") + event.set() + thread.join() + AttrScope.current = AttrScope() + event.clear() + assert status[0], "Spawned thread didn't set the correct attr key values" + +def test_name(): + name_list = [] + NameManager.current = NameManager() + NameManager.current.get(None, "main_thread") + name_list.append(NameManager.current) + def f(): + NameManager.current = NameManager() + NameManager.current.get(None, "spawned_thread") + name_list.append(NameManager.current) + thread = threading.Thread(target=f) + thread.start() + thread.join() + assert "main_thread" in name_list[0]._counter, "cannot find the string `main thread` in name_list[0]._counter" + assert "spawned_thread" in name_list[1]._counter, "cannot find the string `spawned thread` in name_list[1]._counter" + + event = threading.Event() + status = [False] + def g(): + with NameManager(): + if "main_thread" not in NameManager.current._counter: + status[0] = True + thread = threading.Thread(target=g) + thread.start() + NameManager.current = NameManager() + NameManager.current.get(None, "main_thread") + event.set() + thread.join() + event.clear() + assert status[0], "Spawned thread isn't using thread local NameManager" + +def test_blockscope(): + class dummy_block(object): + def __init__(self, prefix): + self.prefix = prefix + self._empty_prefix = False + blockscope_list = [] + status = [False] + event = threading.Event() + def f(): + with block._BlockScope(dummy_block("spawned_")): + x= NameManager.current.get(None, "hello") + event.wait() + if x == "spawned_hello0": + status[0] = True + thread = threading.Thread(target=f) + thread.start() + block._BlockScope.create("main_thread", None, "hi") + event.set() + thread.join() + event.clear() + assert status[0], "Spawned thread isn't using the correct blockscope namemanager" + +if __name__ == '__main__': + import nose + nose.runmodule()