From 1a8dda6a39298c9c8e02090025e379840e2a4a2d Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 8 May 2018 07:45:27 +0000 Subject: [PATCH 01/13] Change to simpler implementation --- python/mxnet/attribute.py | 14 ++++++++------ python/mxnet/context.py | 14 ++++++++------ python/mxnet/gluon/block.py | 14 ++++++++------ python/mxnet/name.py | 12 +++++++----- python/mxnet/ndarray/ndarray.py | 12 ++++++------ python/mxnet/ndarray/sparse.py | 10 +++++----- python/mxnet/symbol/register.py | 8 ++++---- python/mxnet/symbol/symbol.py | 4 ++-- python/mxnet/test_utils.py | 4 ++-- tests/python/unittest/test_contrib_operator.py | 2 +- tests/python/unittest/test_operator.py | 8 ++++---- 11 files changed, 55 insertions(+), 47 deletions(-) diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py index 15d38f81f2e3..0b3c92e4e109 100644 --- a/python/mxnet/attribute.py +++ b/python/mxnet/attribute.py @@ -18,6 +18,7 @@ # coding: utf-8 """Attribute scoping support for symbolic API.""" from __future__ import absolute_import +import threading from .base import string_types @@ -31,7 +32,8 @@ class AttrScope(object): kwargs The attributes to set for all symbol creations in the scope. """ - current = None + _current = threading.local() + _current.value = None def __init__(self, **kwargs): self._old_scope = None @@ -64,15 +66,15 @@ def get(self, attr): def __enter__(self): # pylint: disable=protected-access - self._old_scope = AttrScope.current - attr = AttrScope.current._attr.copy() + self._old_scope = getattr(AttrScope._current, "value", AttrScope()) + attr = AttrScope._current.value._attr.copy() attr.update(self._attr) self._attr = attr - AttrScope.current = self + AttrScope._current.value = self return self def __exit__(self, ptype, value, trace): assert self._old_scope - AttrScope.current = self._old_scope + AttrScope._current.value = self._old_scope -AttrScope.current = AttrScope() +AttrScope._current.value = AttrScope() diff --git a/python/mxnet/context.py b/python/mxnet/context.py index eb47614e3335..012de387a11a 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -18,6 +18,7 @@ # coding: utf-8 """Context management API of mxnet.""" from __future__ import absolute_import +import threading class Context(object): """Constructs a context. @@ -61,7 +62,8 @@ class Context(object): gpu(1) """ # static class variable - default_ctx = None + _default_ctx = threading.local() + _default_ctx.value = None devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'} devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5} def __init__(self, device_type, device_id=0): @@ -109,15 +111,15 @@ def __repr__(self): return self.__str__() def __enter__(self): - self._old_ctx = Context.default_ctx - Context.default_ctx = self + self._old_ctx = getattr(Context._default_ctx, "value", Context('cpu', 0)) + Context._default_ctx.value = self return self def __exit__(self, ptype, value, trace): - Context.default_ctx = self._old_ctx + Context._default_ctx.value = self._old_ctx # initialize the default context in Context -Context.default_ctx = Context('cpu', 0) +Context._default_ctx.value = Context('cpu', 0) def cpu(device_id=0): @@ -234,4 +236,4 @@ def current_context(): ------- default_ctx : Context """ - return Context.default_ctx + return Context._default_ctx.value diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index abc474850f24..8e1a8d5424ef 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -20,6 +20,7 @@ """Base container class for all neural network models.""" __all__ = ['Block', 'HybridBlock', 'SymbolBlock'] +import threading import copy import warnings import re @@ -35,7 +36,8 @@ class _BlockScope(object): """Scope for collecting child `Block` s.""" - _current = None + _current = threading.local() + _current.value = None def __init__(self, block): self._block = block @@ -46,10 +48,10 @@ def __init__(self, block): @staticmethod def create(prefix, params, hint): """Creates prefix and params for new `Block`.""" - current = _BlockScope._current + current = _BlockScope._current.value if current is None: if prefix is None: - prefix = _name.NameManager.current.get(None, hint) + '_' + prefix = _name.NameManager._current.value.get(None, hint) + '_' if params is None: params = ParameterDict(prefix) else: @@ -70,8 +72,8 @@ def create(prefix, params, hint): def __enter__(self): if self._block._empty_prefix: return self - self._old_scope = _BlockScope._current - _BlockScope._current = self + self._old_scope = getattr(_BlockScope._current, "value", None) + _BlockScope._current.value = self self._name_scope = _name.Prefix(self._block.prefix) self._name_scope.__enter__() return self @@ -81,7 +83,7 @@ def __exit__(self, ptype, value, trace): return self._name_scope.__exit__(ptype, value, trace) self._name_scope = None - _BlockScope._current = self._old_scope + _BlockScope._current.value = self._old_scope def _flatten(args, inout_str): diff --git a/python/mxnet/name.py b/python/mxnet/name.py index 966d38280ef7..65c7a85ebf10 100644 --- a/python/mxnet/name.py +++ b/python/mxnet/name.py @@ -18,13 +18,15 @@ # coding: utf-8 """Automatic naming support for symbolic API.""" from __future__ import absolute_import +import threading class NameManager(object): """NameManager to do automatic naming. Developers can also inherit from this class to change naming behavior. """ - current = None + _current = threading.local() + _current.value = None def __init__(self): self._counter = {} @@ -62,13 +64,13 @@ def get(self, name, hint): return name def __enter__(self): - self._old_manager = NameManager.current - NameManager.current = self + self._old_manager = getattr(NameManager._current, "value", NameManager()) + NameManager._current.value = self return self def __exit__(self, ptype, value, trace): assert self._old_manager - NameManager.current = self._old_manager + NameManager._current.value = self._old_manager class Prefix(NameManager): @@ -92,4 +94,4 @@ def get(self, name, hint): return self._prefix + name # initialize the default name manager -NameManager.current = NameManager() +NameManager._current.value = NameManager() diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 2411932af268..a8f46eebe56a 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2257,7 +2257,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs): The shape of the empty array. ctx : Context, optional An optional device context. - Defaults to the current default context (``mxnet.Context.default_ctx``). + Defaults to the current default context (``mxnet.Context._default_ctx.value``). dtype : str or numpy.dtype, optional An optional value type (default is `float32`). out : NDArray, optional @@ -2279,7 +2279,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context.default_ctx + ctx = Context._default_ctx.value dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) @@ -2435,7 +2435,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): array([2, 2, 2, 4, 4, 4], dtype=int32) """ if ctx is None: - ctx = Context.default_ctx + ctx = Context._default_ctx.value return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, dtype=dtype, ctx=str(ctx)) # pylint: enable= no-member, protected-access, too-many-arguments @@ -3662,7 +3662,7 @@ def zeros(shape, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context.default_ctx + ctx = Context._default_ctx.value dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) @@ -3701,7 +3701,7 @@ def eye(N, M=0, k=0, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context.default_ctx + ctx = Context._default_ctx.value dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._eye(N=N, M=M, k=k, ctx=ctx, dtype=dtype, **kwargs) @@ -3729,7 +3729,7 @@ def empty(shape, ctx=None, dtype=None): if isinstance(shape, int): shape = (shape, ) if ctx is None: - ctx = Context.default_ctx + ctx = Context._default_ctx.value if dtype is None: dtype = mx_real_t return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype)) diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py index c7355c2e46d0..524b87f3cd3d 100644 --- a/python/mxnet/ndarray/sparse.py +++ b/python/mxnet/ndarray/sparse.py @@ -977,7 +977,7 @@ def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None, # pylint: disable= no-member, protected-access storage_type = 'csr' # context - ctx = Context.default_ctx if ctx is None else ctx + ctx = Context._default_ctx.value if ctx is None else ctx # types dtype = _prepare_default_dtype(data, dtype) indptr_type = _STORAGE_AUX_TYPES[storage_type][0] if indptr_type is None else indptr_type @@ -1140,7 +1140,7 @@ def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None, """Create a `RowSparseNDArray` based on data and indices""" storage_type = 'row_sparse' # context - ctx = Context.default_ctx if ctx is None else ctx + ctx = Context._default_ctx.value if ctx is None else ctx # types dtype = _prepare_default_dtype(data, dtype) indices_type = _STORAGE_AUX_TYPES[storage_type][0] if indices_type is None else indices_type @@ -1529,7 +1529,7 @@ def zeros(stype, shape, ctx=None, dtype=None, **kwargs): if stype == 'default': return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs) if ctx is None: - ctx = Context.default_ctx + ctx = Context._default_ctx.value dtype = mx_real_t if dtype is None else dtype if stype == 'row_sparse' or stype == 'csr': aux_types = _STORAGE_AUX_TYPES[stype] @@ -1562,7 +1562,7 @@ def empty(stype, shape, ctx=None, dtype=None): if isinstance(shape, int): shape = (shape, ) if ctx is None: - ctx = Context.default_ctx + ctx = Context._default_ctx.value if dtype is None: dtype = mx_real_t assert(stype is not None) @@ -1603,7 +1603,7 @@ def array(source_array, ctx=None, dtype=None): >>> mx.nd.sparse.array(mx.nd.sparse.zeros('row_sparse', (3, 2))) """ - ctx = Context.default_ctx if ctx is None else ctx + ctx = Context._default_ctx.value if ctx is None else ctx if isinstance(source_array, NDArray): assert(source_array.stype != 'default'), \ "Please use `tostype` to create RowSparseNDArray or CSRNDArray from an NDArray" diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py index 6f9e868e2321..3e81dcf3a6c9 100644 --- a/python/mxnet/symbol/register.py +++ b/python/mxnet/symbol/register.py @@ -113,9 +113,9 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) dtype_name, dtype_name, dtype_name)) code.append(""" attr = kwargs.pop('attr', None) - kwargs.update(AttrScope.current.get(attr)) + kwargs.update(AttrScope._current.value.get(attr)) name = kwargs.pop('name', None) - name = NameManager.current.get(name, '%s') + name = NameManager._current.value.get(name, '%s') _ = kwargs.pop('out', None) keys = [] vals = [] @@ -141,7 +141,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) def %s(%s):"""%(func_name, ', '.join(signature))) if not signature_only: code.append(""" - kwargs.update(AttrScope.current.get(attr)) + kwargs.update(AttrScope._current.value.get(attr)) sym_kwargs = dict() _keys = [] _vals = [] @@ -172,7 +172,7 @@ def %s(%s):"""%(func_name, ', '.join(signature))) _vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) code.append(""" - name = NameManager.current.get(name, '%s') + name = NameManager._current.value.get(name, '%s') return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name)"""%( func_name.lower(), handle.value)) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 1ab7cf87bf50..757c31f98fa2 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1767,7 +1767,7 @@ def eval(self, ctx=None, **kwargs): the result will be a list with one element. """ if ctx is None: - ctx = Context.default_ctx + ctx = Context._default_ctx.value return self.bind(ctx, kwargs).forward() def reshape(self, *args, **kwargs): @@ -2448,7 +2448,7 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, handle = SymbolHandle() check_call(_LIB.MXSymbolCreateVariable(c_str(name), ctypes.byref(handle))) ret = Symbol(handle) - attr = AttrScope.current.get(attr) + attr = AttrScope._current.value.get(attr) attr = {} if attr is None else attr if shape is not None: attr['__shape__'] = str(shape) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index aa388c14ea1e..2f0ae97e58c4 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -54,12 +54,12 @@ def default_context(): """Get default context for regression test.""" # _TODO: get context from environment variable to support # testing with GPUs - return Context.default_ctx + return Context._default_ctx.value def set_default_context(ctx): """Set default context.""" - Context.default_ctx = ctx + Context._default_ctx.value = ctx def default_dtype(): diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index 800426c035b2..5618e11a0400 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -42,7 +42,7 @@ def test_box_nms_backward(data, grad, expected, thresh=0.5, topk=-1, coord=2, sc op = mx.contrib.sym.box_nms(in_var, overlap_thresh=thresh, topk=topk, coord_start=coord, score_index=score, id_index=cid, force_suppress=force, in_format=in_format, out_format=out_format) - exe = op.bind(ctx=mx.context.Context.default_ctx, args=[arr_data], args_grad=[arr_grad]) + exe = op.bind(ctx=default_context(), args=[arr_data], args_grad=[arr_grad]) exe.forward(is_train=True) exe.backward(mx.nd.array(grad)) assert_almost_equal(arr_grad.asnumpy(), expected) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7ee67dd20660..762719d3b050 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3667,7 +3667,7 @@ def test_tile_backward(): reps2 = 2 reps = (reps1, reps2) test = mx.sym.tile(data, reps=reps) - exe = test.bind(ctx=mx.context.Context.default_ctx, args=[arr_data], args_grad=[arr_grad]) + exe = test.bind(ctx=default_context(), args=[arr_data], args_grad=[arr_grad]) npout_grad = np.random.randint(0, 10, n1 * n2 * reps1 * reps2).reshape(n1 * reps1, n2 * reps2) out_grad = mx.nd.array(npout_grad) exe.backward(out_grad) @@ -4364,7 +4364,7 @@ def test_psroipooling(): output_dim=num_classes, name='test_op') rtol, atol = 1e-2, 1e-3 # By now we only have gpu implementation - if mx.Context.default_ctx.device_type == 'gpu': + if default_context().device_type == 'gpu': check_numeric_gradient(op, [im_data, rois_data], rtol=rtol, atol=atol, grad_nodes=grad_nodes, ctx=mx.gpu(0)) @@ -4402,7 +4402,7 @@ def test_deformable_convolution(): else: rtol, atol = 0.05, 1e-3 # By now we only have gpu implementation - if mx.Context.default_ctx.device_type == 'gpu': + if default_context().device_type == 'gpu': check_numeric_gradient(op, [im_data, offset_data, weight, bias], rtol=rtol, atol=atol, grad_nodes=grad_nodes, ctx=mx.gpu(0)) @@ -4438,7 +4438,7 @@ def test_deformable_psroipooling(): else: rtol, atol = 1e-2, 1e-3 # By now we only have gpu implementation - if mx.Context.default_ctx.device_type == 'gpu': + if default_context().device_type == 'gpu': check_numeric_gradient(op, [im_data, rois_data, offset_data], rtol=rtol, atol=atol, grad_nodes=grad_nodes, ctx=mx.gpu(0)) From 9c5c903157e76a03218daa96899621c288f4a583 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Tue, 8 May 2018 22:58:43 +0000 Subject: [PATCH 02/13] Add property --- python/mxnet/attribute.py | 17 ++++++++++++++--- python/mxnet/base.py | 31 +++++++++++++++++++++++++++++++ python/mxnet/context.py | 17 +++++++++++++++-- python/mxnet/gluon/block.py | 3 +-- python/mxnet/name.py | 15 +++++++++++++-- 5 files changed, 74 insertions(+), 9 deletions(-) diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py index 0b3c92e4e109..934fc3fbb400 100644 --- a/python/mxnet/attribute.py +++ b/python/mxnet/attribute.py @@ -20,7 +20,7 @@ from __future__ import absolute_import import threading -from .base import string_types +from .base import string_types, classproperty class AttrScope(object): """Attribute manager for scoping. @@ -33,7 +33,6 @@ class AttrScope(object): The attributes to set for all symbol creations in the scope. """ _current = threading.local() - _current.value = None def __init__(self, **kwargs): self._old_scope = None @@ -66,7 +65,9 @@ def get(self, attr): def __enter__(self): # pylint: disable=protected-access - self._old_scope = getattr(AttrScope._current, "value", AttrScope()) + if not hasattr(AttrScope._current, "value"): + AttrScope._current.value = AttrScope() + self._old_scope = AttrScope._current.value attr = AttrScope._current.value._attr.copy() attr.update(self._attr) self._attr = attr @@ -77,4 +78,14 @@ def __exit__(self, ptype, value, trace): assert self._old_scope AttrScope._current.value = self._old_scope + @classproperty + def current(cls): + if not hasattr(AttrScope._current, "value"): + cls._current.value = AttrScope() + return cls._current.value + + @current.setter + def current(cls, val): + cls._current.value = val + AttrScope._current.value = AttrScope() diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 9790e090e387..4c1eb8441e6e 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -227,6 +227,37 @@ def c_str_array(strings): arr[:] = [s.encode('utf-8') for s in strings] return arr +class _MXClassPropertyDescriptor(object): + + def __init__(self, fget, fset=None): + self.fget = fget + self.fset = fset + + def __get__(self, obj, clas=None): + if clas is None: + clas = type(obj) + return self.fget.__get__(obj, clas)() + + def __set__(self, obj, value): + if not self.fset: + raise MXNetError("cannot use the setter: %s to set attribute".format(obj.__name__)) + type_ = type(obj) + return self.fset.__get__(obj, type_)(value) + + def setter(self, func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + self.fset = func + return self + +def classproperty(func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + + return _MXClassPropertyDescriptor(func) + + + def c_array(ctype, values): """Create ctypes array from a Python array. diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 012de387a11a..701c0523f41f 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -19,6 +19,7 @@ """Context management API of mxnet.""" from __future__ import absolute_import import threading +from .base import classproperty class Context(object): """Constructs a context. @@ -63,7 +64,6 @@ class Context(object): """ # static class variable _default_ctx = threading.local() - _default_ctx.value = None devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'} devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5} def __init__(self, device_type, device_id=0): @@ -111,13 +111,26 @@ def __repr__(self): return self.__str__() def __enter__(self): - self._old_ctx = getattr(Context._default_ctx, "value", Context('cpu', 0)) + import pdb; pdb.set_trace() + if not hasattr(Context._default_ctx, "value"): + Context._default_ctx.value = Context('cpu', 0) + self._old_ctx = Context._default_ctx.value Context._default_ctx.value = self return self def __exit__(self, ptype, value, trace): Context._default_ctx.value = self._old_ctx + @classproperty + def default_ctx(cls): + if not hasattr(Context._default_ctx, "value"): + cls._default_ctx.value = Context('cpu', 0) + return cls._default_ctx.value + + @default_ctx.setter + def default_ctx(cls, val): + cls._default_ctx.value = val + # initialize the default context in Context Context._default_ctx.value = Context('cpu', 0) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 8e1a8d5424ef..7e4127250a09 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -37,7 +37,6 @@ class _BlockScope(object): """Scope for collecting child `Block` s.""" _current = threading.local() - _current.value = None def __init__(self, block): self._block = block @@ -48,7 +47,7 @@ def __init__(self, block): @staticmethod def create(prefix, params, hint): """Creates prefix and params for new `Block`.""" - current = _BlockScope._current.value + current = getattr(_BlockScope._current, "value", None) if current is None: if prefix is None: prefix = _name.NameManager._current.value.get(None, hint) + '_' diff --git a/python/mxnet/name.py b/python/mxnet/name.py index 65c7a85ebf10..c67eb8f063db 100644 --- a/python/mxnet/name.py +++ b/python/mxnet/name.py @@ -19,6 +19,7 @@ """Automatic naming support for symbolic API.""" from __future__ import absolute_import import threading +from .base import classproperty class NameManager(object): """NameManager to do automatic naming. @@ -26,7 +27,6 @@ class NameManager(object): Developers can also inherit from this class to change naming behavior. """ _current = threading.local() - _current.value = None def __init__(self): self._counter = {} @@ -64,7 +64,9 @@ def get(self, name, hint): return name def __enter__(self): - self._old_manager = getattr(NameManager._current, "value", NameManager()) + if not hasattr(NameManager._current, "value"): + NameManager._current.value = NameManager() + self._old_manager = NameManager._current.value NameManager._current.value = self return self @@ -72,6 +74,15 @@ def __exit__(self, ptype, value, trace): assert self._old_manager NameManager._current.value = self._old_manager + @classproperty + def current(cls): + if not hasattr(NameManager._current, "value"): + cls._current.value = NameManager() + return cls._current.value + + @current.setter + def current(cls, val): + cls._current.value = val class Prefix(NameManager): """A name manager that attaches a prefix to all names. From 4623ce56a174ead7445c390cc261bdce9f5f492b Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Wed, 9 May 2018 05:05:29 +0000 Subject: [PATCH 03/13] Remove pdb --- python/mxnet/context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 701c0523f41f..8924726aa825 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -111,7 +111,6 @@ def __repr__(self): return self.__str__() def __enter__(self): - import pdb; pdb.set_trace() if not hasattr(Context._default_ctx, "value"): Context._default_ctx.value = Context('cpu', 0) self._old_ctx = Context._default_ctx.value From d8b1b2f5a61e5367c20affcc734d982df058130d Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Wed, 9 May 2018 21:56:41 +0000 Subject: [PATCH 04/13] Add support for setter and getter --- python/mxnet/attribute.py | 4 +- python/mxnet/base.py | 79 ++++++++++++++++++++++++--------------- python/mxnet/context.py | 15 +++++++- python/mxnet/name.py | 4 +- 4 files changed, 66 insertions(+), 36 deletions(-) diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py index 934fc3fbb400..2dae078cf5c2 100644 --- a/python/mxnet/attribute.py +++ b/python/mxnet/attribute.py @@ -20,9 +20,9 @@ from __future__ import absolute_import import threading -from .base import string_types, classproperty +from .base import string_types, classproperty, _MXPr -class AttrScope(object): +class AttrScope(_MXPropMetaClassHolder): """Attribute manager for scoping. User can also inherit this object to change naming behavior. diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 4c1eb8441e6e..ec0276583879 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -98,6 +98,55 @@ class MXCallbackList(ctypes.Structure): ('contexts', ctypes.POINTER(ctypes.c_void_p)) ] +class _MXClassPropertyDescriptor(object): + def __init__(self, fget, fset=None): + self.fget = fget + self.fset = fset + + def __get__(self, obj, clas=None): + if clas is None: + clas = type(obj) + return self.fget.__get__(obj, clas)() + + def __set__(self, obj, value): + if not self.fset: + raise MXNetError("cannot use the setter: %s to set attribute".format(obj.__name__)) + if inspect.isclass(obj): + type_ = obj + obj = None + else: + type_ = type(obj) + return self.fset.__get__(obj, type_)(value) + + def setter(self, func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + self.fset = func + return self + +class _MXClassPropertyMetaClass(type): + def __setattr__(self, key, value): + if key in self.__dict__: + obj = self.__dict__.get(key) + if obj and type(obj) is _MXClassPropertyDescriptor: + return obj.__set__(self, value) + + return super(_MXClassPropertyMetaClass, self).__setattr__(key, value) + +if sys.version_info[0] < 3: + class _MXPropMetaClassHolder(object): + __metaclass__ = _MXClassPropertyMetaClass +else: + class _MXPropMetaClassHolder(object, metaclass = _MXClassPropertyMetaClass): + pass + +def classproperty(func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + + return _MXClassPropertyDescriptor(func) + + def _load_lib(): """Load library by searching possible path.""" @@ -227,36 +276,6 @@ def c_str_array(strings): arr[:] = [s.encode('utf-8') for s in strings] return arr -class _MXClassPropertyDescriptor(object): - - def __init__(self, fget, fset=None): - self.fget = fget - self.fset = fset - - def __get__(self, obj, clas=None): - if clas is None: - clas = type(obj) - return self.fget.__get__(obj, clas)() - - def __set__(self, obj, value): - if not self.fset: - raise MXNetError("cannot use the setter: %s to set attribute".format(obj.__name__)) - type_ = type(obj) - return self.fset.__get__(obj, type_)(value) - - def setter(self, func): - if not isinstance(func, (classmethod, staticmethod)): - func = classmethod(func) - self.fset = func - return self - -def classproperty(func): - if not isinstance(func, (classmethod, staticmethod)): - func = classmethod(func) - - return _MXClassPropertyDescriptor(func) - - def c_array(ctype, values): """Create ctypes array from a Python array. diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 8924726aa825..c8317e82ab08 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -19,9 +19,10 @@ """Context management API of mxnet.""" from __future__ import absolute_import import threading -from .base import classproperty +import warnings +from .base import classproperty, _MXClassPropertyMetaClass, _MXPropMetaClassHolder -class Context(object): +class Context(_MXPropMetaClassHolder): """Constructs a context. MXNet can run operations on CPU and different GPUs. @@ -122,12 +123,22 @@ def __exit__(self, ptype, value, trace): @classproperty def default_ctx(cls): + warnings.warn("Context.default_ctx has been deprecated. " + "Please use Context.current_context() instead. " + "Please use test_utils.set_default_context to set a default context", + DeprecationWarning, + stacklevel=3) if not hasattr(Context._default_ctx, "value"): cls._default_ctx.value = Context('cpu', 0) return cls._default_ctx.value @default_ctx.setter def default_ctx(cls, val): + warnings.warn("Context.default_ctx has been deprecated. " + "Please use Context.current_context() instead. " + "Please use test_utils.set_default_context to set a default context", + DeprecationWarning, + stacklevel=3) cls._default_ctx.value = val # initialize the default context in Context diff --git a/python/mxnet/name.py b/python/mxnet/name.py index c67eb8f063db..f3cc360ca37c 100644 --- a/python/mxnet/name.py +++ b/python/mxnet/name.py @@ -19,9 +19,9 @@ """Automatic naming support for symbolic API.""" from __future__ import absolute_import import threading -from .base import classproperty +from .base import classproperty, _MXPropMetaClassHolder -class NameManager(object): +class NameManager(_MXPropMetaClassHolder): """NameManager to do automatic naming. Developers can also inherit from this class to change naming behavior. From 40fc6bb70b358b72ceaf487c893aea5677c45fbe Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 10 May 2018 00:24:11 +0000 Subject: [PATCH 05/13] fix issues --- python/mxnet/attribute.py | 10 ++++++++-- python/mxnet/base.py | 22 ++++++++++++++++------ python/mxnet/context.py | 10 ++++------ python/mxnet/name.py | 7 +++++-- 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py index 2dae078cf5c2..dde56f4d9636 100644 --- a/python/mxnet/attribute.py +++ b/python/mxnet/attribute.py @@ -20,9 +20,9 @@ from __future__ import absolute_import import threading -from .base import string_types, classproperty, _MXPr +from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass -class AttrScope(_MXPropMetaClassHolder): +class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)): """Attribute manager for scoping. User can also inherit this object to change naming behavior. @@ -80,12 +80,18 @@ def __exit__(self, ptype, value, trace): @classproperty def current(cls): + warnings.warn("AttrScope.current has been deprecated. " + "It is advised to use the `with` statement with AttrScope.", + DeprecationWarning) if not hasattr(AttrScope._current, "value"): cls._current.value = AttrScope() return cls._current.value @current.setter def current(cls, val): + warnings.warn("AttrScope.current has been deprecated. " + "It is advised to use the `with` statement with AttrScope.", + DeprecationWarning) cls._current.value = val AttrScope._current.value = AttrScope() diff --git a/python/mxnet/base.py b/python/mxnet/base.py index ec0276583879..7b004df8d3f0 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -98,6 +98,7 @@ class MXCallbackList(ctypes.Structure): ('contexts', ctypes.POINTER(ctypes.c_void_p)) ] +# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property class _MXClassPropertyDescriptor(object): def __init__(self, fget, fset=None): self.fget = fget @@ -133,12 +134,21 @@ def __setattr__(self, key, value): return super(_MXClassPropertyMetaClass, self).__setattr__(key, value) -if sys.version_info[0] < 3: - class _MXPropMetaClassHolder(object): - __metaclass__ = _MXClassPropertyMetaClass -else: - class _MXPropMetaClassHolder(object, metaclass = _MXClassPropertyMetaClass): - pass +# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py +def with_metaclass(meta, *bases): + """Create a base class with a metaclass.""" + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. + class metaclass(type): + + def __new__(cls, name, this_bases, d): + return meta(name, bases, d) + + @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) + return type.__new__(metaclass, 'temporary_class', (), {}) def classproperty(func): if not isinstance(func, (classmethod, staticmethod)): diff --git a/python/mxnet/context.py b/python/mxnet/context.py index c8317e82ab08..6dd39f6b6146 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -20,9 +20,9 @@ from __future__ import absolute_import import threading import warnings -from .base import classproperty, _MXClassPropertyMetaClass, _MXPropMetaClassHolder +from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass -class Context(_MXPropMetaClassHolder): +class Context(with_metaclass(_MXClassPropertyMetaClass, object)): """Constructs a context. MXNet can run operations on CPU and different GPUs. @@ -126,8 +126,7 @@ def default_ctx(cls): warnings.warn("Context.default_ctx has been deprecated. " "Please use Context.current_context() instead. " "Please use test_utils.set_default_context to set a default context", - DeprecationWarning, - stacklevel=3) + DeprecationWarning) if not hasattr(Context._default_ctx, "value"): cls._default_ctx.value = Context('cpu', 0) return cls._default_ctx.value @@ -137,8 +136,7 @@ def default_ctx(cls, val): warnings.warn("Context.default_ctx has been deprecated. " "Please use Context.current_context() instead. " "Please use test_utils.set_default_context to set a default context", - DeprecationWarning, - stacklevel=3) + DeprecationWarning) cls._default_ctx.value = val # initialize the default context in Context diff --git a/python/mxnet/name.py b/python/mxnet/name.py index f3cc360ca37c..4f611f87e1b2 100644 --- a/python/mxnet/name.py +++ b/python/mxnet/name.py @@ -19,9 +19,9 @@ """Automatic naming support for symbolic API.""" from __future__ import absolute_import import threading -from .base import classproperty, _MXPropMetaClassHolder +from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass -class NameManager(_MXPropMetaClassHolder): +class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)): """NameManager to do automatic naming. Developers can also inherit from this class to change naming behavior. @@ -76,6 +76,9 @@ def __exit__(self, ptype, value, trace): @classproperty def current(cls): + warnings.warn("NameManager.current has been deprecated. " + "It is advised to use the `with` statement with NameManager.", + DeprecationWarning) if not hasattr(NameManager._current, "value"): cls._current.value = NameManager() return cls._current.value From 6b79d89099ce358b71f56e1bc1a4d2c8a748491b Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 10 May 2018 00:41:24 +0000 Subject: [PATCH 06/13] Add warnings --- python/mxnet/attribute.py | 1 + python/mxnet/name.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py index dde56f4d9636..3a829eb2c370 100644 --- a/python/mxnet/attribute.py +++ b/python/mxnet/attribute.py @@ -19,6 +19,7 @@ """Attribute scoping support for symbolic API.""" from __future__ import absolute_import import threading +import warnings from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass diff --git a/python/mxnet/name.py b/python/mxnet/name.py index 4f611f87e1b2..271f7e75cf26 100644 --- a/python/mxnet/name.py +++ b/python/mxnet/name.py @@ -19,6 +19,7 @@ """Automatic naming support for symbolic API.""" from __future__ import absolute_import import threading +import warnings from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)): From df057e2c40328206cceb2899624a333144d0967a Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 10 May 2018 01:06:08 +0000 Subject: [PATCH 07/13] Add thread local unittest and tlocal race condition --- tests/nightly/test_tlocal_racecondition.py | 110 ++++++++++++++++ tests/python/unittest/test_thread_local.py | 146 +++++++++++++++++++++ 2 files changed, 256 insertions(+) create mode 100644 tests/nightly/test_tlocal_racecondition.py create mode 100644 tests/python/unittest/test_thread_local.py diff --git a/tests/nightly/test_tlocal_racecondition.py b/tests/nightly/test_tlocal_racecondition.py new file mode 100644 index 000000000000..d43c45937c05 --- /dev/null +++ b/tests/nightly/test_tlocal_racecondition.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import mxnet as mx +from mxnet import gluon +from mxnet import image +from mxnet import nd +import numpy as np +import logging + +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +root_url = ('https://apache-mxnet.s3-accelerate.amazonaws.com/' + 'gluon/dataset/pikachu/') +data_dir = './data/pikachu/' +dataset = {'train.rec': 'e6bcb6ffba1ac04ff8a9b1115e650af56ee969c8', + 'train.idx': 'dcf7318b2602c06428b9988470c731621716c393', + 'val.rec': 'd6c33f799b4d058e82f2cb5bd9a976f69d72d520'} +for k, v in dataset.items(): + gluon.utils.download(root_url+k, data_dir+k, sha1_hash=v) + +T = 1 +devs = [mx.gpu(i) for i in range(4)] +data_shape = 224 * T +batch_size = 20 * len(devs) +rgb_mean = np.array([1,2,3]) + +class_names = ['pikachu'] +num_class = len(class_names) + +def get_iterators(data_shape, batch_size): + train_iter = image.ImageDetIter( + batch_size=batch_size, + data_shape=(3, data_shape, data_shape), + path_imgrec=data_dir+'train.rec', + path_imgidx=data_dir+'train.idx', + shuffle=True, + mean=True, + rand_crop=1, + min_object_covered=0.95, + max_attempts=200) + val_iter = image.ImageDetIter( + batch_size=batch_size, + data_shape=(3, data_shape, data_shape), + path_imgrec=data_dir+'val.rec', + shuffle=False, + mean=True) + return train_iter, val_iter, class_names, num_class + +train_data, test_data, class_names, num_class = get_iterators( + data_shape, batch_size) + + +class MyCustom(mx.operator.CustomOp): + def __init__(self): + super(MyCustom, self).__init__() + def forward(self, is_train, req, in_data, out_data, aux): + self.assign(out_data[0], req[0], 0) + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + self.assign(in_grad[0], req[0], 0) + self.assign(in_grad[1], req[1], 0) + +@mx.operator.register("MyCustom") +class MyCustomProp(mx.operator.CustomOpProp): + def __init__(self): + super(MyCustomProp, self).__init__(need_top_grad = False) + def list_arguments(self): + return ["data", "label"] + def list_outputs(self): + return ["loss"] + def infer_shape(self, in_shape): + return [in_shape[0], in_shape[1]], [(1, )], [] + def infer_type(self, in_type): + dtype = in_type[0] + return [dtype, dtype], [dtype], [] + def create_operator(self, ctx, shapes, dtypes): + return MyCustom() + +class MyMetric(mx.metric.EvalMetric): + def __init__(self): + super(MyMetric, self).__init__("MyMetric") + self.name = ['empty'] + def update(self, labels, preds): + pass + def get(self): + return self.name, [0] + +if __name__ == '__main__': + x = mx.sym.Variable("data") + label = mx.sym.Variable("label") + x = mx.sym.FullyConnected(data = x, num_hidden = 100) + label = mx.sym.Reshape(data = label, shape = (0, -1)) + sym = mx.sym.Custom(data = x, label = label, op_type = "MyCustom") + model = mx.module.Module(context = devs, symbol = sym, data_names = ('data',), label_names = ('label',)) + model.fit(train_data = train_data, begin_epoch = 0, num_epoch = 20, allow_missing = True, batch_end_callback = mx.callback.Speedometer(batch_size, 5), eval_metric = MyMetric()) diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py new file mode 100644 index 000000000000..276e0034d4d8 --- /dev/null +++ b/tests/python/unittest/test_thread_local.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import threading +import mxnet as mx +from mxnet import context, attribute, name +from mxnet.gluon import block +from mxnet.context import Context +from mxnet.attribute import AttrScope +from mxnet.name import NameManager +from mxnet.test_utils import set_default_context + +def test_context(): + ctx_list = [] + ctx_list.append(Context.default_ctx) + def f(): + set_default_context(mx.gpu(11)) + ctx_list.append(Context.default_ctx) + thread = threading.Thread(target=f) + thread.start() + thread.join() + assert Context.devtype2str[ctx_list[0].device_typeid] == "cpu" + assert ctx_list[0].device_id == 0 + assert Context.devtype2str[ctx_list[1].device_typeid] == "gpu" + assert ctx_list[1].device_id == 11 + + condition = threading.Condition() + status = [False] + def g(): + condition.acquire() + with mx.cpu(10): + condition.wait() + if Context.default_ctx.device_id == 10: + status[0] = True + condition.release() + thread = threading.Thread(target=g) + thread.start() + condition.acquire() + Context.default_ctx = Context("cpu", 11) + condition.notify() + condition.release() + thread.join() + assert status[0], "Spawned thread didn't set the correct context" + +def test_attrscope(): + attrscope_list = [] + AttrScope.current = AttrScope(y="hi", z="hey") + attrscope_list.append(AttrScope.current) + def f(): + AttrScope.current = AttrScope(x="hello") + attrscope_list.append(AttrScope.current) + thread = threading.Thread(target=f) + thread.start() + thread.join() + assert len(attrscope_list[0]._attr) == 2 + assert attrscope_list[1]._attr["x"] == "hello" + + condition = threading.Condition() + status = [False] + def g(): + condition.acquire() + with mx.AttrScope(x="hello"): + condition.wait() + if "hello" in AttrScope.current._attr.values(): + status[0] = True + thread = threading.Thread(target=g) + thread.start() + condition.acquire() + AttrScope.current = AttrScope(x="hi") + condition.notify() + condition.release() + thread.join() + assert status[0], "Spawned thread didn't set the correct attr key values" + +def test_name(): + name_list = [] + NameManager.current = NameManager() + NameManager.current.get(None, "main_thread") + name_list.append(NameManager.current) + def f(): + NameManager.current = NameManager() + NameManager.current.get(None, "spawned_thread") + name_list.append(NameManager.current) + thread = threading.Thread(target=f) + thread.start() + thread.join() + assert "main_thread" in name_list[0]._counter, "cannot find the string `main thread` in name_list[0]._counter" + assert "spawned_thread" in name_list[1]._counter, "cannot find the string `spawned thread` in name_list[1]._counter" + + condition = threading.Condition() + status = [False] + def g(): + condition.acquire() + with NameManager(): + condition.wait() + if "main_thread" not in NameManager.current._counter: + status[0] = True + thread = threading.Thread(target=g) + thread.start() + condition.acquire() + NameManager.current = NameManager() + NameManager.current.get(None, "main_thread") + condition.notify() + condition.release() + thread.join() + assert status[0], "Spawned thread isn't using thread local NameManager" + +def test_blockscope(): + class dummy_block(object): + def __init__(self, prefix): + self.prefix = prefix + self._empty_prefix = False + blockscope_list = [] + status = [False] + condition = threading.Condition() + def f(): + with block._BlockScope(dummy_block("spawned_")): + x= NameManager.current.get(None, "hello") + if x == "spawned_hello0": + status[0] = True + thread = threading.Thread(target=f) + thread.start() + condition.acquire() + block._BlockScope.create("main_thread", None, "hi") + condition.notify() + condition.release() + thread.join() + assert status[0], "Spawned thread isn't using the correct blockscope namemanager" + +if __name__ == '__main__': + import nose + nose.runmodule() From 3d6501f7ab125f56dc6fa9fba95aba5ce36e5bbc Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 10 May 2018 01:34:20 +0000 Subject: [PATCH 08/13] Fix pylint --- python/mxnet/attribute.py | 2 ++ python/mxnet/base.py | 18 ++++++++++-------- python/mxnet/context.py | 2 ++ python/mxnet/name.py | 2 ++ 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/mxnet/attribute.py b/python/mxnet/attribute.py index 3a829eb2c370..17044ddaef06 100644 --- a/python/mxnet/attribute.py +++ b/python/mxnet/attribute.py @@ -79,6 +79,7 @@ def __exit__(self, ptype, value, trace): assert self._old_scope AttrScope._current.value = self._old_scope + #pylint: disable=no-self-argument @classproperty def current(cls): warnings.warn("AttrScope.current has been deprecated. " @@ -94,5 +95,6 @@ def current(cls, val): "It is advised to use the `with` statement with AttrScope.", DeprecationWarning) cls._current.value = val + #pylint: enable=no-self-argument AttrScope._current.value = AttrScope() diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 7b004df8d3f0..0fb73b3c7dda 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -16,7 +16,7 @@ # under the License. # coding: utf-8 -# pylint: disable=invalid-name, no-member, trailing-comma-tuple +# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument """ctypes library of mxnet and helper functions.""" from __future__ import absolute_import @@ -111,7 +111,7 @@ def __get__(self, obj, clas=None): def __set__(self, obj, value): if not self.fset: - raise MXNetError("cannot use the setter: %s to set attribute".format(obj.__name__)) + raise MXNetError("cannot use the setter: %s to set attribute" % obj.__name__) if inspect.isclass(obj): type_ = obj obj = None @@ -126,15 +126,16 @@ def setter(self, func): return self class _MXClassPropertyMetaClass(type): - def __setattr__(self, key, value): - if key in self.__dict__: - obj = self.__dict__.get(key) - if obj and type(obj) is _MXClassPropertyDescriptor: - return obj.__set__(self, value) + def __setattr__(cls, key, value): + if key in cls.__dict__: + obj = cls.__dict__.get(key) + if obj and isinstance(obj, _MXClassPropertyDescriptor): + return obj.__set__(cls, value) - return super(_MXClassPropertyMetaClass, self).__setattr__(key, value) + return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value) # with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py +#pylint: disable=unused-argument def with_metaclass(meta, *bases): """Create a base class with a metaclass.""" # This requires a bit of explanation: the basic idea is to make a dummy @@ -149,6 +150,7 @@ def __new__(cls, name, this_bases, d): def __prepare__(cls, name, this_bases): return meta.__prepare__(name, bases) return type.__new__(metaclass, 'temporary_class', (), {}) +#pylint: enable=unused-argument def classproperty(func): if not isinstance(func, (classmethod, staticmethod)): diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 6dd39f6b6146..b9a7c652d254 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -121,6 +121,7 @@ def __enter__(self): def __exit__(self, ptype, value, trace): Context._default_ctx.value = self._old_ctx + #pylint: disable=no-self-argument @classproperty def default_ctx(cls): warnings.warn("Context.default_ctx has been deprecated. " @@ -138,6 +139,7 @@ def default_ctx(cls, val): "Please use test_utils.set_default_context to set a default context", DeprecationWarning) cls._default_ctx.value = val + #pylint: enable=no-self-argument # initialize the default context in Context Context._default_ctx.value = Context('cpu', 0) diff --git a/python/mxnet/name.py b/python/mxnet/name.py index 271f7e75cf26..4149d1db2731 100644 --- a/python/mxnet/name.py +++ b/python/mxnet/name.py @@ -75,6 +75,7 @@ def __exit__(self, ptype, value, trace): assert self._old_manager NameManager._current.value = self._old_manager + #pylint: disable=no-self-argument @classproperty def current(cls): warnings.warn("NameManager.current has been deprecated. " @@ -87,6 +88,7 @@ def current(cls): @current.setter def current(cls, val): cls._current.value = val + #pylint: enable=no-self-argument class Prefix(NameManager): """A name manager that attaches a prefix to all names. From 57af178061fe023c278ed277a7e83225e07d89e4 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 10 May 2018 06:57:51 +0000 Subject: [PATCH 09/13] Use current_context instead of _default_ctx --- python/mxnet/context.py | 2 ++ python/mxnet/ndarray/ndarray.py | 12 ++++++------ python/mxnet/ndarray/sparse.py | 10 +++++----- python/mxnet/symbol/symbol.py | 2 +- python/mxnet/test_utils.py | 2 +- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/python/mxnet/context.py b/python/mxnet/context.py index b9a7c652d254..5861890f40c1 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -259,4 +259,6 @@ def current_context(): ------- default_ctx : Context """ + if not hasattr(Context._default_ctx, "value"): + Context._default_ctx.value = Context('cpu', 0) return Context._default_ctx.value diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 36825dd87f1d..e40b6fc35c1f 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2261,7 +2261,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs): The shape of the empty array. ctx : Context, optional An optional device context. - Defaults to the current default context (``mxnet.Context._default_ctx.value``). + Defaults to the current default context (``mxnet.Context.current_context()``). dtype : str or numpy.dtype, optional An optional value type (default is `float32`). out : NDArray, optional @@ -2283,7 +2283,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context._default_ctx.value + ctx = Context.current_context() dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) @@ -2439,7 +2439,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): array([2, 2, 2, 4, 4, 4], dtype=int32) """ if ctx is None: - ctx = Context._default_ctx.value + ctx = Context.current_context() return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, dtype=dtype, ctx=str(ctx)) # pylint: enable= no-member, protected-access, too-many-arguments @@ -3666,7 +3666,7 @@ def zeros(shape, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context._default_ctx.value + ctx = Context.current_context() dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) @@ -3705,7 +3705,7 @@ def eye(N, M=0, k=0, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context._default_ctx.value + ctx = Context.current_context() dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._eye(N=N, M=M, k=k, ctx=ctx, dtype=dtype, **kwargs) @@ -3733,7 +3733,7 @@ def empty(shape, ctx=None, dtype=None): if isinstance(shape, int): shape = (shape, ) if ctx is None: - ctx = Context._default_ctx.value + ctx = Context.current_context() if dtype is None: dtype = mx_real_t return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype)) diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py index 524b87f3cd3d..e320271daa94 100644 --- a/python/mxnet/ndarray/sparse.py +++ b/python/mxnet/ndarray/sparse.py @@ -977,7 +977,7 @@ def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None, # pylint: disable= no-member, protected-access storage_type = 'csr' # context - ctx = Context._default_ctx.value if ctx is None else ctx + ctx = Context.current_context() if ctx is None else ctx # types dtype = _prepare_default_dtype(data, dtype) indptr_type = _STORAGE_AUX_TYPES[storage_type][0] if indptr_type is None else indptr_type @@ -1140,7 +1140,7 @@ def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None, """Create a `RowSparseNDArray` based on data and indices""" storage_type = 'row_sparse' # context - ctx = Context._default_ctx.value if ctx is None else ctx + ctx = Context.current_context() if ctx is None else ctx # types dtype = _prepare_default_dtype(data, dtype) indices_type = _STORAGE_AUX_TYPES[storage_type][0] if indices_type is None else indices_type @@ -1529,7 +1529,7 @@ def zeros(stype, shape, ctx=None, dtype=None, **kwargs): if stype == 'default': return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs) if ctx is None: - ctx = Context._default_ctx.value + ctx = Context.current_context() dtype = mx_real_t if dtype is None else dtype if stype == 'row_sparse' or stype == 'csr': aux_types = _STORAGE_AUX_TYPES[stype] @@ -1562,7 +1562,7 @@ def empty(stype, shape, ctx=None, dtype=None): if isinstance(shape, int): shape = (shape, ) if ctx is None: - ctx = Context._default_ctx.value + ctx = Context.current_context() if dtype is None: dtype = mx_real_t assert(stype is not None) @@ -1603,7 +1603,7 @@ def array(source_array, ctx=None, dtype=None): >>> mx.nd.sparse.array(mx.nd.sparse.zeros('row_sparse', (3, 2))) """ - ctx = Context._default_ctx.value if ctx is None else ctx + ctx = Context.current_context() if ctx is None else ctx if isinstance(source_array, NDArray): assert(source_array.stype != 'default'), \ "Please use `tostype` to create RowSparseNDArray or CSRNDArray from an NDArray" diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 757c31f98fa2..190e35396e0b 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1767,7 +1767,7 @@ def eval(self, ctx=None, **kwargs): the result will be a list with one element. """ if ctx is None: - ctx = Context._default_ctx.value + ctx = Context.current_context() return self.bind(ctx, kwargs).forward() def reshape(self, *args, **kwargs): diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 2f0ae97e58c4..ca0748fd1fc3 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -54,7 +54,7 @@ def default_context(): """Get default context for regression test.""" # _TODO: get context from environment variable to support # testing with GPUs - return Context._default_ctx.value + return Context.current_context() def set_default_context(ctx): From d12e4a7402447f17fa3998bb8b73d1b73d0a4d55 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 10 May 2018 07:24:47 +0000 Subject: [PATCH 10/13] Use current_context --- python/mxnet/ndarray/ndarray.py | 14 +++++++------- python/mxnet/ndarray/sparse.py | 12 ++++++------ python/mxnet/test_utils.py | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index e40b6fc35c1f..007b3c82def2 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -37,7 +37,7 @@ from ..base import c_array, c_array_buf, c_handle_array, mx_real_t from ..base import mx_uint, NDArrayHandle, check_call from ..base import ctypes2buffer -from ..context import Context +from ..context import Context, current_context from . import _internal from . import op from ._internal import NDArrayBase @@ -2261,7 +2261,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs): The shape of the empty array. ctx : Context, optional An optional device context. - Defaults to the current default context (``mxnet.Context.current_context()``). + Defaults to the current default context (``mxnet.context.current_context()``). dtype : str or numpy.dtype, optional An optional value type (default is `float32`). out : NDArray, optional @@ -2283,7 +2283,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context.current_context() + ctx = current_context() dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) @@ -2439,7 +2439,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): array([2, 2, 2, 4, 4, 4], dtype=int32) """ if ctx is None: - ctx = Context.current_context() + ctx = current_context() return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, dtype=dtype, ctx=str(ctx)) # pylint: enable= no-member, protected-access, too-many-arguments @@ -3666,7 +3666,7 @@ def zeros(shape, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context.current_context() + ctx = current_context() dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) @@ -3705,7 +3705,7 @@ def eye(N, M=0, k=0, ctx=None, dtype=None, **kwargs): """ # pylint: disable= unused-argument if ctx is None: - ctx = Context.current_context() + ctx = current_context() dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._eye(N=N, M=M, k=k, ctx=ctx, dtype=dtype, **kwargs) @@ -3733,7 +3733,7 @@ def empty(shape, ctx=None, dtype=None): if isinstance(shape, int): shape = (shape, ) if ctx is None: - ctx = Context.current_context() + ctx = current_context() if dtype is None: dtype = mx_real_t return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype)) diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py index e320271daa94..9c02b8e2cf27 100644 --- a/python/mxnet/ndarray/sparse.py +++ b/python/mxnet/ndarray/sparse.py @@ -42,7 +42,7 @@ from ..base import _LIB, numeric_types from ..base import c_array_buf, mx_real_t, integer_types from ..base import mx_uint, NDArrayHandle, check_call -from ..context import Context +from ..context import Context, current_context from . import _internal from . import op try: @@ -977,7 +977,7 @@ def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None, # pylint: disable= no-member, protected-access storage_type = 'csr' # context - ctx = Context.current_context() if ctx is None else ctx + ctx = current_context() if ctx is None else ctx # types dtype = _prepare_default_dtype(data, dtype) indptr_type = _STORAGE_AUX_TYPES[storage_type][0] if indptr_type is None else indptr_type @@ -1140,7 +1140,7 @@ def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None, """Create a `RowSparseNDArray` based on data and indices""" storage_type = 'row_sparse' # context - ctx = Context.current_context() if ctx is None else ctx + ctx = current_context() if ctx is None else ctx # types dtype = _prepare_default_dtype(data, dtype) indices_type = _STORAGE_AUX_TYPES[storage_type][0] if indices_type is None else indices_type @@ -1529,7 +1529,7 @@ def zeros(stype, shape, ctx=None, dtype=None, **kwargs): if stype == 'default': return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs) if ctx is None: - ctx = Context.current_context() + ctx = current_context() dtype = mx_real_t if dtype is None else dtype if stype == 'row_sparse' or stype == 'csr': aux_types = _STORAGE_AUX_TYPES[stype] @@ -1562,7 +1562,7 @@ def empty(stype, shape, ctx=None, dtype=None): if isinstance(shape, int): shape = (shape, ) if ctx is None: - ctx = Context.current_context() + ctx = current_context() if dtype is None: dtype = mx_real_t assert(stype is not None) @@ -1603,7 +1603,7 @@ def array(source_array, ctx=None, dtype=None): >>> mx.nd.sparse.array(mx.nd.sparse.zeros('row_sparse', (3, 2))) """ - ctx = Context.current_context() if ctx is None else ctx + ctx = current_context() if ctx is None else ctx if isinstance(source_array, NDArray): assert(source_array.stype != 'default'), \ "Please use `tostype` to create RowSparseNDArray or CSRNDArray from an NDArray" diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index ca0748fd1fc3..bcdcc9c64080 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -44,7 +44,7 @@ # in rare cases requests may be not installed pass import mxnet as mx -from .context import Context +from .context import Context, current_context from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from .ndarray import array from .symbol import Symbol @@ -54,7 +54,7 @@ def default_context(): """Get default context for regression test.""" # _TODO: get context from environment variable to support # testing with GPUs - return Context.current_context() + return current_context() def set_default_context(ctx): From 1a91314330c23f4f98b42d3a36e991c3f4753ca9 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 10 May 2018 21:59:35 +0000 Subject: [PATCH 11/13] Fix race condition --- tests/python/unittest/test_thread_local.py | 38 +++++++++------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py index 276e0034d4d8..b709e87f45af 100644 --- a/tests/python/unittest/test_thread_local.py +++ b/tests/python/unittest/test_thread_local.py @@ -38,22 +38,19 @@ def f(): assert Context.devtype2str[ctx_list[1].device_typeid] == "gpu" assert ctx_list[1].device_id == 11 - condition = threading.Condition() + event = threading.Event() status = [False] def g(): - condition.acquire() with mx.cpu(10): - condition.wait() + event.wait() if Context.default_ctx.device_id == 10: status[0] = True - condition.release() thread = threading.Thread(target=g) thread.start() - condition.acquire() Context.default_ctx = Context("cpu", 11) - condition.notify() - condition.release() + event.set() thread.join() + event.clear() assert status[0], "Spawned thread didn't set the correct context" def test_attrscope(): @@ -69,21 +66,19 @@ def f(): assert len(attrscope_list[0]._attr) == 2 assert attrscope_list[1]._attr["x"] == "hello" - condition = threading.Condition() + event = threading.Event() status = [False] def g(): - condition.acquire() with mx.AttrScope(x="hello"): - condition.wait() + event.wait() if "hello" in AttrScope.current._attr.values(): status[0] = True thread = threading.Thread(target=g) thread.start() - condition.acquire() AttrScope.current = AttrScope(x="hi") - condition.notify() - condition.release() + event.set() thread.join() + event.clear() assert status[0], "Spawned thread didn't set the correct attr key values" def test_name(): @@ -101,22 +96,19 @@ def f(): assert "main_thread" in name_list[0]._counter, "cannot find the string `main thread` in name_list[0]._counter" assert "spawned_thread" in name_list[1]._counter, "cannot find the string `spawned thread` in name_list[1]._counter" - condition = threading.Condition() + event = threading.Event() status = [False] def g(): - condition.acquire() with NameManager(): - condition.wait() if "main_thread" not in NameManager.current._counter: status[0] = True thread = threading.Thread(target=g) thread.start() - condition.acquire() NameManager.current = NameManager() NameManager.current.get(None, "main_thread") - condition.notify() - condition.release() + event.set() thread.join() + event.clear() assert status[0], "Spawned thread isn't using thread local NameManager" def test_blockscope(): @@ -126,19 +118,19 @@ def __init__(self, prefix): self._empty_prefix = False blockscope_list = [] status = [False] - condition = threading.Condition() + event = threading.Event() def f(): with block._BlockScope(dummy_block("spawned_")): x= NameManager.current.get(None, "hello") + event.wait() if x == "spawned_hello0": status[0] = True thread = threading.Thread(target=f) thread.start() - condition.acquire() block._BlockScope.create("main_thread", None, "hi") - condition.notify() - condition.release() + event.set() thread.join() + event.clear() assert status[0], "Spawned thread isn't using the correct blockscope namemanager" if __name__ == '__main__': From c31cf5960552e0ed729c9e3d6b7770a95bcfae0b Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 10 May 2018 23:32:41 +0000 Subject: [PATCH 12/13] Fix thread local test --- tests/python/unittest/test_thread_local.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py index b709e87f45af..a571a25ab2a6 100644 --- a/tests/python/unittest/test_thread_local.py +++ b/tests/python/unittest/test_thread_local.py @@ -78,6 +78,7 @@ def g(): AttrScope.current = AttrScope(x="hi") event.set() thread.join() + AttrScope.current = AttrScope() event.clear() assert status[0], "Spawned thread didn't set the correct attr key values" From 0af608b543e10ff18a2cc96067c1cf0ff834ec87 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Thu, 10 May 2018 23:36:41 +0000 Subject: [PATCH 13/13] Change to current_context --- python/mxnet/symbol/symbol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 190e35396e0b..49023db2fe0c 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -37,7 +37,7 @@ from ..base import mx_uint, py_str, string_types from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle from ..base import check_call, MXNetError, NotImplementedForSymbol -from ..context import Context +from ..context import Context, current_context from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from ..ndarray import _ndarray_cls @@ -1767,7 +1767,7 @@ def eval(self, ctx=None, **kwargs): the result will be a list with one element. """ if ctx is None: - ctx = Context.current_context() + ctx = current_context() return self.bind(ctx, kwargs).forward() def reshape(self, *args, **kwargs):