23 changes: 23 additions & 0 deletions python/tvm/relay/backend/te_compiler.py
@@ -412,3 +412,26 @@ def get():
The TE Compiler.
"""
return _backend._TECompilerGlobal()


def lower_to_primfunc(relay_func, target):
"""Lower Relay Function to TIR PrimFunc.

Parameters
----------
relay_func : relay.Function
The source primitive function, created by FuseOps.

target : Target
The compilation target.

Returns
-------
prim_func : tir.PrimFunc
The created prim func.
"""
f = tvm._ffi.get_global_func("relay.backend.LowerToPrimFunc")
assert f is not None, "relay.backend.LowerToPrimFunc does not exist. "

with target:
return f(relay_func, target)
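
For context, a minimal usage sketch of the new wrapper (not part of the diff): the toy add-plus-relu workload, the explicit InferType/FuseOps invocations, the llvm target, and the way the fused function is pulled out of "main" are all illustrative assumptions.

import tvm
from tvm import relay
from tvm.relay.backend.te_compiler import lower_to_primfunc

# Build a small Relay module and fuse it so that "main" calls a primitive function.
x = relay.var("x", shape=(4, 16), dtype="float32")
y = relay.nn.relu(x + relay.const(1.0, "float32"))
mod = tvm.IRModule.from_expr(y)
mod = relay.transform.InferType()(mod)
mod = relay.transform.FuseOps(fuse_opt_level=3)(mod)
mod = relay.transform.InferType()(mod)

# The fused call is the body of "main"; its op is the kPrimitive relay.Function
# that lower_to_primfunc expects (illustrative access pattern).
fused_func = mod["main"].body.op
prim_func = lower_to_primfunc(fused_func, tvm.target.Target("llvm"))

The wrapper itself only fetches the packed function registered below as "relay.backend.LowerToPrimFunc" and calls it inside the target scope.
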
6 changes: 2 additions & 4 deletions src/relay/backend/task_extraction.cc
@@ -59,7 +59,6 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod, Target target,
using meta_schedule::ExtractedTask;
using meta_schedule::ModuleEqual;
using meta_schedule::ModuleHash;
backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
backend::BindParamsInModule(mod, params);
// is_vm=true for backward compatibility
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
@@ -84,10 +83,9 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod, Target target,
if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
return;
}
auto [inputs_outputs, constants, fused_name] =
tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true);

if (Optional<tir::PrimFunc> f = tir_converter(inputs_outputs, constants)) {
auto [f, fused_name] = tec::LowerToPrimFunc(relay_func, target, constant_name_supply);
if (f) {
IRModule tir_mod = PrimFuncToIRModule(f.value());
lower_results.push_back(std::make_tuple(fused_name, relay_func, tir_mod));
}
27 changes: 27 additions & 0 deletions src/relay/backend/te_compiler_cache.cc
@@ -1088,6 +1088,33 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompu
return std::make_tuple(tensor_outs, constants, lower_te_compute.candidate_name_);
}

std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
Target target,
NameSupply constant_name_supply) {
ICHECK(relay_func->HasNonzeroAttr(attr::kPrimitive))
<< "The input must be a Relay primitive function.";

auto [inputs_outputs, constants, fused_name] =
tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true);
auto tir_converter = backend::GetTIRConverter();
return std::make_pair(tir_converter(inputs_outputs, constants), fused_name);
}

tir::PrimFunc LowerToPrimFunc(const Function& relay_func, Target target) {
auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply(""));
(void)_; // to suppress -Werror=unused-variable warning
if (f_opt) {
return f_opt.value();
}
LOG(FATAL) << "Failed to convert the Relay function: " << AsText(relay_func, false);
return PrimFunc();
}

TVM_REGISTER_GLOBAL("relay.backend.LowerToPrimFunc")
.set_body_typed([](Function relay_func, Target target) {
return LowerToPrimFunc(relay_func, target);
});

TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
auto tgt = tvm::Target("ext_dev");
LowerToTECompute lower_te_compute(tgt, NameSupply(""));
20 changes: 16 additions & 4 deletions src/relay/backend/te_compiler_cache.h
@@ -212,10 +212,10 @@ class CCacheValue : public ObjectRef {
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);

/*!
* \brief Lowers Relay primitive Function to TE Compute
* \brief Lower Relay primitive Function to TE Compute
* \param source_func The primitive function to be lowered.
* \param target The target we want to create schedule for.
* \param constant_name_supply A name supplier for constants.
* \param target The compilation target.
* \param constant_name_supply A name supplier for constants
* across different invocations of this function.
* \param return_inputs If true, prepend input tensors to the output array of tensors.
* \return Tuple of the lowered TE compute, constant raw data, and fused function name.
@@ -224,10 +224,22 @@ std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompu
const Function& source_func, Target target, NameSupply constant_name_supply,
bool return_inputs = true);

/*!
* \brief Lower Relay Function to TIR PrimFunc by composing LowerTECompute and CreatePrimFunc.
* \param relay_func The primitive function to be lowered.
* \param target The compilation target.
* \param constant_name_supply A name supplier for constants
* across different invocations of this function.
* \return A pair of the created prim func and the name of the fused function.
*/
std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
Target target,
NameSupply constant_name_supply);

/*!
* \brief Create schedule for target.
* \param source_func The primitive function to be lowered.
* \param target The target we want to create schedule for.
* \param target The compilation target.
* \param global_var_supply A name supplier for global variables.
* \param constant_name_supply A name supplier for constants.
* \return Pair of schedule and cache.
@@ -14,10 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
import tvm.testing
from tvm import te
from tvm.script import tir as T
from tvm import relay, tir
from tvm.relay.backend.te_compiler import lower_to_primfunc
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN


def _check(original, transformed):
@@ -360,5 +365,56 @@ def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]):
_check(before, after)


def test_allocate_const_after_tensorize():
i_size, o_size, h_size, w_size = 64, 64, 56, 56
k_height_size = k_width_size = 3
w_shape = (o_size, i_size, k_height_size, k_width_size)

data = relay.var("data", shape=(1, i_size, h_size, w_size), dtype="uint8")
weight = relay.var("weight", shape=w_shape, dtype="uint8")
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=(k_height_size, k_width_size),
channels=o_size,
padding=(0, 0),
strides=(1, 1),
out_dtype="int32",
)
mod = tvm.IRModule.from_expr(conv2d)

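# link-params embeds the weights passed to relay.optimize into the module as
# constants, so the lowered PrimFunc carries them as AllocateConst nodes.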
executor = relay.backend.Executor("graph", {"link-params": True})
mod = mod.with_attr("executor", executor)

weight_np = np.random.uniform(1, 10, size=w_shape).astype("uint8")

target = tvm.target.Target("hexagon")

with tvm.transform.PassContext(opt_level=3):
opt_mod, _ = relay.optimize(mod, params={"weight": weight_np}, target=target)

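# Extract the fused conv2d primitive function from the optimized module and
# lower it to a TIR PrimFunc with the new helper.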
conv2d_func = opt_mod["main"].body.args[0].op
prim_func = lower_to_primfunc(conv2d_func, target)

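# Schedule the conv2d block so that its inner u8u8 -> i32 dot product can be
# tensorized with the Hexagon vrmpy intrinsic.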
sch = tir.Schedule(prim_func)
block = sch.get_block("conv2d_NCHWc_int8")
loops = sch.get_loops(block)

sch.reorder(loops[8], loops[4], loops[-1])
sch.decompose_reduction(block, loops[1])
sch.tensorize(loops[4], VRMPY_u8u8i32_INTRIN)

seq = tvm.transform.Sequential(
[
tvm.tir.transform.LowerInitBlock(),
tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
]
)

# The following error is emitted if AllocateConst nodes are not correctly handled:
# Check failed: (buffer_data_to_buffer_.count(source_var)) is false:
_ = seq(sch.mod)


if __name__ == "__main__":
tvm.testing.main()