From 42b78dd7edc5382e444be701d2430c48084d85ba Mon Sep 17 00:00:00 2001
From: Andrew Luo
Date: Tue, 10 Jan 2023 15:01:46 -0800
Subject: [PATCH 1/3] forward rewrite for generic

---
 python/tvm/relay/op/strategy/generic.py | 44 +++++++++++++++++++------
 1 file changed, 34 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 1cf55f7145cd..b11e23466bcf 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -20,6 +20,8 @@
 import re
 
 from tvm import _ffi, ir, te, topi
+from tvm.auto_scheduler import is_auto_scheduler_enabled
+from tvm.meta_schedule import is_meta_schedule_enabled
 from tvm.target import generic_func, override_native_generic_func
 from tvm.topi.utils import (
     get_const_float,
@@ -236,10 +238,14 @@ def wrap_compute_conv2d(
     need_kernel_layout=False,
     need_out_layout=False,
     has_groups=False,
-    need_auto_scheduler_layout=False,
-    need_meta_schedule_layout=False,
+    need_auto_scheduler_layout=None,
+    need_meta_schedule_layout=None,
 ):
     """Wrap conv2d topi compute"""
+    if need_auto_scheduler_layout is None:
+        need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    if need_meta_schedule_layout is None:
+        need_meta_schedule_layout = is_meta_schedule_enabled()
 
     def _compute_conv2d(attrs, inputs, out_type):
         padding = get_const_tuple(attrs.padding)
@@ -562,11 +568,16 @@ def conv3d_transpose_strategy(attrs, inputs, out_type, target):
 def wrap_compute_conv3d(
     topi_compute,
     need_layout=False,
-    need_auto_scheduler_layout=False,
-    need_meta_schedule_layout=False,
+    need_auto_scheduler_layout=None,
+    need_meta_schedule_layout=None,
 ):
     """wrap conv3d topi compute"""
 
+    if need_auto_scheduler_layout is None:
+        need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    if need_meta_schedule_layout is None:
+        need_meta_schedule_layout = is_meta_schedule_enabled()
+
     def _compute_conv3d(attrs, inputs, out_type):
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
@@ -821,11 +832,16 @@ def copy_if_identical(tensor_a, tensor_b):
 # matmul
 def wrap_compute_matmul(
     topi_compute,
-    need_auto_scheduler_layout=False,
-    need_meta_schedule_layout=False,
+    need_auto_scheduler_layout=None,
+    need_meta_schedule_layout=None,
 ):
     """wrap matmul topi compute"""
 
+    if need_auto_scheduler_layout is None:
+        need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    if need_meta_schedule_layout is None:
+        need_meta_schedule_layout = is_meta_schedule_enabled()
+
     def _compute_matmul(attrs, inputs, out_type):
         """Compute definition of matmul"""
         out_dtype = attrs.out_dtype
@@ -865,10 +881,14 @@ def matmul_strategy(attrs, inputs, out_type, target):
 # dense
 def wrap_compute_dense(
     topi_compute,
-    need_auto_scheduler_layout=False,
-    need_meta_schedule_layout=False,
+    need_auto_scheduler_layout=None,
+    need_meta_schedule_layout=None,
 ):
     """wrap dense topi compute"""
+    if need_auto_scheduler_layout is None:
+        need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    if need_meta_schedule_layout is None:
+        need_meta_schedule_layout = is_meta_schedule_enabled()
 
     def _compute_dense(attrs, inputs, out_type):
         """Compute definition of dense"""
@@ -916,11 +936,15 @@ def dense_pack_strategy(attrs, inputs, out_type, target):
 def wrap_compute_batch_matmul(
     topi_compute,
     *,
-    need_auto_scheduler_layout=False,
-    need_meta_schedule_layout=False,
+    need_auto_scheduler_layout=None,
+    need_meta_schedule_layout=None,
     need_out_dtype=False,
 ):
     """wrap batch_matmul topi compute"""
+    if need_auto_scheduler_layout is None:
+        need_auto_scheduler_layout = is_auto_scheduler_enabled()
+    if need_meta_schedule_layout is None:
+        need_meta_schedule_layout = is_meta_schedule_enabled()
 
     def _compute_batch_matmul(attrs, inputs, out_type):
         args = [inputs[0], inputs[1], out_type.shape]

From f6fb10e8acd0489e653f80589bc91db41d4395fa Mon Sep 17 00:00:00 2001
From: Andrew Luo
Date: Wed, 11 Jan 2023 10:08:04 -0800
Subject: [PATCH 2/3] move to layout to strategy

---
 python/tvm/relay/op/strategy/generic.py | 63 ++++++++++---------------
 1 file changed, 26 insertions(+), 37 deletions(-)

diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index b11e23466bcf..ff9788cafd54 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -238,15 +238,10 @@ def wrap_compute_conv2d(
     need_kernel_layout=False,
     need_out_layout=False,
     has_groups=False,
-    need_auto_scheduler_layout=None,
-    need_meta_schedule_layout=None,
+    need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
 ):
     """Wrap conv2d topi compute"""
-    if need_auto_scheduler_layout is None:
-        need_auto_scheduler_layout = is_auto_scheduler_enabled()
-    if need_meta_schedule_layout is None:
-        need_meta_schedule_layout = is_meta_schedule_enabled()
-
     def _compute_conv2d(attrs, inputs, out_type):
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
@@ -568,16 +563,11 @@ def conv3d_transpose_strategy(attrs, inputs, out_type, target):
 def wrap_compute_conv3d(
     topi_compute,
     need_layout=False,
-    need_auto_scheduler_layout=None,
-    need_meta_schedule_layout=None,
+    need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
 ):
     """wrap conv3d topi compute"""
 
-    if need_auto_scheduler_layout is None:
-        need_auto_scheduler_layout = is_auto_scheduler_enabled()
-    if need_meta_schedule_layout is None:
-        need_meta_schedule_layout = is_meta_schedule_enabled()
-
     def _compute_conv3d(attrs, inputs, out_type):
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
@@ -832,16 +822,11 @@ def copy_if_identical(tensor_a, tensor_b):
 # matmul
 def wrap_compute_matmul(
     topi_compute,
-    need_auto_scheduler_layout=None,
-    need_meta_schedule_layout=None,
+    need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
 ):
     """wrap matmul topi compute"""
 
-    if need_auto_scheduler_layout is None:
-        need_auto_scheduler_layout = is_auto_scheduler_enabled()
-    if need_meta_schedule_layout is None:
-        need_meta_schedule_layout = is_meta_schedule_enabled()
-
     def _compute_matmul(attrs, inputs, out_type):
         """Compute definition of matmul"""
         out_dtype = attrs.out_dtype
@@ -870,8 +855,13 @@ def matmul_strategy(attrs, inputs, out_type, target):
     """matmul generic strategy"""
     logger.warning("matmul is not optimized for this platform.")
     strategy = _op.OpStrategy()
+
     strategy.add_implementation(
-        wrap_compute_matmul(topi.nn.matmul),
+        wrap_compute_matmul(
+            topi.nn.matmul,
+            need_auto_scheduler_layout=is_auto_scheduler_enabled(),
+            need_meta_schedule_layout=is_meta_schedule_enabled(),
+        ),
         wrap_topi_schedule(topi.generic.schedule_matmul),
         name="matmul.generic",
     )
@@ -881,15 +871,10 @@ def matmul_strategy(attrs, inputs, out_type, target):
 # dense
 def wrap_compute_dense(
     topi_compute,
-    need_auto_scheduler_layout=None,
-    need_meta_schedule_layout=None,
+    need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
 ):
     """wrap dense topi compute"""
-    if need_auto_scheduler_layout is None:
-        need_auto_scheduler_layout = is_auto_scheduler_enabled()
-    if need_meta_schedule_layout is None:
-        need_meta_schedule_layout = is_meta_schedule_enabled()
-
     def _compute_dense(attrs, inputs, out_type):
         """Compute definition of dense"""
         out_dtype = attrs.out_dtype
@@ -912,7 +897,11 @@ def dense_strategy(attrs, inputs, out_type, target):
     logger.warning("dense is not optimized for this platform.")
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_dense(topi.nn.dense),
+        wrap_compute_dense(
+            topi.nn.dense,
+            need_auto_scheduler_layout=is_auto_scheduler_enabled(),
+            need_meta_schedule_layout=is_meta_schedule_enabled(),
+        ),
         wrap_topi_schedule(topi.generic.schedule_dense),
         name="dense.generic",
     )
@@ -936,15 +925,11 @@ def dense_pack_strategy(attrs, inputs, out_type, target):
 def wrap_compute_batch_matmul(
     topi_compute,
     *,
-    need_auto_scheduler_layout=None,
-    need_meta_schedule_layout=None,
+    need_auto_scheduler_layout=False,
+    need_meta_schedule_layout=False,
     need_out_dtype=False,
 ):
     """wrap batch_matmul topi compute"""
-    if need_auto_scheduler_layout is None:
-        need_auto_scheduler_layout = is_auto_scheduler_enabled()
-    if need_meta_schedule_layout is None:
-        need_meta_schedule_layout = is_meta_schedule_enabled()
 
     def _compute_batch_matmul(attrs, inputs, out_type):
         args = [inputs[0], inputs[1], out_type.shape]
@@ -968,7 +953,11 @@ def batch_matmul_strategy(attrs, inputs, out_type, target):
     logger.warning("batch_matmul is not optimized for this platform.")
     strategy = _op.OpStrategy()
    strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_compute_batch_matmul(
+            topi.nn.batch_matmul,
+            need_auto_scheduler_layout=is_auto_scheduler_enabled(),
+            need_meta_schedule_layout=is_meta_schedule_enabled,
+        ),
         wrap_topi_schedule(topi.generic.schedule_batch_matmul),
         name="batch_matmul.generic",
     )

From 92661dc835229a23a586644b7c39e901f61a0096 Mon Sep 17 00:00:00 2001
From: Andrew Luo
Date: Wed, 11 Jan 2023 10:44:18 -0800
Subject: [PATCH 3/3] missing ()

---
 python/tvm/relay/op/strategy/generic.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index ff9788cafd54..4be504bb3577 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -23,12 +23,8 @@
 from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.meta_schedule import is_meta_schedule_enabled
 from tvm.target import generic_func, override_native_generic_func
-from tvm.topi.utils import (
-    get_const_float,
-    get_const_int,
-    get_const_tuple,
-    get_float_tuple,
-)
+from tvm.topi.utils import (get_const_float, get_const_int, get_const_tuple,
+                            get_float_tuple)
 
 from .. import op as _op
 
@@ -956,7 +952,7 @@ def batch_matmul_strategy(attrs, inputs, out_type, target):
         wrap_compute_batch_matmul(
             topi.nn.batch_matmul,
             need_auto_scheduler_layout=is_auto_scheduler_enabled(),
-            need_meta_schedule_layout=is_meta_schedule_enabled,
+            need_meta_schedule_layout=is_meta_schedule_enabled(),
         ),
         wrap_topi_schedule(topi.generic.schedule_batch_matmul),
         name="batch_matmul.generic",
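
Note on the resulting call pattern (an illustrative sketch, not part of the patches): after patch 2 and the fix in patch 3, the need_*_layout flags keep their False defaults inside the wrap_compute_* helpers, and each generic strategy opts in explicitly by calling is_auto_scheduler_enabled() / is_meta_schedule_enabled() at the call site. The sketch below shows that end state for dense; it assumes a TVM build where the two predicates are importable exactly as added in patch 1, and the name dense_strategy_sketch is hypothetical so it does not shadow the real dense_strategy.

    # Minimal sketch of the post-series call pattern (dense shown; matmul and
    # batch_matmul follow the same shape). All names except dense_strategy_sketch
    # come from the patched python/tvm/relay/op/strategy/generic.py.
    from tvm import topi
    from tvm.auto_scheduler import is_auto_scheduler_enabled
    from tvm.meta_schedule import is_meta_schedule_enabled
    from tvm.relay.op import op as _op
    from tvm.relay.op.strategy.generic import wrap_compute_dense, wrap_topi_schedule


    def dense_strategy_sketch(attrs, inputs, out_type, target):
        """Generic dense strategy: the tuner check happens at the call site,
        so wrap_compute_dense keeps plain False defaults."""
        strategy = _op.OpStrategy()
        strategy.add_implementation(
            wrap_compute_dense(
                topi.nn.dense,
                # Evaluated once when the strategy is built; note the trailing ()
                # that patch 3 restores for the batch_matmul call.
                need_auto_scheduler_layout=is_auto_scheduler_enabled(),
                need_meta_schedule_layout=is_meta_schedule_enabled(),
            ),
            wrap_topi_schedule(topi.generic.schedule_dense),
            name="dense.generic",
        )
        return strategy

Moving the check out of the wrappers keeps them free of tuner state and makes each strategy's layout requirement explicit at its registration point.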