Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
import re

from tvm import _ffi, ir, te, topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.meta_schedule import is_meta_schedule_enabled
from tvm.target import generic_func, override_native_generic_func
from tvm.topi.utils import (
get_const_float,
get_const_int,
get_const_tuple,
get_float_tuple,
)
from tvm.topi.utils import (get_const_float, get_const_int, get_const_tuple,
get_float_tuple)

from .. import op as _op

Expand Down Expand Up @@ -240,7 +238,6 @@ def wrap_compute_conv2d(
need_meta_schedule_layout=False,
):
"""Wrap conv2d topi compute"""

def _compute_conv2d(attrs, inputs, out_type):
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
Expand Down Expand Up @@ -854,8 +851,13 @@ def matmul_strategy(attrs, inputs, out_type, target):
"""matmul generic strategy"""
logger.warning("matmul is not optimized for this platform.")
strategy = _op.OpStrategy()

strategy.add_implementation(
wrap_compute_matmul(topi.nn.matmul),
wrap_compute_matmul(
topi.nn.matmul,
need_auto_scheduler_layout=is_auto_scheduler_enabled(),
need_meta_schedule_layout=is_meta_schedule_enabled(),
),
wrap_topi_schedule(topi.generic.schedule_matmul),
name="matmul.generic",
)
Expand All @@ -869,7 +871,6 @@ def wrap_compute_dense(
need_meta_schedule_layout=False,
):
"""wrap dense topi compute"""

def _compute_dense(attrs, inputs, out_type):
"""Compute definition of dense"""
out_dtype = attrs.out_dtype
Expand All @@ -892,7 +893,11 @@ def dense_strategy(attrs, inputs, out_type, target):
logger.warning("dense is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.nn.dense),
wrap_compute_dense(
topi.nn.dense,
need_auto_scheduler_layout=is_auto_scheduler_enabled(),
need_meta_schedule_layout=is_meta_schedule_enabled(),
),
wrap_topi_schedule(topi.generic.schedule_dense),
name="dense.generic",
)
Expand Down Expand Up @@ -944,7 +949,11 @@ def batch_matmul_strategy(attrs, inputs, out_type, target):
logger.warning("batch_matmul is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.nn.batch_matmul),
wrap_compute_batch_matmul(
topi.nn.batch_matmul,
need_auto_scheduler_layout=is_auto_scheduler_enabled(),
need_meta_schedule_layout=is_meta_schedule_enabled(),
),
wrap_topi_schedule(topi.generic.schedule_batch_matmul),
name="batch_matmul.generic",
)
Expand Down