@@ -99,6 +99,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
     // 1d conv
     CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x";
     Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW);
+    CHECK_NE(param_.num_group, 0U) \
+        << "num_group must be non-zero";
     Shape<3> wshape = Shape3(param_.num_filter / param_.num_group,
                              mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
                              param_.kernel[0]);
@@ -149,6 +151,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(dshp.ndim(), 4U) \
         << "Input data should be 4D in batch-num_filter-y-x";
     Shape<4> dshape = ConvertLayout(dshp.get<4>(), param_.layout.value(), kNCHW);
+    CHECK_NE(param_.num_group, 0U) \
+        << "num_group must be non-zero";
     Shape<4> wshape = Shape4(param_.num_filter / param_.num_group,
                              mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
                              param_.kernel[0], param_.kernel[1]);
@@ -208,6 +212,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(dshp.ndim(), 5U) \
         << "Input data should be 5D in batch-num_filter-depth-y-x";
     Shape<5> dshape = ConvertLayout(dshp.get<5>(), param_.layout.value(), kNCDHW);
+    CHECK_NE(param_.num_group, 0U) \
+        << "num_group must be non-zero";
     Shape<5> wshape = Shape5(param_.num_filter / param_.num_group,
                              mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
                              param_.kernel[0], param_.kernel[1], param_.kernel[2]);