This repository was archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.7k
[MXNET-374] handle row_sparse weight in parameter and trainer #11001
Merged
Merged
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
2863a1f
+ rsp parameter
eric-haibin-lin e3d20c7
draft
eric-haibin-lin ad672a7
Fix optimizer pickle
eric-haibin-lin 674d374
refactor and document
eric-haibin-lin 6db6e29
add test for save load with cast_stype
eric-haibin-lin 6f0f403
refactor trainer tests
eric-haibin-lin 8db0499
add test
eric-haibin-lin 60d9f16
merge
eric-haibin-lin 83009bc
add back test
eric-haibin-lin cf006c8
raise error for load params
eric-haibin-lin 4e9ab9c
add comment
eric-haibin-lin a991e98
remove print
eric-haibin-lin 468b599
fix doc
eric-haibin-lin 0f70344
CR comments
eric-haibin-lin ff9bf84
CR comments
eric-haibin-lin bee6774
change error
eric-haibin-lin 077b7a5
remove cast stype
eric-haibin-lin 6038fe9
fix test
eric-haibin-lin 70de567
add reset kvstore to trainer
eric-haibin-lin 12a8b59
lint
eric-haibin-lin 2a06884
add test to CI
eric-haibin-lin fbcf15d
Merge remote-tracking branch 'upstream/master' into sparse-block
eric-haibin-lin 01b3e4d
add more checks
eric-haibin-lin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,6 +81,8 @@ class Parameter(object): | |
| Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult. | ||
| init : Initializer, default None | ||
| Initializer of this parameter. Will use the global initializer by default. | ||
| stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'. | ||
| The storage type of the parameter. | ||
| grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'. | ||
| The storage type of the parameter's gradient. | ||
|
|
||
|
|
@@ -99,12 +101,13 @@ class Parameter(object): | |
| """ | ||
| def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t, | ||
| lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False, | ||
| differentiable=True, grad_stype='default'): | ||
| differentiable=True, stype='default', grad_stype='default'): | ||
| self._var = None | ||
| self._data = None | ||
| self._grad = None | ||
| self._ctx_list = None | ||
| self._ctx_map = None | ||
| self._trainer = None | ||
| self._deferred_init = () | ||
| self._differentiable = differentiable | ||
| self._allow_deferred_init = allow_deferred_init | ||
|
|
@@ -116,10 +119,14 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t, | |
| self.wd_mult = wd_mult | ||
| self.grad_req = grad_req | ||
| self.init = init | ||
| assert grad_stype in ['default', 'row_sparse', 'csr'], \ | ||
| "grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \ | ||
| " but got '%s'" % (name, grad_stype) | ||
| # sparse related storage type information | ||
| valid_stypes = ['default', 'row_sparse', 'csr'] | ||
| assert grad_stype in valid_stypes, "grad_stype for Parameter '%s' must be " \ | ||
| "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, grad_stype) | ||
| assert stype in valid_stypes, "stype for Parameter '%s' must be " \ | ||
| "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, stype) | ||
| self._grad_stype = grad_stype | ||
| self._stype = stype | ||
|
|
||
|
|
||
| def __repr__(self): | ||
|
|
@@ -162,6 +169,16 @@ def shape(self, new_shape): | |
|
|
||
| self._shape = new_shape | ||
|
|
||
| def _set_trainer(self, trainer): | ||
| """ Set the trainer this parameter is associated with. """ | ||
| # trainer cannot be replaced for sparse params | ||
| if self._stype != 'default' and self._trainer and trainer and self._trainer is not trainer: | ||
| raise RuntimeError( | ||
| "Failed to set the trainer for Parameter '%s' because it was already set. " \ | ||
| "More than one trainers for a %s Parameter is not supported." \ | ||
| %(self.name, self._stype)) | ||
| self._trainer = trainer | ||
|
|
||
| def _check_and_get(self, arr_list, ctx): | ||
| if arr_list is not None: | ||
| if ctx is list: | ||
|
|
@@ -194,6 +211,20 @@ def _check_and_get(self, arr_list, ctx): | |
| "because the later does not include Parameters of " \ | ||
| "nested child Blocks"%(self.name)) | ||
|
|
||
| def _get_row_sparse(self, arr_list, ctx, row_id): | ||
| """ Get row_sparse data from row_sparse parameters based on row_id. """ | ||
| # get row sparse params based on row ids | ||
| if not isinstance(row_id, ndarray.NDArray): | ||
| raise TypeError("row_id must have NDArray type, but %s is given"%(type(row_id))) | ||
| if not self._trainer: | ||
| raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \ | ||
| "Trainer is created with it."%self.name) | ||
| results = self._check_and_get(arr_list, ctx) | ||
|
|
||
| # fetch row sparse params from the trainer | ||
| self._trainer._row_sparse_pull(self, results, row_id) | ||
| return results | ||
|
|
||
| def _load_init(self, data, ctx): | ||
| """(Re)initializes by loading from data.""" | ||
| if self.shape: | ||
|
|
@@ -208,6 +239,8 @@ def _load_init(self, data, ctx): | |
| "Failed loading Parameter '%s' from saved params: " \ | ||
| "dtype incompatible expected %s vs saved %s"%( | ||
| self.name, str(self.dtype), str(data.dtype)) | ||
| if self._stype != data.stype: | ||
| data = data.tostype(self._stype) | ||
| if isinstance(ctx, Context): | ||
| ctx = [ctx] | ||
| if self._data is None: | ||
|
|
@@ -243,7 +276,7 @@ def _finish_deferred_init(self): | |
| with autograd.pause(): | ||
| if data is None: | ||
| data = ndarray.zeros(shape=self.shape, dtype=self.dtype, | ||
| ctx=context.cpu()) | ||
| ctx=context.cpu(), stype=self._stype) | ||
| initializer.create(default_init)( | ||
| initializer.InitDesc(self.name, {'__init__': init}), data) | ||
|
|
||
|
|
@@ -271,12 +304,18 @@ def _init_grad(self): | |
| self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context, | ||
| stype=self._grad_stype) for i in self._data] | ||
|
|
||
| autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req) | ||
| autograd.mark_variables(self._check_and_get(self._data, list), | ||
| self._grad, self.grad_req) | ||
|
|
||
| def _reduce(self): | ||
| """Reduce data from multiple context.""" | ||
| block = self.list_data() | ||
| data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block) | ||
| if self._stype == 'default': | ||
| block = self.list_data() | ||
| data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block) | ||
| else: | ||
| # fetch all rows for 'row_sparse' param | ||
| all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu()) | ||
| data = self.row_sparse_data(all_row_ids) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it possible to have row_sparse but update_on_kvstore=false?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Currently when gluon sees a row_sparse weight, it always creates a kvstore and sets update_on_kvstore=True. |
||
| return data | ||
|
|
||
| def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), | ||
|
|
@@ -380,12 +419,58 @@ def set_data(self, data): | |
| self._deferred_init = self._deferred_init[:3] + (data,) | ||
| return | ||
|
|
||
| for arr in self.list_data(): | ||
| # if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync | ||
| if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore: | ||
| if self not in self._trainer._params_to_init: | ||
| self._trainer._reset_kvstore() | ||
|
|
||
| for arr in self._check_and_get(self._data, list): | ||
| arr[:] = data | ||
|
|
||
| def row_sparse_data(self, row_id): | ||
| """Returns a copy of the 'row_sparse' parameter on the same context as row_id's. | ||
| The copy only retains rows whose ids occur in provided row ids. | ||
| The parameter must have been initialized on this context before. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| row_id: NDArray | ||
| Row ids to retain for the 'row_sparse' parameter. | ||
|
|
||
| Returns | ||
| ------- | ||
| NDArray on row_id's context | ||
| """ | ||
| if self._stype != 'row_sparse': | ||
| raise RuntimeError("Cannot return a copy of Parameter %s via row_sparse_data() " \ | ||
| "because its storage type is %s. Please use data() instead." \ | ||
| %(self.name, self._stype)) | ||
| return self._get_row_sparse(self._data, row_id.context, row_id) | ||
|
|
||
| def list_row_sparse_data(self, row_id): | ||
| """Returns copies of the 'row_sparse' parameter on all contexts, in the same order | ||
| as creation. The copy only retains rows whose ids occur in provided row ids. | ||
| The parameter must have been initialized before. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| row_id: NDArray | ||
| Row ids to retain for the 'row_sparse' parameter. | ||
|
|
||
| Returns | ||
| ------- | ||
| list of NDArrays | ||
| """ | ||
| if self._stype != 'row_sparse': | ||
| raise RuntimeError("Cannot return copies of Parameter '%s' on all contexts via " \ | ||
| "list_row_sparse_data() because its storage type is %s. Please " \ | ||
| "use data() instead." % (self.name, self._stype)) | ||
| return self._get_row_sparse(self._data, list, row_id) | ||
|
|
||
| def data(self, ctx=None): | ||
| """Returns a copy of this parameter on one context. Must have been | ||
| initialized on this context before. | ||
| initialized on this context before. For sparse parameters, use | ||
| :py:meth:`Parameter.row_sparse_data` instead. | ||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -396,11 +481,25 @@ def data(self, ctx=None): | |
| ------- | ||
| NDArray on ctx | ||
| """ | ||
| if self._stype != 'default': | ||
| raise RuntimeError("Cannot return a copy of Parameter '%s' on ctx %s via data() " \ | ||
| "because its storage type is %s. Please use row_sparse_data() " \ | ||
| "instead." % (self.name, str(ctx), self._stype)) | ||
| return self._check_and_get(self._data, ctx) | ||
|
|
||
| def list_data(self): | ||
| """Returns copies of this parameter on all contexts, in the same order | ||
| as creation.""" | ||
| as creation. For sparse parameters, use :py:meth:`Parameter.list_row_sparse_data` | ||
| instead. | ||
|
|
||
| Returns | ||
| ------- | ||
| list of NDArrays | ||
| """ | ||
| if self._stype != 'default': | ||
| raise RuntimeError("Cannot return copies of Parameter '%s' on all contexts via " \ | ||
| "list_data() because its storage type is %s. Please use " \ | ||
| "row_sparse_data() instead." % (self.name, self._stype)) | ||
| return self._check_and_get(self._data, list) | ||
|
|
||
| def grad(self, ctx=None): | ||
|
|
@@ -447,7 +546,7 @@ def var(self): | |
| if self._var is None: | ||
| self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype, | ||
| lr_mult=self.lr_mult, wd_mult=self.wd_mult, | ||
| init=self.init) | ||
| init=self.init, stype=self._stype) | ||
| return self._var | ||
|
|
||
| def cast(self, dtype): | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
might as well make it a set.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Only has 3 elements. I don't think this makes any real difference