From 7ba1970546bf7e974d501d936ce0e3a83ce1d0c2 Mon Sep 17 00:00:00 2001 From: Bin Li Date: Fri, 19 Jan 2024 09:08:22 -0800 Subject: [PATCH] [BugFix]Ensure that bf16 arrays are created as expected --- python/tvm/runtime/ndarray.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index d6902bc62574..300311d03b82 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -176,6 +176,8 @@ def copyfrom(self, source_array): if (not source_array.flags["C_CONTIGUOUS"]) or ( dtype == "bfloat16" or dtype != np_dtype_str ): + if dtype == "bfloat16": + source_array = np.frombuffer(source_array.tobytes(), "uint16") source_array = np.ascontiguousarray( source_array, dtype="uint16" if dtype == "bfloat16" else dtype )