Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions python/mxnet/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
# coding: utf-8
"""Attribute scoping support for symbolic API."""
from __future__ import absolute_import
import threading
import warnings

from .base import string_types
from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass

class AttrScope(object):
class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)):
"""Attribute manager for scoping.

User can also inherit this object to change naming behavior.
Expand All @@ -31,7 +33,7 @@ class AttrScope(object):
kwargs
The attributes to set for all symbol creations in the scope.
"""
current = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_current = threading.local()

_current = threading.local()

def __init__(self, **kwargs):
self._old_scope = None
Expand Down Expand Up @@ -64,15 +66,35 @@ def get(self, attr):

def __enter__(self):
# pylint: disable=protected-access
self._old_scope = AttrScope.current
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._old_scope = AttrScope.current.value

attr = AttrScope.current._attr.copy()
if not hasattr(AttrScope._current, "value"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

AttrScope._current.value = AttrScope()
self._old_scope = AttrScope._current.value
attr = AttrScope._current.value._attr.copy()
attr.update(self._attr)
self._attr = attr
AttrScope.current = self
AttrScope._current.value = self
return self

def __exit__(self, ptype, value, trace):
assert self._old_scope
AttrScope.current = self._old_scope
AttrScope._current.value = self._old_scope

AttrScope.current = AttrScope()
#pylint: disable=no-self-argument
@classproperty
def current(cls):
warnings.warn("AttrScope.current has been deprecated. "
"It is advised to use the `with` statement with AttrScope.",
DeprecationWarning)
if not hasattr(AttrScope._current, "value"):
cls._current.value = AttrScope()
return cls._current.value

@current.setter
def current(cls, val):
warnings.warn("AttrScope.current has been deprecated. "
"It is advised to use the `with` statement with AttrScope.",
DeprecationWarning)
cls._current.value = val
#pylint: enable=no-self-argument

AttrScope._current.value = AttrScope()
64 changes: 63 additions & 1 deletion python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

# coding: utf-8
# pylint: disable=invalid-name, no-member, trailing-comma-tuple
# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import

Expand Down Expand Up @@ -98,6 +98,67 @@ class MXCallbackList(ctypes.Structure):
('contexts', ctypes.POINTER(ctypes.c_void_p))
]

# Please see: https://stackoverflow.com/questions/5189699/how-to-make-a-class-property
class _MXClassPropertyDescriptor(object):
def __init__(self, fget, fset=None):
self.fget = fget
self.fset = fset

def __get__(self, obj, clas=None):
if clas is None:
clas = type(obj)
return self.fget.__get__(obj, clas)()

def __set__(self, obj, value):
if not self.fset:
raise MXNetError("cannot use the setter: %s to set attribute" % obj.__name__)
if inspect.isclass(obj):
type_ = obj
obj = None
else:
type_ = type(obj)
return self.fset.__get__(obj, type_)(value)

def setter(self, func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)
self.fset = func
return self

class _MXClassPropertyMetaClass(type):
def __setattr__(cls, key, value):
if key in cls.__dict__:
obj = cls.__dict__.get(key)
if obj and isinstance(obj, _MXClassPropertyDescriptor):
return obj.__set__(cls, value)

return super(_MXClassPropertyMetaClass, cls).__setattr__(key, value)

# with_metaclass function obtained from: https://github.com/benjaminp/six/blob/master/six.py
#pylint: disable=unused-argument
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(type):

def __new__(cls, name, this_bases, d):
return meta(name, bases, d)

@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, 'temporary_class', (), {})
#pylint: enable=unused-argument

def classproperty(func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)

return _MXClassPropertyDescriptor(func)



def _load_lib():
"""Load library by searching possible path."""
Expand Down Expand Up @@ -227,6 +288,7 @@ def c_str_array(strings):
arr[:] = [s.encode('utf-8') for s in strings]
return arr


def c_array(ctype, values):
"""Create ctypes array from a Python array.

Expand Down
41 changes: 34 additions & 7 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
# coding: utf-8
"""Context management API of mxnet."""
from __future__ import absolute_import
import threading
import warnings
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass

class Context(object):
class Context(with_metaclass(_MXClassPropertyMetaClass, object)):
"""Constructs a context.

MXNet can run operations on CPU and different GPUs.
Expand Down Expand Up @@ -61,7 +64,7 @@ class Context(object):
gpu(1)
"""
# static class variable
default_ctx = None
_default_ctx = threading.local()
devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'}
devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5}
def __init__(self, device_type, device_id=0):
Expand Down Expand Up @@ -109,15 +112,37 @@ def __repr__(self):
return self.__str__()

def __enter__(self):
self._old_ctx = Context.default_ctx
Context.default_ctx = self
if not hasattr(Context._default_ctx, "value"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about a situation like below:

import threading
import mxnet as mx

def test_context():
    ctx_list = []
    def f():
        mx.context.Context.default_ctx
        ctx_list.append(Context.default_ctx)
    thread = threading.Thread(target=f)
    thread.start()
    thread.join()

if __name__ == '__main__':
    test_context()

Without hasattr this doesn't work.

Context._default_ctx.value = Context('cpu', 0)
self._old_ctx = Context._default_ctx.value
Context._default_ctx.value = self
return self

def __exit__(self, ptype, value, trace):
Context.default_ctx = self._old_ctx
Context._default_ctx.value = self._old_ctx

#pylint: disable=no-self-argument
@classproperty
def default_ctx(cls):
warnings.warn("Context.default_ctx has been deprecated. "
"Please use Context.current_context() instead. "
"Please use test_utils.set_default_context to set a default context",
DeprecationWarning)
if not hasattr(Context._default_ctx, "value"):
cls._default_ctx.value = Context('cpu', 0)
return cls._default_ctx.value

@default_ctx.setter
def default_ctx(cls, val):
warnings.warn("Context.default_ctx has been deprecated. "
"Please use Context.current_context() instead. "
"Please use test_utils.set_default_context to set a default context",
DeprecationWarning)
cls._default_ctx.value = val
#pylint: enable=no-self-argument

# initialize the default context in Context
Context.default_ctx = Context('cpu', 0)
Context._default_ctx.value = Context('cpu', 0)


def cpu(device_id=0):
Expand Down Expand Up @@ -234,4 +259,6 @@ def current_context():
-------
default_ctx : Context
"""
return Context.default_ctx
if not hasattr(Context._default_ctx, "value"):
Context._default_ctx.value = Context('cpu', 0)
return Context._default_ctx.value
13 changes: 7 additions & 6 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""Base container class for all neural network models."""
__all__ = ['Block', 'HybridBlock', 'SymbolBlock']

import threading
import copy
import warnings
import re
Expand All @@ -35,7 +36,7 @@

class _BlockScope(object):
"""Scope for collecting child `Block` s."""
_current = None
_current = threading.local()

def __init__(self, block):
self._block = block
Expand All @@ -46,10 +47,10 @@ def __init__(self, block):
@staticmethod
def create(prefix, params, hint):
"""Creates prefix and params for new `Block`."""
current = _BlockScope._current
current = getattr(_BlockScope._current, "value", None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initialize _BlockScope._current.value to None globally and use _BlockScope._current.value directly here?

if current is None:
if prefix is None:
prefix = _name.NameManager.current.get(None, hint) + '_'
prefix = _name.NameManager._current.value.get(None, hint) + '_'
if params is None:
params = ParameterDict(prefix)
else:
Expand All @@ -70,8 +71,8 @@ def create(prefix, params, hint):
def __enter__(self):
if self._block._empty_prefix:
return self
self._old_scope = _BlockScope._current
_BlockScope._current = self
self._old_scope = getattr(_BlockScope._current, "value", None)
_BlockScope._current.value = self
self._name_scope = _name.Prefix(self._block.prefix)
self._name_scope.__enter__()
return self
Expand All @@ -81,7 +82,7 @@ def __exit__(self, ptype, value, trace):
return
self._name_scope.__exit__(ptype, value, trace)
self._name_scope = None
_BlockScope._current = self._old_scope
_BlockScope._current.value = self._old_scope


def _flatten(args, inout_str):
Expand Down
33 changes: 26 additions & 7 deletions python/mxnet/name.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
# coding: utf-8
"""Automatic naming support for symbolic API."""
from __future__ import absolute_import
import threading
import warnings
from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass

class NameManager(object):
class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)):
"""NameManager to do automatic naming.

Developers can also inherit from this class to change naming behavior.
"""
current = None
_current = threading.local()

def __init__(self):
self._counter = {}
Expand Down Expand Up @@ -62,14 +65,30 @@ def get(self, name, hint):
return name

def __enter__(self):
self._old_manager = NameManager.current
NameManager.current = self
if not hasattr(NameManager._current, "value"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

NameManager._current.value = NameManager()
self._old_manager = NameManager._current.value
NameManager._current.value = self
return self

def __exit__(self, ptype, value, trace):
assert self._old_manager
NameManager.current = self._old_manager

NameManager._current.value = self._old_manager

#pylint: disable=no-self-argument
@classproperty
def current(cls):
warnings.warn("NameManager.current has been deprecated. "
"It is advised to use the `with` statement with NameManager.",
DeprecationWarning)
if not hasattr(NameManager._current, "value"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

cls._current.value = NameManager()
return cls._current.value

@current.setter
def current(cls, val):
cls._current.value = val
#pylint: enable=no-self-argument

class Prefix(NameManager):
"""A name manager that attaches a prefix to all names.
Expand All @@ -92,4 +111,4 @@ def get(self, name, hint):
return self._prefix + name

# initialize the default name manager
NameManager.current = NameManager()
NameManager._current.value = NameManager()
14 changes: 7 additions & 7 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from ..base import c_array, c_array_buf, c_handle_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call
from ..base import ctypes2buffer
from ..context import Context
from ..context import Context, current_context
from . import _internal
from . import op
from ._internal import NDArrayBase
Expand Down Expand Up @@ -2261,7 +2261,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs):
The shape of the empty array.
ctx : Context, optional
An optional device context.
Defaults to the current default context (``mxnet.Context.default_ctx``).
Defaults to the current default context (``mxnet.context.current_context()``).
dtype : str or numpy.dtype, optional
An optional value type (default is `float32`).
out : NDArray, optional
Expand All @@ -2283,7 +2283,7 @@ def ones(shape, ctx=None, dtype=None, **kwargs):
"""
# pylint: disable= unused-argument
if ctx is None:
ctx = Context.default_ctx
ctx = current_context()
dtype = mx_real_t if dtype is None else dtype
# pylint: disable= no-member, protected-access
return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
Expand Down Expand Up @@ -2439,7 +2439,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
array([2, 2, 2, 4, 4, 4], dtype=int32)
"""
if ctx is None:
ctx = Context.default_ctx
ctx = current_context()
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
dtype=dtype, ctx=str(ctx))
# pylint: enable= no-member, protected-access, too-many-arguments
Expand Down Expand Up @@ -3666,7 +3666,7 @@ def zeros(shape, ctx=None, dtype=None, **kwargs):
"""
# pylint: disable= unused-argument
if ctx is None:
ctx = Context.default_ctx
ctx = current_context()
dtype = mx_real_t if dtype is None else dtype
# pylint: disable= no-member, protected-access
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
Expand Down Expand Up @@ -3705,7 +3705,7 @@ def eye(N, M=0, k=0, ctx=None, dtype=None, **kwargs):
"""
# pylint: disable= unused-argument
if ctx is None:
ctx = Context.default_ctx
ctx = current_context()
dtype = mx_real_t if dtype is None else dtype
# pylint: disable= no-member, protected-access
return _internal._eye(N=N, M=M, k=k, ctx=ctx, dtype=dtype, **kwargs)
Expand Down Expand Up @@ -3733,7 +3733,7 @@ def empty(shape, ctx=None, dtype=None):
if isinstance(shape, int):
shape = (shape, )
if ctx is None:
ctx = Context.default_ctx
ctx = current_context()
if dtype is None:
dtype = mx_real_t
return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))
Loading