1 change: 1 addition & 0 deletions docker/python/ci-constraints.txt
@@ -17,6 +17,7 @@ flowvision = "==0.1.0"
#h5py = "==3.1.0"
keras = "==2.7"
jinja2 = "==3.0.3"
ml_dtypes = "==0.1.0"
mxnet = "==1.6.0"
mypy = "==0.902"
oneflow = "==0.7.0"
6 changes: 5 additions & 1 deletion include/tvm/driver/driver_api.h
@@ -50,11 +50,14 @@ using tvm::transform::Pass;
/*!
* \brief Configures and returns the composite Pass for the fused module (pre split) that contains
* device and host code.
*
* \param mixed_mod The original mixed module.
* \param target The device Target.
* \param apply_lower_passes Whether to apply lowering passes.
* \return The composite Pass for the fused module.
*/
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target);
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target,
bool apply_lower_passes);

/*!
* \brief Configures and returns the composite Pass for the device Target after device/host from
@@ -140,6 +143,7 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply);

/*!
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
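The flag declared here is also reachable from Python through the `driver.mixed_mod_passes` packed function registered further down in src/driver/driver_api.cc. A minimal sketch of fetching the pipeline, assuming the module argument only seeds the pass manager (an empty IRModule is passed purely for illustration):

# Hypothetical usage sketch, not part of this change.
import tvm

mixed_mod_passes = tvm.get_global_func("driver.mixed_mod_passes")
# The third argument corresponds to apply_lower_passes; True also schedules the
# TE lowering passes ahead of the usual mixed-module passes.
seq = mixed_mod_passes(tvm.IRModule({}), tvm.target.Target("llvm"), True)
print(seq)  # a tvm.transform.Sequential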
4 changes: 3 additions & 1 deletion include/tvm/runtime/data_type.h
@@ -95,7 +95,7 @@ class DataType {
bool is_scalar() const { return lanes() == 1; }
/*! \return whether type is a scalar type. */
bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
/*! \return whether type is a float type. */
/*! \return whether type is an IEEE 754 standard float type. */
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float8 type. */
bool is_float8() const {
@@ -107,6 +107,8 @@
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
/*! \return whether type is a general floating point data type. */
bool is_floating_point() const { return is_float() || is_float8() || is_bfloat16(); }
/*! \return whether type is an int type. */
bool is_int() const { return code() == DataType::kInt; }
/*! \return whether type is an uint type. */
3 changes: 1 addition & 2 deletions include/tvm/tir/op.h
@@ -939,8 +939,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
}
}
if (t.is_float() || t.is_bfloat16() || t.is_float8())
return FloatImm(t, static_cast<double>(value), span);
if (t.is_floating_point()) return FloatImm(t, static_cast<double>(value), span);
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
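Routing MakeConstScalar through the new DataType::is_floating_point() helper means bfloat16 and float8 constants produce FloatImm nodes exactly like the IEEE float widths. A small sketch, assuming bfloat16 immediates are supported end to end after this series:

# Hypothetical sketch: float-family constants all become tir.FloatImm.
import tvm

c_f32 = tvm.tir.const(1.5, "float32")
c_bf16 = tvm.tir.const(1.5, "bfloat16")  # now routed through is_floating_point()
assert isinstance(c_f32, tvm.tir.FloatImm)
assert isinstance(c_bf16, tvm.tir.FloatImm)
print(c_bf16.dtype)  # bfloat16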
1 change: 1 addition & 0 deletions python/tvm/_ffi/runtime_ctypes.py
@@ -110,6 +110,7 @@ class DataType(ctypes.Structure):
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
"bfloat16": {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
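With the "bfloat16" entry added to the table above, the dtype string resolves on the Python side as well; a short sketch (the shape is arbitrary, and runtime support for allocating bfloat16 buffers is assumed):

# Hypothetical sketch: "bfloat16" is now resolvable as a dtype string.
import tvm

x = tvm.nd.empty((2, 2), dtype="bfloat16")
print(x.dtype)  # bfloat16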
65 changes: 53 additions & 12 deletions python/tvm/driver/build_module.py
@@ -71,8 +71,8 @@ def schedule_to_module(
"""According to the given schedule, form a function.

This is a low-level function intended for testing purposes, and
does not apply any optimization passes. In general, `tvm.lower`
and `tvm.build` should be used instead.
does not apply any optimization passes. In general, `tvm.build`
should be used instead.

Parameters
----------
@@ -91,6 +91,47 @@
return ffi.schedule_to_module(sch, args, name, binds)


def as_ir_module(
inp: Union[te.Schedule, PrimFunc, IRModule],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
) -> IRModule:
"""Convert input to IRModule.

Parameters
----------
inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule]
The input TE schedule or TensorIR PrimFunc/IRModule.

args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
The argument lists to the function for TE schedule.
It should be None if :attr:`inp` is a TensorIR PrimFunc/IRModule.

name : str
The name of the result function.

binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]]
Dictionary that maps the Tensor to Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.

Returns
-------
m : IRModule
The result IRModule.
"""
if isinstance(inp, IRModule):
return inp
if isinstance(inp, PrimFunc):
return IRModule({name: inp.with_attr("global_symbol", name)})
if isinstance(inp, te.Schedule):
return schedule_to_module(inp, args, name, binds)
raise ValueError(
f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}"
)


def lower(
inp: Union[te.Schedule, PrimFunc, IRModule],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
@@ -100,6 +141,9 @@
) -> IRModule:
"""Lowering step before build into target.

(Warning) This function is obsolete and maintained only for backward compatibility with
the legacy TE Schedule; please use :func:`build` directly.

Parameters
----------
inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule]
@@ -199,8 +243,7 @@ def build(
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
m = tvm.lower(s, [A, B, C], name="test_add")
rt_mod = tvm.build(m, target="llvm")
rt_mod = tvm.build(s, [A, B, C], target="llvm", name="test_add")

2. it is a dict of compilation target to IRModule.

@@ -213,9 +256,7 @@
s1 = tvm.te.create_schedule(C.op)
with tvm.target.cuda() as cuda_tgt:
s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
m1 = tvm.lower(s1, [A, B, C], name="test_add1")
m2 = tvm.lower(s2, [A, B, C], name="test_add2")
rt_mod = tvm.build({"llvm": m1, "cuda": m2})
rt_mod = tvm.build({"llvm": s1, "cuda": s2})

Note
----
@@ -224,16 +265,16 @@
if isinstance(inputs, te.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
input_mod = lower(inputs, args, name=name, binds=binds)
input_mod = as_ir_module(inputs, args, name=name, binds=binds)
elif isinstance(inputs, (list, tuple, container.Array)):
merged_mod = tvm.IRModule({})
for x in inputs:
merged_mod.update(lower(x))
merged_mod.update(as_ir_module(x))
input_mod = merged_mod
elif isinstance(inputs, PrimFunc):
input_mod = lower(inputs, name=name)
input_mod = as_ir_module(inputs, name=name)
elif isinstance(inputs, tvm.IRModule):
input_mod = lower(inputs)
input_mod = as_ir_module(inputs)
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError(
f"Inputs must be te.Schedule, IRModule, PrimFunc, "
@@ -278,7 +319,7 @@ def build(

annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)

rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
rt_mod_host = _driver_ffi.ir_module_to_runtime_module(annotated_mods, target_host, True)

annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)

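A usage sketch of the new flow, assuming as_ir_module is imported from tvm.driver.build_module as defined above: build() now wraps the input with as_ir_module and runs the lowering passes itself, so no separate tvm.lower step is required.

# Hypothetical sketch: build straight from a TE schedule.
import tvm
from tvm import te
from tvm.driver.build_module import as_ir_module

n = 1024
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
s = te.create_schedule(C.op)

mod = as_ir_module(s, [A, B, C], name="vector_add")  # still-unlowered TIR IRModule
rt_mod = tvm.build(s, [A, B, C], target="llvm", name="vector_add")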
25 changes: 14 additions & 11 deletions python/tvm/runtime/ndarray.py
@@ -19,11 +19,8 @@
import ctypes
import warnings
import numpy as np
import ml_dtypes

try:
import ml_dtypes
except ImportError:
ml_dtypes = None
import tvm._ffi

from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE
@@ -167,18 +164,19 @@ def copyfrom(self, source_array):
raise ValueError(
f"array shape do not match the shape of NDArray {source_array.shape} vs {shape}"
)

numpy_str_map = DataType.NUMPY2STR
np_dtype_str = (
numpy_str_map[source_array.dtype]
if source_array.dtype in numpy_str_map
else str(source_array.dtype)
)
if (not source_array.flags["C_CONTIGUOUS"]) or (
dtype == "bfloat16" or dtype != np_dtype_str
):
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
if (not source_array.flags["C_CONTIGUOUS"]) or dtype != np_dtype_str:
if dtype == "e4m3_float8":
dtype = "float8_e4m3fn"
elif dtype == "e5m2_float8":
dtype = "float8_e5m2"
source_array = np.ascontiguousarray(source_array, dtype)
assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
@@ -221,7 +219,12 @@ def numpy(self):
if dtype == "int4":
dtype = "int8"
if dtype == "bfloat16":
dtype = "uint16"
if ml_dtypes is not None:
dtype = ml_dtypes.bfloat16
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert bfloat16 array to numpy."
)
if dtype == "e4m3_float8":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
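A round-trip sketch of the new conversion paths, assuming ml_dtypes is installed (it is now imported unconditionally above) and that the numpy-to-string mapping recognizes the ml_dtypes bfloat16 dtype:

# Hypothetical sketch: copy a bfloat16 numpy array in and read it back.
import numpy as np
import ml_dtypes
import tvm

src = np.array([1.0, 2.0, 3.5], dtype=ml_dtypes.bfloat16)
arr = tvm.nd.array(src)   # copyfrom() no longer reinterprets bfloat16 as uint16
back = arr.numpy()        # numpy() now yields ml_dtypes.bfloat16 instead of uint16
assert back.dtype == ml_dtypes.bfloat16
np.testing.assert_allclose(back.astype("float32"), src.astype("float32"))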
16 changes: 16 additions & 0 deletions python/tvm/testing/utils.py
@@ -106,6 +106,17 @@ def test_something():
)


def promote_bf16_to_fp32(x):
r"""Promote the data type of an array-like structure from bfloat16 to float32."""
if isinstance(x, list):
return [promote_bf16_to_fp32(y) for y in x]
else:
if isinstance(x, np.ndarray) and x.dtype == "bfloat16":
return x.astype("float32")
else:
return x


def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
"""Version of np.testing.assert_allclose with `atol` and `rtol` fields set
in reasonable defaults.
@@ -114,6 +125,11 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we
often allow `desired` to be close to zero, we generally want non-zero `atol`.
"""
# ml_dtypes v0.2 is not compatible with numpy's asanyarray function, so promote any
# bfloat16 arrays to float32 first.
actual = promote_bf16_to_fp32(actual)
desired = promote_bf16_to_fp32(desired)

actual = np.asanyarray(actual)
desired = np.asanyarray(desired)
np.testing.assert_allclose(actual.shape, desired.shape)
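A usage sketch for the promoted comparison, assuming ml_dtypes provides the bfloat16 numpy dtype; the tolerances are illustrative and chosen loosely for bfloat16's reduced precision:

# Hypothetical sketch: assert_allclose now accepts bfloat16 operands directly.
import numpy as np
import ml_dtypes
import tvm.testing

desired = np.linspace(0.0, 1.0, 8, dtype="float32")
actual = desired.astype(ml_dtypes.bfloat16)  # stand-in for a bfloat16 kernel output
tvm.testing.assert_allclose(actual, desired, rtol=1e-2, atol=1e-2)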
40 changes: 28 additions & 12 deletions src/driver/driver_api.cc
@@ -151,7 +151,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
bool enable_equiv_terms_in_cse_tir =
pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();

bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32", Bool(false)).value();

// Get any user-added passes
@@ -405,13 +404,14 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
* device and host. Then it also applies transformations on the new splitted modules.
*/
std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target& target_arg,
const Target& target_host_arg) {
const Target& target_host_arg,
bool apply_lower_passes) {
Target target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);

ICHECK(mod_mixed.defined()) << "This module must be defined";

mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target, apply_lower_passes));

IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));

@@ -430,8 +430,8 @@
return {host_mod, device_mod};
}

runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
const Target& target_host_arg) {
runtime::Module IRModuleToRuntimeModule(const Map<Target, IRModule>& inputs_arg,
const Target& target_host_arg, bool apply_lower_passes) {
std::vector<runtime::Module> device_modules;
Map<Target, IRModule> inputs = inputs_arg;
Target target_host = target_host_arg;
@@ -467,7 +467,7 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
if (it.second.defined()) {
const Target& target = it.first;
const IRModule& ir_module = it.second;
auto pair = SplitMixedModule(ir_module, target, target_host);
auto pair = SplitMixedModule(ir_module, target, target_host, apply_lower_passes);
auto& host_mod = pair.first;
auto& device_mod = pair.second;

@@ -504,6 +504,17 @@ runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
return mhost;
}

TVM_REGISTER_GLOBAL("driver.ir_module_to_runtime_module")
.set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target host_target,
bool apply_lower_passes) {
return IRModuleToRuntimeModule(inputs_arg, host_target, apply_lower_passes);
});

runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
const Target& target_host_arg) {
return IRModuleToRuntimeModule(inputs_arg, target_host_arg, false);
}

TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
.set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target host_target) {
return TIRToRuntime(inputs_arg, host_target);
@@ -542,18 +553,24 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
return TIRToRuntime(inputs, target_host);
}

transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target,
bool apply_lower_passes = false) {
transform::PassContext pass_ctx = transform::PassContext::Current();

Array<Pass> mixed_pass_list;

mixed_pass_list.push_back(tir::transform::BindTarget(target));
if (apply_lower_passes) {
for (auto&& pass : CreatePassList(false)) {
mixed_pass_list.push_back(pass);
}
}

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());

mixed_pass_list.push_back(tir::transform::BindTarget(target));

mixed_pass_list.push_back(tir::transform::VerifyMemory());

mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
@@ -596,15 +613,14 @@
}
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());

mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());

return transform::Sequential(mixed_pass_list);
}

TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
.set_body_typed([](IRModule mixed_mod, Target target) {
return MixedModulePassManager(mixed_mod, target);
.set_body_typed([](IRModule mixed_mod, Target target, bool apply_lower_passes) {
return MixedModulePassManager(mixed_mod, target, apply_lower_passes);
});

transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) {
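A sketch of driving the new entry point from Python, mirroring the _driver_ffi.ir_module_to_runtime_module(annotated_mods, target_host, True) call that build() makes above; the schedule, target choice, and function name here are illustrative, and build() remains the supported interface:

# Hypothetical sketch: hand an unlowered module to the new global directly.
import tvm
from tvm import te
from tvm.driver.build_module import schedule_to_module

A = te.placeholder((64,), name="A")
B = te.compute(A.shape, lambda i: A[i] * 2.0, name="B")
mod = schedule_to_module(te.create_schedule(B.op), [A, B], name="double")

target = tvm.target.Target("llvm")
to_runtime = tvm.get_global_func("driver.ir_module_to_runtime_module")
rt_mod = to_runtime({target: mod}, target, True)  # apply_lower_passes=True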
@@ -605,6 +605,8 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
const DataType& dtype = cache_read_buffer->dtype;
if (dtype.is_float16()) {
sch->StorageAlign(cache_read, 0, -2, 32, 8);
} else if (dtype.is_bfloat16()) {
sch->StorageAlign(cache_read, 0, -2, 32, 8);
} else if (dtype.is_int() && dtype.bits() == 8) {
sch->StorageAlign(cache_read, 0, -2, 32, 16);
} else {