1 change: 1 addition & 0 deletions python/tvm/relax/vm_build.py
@@ -307,6 +307,7 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
 
     lowering_passes = tvm.transform.Sequential(
         [
+            relax.transform.LegalizeOps(),
            relax.transform.RewriteDataflowReshape(),
            relax.transform.ToNonDataflow(),
            relax.transform.RemovePurityChecking(),
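With LegalizeOps() at the head of the default lowering pipeline, relax.build now legalizes high-level Relax ops into TensorIR by itself, which is why the test changes below simply drop their manual LegalizeOps calls. A minimal sketch of the resulting workflow (the AddModule below is illustrative, not part of this PR):

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class AddModule:
        @R.function
        def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
            return R.add(x, y)

    # No explicit relax.transform.LegalizeOps()(AddModule) call is needed:
    # the default lowering pipeline applies the pass during relax.build.
    ex = relax.build(AddModule, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    ones = tvm.nd.array(np.ones((3, 4), "float32"))
    out = vm["main"](ones, ones)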
9 changes: 8 additions & 1 deletion src/relax/transform/legalize_ops.cc
@@ -31,6 +31,8 @@
 namespace tvm {
 namespace relax {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.transform.apply_legalize_ops", Bool);
+
 /*!
  * \brief Check if a given Tensor/Shape/TupleStructInfo contains shapes whose
  *  values are all known.
@@ -206,7 +208,12 @@ namespace transform {
 Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
                                                                             PassContext pc) {
-    return LegalizeMutator(mod, cmap, enable_warning).Transform();
+    bool apply_legalize_ops =
+        pc->GetConfig<Bool>("relax.transform.apply_legalize_ops").value_or(Bool(true))->value;
+    if (apply_legalize_ops) {
+      mod = LegalizeMutator(mod, cmap, enable_warning).Transform();
+    }
+    return mod;
   };
   return CreateModulePass(/*pass_function=*/pass_func,
                           /*opt_level=*/0,
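The registered option defaults to true, so the pass stays on for existing relax.build callers; setting it to false skips legalization entirely, which matters when BYOC partitioning is expected to consume every high-level op. A sketch of the toggle (mod and target stand in for any Relax module and build target):

    import tvm
    from tvm import relax

    # Opt out of the built-in legalization, mirroring the updated tests below.
    with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": False}):
        ex = relax.build(mod, target)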
10 changes: 6 additions & 4 deletions tests/python/relax/test_codegen_cublas.py
@@ -42,11 +42,13 @@ def reset_seed():
 
 
 def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     dev = tvm.device(target, 0)
-    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
+    with tvm.transform.PassContext(
+        config={
+            "relax.backend.use_cuda_graph": cuda_graph,
+            "relax.transform.apply_legalize_ops": legalize,
+        }
+    ):
         ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
10 changes: 6 additions & 4 deletions tests/python/relax/test_codegen_cudnn.py
@@ -110,11 +110,13 @@ def get_result_with_relax_cudnn_offload(mod, np_inputs, cuda_graph=False):
 
 
 def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     dev = tvm.device(target, 0)
-    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
+    with tvm.transform.PassContext(
+        config={
+            "relax.backend.use_cuda_graph": cuda_graph,
+            "relax.transform.apply_legalize_ops": legalize,
+        }
+    ):
         ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
10 changes: 6 additions & 4 deletions tests/python/relax/test_codegen_cutlass.py
@@ -86,10 +86,12 @@ def main(
 
 
 def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)  # For cpu reference, nop for cutlass.
-
-    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
+    with tvm.transform.PassContext(
+        config={
+            "relax.backend.use_cuda_graph": cuda_graph,
+            "relax.transform.apply_legalize_ops": legalize,
+        }
+    ):
         ex = relax.build(mod, target)
 
     dev = tvm.device(target, 0)
6 changes: 2 additions & 4 deletions tests/python/relax/test_codegen_dnnl.py
@@ -52,14 +52,12 @@ def main(
 
 
 def build_and_run(mod, inputs, legalize=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     target = tvm.target.Target("llvm")
     dev = tvm.cpu()
     inputs = [tvm.nd.array(inp, dev) for inp in inputs]
 
-    ex = relax.build(mod, target)
+    with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}):
+        ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
     return f(*inputs).numpy()
6 changes: 2 additions & 4 deletions tests/python/relax/test_codegen_tensorrt.py
@@ -53,11 +53,9 @@ def main(
 
 
 def build_and_run(mod, inputs_np, target, legalize=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     dev = tvm.device(target, 0)
-    ex = relax.build(mod, target)
+    with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}):
+        ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
     inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
1 change: 0 additions & 1 deletion tests/python/relax/test_codegen_tir_cutlass.py
@@ -65,7 +65,6 @@ def build(mod):
 
 
 def build_and_run_reference(mod, inputs_np):
-    mod = relax.transform.LegalizeOps()(mod)
     dev = tvm.device("llvm", 0)
     ex = relax.build(mod, "llvm")
     vm = relax.VirtualMachine(ex, dev)
2 changes: 0 additions & 2 deletions tests/python/relax/test_dataflow_pattern.py
@@ -1181,7 +1181,6 @@ def expected(
     # make sure it builds
     mod = tvm.IRModule()
     mod["main"] = rewritten
-    mod = rx.transform.LegalizeOps()(mod)
 
     rx.build(mod, target="llvm")
 
@@ -1279,7 +1278,6 @@ def rewriter(matchings, _):
     # make sure it builds
     mod = tvm.IRModule()
    mod["main"] = rewritten
-    mod = rx.transform.LegalizeOps()(mod)
 
     rx.build(mod, target="llvm")
 
6 changes: 2 additions & 4 deletions tests/python/relax/test_e2e_op_dynamic.py
@@ -49,8 +49,7 @@ def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin: R.Tensor((4,),"int64"),
         gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
         return gv
     # fmt: on
-    mod = LegalizeOps()(DynamicStridedSlice)
-    vm = build(mod)
+    vm = build(DynamicStridedSlice)
 
     x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
     data_nd = tvm.nd.array(x_np, dev)
@@ -83,8 +82,7 @@ def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin: R.Tensor((4,),"int64"),
         gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
         return gv
     # fmt: on
-    mod = LegalizeOps()(DynamicStridedSlice)
-    vm = build(mod)
+    vm = build(DynamicStridedSlice)
 
     x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
     data_nd = tvm.nd.array(x_np, dev)
4 changes: 0 additions & 4 deletions tests/python/relax/test_frontend_stablehlo.py
@@ -115,9 +115,6 @@ def check_correctness(
     # Run the jax jitted model with the input jax numpy data
     jax_output = jax_jit_mod(*inputs_jnp)
 
-    # Legalize the Relax Operators into TensorIR
-    # TODO (relax-team): add LegalizeOps in default seq in vm_build
-    ir_mod = relax.transform.LegalizeOps()(ir_mod)
     # TODO (yongwww): support multiple targets,
     # "llvm" should be good for this check
     target = tvm.target.Target("llvm", host="llvm")
@@ -157,7 +154,6 @@ def get_vm_res(
     out: Union[tvm.nd.NDArray, List[tvm.nd.NDArray]]
         inference result
     """
-    ir_mod = relax.transform.LegalizeOps()(ir_mod)
     target = tvm.target.Target("llvm", host="llvm")
     # Compile and run
     ex = relax.build(ir_mod, target)
6 changes: 2 additions & 4 deletions tests/python/relax/test_op_gradient_numeric.py
@@ -122,8 +122,7 @@ def _is_call_no_grad(expr):
            out = forward_bb.emit_output(call)
        forward_bb.emit_func_output(out)
     forward_mod = forward_bb.get()
-    forward_lower_mod = LegalizeOps()(forward_mod)
-    forward_ex = relax.build(forward_lower_mod, target)
+    forward_ex = relax.build(forward_mod, target)
     forward_vm = relax.VirtualMachine(forward_ex, dev)
 
     # Generate weights
@@ -187,8 +186,7 @@ def forward(*inputs):
        grad_bb.emit_func_output(out)
 
     grad_mod = grad_bb.get()
-    grad_lower_mod = LegalizeOps()(grad_mod)
-    grad_ex = relax.build(grad_lower_mod, target)
+    grad_ex = relax.build(grad_mod, target)
     grad_vm = relax.VirtualMachine(grad_ex, dev)
 
     # tvm.runtime.NDArray inputs
4 changes: 1 addition & 3 deletions tests/python/relax/test_training_optimizer_numeric.py
@@ -23,15 +23,13 @@
 from tvm import relax
 from tvm import IRModule
 from tvm.relax.training.optimizer import Adam, SGD, MomentumSGD
-from tvm.relax.transform import LegalizeOps
 from tvm.script.parser import relax as R
 from tvm.runtime.relax_vm import VirtualMachine
 from tvm.testing import assert_allclose
 
 
 def _legalize_and_build(mod: IRModule, target, dev):
-    lowered_mod = LegalizeOps()(mod)
-    ex = relax.build(lowered_mod, target)
+    ex = relax.build(mod, target)
     vm = VirtualMachine(ex, dev)
     return vm
 
4 changes: 1 addition & 3 deletions tests/python/relax/test_transform_gradient_numeric.py
@@ -22,12 +22,10 @@
 from tvm.testing import assert_allclose
 from tvm.testing.utils import check_numerical_grads
 from tvm.script.parser import ir as I, relax as R
-from tvm.relax.transform import LegalizeOps
 
 
 def _legalize_and_build(mod, target, dev):
-    lowered_mod = LegalizeOps()(mod)
-    ex = relax.build(lowered_mod, target)
+    ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     return vm
 
3 changes: 1 addition & 2 deletions tests/python/relax/test_vm_execbuilder.py
@@ -277,8 +277,7 @@ def main(inp: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="
            R.output(gv)
        return gv
 
-    mod = relax.transform.LegalizeOps()(Module)
-    ex = relax.build(mod, "llvm")
+    ex = relax.build(Module, "llvm")
     vm = relax.VirtualMachine(ex, tvm.cpu())
 
     correct_input = tvm.nd.array(np.random.normal(size=(10, 10)).astype("float32"))