-
Notifications
You must be signed in to change notification settings - Fork 6.7k
[MXNET-72] Improve sparse sgd on GPU #10293
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -98,6 +98,38 @@ Where the parameter ``momentum`` is the decay rate of momentum estimates at each | |
| .add_argument("mom", "NDArray-or-Symbol", "Momentum") | ||
| .add_arguments(SignumParam::__FIELDS__()); | ||
|
|
||
| template<int req> | ||
| struct SGDMomStdDnsRspDnsKernel<req, cpu> { | ||
| template<typename DType, typename IType, typename RType> | ||
| MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data, | ||
| DType* mom_data, const DType* weight_data, const IType* grad_idx, | ||
| const DType* grad_data, const RType* prefix_sum, const DType clip_gradient, | ||
| const DType momentum, const DType lr, const DType wd, const DType rescale_grad) { | ||
| const DType rate = lr * wd; | ||
| const bool non_zero = (i == 0) ? prefix_sum[0] > 0 | ||
| : prefix_sum[i] > prefix_sum[i-1]; | ||
|
|
||
| const index_t row_i = i * row_length; | ||
| const RType grad_i = (prefix_sum[i]-1) * row_length; | ||
| for (index_t j = 0; j < row_length; j++) { | ||
| const index_t data_i = row_i + j; | ||
| const DType grad = non_zero ? grad_data[grad_i + j] | ||
| : static_cast<DType>(0); | ||
| if (clip_gradient >= 0.0f) { | ||
| mom_data[data_i] = momentum * mom_data[data_i] | ||
| - rate * weight_data[data_i] | ||
| - lr * | ||
| mshadow_op::clip::Map(rescale_grad * grad, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is there a line break here?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No particular reason. Without the line break, it would be a 200-character line.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I mean the extra line break in `lr * mshadow_op::clip::Map`; these two places are inconsistent with what you have on line 52 of optimizer_op.cu below.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think it's necessary to add or remove that extra line break. Please provide constructive feedback/review comments. |
||
| clip_gradient); | ||
| } else { | ||
| mom_data[data_i] = momentum * mom_data[data_i] | ||
| - rate * weight_data[data_i] | ||
| - lr * rescale_grad * grad; | ||
| } | ||
| KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| template<> | ||
| void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param, | ||
|
|
@@ -139,7 +171,7 @@ void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param, | |
| prefix_sum[i] += prefix_sum[i - 1]; | ||
| } | ||
| } | ||
| Kernel<SGDMomStdDnsRspDnsKernel<req_type>, cpu>::Launch(s, num_rows, row_length, | ||
| Kernel<SGDMomStdDnsRspDnsKernel<req_type, cpu>, cpu>::Launch(s, num_rows, row_length, | ||
| out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum, | ||
| static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum), | ||
| static_cast<DType>(param.lr), static_cast<DType>(param.wd), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why a line break here?