Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ class Target : public ObjectRef {
*/
static Target WithHost(const Target& target, const Target& host);

/*! \return The target with the host stripped out */
Target WithoutHost() const;

/*!
* \brief Returns true if \p this target represents an external codegen. If so,
* \p this->kind->name can be used as the "Compiler" attribute on partitioned functions,
Expand Down
38 changes: 38 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,51 @@ TVM_DLL Pass LowerCustomDatatypes();
*/
TVM_DLL Pass DecorateDeviceScope();

/*!
* \brief Annotate locations that should be run on the device
*
* Insert `AttrStmt` nodes specifying a target on which regions within
* the PrimFunc should be executed. Only modifies functions that have
* a `tvm::attr::kTarget` attribute, and where that target defines a
* host.
*
* \return The pass.
*/
TVM_DLL Pass AnnotateDeviceRegions();

/*!
* \brief Split the function into a host function and device functions.
*
* The resulting host-side function will keep the same
* `tvm::attr::kTarget` attribute (e.g. `T.target("cuda",
* host=T.target("llvm"))`). This ensures that `MakePackedAPI` knows
* which device type should be used for the input buffers.
*
* The resulting device-side function will
* have the host stripped from its target attribute
* (e.g. `T.target("cuda")`).
*
* \return The pass.
*/
TVM_DLL Pass SplitHostDevice();

/*!
* \brief Lower cross-device function calls.
*
* Prior to this pass, host to device calls are represented as
* subroutine calls, with environment parameters (e.g. env_thread)
* specified internally. The device function is an internal function,
* without a `tvm::attr::kGlobalSymbol` attribute.
*
* After this pass, host to device calls are represented as
* tvm_call_packed built-in. The device function is an
* externally-exposed function, with a non-empty
* `tvm::attr::kGlobalSymbol` attribute.
*
* \return The pass.
*/
TVM_DLL Pass LowerDeviceKernelLaunch();

/*!
* \brief skip assert stmt.
*
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def call_tir(global_var: tvm.ir.GlobalVar, *args):
The call expression.
"""
assert isinstance(global_var, tvm.ir.GlobalVar)
return Call(dtype="handle", op=global_var, args=args)
return Call(dtype="void", op=global_var, args=args)


def start_profile_intrinsic(id):
Expand Down
38 changes: 38 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,22 @@ def MakeUnpackedAPI():
return _ffi_api.MakeUnpackedAPI() # type: ignore


def AnnotateDeviceRegions():
"""Annotate locations that should be run on the device

Insert `AttrStmt` nodes specifying a target on which regions
within the PrimFunc should be executed. Only modifies functions
that have a `tvm::attr::kTarget` attribute, and where that target
defines a host.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.AnnotateDeviceRegions() # type: ignore


def SplitHostDevice():
"""Split the function into a host function and device functions.

Expand All @@ -446,6 +462,28 @@ def SplitHostDevice():
return _ffi_api.SplitHostDevice() # type: ignore


def LowerDeviceKernelLaunch():
"""Lower cross-device function calls.

Prior to this pass, host to device calls are represented as
subroutine calls, with environment parameters (e.g. env_thread)
specified internally. The device function is an internal
function, without a `tvm::attr::kGlobalSymbol` attribute.

After this pass, host to device calls are represented as
tvm_call_packed built-in. The device function is an
externally-exposed function, with a non-empty
`tvm::attr::kGlobalSymbol` attribute.


Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerDeviceKernelLaunch() # type: ignore


def DecorateDeviceScope():
"""Decorate all the function's body as device function.

Expand Down
3 changes: 3 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,10 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::MakePackedAPI());
}
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
}
Expand Down
10 changes: 10 additions & 0 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,16 @@ Map<String, ObjectRef> TargetNode::Export() const {

Optional<Target> TargetNode::GetHost() const { return this->host.as<Target>(); }

Target Target::WithoutHost() const {
if ((*this)->GetHost()) {
auto output = make_object<TargetNode>(*get());
output->host = NullOpt;
return Target(output);
} else {
return *this;
}
}

int TargetNode::GetTargetDeviceType() const {
if (Optional<Integer> device_type = GetAttr<Integer>("target_device_type")) {
return Downcast<Integer>(device_type)->value;
Expand Down
81 changes: 81 additions & 0 deletions src/tir/transforms/annotate_device_regions.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file annotate_device_regions.cc
* \brief Split device function from host.
*/
#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {

class DeviceRegionAnnotater : public StmtMutator {
public:
explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tvm::attr::kTarget) {
// If a target attribute already exists, use it as-is.
return GetRef<Stmt>(op);
} else if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
// These attributes are only allowed in device-side code, so
// they should be annotated with the function's default target.
Stmt body = GetRef<Stmt>(op);
return AttrStmt(device_target_, tvm::attr::kTarget, 0, body);
} else {
// All other annotations are ignored
return StmtMutator::VisitStmt_(op);
}
}

private:
Target device_target_;
};

namespace transform {

Pass AnnotateDeviceRegions() {
auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc {
auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute";
Target target = opt_target.value();

if (target->GetHost()) {
DeviceRegionAnnotater mutator(target.WithoutHost());
func.CopyOnWrite()->body = mutator(func->body);
}
return func;
};

return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {});
}

TVM_REGISTER_GLOBAL("tir.transform.AnnotateDeviceRegions").set_body_typed(AnnotateDeviceRegions);

} // namespace transform
} // namespace tir
} // namespace tvm
Loading