-
Notifications
You must be signed in to change notification settings - Fork 6.7k
Improve sparse pull performance for gluon trainer #11429
Changes from all commits
f0f7bd6
3e6c4c2
62b2c6a
1548d03
762f816
4cafce1
932bf49
1a6dc36
6f38f75
0b9be78
a533c23
7281797
a834826
2622924
46ff1a0
47b143d
3d8d666
a2b1cc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,7 +69,8 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', | |
| "got %s."%(type(params))) | ||
| self._params = [] | ||
| # parameters to initialize on the kvstore | ||
| self._contains_sparse = False | ||
| self._contains_sparse_weight = False | ||
| self._contains_sparse_grad = False | ||
| self._param2idx = {} | ||
| for i, param in enumerate(params): | ||
| if not isinstance(param, Parameter): | ||
|
|
@@ -80,7 +81,9 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', | |
| self._params.append(param) | ||
| param._set_trainer(self) | ||
| if param._stype != 'default': | ||
| self._contains_sparse = True | ||
| self._contains_sparse_weight = True | ||
| if param._grad_stype != 'default': | ||
| self._contains_sparse_grad = True | ||
| self._compression_params = compression_params | ||
| optimizer_params = optimizer_params if optimizer_params else {} | ||
| self._scale = float(optimizer_params.get('rescale_grad', 1.0)) | ||
|
|
@@ -153,13 +156,31 @@ def _reset_kvstore(self): | |
| def _init_kvstore(self): | ||
| """Create kvstore.""" | ||
| config = self._kvstore_params | ||
| if self._contains_sparse: | ||
| # if weight is sparse, the weight must be updated on KVStore. | ||
| # training loop contains: | ||
| # - row_sparse_pull(sparse_weight) | ||
| # - forward() | ||
| # - backward() | ||
| # - push(sparse_grad), push(dense_grad) | ||
| # - pull(dense_weight) | ||
| if self._contains_sparse_weight: | ||
| kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore']) | ||
| # update_on_kvstore is set to False by the user | ||
| # raise Error if update_on_kvstore is set to False by the user | ||
| if config['update_on_kvstore'] is False: | ||
| raise RuntimeError("Cannot set update_on_kvstore to False when sparse " | ||
| "gradients and/or sparse weights are present for " | ||
| "Parameter '%s'."%param.name) | ||
| raise RuntimeError("Cannot set update_on_kvstore to False when sparse weights " | ||
| "are present.") | ||
| # if weight is dense and grad is sparse, the weight better not be updated on KVStore. | ||
| # training loop contains: | ||
| # - forward() | ||
| # - backward() | ||
| # - push(grad) | ||
| # - pull(grad) | ||
| # - update(grad, weight) | ||
| elif self._contains_sparse_grad: | ||
| arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} | ||
| kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) | ||
| update_on_kvstore = False | ||
| # normal case | ||
| else: | ||
| arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} | ||
| kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), | ||
|
|
@@ -169,9 +190,9 @@ def _init_kvstore(self): | |
| if kvstore: | ||
| if self._compression_params: | ||
| kvstore.set_gradient_compression(self._compression_params) | ||
| # kv.pull(row_sparse_grad) is not supported | ||
| if 'dist' in kvstore.type and not self._contains_sparse: | ||
| update_on_kvstore = False | ||
| if 'dist' in kvstore.type: | ||
| # kv.pull(row_sparse_grad) is not supported for dist kvstore | ||
| update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. from the comment I'm guessing you meant
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is intended. kv.pull(row_sparse_grad) is not supported for dist kvstore, so we want to set update_on_kvstore = True if there's sparse grad. |
||
| if update_on_kvstore: | ||
| # optimizer preferably needs to be set before init for multiprecision | ||
| kvstore.set_optimizer(self._optimizer) | ||
|
|
@@ -211,8 +232,8 @@ def _row_sparse_pull(self, parameter, out, row_id): | |
| self._init_kvstore() | ||
| if self._params_to_init: | ||
| self._init_params() | ||
| self._kvstore.row_sparse_pull(self._param2idx[parameter.name], \ | ||
| out=out, row_ids=row_id) | ||
| idx = self._param2idx[parameter.name] | ||
| self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx) | ||
|
|
||
| def step(self, batch_size, ignore_stale_grad=False): | ||
| """Makes one step of parameter update. Should be called after | ||
|
|
@@ -272,7 +293,7 @@ def _allreduce_grads(self): | |
| self._kvstore.push(i, param.list_grad(), priority=-i) | ||
|
|
||
| if not self._update_on_kvstore: | ||
| self._kvstore.pull(i, param.list_grad(), priority=-i) | ||
| self._kvstore.pull(i, param.list_grad(), priority=-i, ignore_sparse=False) | ||
|
|
||
| def update(self, batch_size, ignore_stale_grad=False): | ||
| """Makes one step of parameter update. | ||
|
|
@@ -327,7 +348,7 @@ def _update(self, ignore_stale_grad=False): | |
| if self._kvstore and self._update_on_kvstore: | ||
| if param._stype == 'default': | ||
| # 'row_sparse' parameters are not pulled immediately - they're pulled | ||
| # in `SparseBlock.sparse_forward` | ||
| # in `Block.forward` | ||
| self._kvstore.pull(i, param.list_data(), priority=-i) | ||
| continue | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -235,7 +235,7 @@ def push(self, key, value, priority=0): | |
| self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) | ||
|
|
||
|
|
||
| def pull(self, key, out=None, priority=0): | ||
| def pull(self, key, out=None, priority=0, ignore_sparse=True): | ||
| """ Pulls a single value or a sequence of values from the store. | ||
|
|
||
| This function returns immediately after adding an operator to the engine. | ||
|
|
@@ -247,8 +247,8 @@ def pull(self, key, out=None, priority=0): | |
|
|
||
| The returned values are guaranteed to be the latest values in the store. | ||
|
|
||
| For `RowSparseNDArray` values, this call is ignored, | ||
| please use ``row_sparse_pull`` instead. | ||
| pull with `RowSparseNDArray` is not supported for dist kvstore. | ||
| Please use ``row_sparse_pull`` instead. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should ignore_sparse be defaulted to false to be consistent with previous behavior?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Previous behavior is to always ignore sparse. So it's consistent |
||
|
|
||
| Parameters | ||
| ---------- | ||
|
|
@@ -263,6 +263,9 @@ def pull(self, key, out=None, priority=0): | |
| Higher priority pull operations are likely to be executed before | ||
| other pull actions. | ||
|
|
||
| ignore_sparse: bool, optional, default True | ||
| Whether to ignore sparse arrays in the request. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> # pull a single key-value pair | ||
|
|
@@ -298,11 +301,13 @@ def pull(self, key, out=None, priority=0): | |
| assert(out is not None) | ||
| ckeys, cvals, use_str_keys = _ctype_key_value(key, out) | ||
| if use_str_keys: | ||
| check_call(_LIB.MXKVStorePullEx( | ||
| self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) | ||
| check_call(_LIB.MXKVStorePullWithSparseEx(self.handle, mx_uint(len(ckeys)), ckeys, | ||
| cvals, ctypes.c_int(priority), | ||
| ctypes.c_bool(ignore_sparse))) | ||
| else: | ||
| check_call(_LIB.MXKVStorePull( | ||
| self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) | ||
| check_call(_LIB.MXKVStorePullWithSparse(self.handle, mx_uint(len(ckeys)), ckeys, | ||
| cvals, ctypes.c_int(priority), | ||
| ctypes.c_bool(ignore_sparse))) | ||
|
|
||
| def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): | ||
| """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this have to be a error, or can it be a warning and automatically use update_on_kvstore ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, shouldn't this be outside the if contains_sparse_weight condition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By default update_on_kvstore is None. It's only set if user provides a value on purpose. I think an explicit err is better, since we cannot satisfy user's original intent.
If user set update_on_kvstore to False and the model contains no sparse weight, it's totally fine. Why should this be outside the if condition?