-
Notifications
You must be signed in to change notification settings - Fork 6.7k
Change class variables to thread local variables #10833
Changes from all commits
1a8dda6
9c5c903
d4cc2af
4623ce5
d8b1b2f
40fc6bb
6b79d89
df057e2
3d6501f
b4996ac
57af178
d12e4a7
1a91314
c31cf59
0af608b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,10 +18,12 @@ | |
| # coding: utf-8 | ||
| """Attribute scoping support for symbolic API.""" | ||
| from __future__ import absolute_import | ||
| import threading | ||
| import warnings | ||
|
|
||
| from .base import string_types | ||
| from .base import string_types, classproperty, with_metaclass, _MXClassPropertyMetaClass | ||
|
|
||
| class AttrScope(object): | ||
| class AttrScope(with_metaclass(_MXClassPropertyMetaClass, object)): | ||
| """Attribute manager for scoping. | ||
|
|
||
| User can also inherit this object to change naming behavior. | ||
|
|
@@ -31,7 +33,7 @@ class AttrScope(object): | |
| kwargs | ||
| The attributes to set for all symbol creations in the scope. | ||
| """ | ||
| current = None | ||
| _current = threading.local() | ||
|
|
||
| def __init__(self, **kwargs): | ||
| self._old_scope = None | ||
|
|
@@ -64,15 +66,35 @@ def get(self, attr): | |
|
|
||
| def __enter__(self): | ||
| # pylint: disable=protected-access | ||
| self._old_scope = AttrScope.current | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. self._old_scope = AttrScope.current.value |
||
| attr = AttrScope.current._attr.copy() | ||
| if not hasattr(AttrScope._current, "value"): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same |
||
| AttrScope._current.value = AttrScope() | ||
| self._old_scope = AttrScope._current.value | ||
| attr = AttrScope._current.value._attr.copy() | ||
| attr.update(self._attr) | ||
| self._attr = attr | ||
| AttrScope.current = self | ||
| AttrScope._current.value = self | ||
| return self | ||
|
|
||
| def __exit__(self, ptype, value, trace): | ||
| assert self._old_scope | ||
| AttrScope.current = self._old_scope | ||
| AttrScope._current.value = self._old_scope | ||
|
|
||
| AttrScope.current = AttrScope() | ||
| #pylint: disable=no-self-argument | ||
| @classproperty | ||
| def current(cls): | ||
| warnings.warn("AttrScope.current has been deprecated. " | ||
| "It is advised to use the `with` statement with AttrScope.", | ||
| DeprecationWarning) | ||
| if not hasattr(AttrScope._current, "value"): | ||
| cls._current.value = AttrScope() | ||
| return cls._current.value | ||
|
|
||
| @current.setter | ||
| def current(cls, val): | ||
| warnings.warn("AttrScope.current has been deprecated. " | ||
| "It is advised to use the `with` statement with AttrScope.", | ||
| DeprecationWarning) | ||
| cls._current.value = val | ||
| #pylint: enable=no-self-argument | ||
|
|
||
| AttrScope._current.value = AttrScope() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,8 +18,11 @@ | |
| # coding: utf-8 | ||
| """Context management API of mxnet.""" | ||
| from __future__ import absolute_import | ||
| import threading | ||
| import warnings | ||
| from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass | ||
|
|
||
| class Context(object): | ||
| class Context(with_metaclass(_MXClassPropertyMetaClass, object)): | ||
| """Constructs a context. | ||
|
|
||
| MXNet can run operations on CPU and different GPUs. | ||
|
|
@@ -61,7 +64,7 @@ class Context(object): | |
| gpu(1) | ||
| """ | ||
| # static class variable | ||
| default_ctx = None | ||
| _default_ctx = threading.local() | ||
| devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned', 5: 'cpu_shared'} | ||
| devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3, 'cpu_shared': 5} | ||
| def __init__(self, device_type, device_id=0): | ||
|
|
@@ -109,15 +112,37 @@ def __repr__(self): | |
| return self.__str__() | ||
|
|
||
| def __enter__(self): | ||
| self._old_ctx = Context.default_ctx | ||
| Context.default_ctx = self | ||
| if not hasattr(Context._default_ctx, "value"): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do you need this? Isn't it set with https://github.com/apache/incubator-mxnet/pull/10833/files#diff-1a5e06031378f44204fd1da1fffc0b07R145
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what about a situation like below: Without hasattr this doesn't work. |
||
| Context._default_ctx.value = Context('cpu', 0) | ||
| self._old_ctx = Context._default_ctx.value | ||
| Context._default_ctx.value = self | ||
| return self | ||
|
|
||
| def __exit__(self, ptype, value, trace): | ||
| Context.default_ctx = self._old_ctx | ||
| Context._default_ctx.value = self._old_ctx | ||
|
|
||
| #pylint: disable=no-self-argument | ||
| @classproperty | ||
| def default_ctx(cls): | ||
| warnings.warn("Context.default_ctx has been deprecated. " | ||
| "Please use Context.current_context() instead. " | ||
| "Please use test_utils.set_default_context to set a default context", | ||
| DeprecationWarning) | ||
| if not hasattr(Context._default_ctx, "value"): | ||
| cls._default_ctx.value = Context('cpu', 0) | ||
| return cls._default_ctx.value | ||
|
|
||
| @default_ctx.setter | ||
| def default_ctx(cls, val): | ||
| warnings.warn("Context.default_ctx has been deprecated. " | ||
| "Please use Context.current_context() instead. " | ||
| "Please use test_utils.set_default_context to set a default context", | ||
| DeprecationWarning) | ||
| cls._default_ctx.value = val | ||
| #pylint: enable=no-self-argument | ||
|
|
||
| # initialize the default context in Context | ||
| Context.default_ctx = Context('cpu', 0) | ||
| Context._default_ctx.value = Context('cpu', 0) | ||
|
|
||
|
|
||
| def cpu(device_id=0): | ||
|
|
@@ -234,4 +259,6 @@ def current_context(): | |
| ------- | ||
| default_ctx : Context | ||
| """ | ||
| return Context.default_ctx | ||
| if not hasattr(Context._default_ctx, "value"): | ||
| Context._default_ctx.value = Context('cpu', 0) | ||
| return Context._default_ctx.value | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| """Base container class for all neural network models.""" | ||
| __all__ = ['Block', 'HybridBlock', 'SymbolBlock'] | ||
|
|
||
| import threading | ||
| import copy | ||
| import warnings | ||
| import re | ||
|
|
@@ -35,7 +36,7 @@ | |
|
|
||
| class _BlockScope(object): | ||
| """Scope for collecting child `Block` s.""" | ||
| _current = None | ||
| _current = threading.local() | ||
|
|
||
| def __init__(self, block): | ||
| self._block = block | ||
|
|
@@ -46,10 +47,10 @@ def __init__(self, block): | |
| @staticmethod | ||
| def create(prefix, params, hint): | ||
| """Creates prefix and params for new `Block`.""" | ||
| current = _BlockScope._current | ||
| current = getattr(_BlockScope._current, "value", None) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. initialize _BlockScope._current.value to None globally and use _BlockScope._current.value directly here? |
||
| if current is None: | ||
| if prefix is None: | ||
| prefix = _name.NameManager.current.get(None, hint) + '_' | ||
| prefix = _name.NameManager._current.value.get(None, hint) + '_' | ||
| if params is None: | ||
| params = ParameterDict(prefix) | ||
| else: | ||
|
|
@@ -70,8 +71,8 @@ def create(prefix, params, hint): | |
| def __enter__(self): | ||
| if self._block._empty_prefix: | ||
| return self | ||
| self._old_scope = _BlockScope._current | ||
| _BlockScope._current = self | ||
| self._old_scope = getattr(_BlockScope._current, "value", None) | ||
| _BlockScope._current.value = self | ||
| self._name_scope = _name.Prefix(self._block.prefix) | ||
| self._name_scope.__enter__() | ||
| return self | ||
|
|
@@ -81,7 +82,7 @@ def __exit__(self, ptype, value, trace): | |
| return | ||
| self._name_scope.__exit__(ptype, value, trace) | ||
| self._name_scope = None | ||
| _BlockScope._current = self._old_scope | ||
| _BlockScope._current.value = self._old_scope | ||
|
|
||
|
|
||
| def _flatten(args, inout_str): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,13 +18,16 @@ | |
| # coding: utf-8 | ||
| """Automatic naming support for symbolic API.""" | ||
| from __future__ import absolute_import | ||
| import threading | ||
| import warnings | ||
| from .base import classproperty, with_metaclass, _MXClassPropertyMetaClass | ||
|
|
||
| class NameManager(object): | ||
| class NameManager(with_metaclass(_MXClassPropertyMetaClass, object)): | ||
| """NameManager to do automatic naming. | ||
|
|
||
| Developers can also inherit from this class to change naming behavior. | ||
| """ | ||
| current = None | ||
| _current = threading.local() | ||
|
|
||
| def __init__(self): | ||
| self._counter = {} | ||
|
|
@@ -62,14 +65,30 @@ def get(self, name, hint): | |
| return name | ||
|
|
||
| def __enter__(self): | ||
| self._old_manager = NameManager.current | ||
| NameManager.current = self | ||
| if not hasattr(NameManager._current, "value"): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same |
||
| NameManager._current.value = NameManager() | ||
| self._old_manager = NameManager._current.value | ||
| NameManager._current.value = self | ||
| return self | ||
|
|
||
| def __exit__(self, ptype, value, trace): | ||
| assert self._old_manager | ||
| NameManager.current = self._old_manager | ||
|
|
||
| NameManager._current.value = self._old_manager | ||
|
|
||
| #pylint: disable=no-self-argument | ||
| @classproperty | ||
| def current(cls): | ||
| warnings.warn("NameManager.current has been deprecated. " | ||
| "It is advised to use the `with` statement with NameManager.", | ||
| DeprecationWarning) | ||
| if not hasattr(NameManager._current, "value"): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same |
||
| cls._current.value = NameManager() | ||
| return cls._current.value | ||
|
|
||
| @current.setter | ||
| def current(cls, val): | ||
| cls._current.value = val | ||
| #pylint: enable=no-self-argument | ||
|
|
||
| class Prefix(NameManager): | ||
| """A name manager that attaches a prefix to all names. | ||
|
|
@@ -92,4 +111,4 @@ def get(self, name, hint): | |
| return self._prefix + name | ||
|
|
||
| # initialize the default name manager | ||
| NameManager.current = NameManager() | ||
| NameManager._current.value = NameManager() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_current = threading.local()