diff --git a/src/operator/l2_normalization-inl.h b/src/operator/l2_normalization-inl.h
index c7e71424ada9..9dc23c051244 100644
--- a/src/operator/l2_normalization-inl.h
+++ b/src/operator/l2_normalization-inl.h
@@ -34,6 +34,7 @@
 #include <utility>
 #include "./operator_common.h"
 #include "./mshadow_op.h"
+#include "./tensor/broadcast_reduce_op.h"
 
 namespace mxnet {
 namespace op {
@@ -87,6 +88,10 @@ class L2NormalizationOp : public Operator {
     Stream<xpu> *s = ctx.get_stream<xpu>();
     TShape orig_shape = in_data[l2_normalization::kData].shape_;
     if (param_.mode == l2_normalization::kInstance) {
+      TShape small = out_data[1].shape_;
+      ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false,
+        mshadow_op::identity>(ctx, in_data, req,
+        { out_data[l2_normalization::kNorm] }, small);
       Shape<2> dshape = Shape2(orig_shape[0],
         orig_shape.ProdShape(1, orig_shape.ndim()));
       Tensor<xpu, 2, DType> data = in_data[l2_normalization::kData]
@@ -94,15 +99,13 @@ class L2NormalizationOp : public Operator {
       Tensor<xpu, 2, DType> out = out_data[l2_normalization::kOut]
         .get_with_shape<xpu, 2, DType>(dshape, s);
       Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
-      norm = sumall_except_dim<0>(F<mxnet::op::mshadow_op::square>(data));
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow::op::plus, Req>, xpu>::Launch(
-          s, norm.size(0), norm.dptr_, norm.dptr_, DType(param_.eps));
-      });
-      norm = F<mxnet::op::mshadow_op::square_root>(norm);
-      out = data / broadcast<0>(norm, out.shape_);
+      out = data / mshadow::expr::broadcast<0>(norm, out.shape_);
     } else if (param_.mode == l2_normalization::kChannel) {
       CHECK_GE(orig_shape.ndim(), 3U);
+      TShape small = out_data[1].shape_;
+      ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false,
+        mshadow_op::identity>(ctx, in_data, req,
+        { out_data[l2_normalization::kNorm] }, small);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
       Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
@@ -112,15 +115,13 @@ class L2NormalizationOp : public Operator {
       Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
       Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
         .get_with_shape<xpu, 2, DType>(norm_shape, s);
-      norm = reduce_with_axis<mshadow::red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 1);
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow::op::plus, Req>, xpu>::Launch(
-          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
-      });
-      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast_with_axis(norm, 0, orig_shape[1]);
     } else if (param_.mode == l2_normalization::kSpatial) {
       CHECK_GE(orig_shape.ndim(), 3U);
+      TShape small = out_data[1].shape_;
+      ReduceAxesComputeImpl<xpu, mshadow_op::nrm2, false,
+        mshadow_op::identity>(ctx, in_data, req,
+        { out_data[l2_normalization::kNorm] }, small);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
       Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
@@ -130,12 +131,6 @@ class L2NormalizationOp : public Operator {
       Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
       Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
         .get_with_shape<xpu, 2, DType>(norm_shape, s);
-      norm = reduce_with_axis<mshadow::red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 2);
-      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow::op::plus, Req>, xpu>::Launch(
-          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
-      });
-      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast_with_axis(norm, 1, dshape[2]);
     } else {
       LOG(FATAL) << "Unexpected mode in l2 normalization";
@@ -171,8 +166,8 @@ class L2NormalizationOp : public Operator {
         .get_space_typed<xpu, 1, DType>(mshadow::Shape1(data.shape_[0]), s);
       temp = sumall_except_dim<0>(grad_out * data);
       Assign(grad_in, req[l2_normalization::kData],
-             (grad_out - data * broadcast<0>(temp, data.shape_)) /
-             broadcast<0>(norm, data.shape_));
+             (grad_out - data * mshadow::expr::broadcast<0>(temp, data.shape_)) /
+             mshadow::expr::broadcast<0>(norm, data.shape_));
     } else if (param_.mode == l2_normalization::kChannel) {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
@@ -314,6 +309,11 @@ class L2NormalizationProp : public OperatorProperty {
     return {{out_grad[l2_normalization::kOut], in_grad[l2_normalization::kData]}};
   }
 
+  std::vector<ResourceRequest> ForwardResource(
+      const std::vector<TShape> &in_shape) const override {
+    return {ResourceRequest::kTempSpace};
+  }
+
   std::vector<ResourceRequest> BackwardResource(
       const std::vector<TShape> &in_shape) const override {
     return {ResourceRequest::kTempSpace};
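
Reviewer note: below is a minimal standalone sketch (plain C++, not MXNet code) of the numerics the new forward path relies on in kInstance mode. An nrm2-style reduction writes the per-row L2 norm, sqrt(sum(x^2)), into the kNorm output, and the surviving expression divides the data by the broadcast norm. All names and shapes here (N, D, data, norm, out) are illustrative assumptions, not identifiers from the patch.

// Standalone sketch of the kInstance-mode numerics: an nrm2-style reduction
// (sqrt of sum of squares) per row, followed by a broadcast divide.
// Hypothetical layout: data is (N, D) row-major, norm is (N,).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int N = 2, D = 3;
  std::vector<float> data = {1.f, 2.f, 2.f,   // row 0: norm = 3
                             0.f, 3.f, 4.f};  // row 1: norm = 5
  std::vector<float> norm(N), out(N * D);
  for (int i = 0; i < N; ++i) {
    float sum_sq = 0.f;
    for (int j = 0; j < D; ++j)
      sum_sq += data[i * D + j] * data[i * D + j];
    norm[i] = std::sqrt(sum_sq);  // what the reduction writes into kNorm
    for (int j = 0; j < D; ++j)   // out = data / broadcast<0>(norm)
      out[i * D + j] = data[i * D + j] / norm[i];
  }
  for (int i = 0; i < N; ++i)
    std::printf("norm[%d] = %g, out row = {%g, %g, %g}\n",
                i, norm[i], out[i * D], out[i * D + 1], out[i * D + 2]);
  return 0;
}

Two review observations follow from the sketch: the removed code added param_.eps to the sum of squares before taking the square root, so all-zero rows were eps-guarded, while a plain norm reduction as sketched divides by exact zero in that case; and the new ForwardResource request mirrors BackwardResource, presumably because the reduction needs temporary workspace in the forward pass.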