diff --git a/src/operator/numpy/np_matmul_op-inl.h b/src/operator/numpy/np_matmul_op-inl.h
index 89560f64d8c0..8f1b4f9f3a30 100644
--- a/src/operator/numpy/np_matmul_op-inl.h
+++ b/src/operator/numpy/np_matmul_op-inl.h
@@ -138,6 +138,8 @@ inline void MatmulImpl(const OpContext& ctx,
   mshadow::Tensor<xpu, 1, char> workspace;
   mshadow::Tensor<xpu, 3, DType> ans, mlhs, mrhs;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  bool isCPU = std::is_same<xpu, cpu>::value;
+  // true if either a or b requires broadcasting
   if (MatmulNeedBroadcast(a_shape, b_shape)) {
     // e.g. a.shape = (2, 3, 1, 4, 2)
     //      b.shape =       (5, 2, 4)
@@ -157,12 +159,38 @@ inline void MatmulImpl(const OpContext& ctx,
     DType* bc_b_ptr = bc_a_ptr + bc_size_a;
     MSHADOW_TYPE_SWITCH_WITH_BOOL(input_a.type_flag_, IType, {
       MSHADOW_TYPE_SWITCH_WITH_BOOL(input_b.type_flag_, OType, {
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
-          k_a_shape, k_a_shape_bc, OpReqType::kWriteTo, ndim);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, bc_size_b, input_b.dptr<OType>(), bc_b_ptr,
-          k_b_shape, k_b_shape_bc, OpReqType::kWriteTo, ndim);
+        struct ShapeAndStride aux_data_a, aux_data_b;
+        PrepareAUXData(&aux_data_a, k_a_shape, k_a_shape_bc, ndim);
+        PrepareAUXData(&aux_data_b, k_b_shape, k_b_shape_bc, ndim);
+        if (isCPU) {
+          if (!aux_data_a.shape_changed) {
+            Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
+              s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr, OpReqType::kWriteTo);
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_b.Size(), input_b.dptr<OType>(), bc_b_ptr,
+              aux_data_b, OpReqType::kWriteTo, ndim);
+          } else if (!aux_data_b.shape_changed) {
+            Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
+              s, bc_size_b, input_b.dptr<OType>(), bc_b_ptr, OpReqType::kWriteTo);
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
+              aux_data_a, OpReqType::kWriteTo, ndim);
+          } else {
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
+              aux_data_a, OpReqType::kWriteTo, ndim);
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_b.Size(), input_b.dptr<OType>(), bc_b_ptr,
+              aux_data_b, OpReqType::kWriteTo, ndim);
+          }
+        } else {
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
+            aux_data_a, OpReqType::kWriteTo, ndim);
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, bc_size_b, input_b.dptr<OType>(), bc_b_ptr,
+            aux_data_b, OpReqType::kWriteTo, ndim);
+        }
       });
     });
     ans = mshadow::Tensor<xpu, 3, DType>(output.dptr<DType>(),
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 5eb0c41aa36c..82b4f7d1f43a 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -25,6 +25,7 @@
 #ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
 #define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
 
+#include <type_traits>
 #include <mxnet/operator_util.h>
 #include <string>
 #include <vector>
@@ -1037,34 +1038,182 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
   ReduceAxesBackwardUseInOutImpl<xpu, OP, normalize>(ctx, small, inputs, req, outputs);
 }
 
+namespace {  // unnamed namespace to keep the scope of the struct within this file
+struct ShapeAndStride {
+  index_t in_stride[MXNET_SPECIAL_MAX_NDIM];
+  index_t out_stride[MXNET_SPECIAL_MAX_NDIM];
+  index_t input_shape[MXNET_SPECIAL_MAX_NDIM];
+  index_t output_shape[MXNET_SPECIAL_MAX_NDIM];
+  // axes: stores which axes of the input are to be broadcast
+  index_t axes[MXNET_SPECIAL_MAX_NDIM];
+  int num_broadcast_axes = -1;
+  bool shape_changed = false;
+};
+}  // unnamed namespace
+
+/*!
+ * \brief Calculates the strides of the input and output tensor dimensions
+ *        and saves the mshadow::Shape data in integer arrays for faster access.
+ * \param aux_data struct to hold stride and shape data.
+ * \param in_shape input shape
+ * \param out_shape output shape
+ * \param ndim number of dimensions in the output
+ */
+inline void PrepareAUXData(ShapeAndStride *aux_data,
+                           mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
+                           mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
+                           int ndim) {
+  int iter = ndim - 1, i = 0;
+  aux_data->out_stride[iter] = 1;
+  aux_data->in_stride[iter] = 1;
+  aux_data->input_shape[iter] = in_shape[iter];
+  aux_data->output_shape[iter] = out_shape[iter];
+  if (in_shape[iter] != out_shape[iter]) {
+    aux_data->axes[i++] = iter;
+    aux_data->shape_changed = true;
+  }
+  iter--;
+  for (; iter >= 0; --iter) {
+    aux_data->out_stride[iter] = aux_data->out_stride[iter + 1] * out_shape[iter + 1];
+    aux_data->in_stride[iter] = aux_data->in_stride[iter + 1] * in_shape[iter + 1];
+    aux_data->input_shape[iter] = in_shape[iter];
+    aux_data->output_shape[iter] = out_shape[iter];
+    if (in_shape[iter] != out_shape[iter]) {
+      aux_data->axes[i++] = iter;
+      aux_data->shape_changed = true;
+    }
+  }
+  aux_data->num_broadcast_axes = i;
+  assert(aux_data->num_broadcast_axes > -1 && aux_data->num_broadcast_axes < 4);
+}
+
 template<typename OP>
-struct broadcast_kernel {
+struct broadcast_kernel_gpu {
   template<typename IType, typename OType>
   MSHADOW_XINLINE static void Map(index_t i,
                                   IType *input,
                                   OType *output,
-                                  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
-                                  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
+                                  const ShapeAndStride& aux_data,
                                   const OpReqType req,
-                                  const uint32_t ndim) {
-    size_t in_stride = 1;
-    size_t out_stride = 1;
+                                  const int ndim) {
     index_t idx = i;
     index_t in_idx = i;
+#pragma unroll 4
     for (int iter = ndim - 1; iter >= 0; --iter) {
-      size_t dim_idx = idx % out_shape[iter];
-      in_idx -= dim_idx * out_stride;
-      if (in_shape[iter] != 1) {
-        in_idx += dim_idx * in_stride;
+      index_t out_dim_shape = aux_data.output_shape[iter];
+      index_t out_dim_stride = aux_data.out_stride[iter];
+      // x % y = x - (x / y) * y
+      // speeds up the modulo (%) operation on GPU
+      index_t dim_idx = idx - (idx / out_dim_shape) * out_dim_shape;
+      if (aux_data.input_shape[iter] != 1) {
+        in_idx += dim_idx * (aux_data.in_stride[iter] - out_dim_stride);
+      } else {
+        in_idx -= dim_idx * out_dim_stride;
       }
-      idx /= out_shape[iter];
-      in_stride *= in_shape[iter];
-      out_stride *= out_shape[iter];
+      idx /= out_dim_shape;
     }
     KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx]));
   }
 };
 
+/**
+ * Changes the thread workload mapping from one thread per output element
+ * to one thread per input element to be broadcast.
+ * This approach leverages vectorization when the fastest-varying
+ * index (stride = 1) of the tensor is to be broadcast.
+ * In other cases it still performs better through better load balancing.
+ */
+template<typename OP>
+struct broadcast_kernel_cpu {
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  IType *input,
+                                  OType *output,
+                                  const ShapeAndStride& aux_data,
+                                  const OpReqType req,
+                                  const int ndim) {
+    index_t idx = i;
+    index_t init_off = 0;
+    for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {
+      size_t dim_idx = idx % aux_data.input_shape[iter];
+      init_off += dim_idx * aux_data.out_stride[iter];
+      idx /= aux_data.input_shape[iter];
+    }
+    index_t stride_0, stride_1, stride_2;
+    // Each case is based on the number of axes to be broadcast
+    // (1, 2 or 3) after merging axes.
+    switch (aux_data.num_broadcast_axes) {
+      // The input shape is one of the following forms:
+      // (x_1, 1), (x_1, 1, x_2) or (1, x_1),
+      // where x_1, x_2 are the sizes of the dimensions that are not broadcast.
+      // In the (x_1, 1) case the loop is vectorized; in the other two cases
+      // performance improves by avoiding duplicate stride calculations
+      // for each output location that input[i] needs to be written to.
+      case 1 :
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+          KERNEL_ASSIGN(output[init_off + l * stride_0],
+                        req, OP::Map(input[i]));
+        }
+        break;
+      // The input shape is one of the following forms:
+      // (x_1, 1, x_2, 1), (1, x_1, 1, x_2) or (x_1, 1, x_2, 1, x_3),
+      // where x_1, x_2, x_3 are the sizes of the dimensions that are not broadcast.
+      // The innermost loop can be vectorized by the compiler; the outer loop
+      // improves performance by avoiding duplicate stride calculations
+      // for each output location that input[i] needs to be written to.
+      case 2:
+        stride_1 = aux_data.out_stride[aux_data.axes[1]];
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
+          for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+            KERNEL_ASSIGN(output[init_off + k * stride_1 + l * stride_0],
+                          req, OP::Map(input[i]));
+          }
+        }
+        break;
+      // The input shape is of the form (1, x_1, 1, x_2, 1),
+      // where x_1, x_2 are the sizes of the dimensions that are not broadcast.
+      // The last axis (index 4) is the one the compiler can vectorize over;
+      // the outer two loops improve performance by avoiding duplicate stride
+      // calculations for each output location that input[i] needs to be written to.
+      case 3:
+        stride_2 = aux_data.out_stride[aux_data.axes[2]];
+        stride_1 = aux_data.out_stride[aux_data.axes[1]];
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t j = 0; j < aux_data.output_shape[aux_data.axes[2]]; j++) {
+          for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
+            for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+              KERNEL_ASSIGN(output[init_off + j * stride_2 + k * stride_1 + l * stride_0],
+                            req, OP::Map(input[i]));
+            }
+          }
+        }
+        break;
+    }
+  }
+};
+
+template<typename OP>
+struct direct_copy {
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  IType *input,
+                                  OType *output,
+                                  const OpReqType req) {
+    KERNEL_ASSIGN(output[i], req, OP::Map(input[i]));
+  }
+};
+
+/**
+ * In the CPU context the kernel is launched with one Map call per input
+ * element, which helps leverage vectorization when possible.
+ * In the GPU context the kernel is launched with one thread per output
+ * element, which ensures coalesced memory writes to the output
+ * and improves coalesced memory reads.
+ */
 template<typename xpu, typename OP>
 inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
@@ -1076,8 +1225,14 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   using namespace mxnet_op;
   mxnet::TShape src_shape, dst_shape;
+  // Combines 2 or more consecutive broadcast/non-broadcast axes together,
+  // e.g. (3,4,1,1,5,1,6,7) (2,3,5) (5,10,9) -> (3*4,1*1,5,1,6*7) (1,3) (5*10, 9)
+  //                                         -> (12,1,5,1,42) (1,3) (50, 9).
+  // This becomes the new input for the broadcast kernels, whose total
+  // number of dimensions cannot be greater than 5 (an error is thrown otherwise).
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  bool isCPU = std::is_same<xpu, cpu>::value;
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
     MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
@@ -1091,21 +1246,38 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
           out_shape[i] = 1;
         }
       }
-      if (dst_shape.ndim() == 2) {
+      struct ShapeAndStride aux_data;
+      PrepareAUXData(&aux_data, in_shape, out_shape, dst_shape.ndim());
+      if (!aux_data.shape_changed) {
+        // If no broadcast is required (i.e. input_shape == output_shape),
+        // simply copy the input to the output.
+        Kernel<direct_copy<OP>, xpu>::Launch(
+          s, outputs[0].Size(), inputs[0].dptr<IType>(), outputs[0].dptr<OType>(), req[0]);
+      } else if (dst_shape.ndim() == 2) {
         Tensor<xpu, 2, OType> out =
           outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
         Tensor<xpu, 2, IType> data =
          inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
-        Kernel<broadcast_kernel<OP>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2);
+        if (isCPU) {
+          Kernel<broadcast_kernel_cpu<OP>, xpu>::Launch(
+            s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        } else {
+          Kernel<broadcast_kernel_gpu<OP>, xpu>::Launch(
+            s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        }
       } else {
         const int ndim = MXNET_SPECIAL_MAX_NDIM;
         Tensor<xpu, ndim, OType> out =
          outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), s);
        Tensor<xpu, ndim, IType> data =
          inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
-        Kernel<broadcast_kernel<OP>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim);
+        if (isCPU) {
+          Kernel<broadcast_kernel_cpu<OP>, xpu>::Launch(
+            s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        } else {
+          Kernel<broadcast_kernel_gpu<OP>, xpu>::Launch(
+            s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        }
       }
     });
   });
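
Illustration (not part of the patch): the sketch below is a minimal, standalone C++ program, under the assumption of plain row-major dense arrays, showing the per-input-element mapping that the CPU kernel's Map uses: decompose the input index along the input shape while accumulating output strides to get init_off, then replicate along the broadcast axis. All names here (in_shape, bcast_axis, etc.) are hypothetical and exist only for this example.

// Sketch: broadcast (2,1,3) -> (2,4,3) with one iteration per *input* element,
// mirroring the init_off / out_stride arithmetic of broadcast_kernel_cpu.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int ndim = 3;
  const int in_shape[ndim]  = {2, 1, 3};
  const int out_shape[ndim] = {2, 4, 3};

  // Row-major strides of the output, as PrepareAUXData would compute them.
  int out_stride[ndim];
  out_stride[ndim - 1] = 1;
  for (int d = ndim - 2; d >= 0; --d)
    out_stride[d] = out_stride[d + 1] * out_shape[d + 1];

  std::vector<float> in = {0, 1, 2, 3, 4, 5};   // 2*1*3 input elements
  std::vector<float> out(2 * 4 * 3, -1.f);

  const int bcast_axis = 1;                     // the axis whose size grows 1 -> 4
  for (int i = 0; i < static_cast<int>(in.size()); ++i) {
    // init_off: output offset of the first copy of input element i, obtained by
    // decomposing i along the *input* shape but accumulating *output* strides.
    int idx = i, init_off = 0;
    for (int d = ndim - 1; idx > 0 && d >= 0; --d) {
      init_off += (idx % in_shape[d]) * out_stride[d];
      idx /= in_shape[d];
    }
    // Replicate along the broadcast axis; the stride-1 inner dimension is
    // untouched, so consecutive i values write to consecutive addresses.
    for (int l = 0; l < out_shape[bcast_axis]; ++l)
      out[init_off + l * out_stride[bcast_axis]] = in[i];
  }

  // Check against the definition of broadcasting.
  for (int a = 0; a < 2; ++a)
    for (int b = 0; b < 4; ++b)
      for (int c = 0; c < 3; ++c)
        assert(out[(a * 4 + b) * 3 + c] == in[a * 3 + c]);
  std::printf("broadcast (2,1,3) -> (2,4,3) verified\n");
  return 0;
}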
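
A second standalone sketch, also an illustration rather than part of the patch, checks the per-output-element index arithmetic used by the GPU kernel: start from the output flat index, correct it axis by axis using precomputed input/output strides, and replace the modulo with x - (x / y) * y. The shapes and names below are hypothetical.

// Sketch: verify the GPU-style input-index computation against a naive reference.
#include <cassert>
#include <cstdio>

int main() {
  const int ndim = 3;
  const int in_shape[ndim]  = {2, 1, 3};
  const int out_shape[ndim] = {2, 4, 3};

  int in_stride[ndim], out_stride[ndim];
  in_stride[ndim - 1] = out_stride[ndim - 1] = 1;
  for (int d = ndim - 2; d >= 0; --d) {
    in_stride[d]  = in_stride[d + 1] * in_shape[d + 1];
    out_stride[d] = out_stride[d + 1] * out_shape[d + 1];
  }

  const int out_size = 2 * 4 * 3;
  for (int i = 0; i < out_size; ++i) {
    // Kernel-style computation: start from in_idx = i and correct it per axis,
    // using x - (x / y) * y in place of the modulo, as in broadcast_kernel_gpu.
    int idx = i, in_idx = i;
    for (int d = ndim - 1; d >= 0; --d) {
      int dim_idx = idx - (idx / out_shape[d]) * out_shape[d];
      if (in_shape[d] != 1)
        in_idx += dim_idx * (in_stride[d] - out_stride[d]);
      else
        in_idx -= dim_idx * out_stride[d];
      idx /= out_shape[d];
    }

    // Naive reference: decompose i into output coordinates, clamp broadcast
    // axes to 0, then recompose the offset with input strides.
    int rem = i, ref = 0;
    for (int d = ndim - 1; d >= 0; --d) {
      int coord = rem % out_shape[d];
      rem /= out_shape[d];
      ref += (in_shape[d] == 1 ? 0 : coord) * in_stride[d];
    }
    assert(in_idx == ref);
  }
  std::printf("gpu-style index mapping matches naive reference\n");
  return 0;
}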