diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 13c7f74c7ecd..75e28c8d078e 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -460,8 +460,33 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (param->kernel_size.defined() && param->channels.defined()) {
     ICHECK_EQ(param->kernel_size.size(), 3);
     ICHECK_EQ(param->dilation.size(), 3);
-    Array<IndexExpr> wshape({param->channels, indexdiv(dshape_ncdhw[1], param->groups),
-                             param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]});
+
+    bool is_depthwise = false;
+    if (param->groups > 1) {
+      if (!(weight && weight->shape.defined())) {
+        reporter->GetDiagCtx().Emit(
+            Diagnostic::Error(reporter->GetSpan())
+            << "Weight shape must be specified when groups is greater than 1.");
+        return false;
+      }
+
+      Array<IndexExpr> wshape_oidhw = trans_kernel_layout.ForwardShape(weight->shape);
+      if (tvm::tir::ExprDeepEqual()(param->groups, dshape_ncdhw[1]) &&
+          tvm::tir::ExprDeepEqual()(param->groups, wshape_oidhw[0])) {
+        is_depthwise = true;
+      }
+    }
+
+    Array<IndexExpr> wshape;
+    if (is_depthwise) {
+      auto channel_multiplier = indexdiv(param->channels, dshape_ncdhw[1]);
+      wshape = {dshape_ncdhw[1], channel_multiplier, param->kernel_size[0], param->kernel_size[1],
+                param->kernel_size[2]};
+    } else {
+      wshape = {param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0],
+                param->kernel_size[1], param->kernel_size[2]};
+    }
+
     wshape = trans_kernel_layout.BackwardShape(wshape);
     channels = param->channels;
     dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index cb785021783d..399f8556e09e 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -823,6 +823,30 @@ def test_conv3d_transpose_ncdhw_run():
     tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
 
 
+def test_compile_depthwise_conv3d():
+    dshape = [1, 16, 10, 10, 10]
+    wshape = [16, 2, 1, 1, 1]
+    params = {}
+    data = relay.var("data", shape=dshape, dtype="float32")
+    kernel = relay.const(tvm.nd.array(np.ones(shape=wshape).astype(dtype="float32")))
+    mod = tvm.IRModule()
+    res = relay.nn.conv3d(
+        data,
+        kernel,
+        kernel_size=[1, 1, 1],
+        padding=[0] * 3,
+        channels=32,
+        groups=16,
+        data_layout="NCDHW",
+        kernel_layout="OIDHW",
+    )
+    func = relay.Function([data], res)
+    mod = tvm.IRModule.from_expr(func)
+
+    target = "llvm"
+    _ = relay.build(mod, tvm.target.Target(target, host=target))
+
+
 @tvm.testing.uses_gpu
 def test_conv2d_transpose_infer_type():
     # symbolic in batch dimension
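
For reference, here is a minimal Python sketch (not part of the patch; the helper name is hypothetical) of the OIDHW weight-shape arithmetic the C++ change implements. It simplifies the depthwise check to groups == in_channels, whereas the patch additionally compares the weight's outer dimension against groups.

def expected_conv3d_weight_shape(in_channels, channels, groups, kernel_size):
    """Infer the OIDHW weight shape for conv3d, mirroring the Conv3DRel change."""
    if groups > 1 and groups == in_channels:
        # Depthwise branch: one filter group per input channel, so the shape is
        # (in_channels, channel_multiplier, kd, kh, kw).
        channel_multiplier = channels // in_channels
        return [in_channels, channel_multiplier, *kernel_size]
    # Grouped/ordinary branch: (channels, in_channels // groups, kd, kh, kw).
    return [channels, in_channels // groups, *kernel_size]


# Matches the test above: 16 input channels, channels=32, groups=16 gives a
# channel multiplier of 2 and the weight shape [16, 2, 1, 1, 1].
assert expected_conv3d_weight_shape(16, 32, 16, [1, 1, 1]) == [16, 2, 1, 1, 1]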