
Commit f6d03ae

Authored by access2rohit (Rohit Kumar Srivastava)
Improve performance of broadcast_axis on CPU (#17882)
* adding comments explaining code optimizations
* fixing broadcast_axis kernel to int32
* fixing slice_axis kernel to int32
* combining CPU and GPU implementation method signatures and cleaned up code
* adding new broadcast_axis to np_matmul

Co-authored-by: Rohit Kumar Srivastava <srivastava.141@buckeyemail.osu.edu>
1 parent 5976f8b commit f6d03ae

File tree

2 files changed: +172 -14 lines changed

src/operator/numpy/np_matmul_op-inl.h

Lines changed: 31 additions & 6 deletions
@@ -138,6 +138,8 @@ inline void MatmulImpl(const OpContext& ctx,
   mshadow::Tensor<xpu, 1, DType*> workspace;
   mshadow::Tensor<xpu, 3, DType> ans, mlhs, mrhs;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  bool isCPU = std::is_same<xpu, cpu>::value;
+  // True if either a or b requires broadcasting
   if (MatmulNeedBroadcast(a_shape, b_shape)) {
     // e.g. a.shape = (2, 3, 1, 4, 2)
     //      b.shape = (5, 2, 4)
@@ -160,12 +162,35 @@ inline void MatmulImpl(const OpContext& ctx,
         struct ShapeAndStride aux_data_a, aux_data_b;
         PrepareAUXData(&aux_data_a, k_a_shape, k_a_shape_bc, ndim);
         PrepareAUXData(&aux_data_b, k_b_shape, k_b_shape_bc, ndim);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
-          aux_data_a, OpReqType::kWriteTo, ndim);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr,
-          aux_data_b, OpReqType::kWriteTo, ndim);
+        if (isCPU) {
+          if (!aux_data_a.shape_changed) {
+            Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
+              s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr, OpReqType::kWriteTo);
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_b.Size(), input_b.dptr<IType>(), bc_b_ptr,
+              aux_data_b, OpReqType::kWriteTo, ndim);
+          } else if (!aux_data_b.shape_changed) {
+            Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
+              s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr, OpReqType::kWriteTo);
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
+              aux_data_a, OpReqType::kWriteTo, ndim);
+          } else {
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_a.Size(), input_a.dptr<IType>(), bc_a_ptr,
+              aux_data_a, OpReqType::kWriteTo, ndim);
+            Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+              s, input_b.Size(), input_b.dptr<IType>(), bc_b_ptr,
+              aux_data_b, OpReqType::kWriteTo, ndim);
+          }
+        } else {
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr,
+            aux_data_a, OpReqType::kWriteTo, ndim);
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr,
+            aux_data_b, OpReqType::kWriteTo, ndim);
+        }
       });
     });
     ans = mshadow::Tensor<xpu, 3, DType>(output.dptr<DType>(),
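
Note on the dispatch above: the new isCPU flag is a compile-time constant (std::is_same<xpu, cpu>::value), so each instantiation of MatmulImpl keeps only one branch. The CPU path uses broadcast_kernel_cpu (plus direct_copy when one operand needs no broadcasting) with one work item per input element, while the GPU path keeps one work item per output element via broadcast_kernel_gpu. A minimal standalone sketch of that idiom, using hypothetical stand-ins (local cpu/gpu tag types and a LaunchBroadcast wrapper) rather than the real mshadow/MXNet types:

    // Standalone sketch (not part of this commit) of the compile-time dispatch idiom.
    // cpu/gpu are local stand-in tags; LaunchBroadcast is a hypothetical wrapper.
    #include <cstddef>
    #include <cstdio>
    #include <type_traits>

    struct cpu {};
    struct gpu {};

    template <typename xpu>
    void LaunchBroadcast(std::size_t input_size, std::size_t output_size) {
      // Evaluated per instantiation at compile time, so the untaken branch
      // costs nothing at run time.
      bool isCPU = std::is_same<xpu, cpu>::value;
      if (isCPU) {
        // CPU path: one work item per input element (broadcast_kernel_cpu).
        std::printf("CPU: %zu work items\n", input_size);
      } else {
        // GPU path: one work item per output element (broadcast_kernel_gpu).
        std::printf("GPU: %zu work items\n", output_size);
      }
    }

    int main() {
      LaunchBroadcast<cpu>(6, 24);  // e.g. broadcasting a (2,1,3) input to (2,4,3)
      LaunchBroadcast<gpu>(6, 24);
      return 0;
    }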

src/operator/tensor/broadcast_reduce_op.h

Lines changed: 141 additions & 8 deletions
@@ -25,6 +25,7 @@
 #ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
 #define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
 
+#include <assert.h>
 #include <mxnet/operator_util.h>
 #include <string>
 #include <vector>
@@ -1043,7 +1044,12 @@ struct ShapeAndStride {
   index_t out_stride[MXNET_SPECIAL_MAX_NDIM];
   index_t input_shape[MXNET_SPECIAL_MAX_NDIM];
   index_t output_shape[MXNET_SPECIAL_MAX_NDIM];
+  // axes: stores which axes of the input are to be broadcast
+  index_t axes[MXNET_SPECIAL_MAX_NDIM];
+  int num_broadcast_axes = -1;
+  bool shape_changed = false;
 };
+}  // unnamed namespace
 
 /*!
  * \brief Calculates Stride of input and output tensor dimensions
@@ -1058,23 +1064,32 @@ inline void PrepareAUXData(ShapeAndStride *aux_data,
                            mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
                            mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
                            int ndim) {
-  int iter = ndim - 1;
+  int iter = ndim - 1, i = 0;
   aux_data->out_stride[iter] = 1;
   aux_data->in_stride[iter] = 1;
   aux_data->input_shape[iter] = in_shape[iter];
   aux_data->output_shape[iter] = out_shape[iter];
+  if (in_shape[iter] != out_shape[iter]) {
+    aux_data->axes[i++] = iter;
+    aux_data->shape_changed = true;
+  }
   iter--;
   for (; iter >= 0; --iter) {
     aux_data->out_stride[iter] = aux_data->out_stride[iter + 1] * out_shape[iter + 1];
     aux_data->in_stride[iter] = aux_data->in_stride[iter + 1] * in_shape[iter + 1];
     aux_data->input_shape[iter] = in_shape[iter];
     aux_data->output_shape[iter] = out_shape[iter];
+    if (in_shape[iter] != out_shape[iter]) {
+      aux_data->axes[i++] = iter;
+      aux_data->shape_changed = true;
+    }
   }
+  aux_data->num_broadcast_axes = i;
+  assert(aux_data->num_broadcast_axes > -1 && aux_data->num_broadcast_axes < 4);
 }
-}  // unnamed namespace
 
 template<typename OP>
-struct broadcast_kernel {
+struct broadcast_kernel_gpu {
   template<typename IType, typename OType>
   MSHADOW_XINLINE static void Map(index_t i,
                                   IType *input,
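
A worked example of the new bookkeeping in PrepareAUXData: scanning from the last (fastest-varying) dimension, every axis where in_shape and out_shape differ is recorded in axes[], num_broadcast_axes counts them, and shape_changed flags whether any broadcasting is needed at all. The sketch below (plain arrays instead of mshadow::Shape, not MXNet code) mirrors that loop for an input of shape (3,1,5,1) broadcast to (3,4,5,6):

    // Sketch (not MXNet code) mirroring PrepareAUXData's bookkeeping for
    // in_shape = (3,1,5,1) broadcast to out_shape = (3,4,5,6).
    #include <cstdio>

    int main() {
      const int ndim = 4;
      long in_shape[]  = {3, 1, 5, 1};
      long out_shape[] = {3, 4, 5, 6};
      long in_stride[4], out_stride[4], axes[4];
      int num_broadcast_axes = 0;
      bool shape_changed = false;

      in_stride[ndim - 1] = 1;
      out_stride[ndim - 1] = 1;
      for (int iter = ndim - 1; iter >= 0; --iter) {
        if (iter < ndim - 1) {
          out_stride[iter] = out_stride[iter + 1] * out_shape[iter + 1];
          in_stride[iter]  = in_stride[iter + 1] * in_shape[iter + 1];
        }
        if (in_shape[iter] != out_shape[iter]) {  // this axis gets broadcast
          axes[num_broadcast_axes++] = iter;
          shape_changed = true;
        }
      }

      // Expected: axes = {3, 1}, num_broadcast_axes = 2, shape_changed = true,
      // out_stride = {120, 30, 6, 1}, in_stride = {5, 5, 1, 1}.
      std::printf("num_broadcast_axes=%d shape_changed=%d axes={%ld, %ld}\n",
                  num_broadcast_axes, (int)shape_changed, axes[0], axes[1]);
      std::printf("out_stride={%ld, %ld, %ld, %ld}\n",
                  out_stride[0], out_stride[1], out_stride[2], out_stride[3]);
      return 0;
    }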
@@ -1102,6 +1117,103 @@ struct broadcast_kernel {
   }
 };
 
+/**
+ * Changed the thread workload mapping from 1 thread per
+ * output element to 1 thread per input element to be broadcast.
+ * This approach leverages vectorization when the fastest-varying
+ * index (stride=1) of the tensor is to be broadcast.
+ * In other cases it simply performs better through better load balancing.
+ */
+template<typename OP>
+struct broadcast_kernel_cpu {
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  IType *input,
+                                  OType *output,
+                                  const ShapeAndStride& aux_data,
+                                  const OpReqType req,
+                                  const int ndim) {
+    index_t idx = i;
+    index_t init_off = 0;
+    for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {
+      size_t dim_idx = idx % aux_data.input_shape[iter];
+      init_off += dim_idx * aux_data.out_stride[iter];
+      idx /= aux_data.input_shape[iter];
+    }
+    index_t stride_0, stride_1, stride_2;
+    // Each case is based on the number of axes to be broadcast
+    // (1, 2 or 3) after merging axes.
+    switch (aux_data.num_broadcast_axes) {
+      // When the input shape is one of the following forms:
+      // (x_1,1) or (x_1,1,x_2) or (1,x_1)
+      // x_1, x_2 are sizes of the dimensions that are not broadcast.
+      // In the (x_1,1) case the loop leverages vectorization; in the other 2,
+      // performance is improved by avoiding duplicate stride calculations
+      // for each output location that input[i] needs to be written to.
+      case 1 :
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+          KERNEL_ASSIGN(output[init_off + l * stride_0],
+                        req, OP::Map(input[i]));
+        }
+        break;
+      // When the input shape is one of the following forms:
+      // (x_1,1,x_2,1) or (1,x_1,1,x_2) or (x_1,1,x_2,1,x_3)
+      // x_1, x_2, x_3 are sizes of the dimensions that are not broadcast.
+      // The innermost loop can be vectorized by the compiler; in the outer loops,
+      // performance is improved by avoiding duplicate stride calculations
+      // for each output location that input[i] needs to be written to.
+      case 2:
+        stride_1 = aux_data.out_stride[aux_data.axes[1]];
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
+          for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+            KERNEL_ASSIGN(output[init_off + k * stride_1 + l * stride_0],
+                          req, OP::Map(input[i]));
+          }
+        }
+        break;
+      // When the input shape is of the form (1,x_1,1,x_2,1)
+      // x_1, x_2 are sizes of the dimensions that are not broadcast.
+      // Here the last axis, [4], is the one where the compiler can vectorize
+      // the code; the outer 2 loops improve performance by avoiding
+      // duplicate stride calculations
+      // for each output location that input[i] needs to be written to.
+      case 3:
+        stride_2 = aux_data.out_stride[aux_data.axes[2]];
+        stride_1 = aux_data.out_stride[aux_data.axes[1]];
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t j = 0; j < aux_data.output_shape[aux_data.axes[2]]; j++) {
+          for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
+            for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+              KERNEL_ASSIGN(output[init_off + j * stride_2 + k * stride_1 + l * stride_0],
+                            req, OP::Map(input[i]));
+            }
+          }
+        }
+        break;
+    }
+  }
+};
+
+template<typename OP>
+struct direct_copy {
+  template<typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  IType *input,
+                                  OType *output,
+                                  const OpReqType req) {
+    KERNEL_ASSIGN(output[i], req, OP::Map(input[i]));
+  }
+};
+
+/**
+ * When the CPU context is used, the number of kernel launches is equal to
+ * the number of input elements; this helps leverage vectorization when possible.
+ * When the GPU context is used, the number of kernel launches is equal to
+ * the number of output elements; this ensures coalesced memory writes to the
+ * output and improves coalesced memory reads.
+ */
 template<typename xpu>
 inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
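
To make the one-work-item-per-input-element mapping concrete: broadcast_kernel_cpu first decomposes the flat input index i into per-dimension coordinates to find init_off, the first output offset that input[i] maps to, and then replicates input[i] across the broadcast axes using only output strides, so the innermost loop walks a single stride and can be vectorized. A self-contained sketch of the single-broadcast-axis case (case 1), for an input of shape (2,1,3) broadcast to (2,4,3); the shapes and array names here are illustrative only:

    // Sketch (not MXNet code): one work item per input element, single
    // broadcast axis, for in_shape = (2,1,3) -> out_shape = (2,4,3).
    #include <cstdio>

    int main() {
      const int ndim = 3;
      const long in_shape[]   = {2, 1, 3};
      const long out_shape[]  = {2, 4, 3};
      const long out_stride[] = {12, 3, 1};  // strides of the output tensor
      const int broadcast_axis = 1;          // the only axis where shapes differ

      float input[6] = {0, 1, 2, 3, 4, 5};   // 2*1*3 elements
      float output[24] = {0};                // 2*4*3 elements

      for (long i = 0; i < 6; ++i) {         // "one thread" per input element
        // Locate where input[i] first lands in the output (init_off), exactly
        // as the index-decomposition loop in broadcast_kernel_cpu does.
        long idx = i, init_off = 0;
        for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {
          long dim_idx = idx % in_shape[iter];
          init_off += dim_idx * out_stride[iter];
          idx /= in_shape[iter];
        }
        // case 1: replicate input[i] along the single broadcast axis.
        long stride_0 = out_stride[broadcast_axis];
        for (long l = 0; l < out_shape[broadcast_axis]; ++l)
          output[init_off + l * stride_0] = input[i];
      }

      // Every (row, k, col) position now holds input[row][0][col]; e.g.
      // output[13], output[16], output[19], output[22] all equal input[4] (= 4).
      std::printf("output[13]=%g output[22]=%g\n", output[13], output[22]);
      return 0;
    }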
@@ -1113,8 +1225,14 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   using namespace mxnet_op;
   mxnet::TShape src_shape, dst_shape;
+  // combines 2 or more consecutive broadcast/non-broadcast axes together
+  // e.g. (3,4,1,1,5,1,6,7) (2,3,5) (5,10,9) -> (3*4,1*1,5,1,6*7) (1,3) (5*10, 9)
+  //      -> (12,1,5,1,42) (1,3) (50, 9)
+  // and this is the new input for broadcast_kernel whose total
+  // num of dimensions cannot be greater than 5 (throws an error otherwise).
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  bool isCPU = std::is_same<xpu, cpu>::value;
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
     MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
@@ -1130,21 +1248,36 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
       }
       struct ShapeAndStride aux_data;
       PrepareAUXData(&aux_data, in_shape, out_shape, dst_shape.ndim());
-      if (dst_shape.ndim() == 2) {
+      if (!aux_data.shape_changed) {
+        // If no broadcast is required (i.e. input_shape == output_shape)
+        // then simply copy input to output.
+        Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
+          s, outputs[0].Size(), inputs[0].dptr<IType>(), outputs[0].dptr<OType>(), req[0]);
+      } else if (dst_shape.ndim() == 2) {
         Tensor<xpu, 2, OType> out =
           outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
         Tensor<xpu, 2, IType> data =
           inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        if (isCPU) {
+          Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+            s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        } else {
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        }
       } else {
         const int ndim = MXNET_SPECIAL_MAX_NDIM;
         Tensor<xpu, ndim, OType> out =
           outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), s);
         Tensor<xpu, ndim, IType> data =
           inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        if (isCPU) {
+          Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+            s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        } else {
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        }
       }
     });
   });
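
For reference, the axis merging described by the comment above BroadcastReduceShapeCompact (consecutive axes that are either all broadcast or all non-broadcast are fused before the kernels see them) can be illustrated with a small sketch. This is a simplified stand-in for what BroadcastReduceShapeCompact achieves, not its actual implementation, and the output shape below is hypothetical, chosen so the input side compacts to (12,1,5,1,42) as in the comment's example:

    // Simplified sketch (not the real BroadcastReduceShapeCompact): fuse runs of
    // consecutive axes that share the same "is this axis broadcast?" flag.
    #include <cstddef>
    #include <cstdio>
    #include <utility>
    #include <vector>

    std::vector<std::pair<long, long>> CompactShapes(const std::vector<long>& in,
                                                     const std::vector<long>& out) {
      std::vector<std::pair<long, long>> merged;  // (in_dim, out_dim) pairs
      for (std::size_t i = 0; i < in.size(); ++i) {
        bool bcast = (in[i] != out[i]);
        if (!merged.empty() &&
            bcast == (merged.back().first != merged.back().second)) {
          merged.back().first *= in[i];           // same kind of axis: fuse it
          merged.back().second *= out[i];
        } else {
          merged.push_back({in[i], out[i]});
        }
      }
      return merged;
    }

    int main() {
      // Input (3,4,1,1,5,1,6,7) broadcast to a hypothetical output
      // (3,4,2,9,5,8,6,7): expected compaction to input (12,1,5,1,42)
      // versus output (12,18,5,8,42), i.e. 5 merged dimensions.
      auto m = CompactShapes({3, 4, 1, 1, 5, 1, 6, 7},
                             {3, 4, 2, 9, 5, 8, 6, 7});
      for (auto& p : m) std::printf("(%ld -> %ld) ", p.first, p.second);
      std::printf("\n");
      return 0;
    }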
