diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 1cf55f7145cd..4be504bb3577 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -20,13 +20,11 @@ import re from tvm import _ffi, ir, te, topi +from tvm.auto_scheduler import is_auto_scheduler_enabled +from tvm.meta_schedule import is_meta_schedule_enabled from tvm.target import generic_func, override_native_generic_func -from tvm.topi.utils import ( - get_const_float, - get_const_int, - get_const_tuple, - get_float_tuple, -) +from tvm.topi.utils import (get_const_float, get_const_int, get_const_tuple, + get_float_tuple) from .. import op as _op @@ -240,7 +238,6 @@ def wrap_compute_conv2d( need_meta_schedule_layout=False, ): """Wrap conv2d topi compute""" - def _compute_conv2d(attrs, inputs, out_type): padding = get_const_tuple(attrs.padding) strides = get_const_tuple(attrs.strides) @@ -854,8 +851,13 @@ def matmul_strategy(attrs, inputs, out_type, target): """matmul generic strategy""" logger.warning("matmul is not optimized for this platform.") strategy = _op.OpStrategy() + strategy.add_implementation( - wrap_compute_matmul(topi.nn.matmul), + wrap_compute_matmul( + topi.nn.matmul, + need_auto_scheduler_layout=is_auto_scheduler_enabled(), + need_meta_schedule_layout=is_meta_schedule_enabled(), + ), wrap_topi_schedule(topi.generic.schedule_matmul), name="matmul.generic", ) @@ -869,7 +871,6 @@ def wrap_compute_dense( need_meta_schedule_layout=False, ): """wrap dense topi compute""" - def _compute_dense(attrs, inputs, out_type): """Compute definition of dense""" out_dtype = attrs.out_dtype @@ -892,7 +893,11 @@ def dense_strategy(attrs, inputs, out_type, target): logger.warning("dense is not optimized for this platform.") strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_dense(topi.nn.dense), + wrap_compute_dense( + topi.nn.dense, + need_auto_scheduler_layout=is_auto_scheduler_enabled(), + need_meta_schedule_layout=is_meta_schedule_enabled(), + ), wrap_topi_schedule(topi.generic.schedule_dense), name="dense.generic", ) @@ -944,7 +949,11 @@ def batch_matmul_strategy(attrs, inputs, out_type, target): logger.warning("batch_matmul is not optimized for this platform.") strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_batch_matmul(topi.nn.batch_matmul), + wrap_compute_batch_matmul( + topi.nn.batch_matmul, + need_auto_scheduler_layout=is_auto_scheduler_enabled(), + need_meta_schedule_layout=is_meta_schedule_enabled(), + ), wrap_topi_schedule(topi.generic.schedule_batch_matmul), name="batch_matmul.generic", )