diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 104f20a61eeb..c611a2b745d8 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1528,8 +1528,11 @@ inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
   return dispatched;
 }
 
+template<typename xpu>
+struct AdagradDnsRspDnsKernel;
 
-struct AdagradDnsRspDnsKernel {
+template<>
+struct AdagradDnsRspDnsKernel<cpu> {
   template<typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
     DType* state_data, const DType* weight_data, const IType* grad_idx,
@@ -1555,6 +1558,30 @@ struct AdagradDnsRspDnsKernel<cpu> {
   }
 };
 
+template<>
+struct AdagradDnsRspDnsKernel<gpu> {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
+    DType* state_data, const DType* weight_data, const IType* grad_idx,
+    const DType* grad_data, const DType clip_gradient, const DType epsilon,
+    const DType lr, const DType rescale_grad) {
+    using nnvm::dim_t;
+    using namespace mshadow_op;
+    const dim_t row_id = i / row_length;
+    const dim_t col_id = i % row_length;
+    const dim_t data_i = grad_idx[row_id] * row_length + col_id;
+    DType grad_rescaled = grad_data[i] * rescale_grad;
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+    const DType grad_squared = grad_rescaled * grad_rescaled;
+    state_data[data_i] += grad_squared;
+    const DType div = grad_rescaled / square_root::Map(state_data[data_i] + epsilon);
+    // No need to use KERNEL_ASSIGN, as we already checked req is kWriteInplace
+    out_data[data_i] = weight_data[data_i] - div * lr;
+  }
+};
+
 template<typename xpu>
 void AdagradUpdateDnsRspDnsImpl(const AdagradParam& param,
                                 const OpContext& ctx,
@@ -1582,7 +1609,11 @@ void AdagradUpdateDnsRspDnsImpl(const AdagradParam& param,
       DType* out_data = out->dptr<DType>();
       const nnvm::dim_t nnr = grad.storage_shape()[0];
       const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
-      Kernel<AdagradDnsRspDnsKernel, xpu>::Launch(s, nnr, row_length,
+      size_t num_threads = nnr;
+      if (std::is_same<xpu, gpu>::value) {
+        num_threads = nnr * row_length;
+      }
+      Kernel<AdagradDnsRspDnsKernel<xpu>, xpu>::Launch(s, num_threads, row_length,
         out_data, state_data, weight_data, grad_idx, grad_val,
         static_cast<DType>(param.clip_gradient), static_cast<DType>(param.epsilon),
         static_cast<DType>(param.lr), static_cast<DType>(param.rescale_grad));