diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py index 24e68a47bbeb..8008391fe86c 100644 --- a/python/tvm/relay/op/strategy/bifrost.py +++ b/python/tvm/relay/op/strategy/bifrost.py @@ -65,6 +65,14 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target): wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_spatial_pack), name="conv2d_nchw_spatial_pack.bifrost", ) + elif layout == "NHWC": + assert kernel_layout == "HWIO" + # For now just reuse general Mali strategy. + strategy.add_implementation( + wrap_compute_conv2d(topi.mali.conv2d_nhwc_spatial_pack), + wrap_topi_schedule(topi.mali.schedule_conv2d_nhwc_spatial_pack), + name="conv2d_nhwc_spatial_pack.bifrost", + ) else: raise RuntimeError("Unsupported conv2d layout {} for Mali(Bifrost)".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py index 1a80b8e50e0f..eb4c5a343b58 100644 --- a/tests/python/topi/python/test_topi_conv2d_nhwc.py +++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py @@ -38,6 +38,10 @@ topi.mali.conv2d_nhwc_spatial_pack, topi.mali.schedule_conv2d_nhwc_spatial_pack, ), + "bifrost": ( + topi.mali.conv2d_nhwc_spatial_pack, + topi.mali.schedule_conv2d_nhwc_spatial_pack, + ), "hls": (topi.nn.conv2d_nhwc, topi.hls.schedule_conv2d_nhwc), }