44 commits
a941951  upd  (yzh119, Jun 3, 2023)
3284151  wip  (yzh119, Jun 3, 2023)
e753640  upd  (yzh119, Jun 10, 2023)
7f31636  upd  (yzh119, Jun 29, 2023)
0f6b858  upd  (yzh119, Jun 29, 2023)
edc938a  rm redundant files  (yzh119, Jun 29, 2023)
4ac20f1  upd  (yzh119, Jun 29, 2023)
a971ee0  Merge branch 'main' into native-bf16  (yzh119, Jun 29, 2023)
de6c5b2  upd  (yzh119, Jun 29, 2023)
9483524  do not change nvcc  (yzh119, Jun 29, 2023)
fd0ff79  fix  (yzh119, Jun 29, 2023)
a0525b9  remove empty line  (yzh119, Jun 29, 2023)
9c6d639  fix  (yzh119, Jun 29, 2023)
bdc3382  lint  (yzh119, Jun 29, 2023)
a15b600  c++ lint  (yzh119, Jun 30, 2023)
c33423a  use ml_dtypes for llvm codegen test  (yzh119, Jun 30, 2023)
6d3b6c3  add ml_dtypes to ci-constraints  (yzh119, Jul 1, 2023)
40f088c  alphabetical  (yzh119, Jul 1, 2023)
259f278  pylint  (yzh119, Jul 1, 2023)
aff8f61  lint  (yzh119, Jul 1, 2023)
d883e33  upd  (yzh119, Jul 1, 2023)
487c15f  improve comments  (yzh119, Jul 1, 2023)
73d6361  improved code comment  (yzh119, Jul 1, 2023)
0d7d8ba  upd  (yzh119, Jul 1, 2023)
4f4e8a6  bugfix  (yzh119, Jul 1, 2023)
2162eba  bugfix  (yzh119, Jul 2, 2023)
31355fa  lint  (yzh119, Jul 2, 2023)
daaae71  refactor buildprocess  (yzh119, Jul 3, 2023)
af72e1e  remove unused functions  (yzh119, Jul 3, 2023)
98cb6e4  pylint  (yzh119, Jul 3, 2023)
aae8208  import error  (yzh119, Jul 3, 2023)
4fc069e  add ml_dtypes to build-environment.yaml  (yzh119, Jul 3, 2023)
3eae70d  update docker scripts  (yzh119, Jul 3, 2023)
f1b24f9  bugfix  (yzh119, Jul 4, 2023)
bd8dd9d  import ml_dtypes for all  (yzh119, Jul 4, 2023)
6e1c7cb  fix bug  (yzh119, Jul 4, 2023)
c3415ce  fix  (yzh119, Jul 8, 2023)
b79c9d7  resolve issues  (yzh119, Jul 16, 2023)
e079c28  fix tests  (yzh119, Jul 17, 2023)
81ee32e  revert changes in nvcc.py  (yzh119, Jul 17, 2023)
524191f  fix lint  (yzh119, Jul 17, 2023)
abbe2ef  fix copyfrom semantics  (yzh119, Jul 17, 2023)
1f7c54a  use numpy's impl for assert equal  (yzh119, Jul 18, 2023)
32298ea  test on tlcpack-staging docker images  (yzh119, Jul 24, 2023)
10 changes: 5 additions & 5 deletions ci/jenkins/docker-images.ini
@@ -21,9 +21,9 @@ ci_arm: tlcpack/ci-arm:20230615-060132-62a5e7acf
ci_cortexm: tlcpack/ci-cortexm:20230613-060122-21361a63a
ci_cpu: tlcpack/ci-cpu:20230604-060130-0af9ff90e
ci_gpu: tlcpack/ci-gpu:20230504-142417-4d37a0a0
ci_hexagon: tlcpack/ci-hexagon:20230504-142417-4d37a0a0
ci_i386: tlcpack/ci-i386:20230504-142417-4d37a0a0
ci_hexagon: tlcpackstaging/ci_hexagon:20230724-060135-684689e92
ci_i386: tlcpackstaging/ci_i386:20230724-060135-684689e92
ci_lint: tlcpack/ci-lint:20230504-142417-4d37a0a0
ci_minimal: tlcpack/ci-minimal:20230504-142417-4d37a0a0
ci_riscv: tlcpack/ci-riscv:20230504-142417-4d37a0a0
ci_wasm: tlcpack/ci-wasm:20230504-142417-4d37a0a0
ci_minimal: tlcpackstaging/ci_minimal:20230724-060135-684689e92
ci_riscv: tlcpackstaging/ci_riscv:20230724-060135-684689e92
ci_wasm: tlcpackstaging/ci_wasm:20230724-060135-684689e92
3 changes: 3 additions & 0 deletions conda/build-environment.yaml
@@ -37,3 +37,6 @@ dependencies:
- make
- scipy
- pillow
- pip
- pip:
- ml_dtypes
1 change: 1 addition & 0 deletions docker/python/ci-constraints.txt
@@ -17,6 +17,7 @@ flowvision = "==0.1.0"
#h5py = "==3.1.0"
keras = "==2.7"
jinja2 = "==3.0.3"
ml_dtypes = "==0.1.0"
mxnet = "==1.6.0"
mypy = "==0.902"
oneflow = "==0.7.0"
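
The new dependency can be sanity-checked outside CI; a minimal editorial sketch, assuming ml_dtypes is installed as pinned above:

    import numpy as np
    import ml_dtypes

    # ml_dtypes registers bfloat16 and the two float8 variants with numpy,
    # so plain dtype-string lookups work once it has been imported
    assert np.dtype(ml_dtypes.bfloat16).itemsize == 2
    assert np.dtype("float8_e4m3fn").itemsize == 1
    assert np.dtype("float8_e5m2").itemsize == 1
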
6 changes: 5 additions & 1 deletion include/tvm/driver/driver_api.h
@@ -50,11 +50,14 @@ using tvm::transform::Pass;
/*!
* \brief Configures and returns the composite Pass for the fused module (pre split) that contains
* device and host code.
*
* \param mixed_mod The original mixed module.
* \param target The device Target.
* \param apply_lower_passes Whether to apply lowering passes.
* \return The composite Pass for the fused module.
 */
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target);
TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target,
bool apply_lower_passes);

/*!
* \brief Configures and returns the composite Pass for the device Target after device/host from
@@ -140,6 +143,7 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
GlobalVarSupply global_var_supply);

/*!
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
4 changes: 3 additions & 1 deletion include/tvm/runtime/data_type.h
@@ -95,7 +95,7 @@ class DataType {
bool is_scalar() const { return lanes() == 1; }
/*! \return whether type is a scalar type. */
bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
/*! \return whether type is a float type. */
/*! \return whether type is an IEEE 754 standard float type. */
bool is_float() const { return code() == DataType::kFloat; }
/*! \return whether type is a float8 type. */
bool is_float8() const {
@@ -107,6 +107,8 @@
bool is_float16() const { return is_float() && bits() == 16; }
/*! \return whether type is a bfloat16 type. */
bool is_bfloat16() const { return code() == DataType::kBFloat && bits() == 16; }
/*! \return whether type is a general floating point data type. */
bool is_floating_point() const { return is_float() || is_float8() || is_bfloat16(); }
/*! \return whether type is an int type. */
bool is_int() const { return code() == DataType::kInt; }
/*! \return whether type is an uint type. */
3 changes: 1 addition & 2 deletions include/tvm/tir/op.h
@@ -939,8 +939,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span);
}
}
if (t.is_float() || t.is_bfloat16() || t.is_float8())
return FloatImm(t, static_cast<double>(value), span);
if (t.is_floating_point()) return FloatImm(t, static_cast<double>(value), span);
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
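
With the consolidated is_floating_point() check, bfloat16 and float8 constants keep flowing through FloatImm; a small editorial sketch (not part of the diff) from the Python side:

    import tvm

    c_bf16 = tvm.tir.const(1.5, "bfloat16")
    c_fp8 = tvm.tir.const(0.25, "e4m3_float8")
    # both land on FloatImm rather than the custom-datatype fallback
    assert isinstance(c_bf16, tvm.tir.FloatImm) and c_bf16.dtype == "bfloat16"
    assert isinstance(c_fp8, tvm.tir.FloatImm) and c_fp8.dtype == "e4m3_float8"
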
1 change: 1 addition & 0 deletions python/tvm/__init__.py
@@ -20,6 +20,7 @@
import sys
import os
import traceback
import ml_dtypes

# top-level alias
# tvm._ffi
13 changes: 4 additions & 9 deletions python/tvm/_ffi/runtime_ctypes.py
@@ -21,10 +21,6 @@

import numpy as np

try:
import ml_dtypes
except ImportError:
ml_dtypes = None
from .base import _LIB, check_call

tvm_shape_index_t = ctypes.c_int64
@@ -96,6 +92,9 @@ class DataType(ctypes.Structure):
np.dtype(np.float32): "float32",
np.dtype(np.float64): "float64",
np.dtype(np.float_): "float64",
np.dtype("bfloat16"): "bfloat16",
np.dtype("float8_e4m3fn"): "e4m3_float8",
np.dtype("float8_e5m2"): "e5m2_float8",
}
STR2DTYPE = {
"void": {"type_code": DataTypeCode.HANDLE, "bits": 0, "lanes": 0},
@@ -110,6 +109,7 @@
"uint64": {"type_code": DataTypeCode.UINT, "bits": 64, "lanes": 1},
"e4m3_float8": {"type_code": DataTypeCode.E4M3Float, "bits": 8, "lanes": 1},
"e5m2_float8": {"type_code": DataTypeCode.E5M2Float, "bits": 8, "lanes": 1},
"bfloat16": {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1},
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
@@ -203,11 +203,6 @@ def __ne__(self, other):
return not self.__eq__(other)


if ml_dtypes is not None:
DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e4m3fn)] = "e4m3_float8"
DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] = "e5m2_float8"

RPC_SESS_MASK = 128


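
An editorial sketch of the resulting dtype mapping between ml_dtypes-backed numpy dtypes and TVM dtype strings:

    import numpy as np
    import ml_dtypes
    from tvm._ffi.runtime_ctypes import DataType

    # the static NUMPY2STR table now covers the ml_dtypes-backed numpy dtypes
    assert DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] == "bfloat16"
    assert DataType.NUMPY2STR[np.dtype(ml_dtypes.float8_e5m2)] == "e5m2_float8"

    # and the string form parses back into type_code/bits/lanes
    dt = DataType("bfloat16")
    assert dt.bits == 16 and dt.lanes == 1
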
65 changes: 53 additions & 12 deletions python/tvm/driver/build_module.py
@@ -71,8 +71,8 @@ def schedule_to_module(
"""According to the given schedule, form a function.

This is a low-level function intended for testing purposes, and
does not apply any optimization passes. In general, `tvm.lower`
and `tvm.build` should be used instead.
does not apply any optimization passes. In general, `tvm.build`
should be used instead.

Parameters
----------
@@ -91,6 +91,47 @@
return ffi.schedule_to_module(sch, args, name, binds)


def as_ir_module(
inp: Union[te.Schedule, PrimFunc, IRModule],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
) -> IRModule:
"""Convert input to IRModule.

Parameters
----------
inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule]
The input TE schedule or TensorIR PrimFunc/IRModule.

args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
The argument lists to the function for TE schedule.
It should be None if :attr:`inp` is a TensorIR PrimFunc/IRModule.

name : str
The name of the result function.

binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]]
Dictionary that maps the Tensor to Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.

Returns
-------
m : IRModule
The result IRModule.
"""
if isinstance(inp, IRModule):
return inp
if isinstance(inp, PrimFunc):
return IRModule({name: inp.with_attr("global_symbol", name)})
if isinstance(inp, te.Schedule):
return schedule_to_module(inp, args, name, binds)
raise ValueError(
f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}"
)


def lower(
inp: Union[te.Schedule, PrimFunc, IRModule],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
@@ -100,6 +141,9 @@
) -> IRModule:
"""Lowering step before build into target.

(Warning) This function is obsolete and kept only for backward compatibility with the
legacy TE schedule; please use :func:`build` directly.

Parameters
----------
inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule]
@@ -199,8 +243,7 @@ def build(
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
m = tvm.lower(s, [A, B, C], name="test_add")
rt_mod = tvm.build(m, target="llvm")
rt_mod = tvm.build(s, [A, B, C], target="llvm")

2. it is a dict of compilation target to IRModule.

@@ -213,9 +256,7 @@
s1 = tvm.te.create_schedule(C.op)
with tvm.target.cuda() as cuda_tgt:
s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
m1 = tvm.lower(s1, [A, B, C], name="test_add1")
m2 = tvm.lower(s2, [A, B, C], name="test_add2")
rt_mod = tvm.build({"llvm": m1, "cuda": m2})
rt_mod = tvm.build({"llvm": s1, "cuda": s2})

Note
----
@@ -224,16 +265,16 @@
if isinstance(inputs, te.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
input_mod = lower(inputs, args, name=name, binds=binds)
input_mod = as_ir_module(inputs, args, name=name, binds=binds)
elif isinstance(inputs, (list, tuple, container.Array)):
merged_mod = tvm.IRModule({})
for x in inputs:
merged_mod.update(lower(x))
merged_mod.update(as_ir_module(x))
input_mod = merged_mod
elif isinstance(inputs, PrimFunc):
input_mod = lower(inputs, name=name)
input_mod = as_ir_module(inputs, name=name)
elif isinstance(inputs, tvm.IRModule):
input_mod = lower(inputs)
input_mod = as_ir_module(inputs)
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError(
f"Inputs must be te.Schedule, IRModule, PrimFunc, "
@@ -278,7 +319,7 @@ def build(

annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)

rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
rt_mod_host = _driver_ffi.ir_module_to_runtime_module(annotated_mods, target_host, True)

annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host)

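
A usage sketch of the refactored build entry points (editorial illustration, not part of the diff; as_ir_module is the helper added above):

    import tvm
    from tvm import te
    from tvm.driver.build_module import as_ir_module

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.placeholder((n,), name="B")
    C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C")
    s = te.create_schedule(C.op)

    # te.Schedule, PrimFunc and IRModule inputs all normalize through as_ir_module
    mod = as_ir_module(s, [A, B, C], name="test_add")
    assert isinstance(mod, tvm.IRModule)

    # build consumes the schedule directly; lowering passes run inside the
    # driver (apply_lower_passes=True) instead of an explicit tvm.lower call
    rt_mod = tvm.build(s, [A, B, C], target="llvm", name="test_add")
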
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -115,7 +115,7 @@ def get_type(elem_type):
raise ImportError(f"Unable to import TensorProto from onnx {e}")

# Onnx mapping converts bfloat16 to float16 because
# numpy does not have a bfloat16 data type. However,
# onnx's numpy dtype mapping does not include bfloat16. However,
# tvm has one, so we force the return type to be bfloat16
if elem_type == int(TensorProto.BFLOAT16):
return "bfloat16"
2 changes: 2 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -4375,6 +4375,8 @@ def _convert_data_type(input_type, default_dtype=None):
return "float32"
elif input_type in ["half", "float16", "torch.float16"]:
return "float16"
elif input_type in ["bfloat16", "torch.bfloat16"]:
return "bfloat16"
elif input_type in ["long", "int64", "torch.int64"]:
return "int64"
elif input_type in ["int", "int32", "torch.int32"]:
33 changes: 7 additions & 26 deletions python/tvm/runtime/ndarray.py
@@ -20,10 +20,6 @@
import warnings
import numpy as np

try:
import ml_dtypes
except ImportError:
ml_dtypes = None
import tvm._ffi

from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE
@@ -167,18 +163,19 @@ def copyfrom(self, source_array):
raise ValueError(
f"array shape do not match the shape of NDArray {source_array.shape} vs {shape}"
)

numpy_str_map = DataType.NUMPY2STR
np_dtype_str = (
numpy_str_map[source_array.dtype]
if source_array.dtype in numpy_str_map
else str(source_array.dtype)
)
if (not source_array.flags["C_CONTIGUOUS"]) or (
dtype == "bfloat16" or dtype != np_dtype_str
):
source_array = np.ascontiguousarray(
source_array, dtype="uint16" if dtype == "bfloat16" else dtype
)
if (not source_array.flags["C_CONTIGUOUS"]) or dtype != np_dtype_str:
if dtype == "e4m3_float8":
dtype = "float8_e4m3fn"
elif dtype == "e5m2_float8":
dtype = "float8_e5m2"
source_array = np.ascontiguousarray(source_array, dtype)
assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
@@ -220,22 +217,6 @@ def numpy(self):
dtype = str(t)
if dtype == "int4":
dtype = "int8"
if dtype == "bfloat16":
dtype = "uint16"
if dtype == "e4m3_float8":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e4m3fn
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy."
)
if dtype == "e5m2_float8":
if ml_dtypes is not None:
dtype = ml_dtypes.float8_e5m2
else:
raise RuntimeError(
"ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy."
)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags["C_CONTIGUOUS"]
data = np_arr.ctypes.data_as(ctypes.c_void_p)
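
An editorial sketch (not part of the diff) of the NDArray round-trip this change targets, assuming ml_dtypes is installed:

    import numpy as np
    import ml_dtypes
    import tvm

    x = np.arange(8, dtype=np.float32).astype(ml_dtypes.bfloat16)
    arr = tvm.nd.array(x)   # copyfrom keeps the bfloat16 payload, no uint16 reinterpretation
    assert arr.dtype == "bfloat16"

    y = arr.numpy()         # comes back as an ml_dtypes.bfloat16 ndarray directly
    assert y.dtype == np.dtype(ml_dtypes.bfloat16)
    np.testing.assert_array_equal(x.astype("float32"), y.astype("float32"))
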
16 changes: 16 additions & 0 deletions python/tvm/testing/utils.py
@@ -106,6 +106,17 @@ def test_something():
)


def promote_bf16_to_fp32(x):
r"""Promote the data type of an array-like structure from bfloat16 to float32."""
if isinstance(x, list):
return [promote_bf16_to_fp32(y) for y in x]
else:
if isinstance(x, np.ndarray) and x.dtype == "bfloat16":
return x.astype("float32")
else:
return x


def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
"""Version of np.testing.assert_allclose with `atol` and `rtol` fields set
in reasonable defaults.
@@ -114,6 +125,11 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=1e-7):
compares the `abs(actual-desired)` with `atol+rtol*abs(desired)`. Since we
often allow `desired` to be close to zero, we generally want non-zero `atol`.
"""
# ml_dtypes v0.2 is not compatible with numpy's asanyarray, so promote bfloat16 to
# float32 first.
actual = promote_bf16_to_fp32(actual)
desired = promote_bf16_to_fp32(desired)

actual = np.asanyarray(actual)
desired = np.asanyarray(desired)
np.testing.assert_allclose(actual.shape, desired.shape)
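
A minimal editorial sketch of how the promotion helper behaves in a test, assuming ml_dtypes is installed:

    import numpy as np
    import ml_dtypes
    import tvm.testing

    a = np.array([1.0, 2.0, 3.0], dtype=ml_dtypes.bfloat16)
    b = np.array([1.0, 2.0, 3.0], dtype="float32")
    # both inputs are promoted to float32 before np.testing.assert_allclose runs,
    # sidestepping the asanyarray incompatibility noted above
    tvm.testing.assert_allclose(a, b, rtol=1e-2, atol=1e-2)
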
2 changes: 1 addition & 1 deletion python/tvm/topi/arm_cpu/injective.py
@@ -68,7 +68,7 @@ def schedule_injective(outs):

if list(s[x].op.axis):
# do not vectorize for broadcast
dtype = "uint16" if x.dtype == "bfloat16" else x.dtype
dtype = x.dtype
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 16 // np.dtype(dtype).itemsize)
s[x].vectorize(ii)
tvm.te.schedule.AutoInlineInjective(s)
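
The uint16 stand-in is no longer needed because numpy resolves "bfloat16" directly once ml_dtypes is imported; a quick editorial check:

    import numpy as np
    import ml_dtypes  # registers "bfloat16" with numpy

    # the vectorization split stays at 16 // itemsize == 8 lanes for bfloat16
    assert np.dtype("bfloat16").itemsize == 2
    assert 16 // np.dtype("bfloat16").itemsize == 8
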
1 change: 0 additions & 1 deletion python/tvm/topi/nn/winograd_util.py
@@ -169,7 +169,6 @@ def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
intp_pts = _interpolation_points(degree)
A_data, B_data, G_data = _cook_toom_convolution(intp_pts, tile_size, kernel_size)

out_dtype = "uint16" if out_dtype == "bfloat16" else out_dtype
return (
const_matrix(A_data.astype(out_dtype), "A"),
const_matrix(B_data.astype(out_dtype), "B"),
2 changes: 1 addition & 1 deletion src/arith/rewrite_simplify.cc
@@ -621,7 +621,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {

// x / 2.0 = x * 0.5
if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() ||
ICHECK(op->dtype.is_floating_point() ||
datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
}
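
A hypothetical Python-side probe of the relaxed guard (editorial, not part of the diff; assumes the rewrite fires on a plain Div over a bfloat16 expression):

    import tvm
    from tvm import tir

    ana = tvm.arith.Analyzer()
    x = tir.Var("x", "bfloat16")
    expr = x / tir.const(2.0, "bfloat16")
    # with is_floating_point() covering bfloat16 and float8, the rule can still
    # rewrite x / 2.0 into x * 0.5 without tripping the ICHECK
    print(ana.rewrite_simplify(expr))
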