 #ifndef MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
 #define MXNET_OPERATOR_TENSOR_BROADCAST_REDUCE_OP_H_
 
+#include <assert.h>
 #include <mxnet/operator_util.h>
 #include <string>
 #include <vector>
@@ -1043,7 +1044,12 @@ struct ShapeAndStride {
   index_t out_stride[MXNET_SPECIAL_MAX_NDIM];
   index_t input_shape[MXNET_SPECIAL_MAX_NDIM];
   index_t output_shape[MXNET_SPECIAL_MAX_NDIM];
+  // axes: stores which axes of the input are to be broadcast
+  index_t axes[MXNET_SPECIAL_MAX_NDIM];
+  int num_broadcast_axes = -1;
+  bool shape_changed = false;
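+  // For illustration: broadcasting an input of shape (1,4,1) to an output of
+  // shape (3,4,5) gives axes = {2, 0} (filled from the last axis to the first
+  // by PrepareAUXData below), num_broadcast_axes = 2 and shape_changed = true.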
 };
+}  // unnamed namespace
 
 /*!
  * \brief Calculates strides of the input and output tensor dimensions
@@ -1058,23 +1064,32 @@ inline void PrepareAUXData(ShapeAndStride *aux_data,
                            mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape,
                            mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape,
                            int ndim) {
-  int iter = ndim - 1;
+  int iter = ndim - 1, i = 0;
   aux_data->out_stride[iter] = 1;
   aux_data->in_stride[iter] = 1;
   aux_data->input_shape[iter] = in_shape[iter];
   aux_data->output_shape[iter] = out_shape[iter];
+  if (in_shape[iter] != out_shape[iter]) {
+    aux_data->axes[i++] = iter;
+    aux_data->shape_changed = true;
+  }
   iter--;
   for (; iter >= 0; --iter) {
     aux_data->out_stride[iter] = aux_data->out_stride[iter + 1] * out_shape[iter + 1];
     aux_data->in_stride[iter] = aux_data->in_stride[iter + 1] * in_shape[iter + 1];
     aux_data->input_shape[iter] = in_shape[iter];
     aux_data->output_shape[iter] = out_shape[iter];
+    if (in_shape[iter] != out_shape[iter]) {
+      aux_data->axes[i++] = iter;
+      aux_data->shape_changed = true;
+    }
   }
+  aux_data->num_broadcast_axes = i;
+  assert(aux_data->num_broadcast_axes > -1 && aux_data->num_broadcast_axes < 4);
 }
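+// For illustration: with in_shape = (2,1,3) and out_shape = (2,5,3) (ndim = 3),
+// PrepareAUXData yields in_stride = {3,3,1}, out_stride = {15,3,1},
+// axes = {1} and num_broadcast_axes = 1.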
-}  // unnamed namespace
 
 template <typename OP>
-struct broadcast_kernel {
+struct broadcast_kernel_gpu {
   template <typename IType, typename OType>
   MSHADOW_XINLINE static void Map(index_t i,
                                   IType *input,
@@ -1102,6 +1117,103 @@ struct broadcast_kernel {
   }
 };
 
+/**
+ * Changed the thread workload mapping from one thread per output element
+ * to one thread per input element to be broadcast.
+ * This approach leverages vectorization when the fastest-varying
+ * index (stride = 1) of the tensor is to be broadcast.
+ * In other cases it still performs better due to better load balancing.
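+ *
+ * For illustration: broadcasting an input of shape (1024, 1) to an output of
+ * shape (1024, 8) runs one work item per input element; each item reads
+ * input[i] once and writes 8 consecutive output elements (stride 1 along the
+ * broadcast axis), a loop the compiler can vectorize.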
+ */
+template <typename OP>
+struct broadcast_kernel_cpu {
+  template <typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  IType *input,
+                                  OType *output,
+                                  const ShapeAndStride& aux_data,
+                                  const OpReqType req,
+                                  const int ndim) {
+    index_t idx = i;
+    index_t init_off = 0;
+    // Compute the base output offset for input element i;
+    // offsets along the broadcast axes are added in the switch below.
+    for (int iter = ndim - 1; idx > 0 && iter >= 0; --iter) {
+      size_t dim_idx = idx % aux_data.input_shape[iter];
+      init_off += dim_idx * aux_data.out_stride[iter];
+      idx /= aux_data.input_shape[iter];
+    }
+    index_t stride_0, stride_1, stride_2;
+    // Each case is based on the number of axes to be broadcast
+    // (1, 2 or 3) after merging axes.
+    switch (aux_data.num_broadcast_axes) {
+      // When the input shape is one of the following forms:
+      // (x_1,1), (x_1,1,x_2) or (1,x_1),
+      // where x_1, x_2 are sizes of the dimensions that are not broadcast.
+      // In the (x_1,1) case the loop is vectorized; in the other two cases
+      // performance improves by avoiding duplicate stride calculations
+      // for each output location input[i] needs to be written to.
+      case 1:
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+          KERNEL_ASSIGN(output[init_off + l * stride_0],
+                        req, OP::Map(input[i]));
+        }
+        break;
+      // When the input shape is one of the following forms:
+      // (x_1,1,x_2,1), (1,x_1,1,x_2) or (x_1,1,x_2,1,x_3),
+      // where x_1, x_2, x_3 are sizes of the dimensions that are not broadcast.
+      // The innermost loop can be vectorized by the compiler; in the outer loop
+      // performance improves by avoiding duplicate stride calculations
+      // for each output location input[i] needs to be written to.
+      case 2:
+        stride_1 = aux_data.out_stride[aux_data.axes[1]];
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
+          for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+            KERNEL_ASSIGN(output[init_off + k * stride_1 + l * stride_0],
+                          req, OP::Map(input[i]));
+          }
+        }
+        break;
+      // When the input shape is of the form (1,x_1,1,x_2,1),
+      // where x_1, x_2 are sizes of the dimensions that are not broadcast.
+      // Here the last axis (axis 4, stored in axes[0]) is the one the compiler
+      // can vectorize over; the outer two loops improve performance by avoiding
+      // duplicate stride calculations
+      // for each output location input[i] needs to be written to.
+      case 3:
+        stride_2 = aux_data.out_stride[aux_data.axes[2]];
+        stride_1 = aux_data.out_stride[aux_data.axes[1]];
+        stride_0 = aux_data.out_stride[aux_data.axes[0]];
+        for (index_t j = 0; j < aux_data.output_shape[aux_data.axes[2]]; j++) {
+          for (index_t k = 0; k < aux_data.output_shape[aux_data.axes[1]]; k++) {
+            for (index_t l = 0; l < aux_data.output_shape[aux_data.axes[0]]; l++) {
+              KERNEL_ASSIGN(output[init_off + j * stride_2 + k * stride_1 + l * stride_0],
+                            req, OP::Map(input[i]));
+            }
+          }
+        }
+        break;
+    }
+  }
+};
+
+template <typename OP>
+struct direct_copy {
+  template <typename IType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t i,
+                                  IType *input,
+                                  OType *output,
+                                  const OpReqType req) {
+    KERNEL_ASSIGN(output[i], req, OP::Map(input[i]));
+  }
+};
+
+/**
+ * When a CPU context is used, the number of kernel launches equals the
+ * number of input elements; this helps leverage vectorization when possible.
+ * When a GPU context is used, the number of kernel launches equals the
+ * number of output elements; this ensures coalesced memory writes to the
+ * output and improves memory read coalescing.
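+ *
+ * For illustration: broadcasting a (1024, 1) input to a (1024, 1024) output
+ * runs ~1K work items on CPU (one per input element, broadcast_kernel_cpu)
+ * but ~1M work items on GPU (one per output element, broadcast_kernel_gpu),
+ * so that consecutive GPU threads write consecutive output elements.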
+ */
 template <typename xpu>
 inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
@@ -1113,8 +1225,14 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   using namespace mxnet_op;
   mxnet::TShape src_shape, dst_shape;
+  // Combines 2 or more consecutive broadcast/non-broadcast axes together,
+  // e.g. (3,4,1,1,5,1,6,7) (2,3,5) (5,10,9) -> (3*4,1*1,5,1,6*7) (1,3) (5*10, 9)
+  //                                         -> (12,1,5,1,42) (1,3) (50, 9)
+  // and this is the new input for the broadcast kernel, whose total
+  // number of dimensions cannot be greater than 5 (throws an error otherwise).
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  bool isCPU = std::is_same<xpu, cpu>::value;
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
     MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
@@ -1130,21 +1248,36 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs,
       }
       struct ShapeAndStride aux_data;
       PrepareAUXData(&aux_data, in_shape, out_shape, dst_shape.ndim());
-      if (dst_shape.ndim() == 2) {
+      if (!aux_data.shape_changed) {
+        // If no broadcast is required (i.e. input_shape == output_shape)
+        // then simply copy input to output.
+        Kernel<direct_copy<mshadow_op::identity>, xpu>::Launch(
+          s, outputs[0].Size(), inputs[0].dptr<IType>(), outputs[0].dptr<OType>(), req[0]);
+      } else if (dst_shape.ndim() == 2) {
         Tensor<xpu, 2, OType> out =
           outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s);
         Tensor<xpu, 2, IType> data =
           inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        if (isCPU) {
+          Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+            s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        } else {
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2);
+        }
       } else {
         const int ndim = MXNET_SPECIAL_MAX_NDIM;
         Tensor<xpu, ndim, OType> out =
           outputs[0].get_with_shape<xpu, ndim, OType>(dst_shape.get<ndim>(), s);
         Tensor<xpu, ndim, IType> data =
           inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s);
-        Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch(
-          s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        if (isCPU) {
+          Kernel<broadcast_kernel_cpu<mshadow_op::identity>, xpu>::Launch(
+            s, data.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        } else {
+          Kernel<broadcast_kernel_gpu<mshadow_op::identity>, xpu>::Launch(
+            s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim);
+        }
       }
     });
   });