diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index d6902bc62574..300311d03b82 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -176,6 +176,8 @@ def copyfrom(self, source_array): if (not source_array.flags["C_CONTIGUOUS"]) or ( dtype == "bfloat16" or dtype != np_dtype_str ): + if dtype == "bfloat16": + source_array = np.frombuffer(source_array.tobytes(), "uint16") source_array = np.ascontiguousarray( source_array, dtype="uint16" if dtype == "bfloat16" else dtype )