diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9374a2491280..30711de0a760 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -847,7 +847,7 @@ def new_zeros(self, inputs, input_types): dtype = _convert_dtype_value(inputs[2]) else: # if dtype is None, use the dtype of the input tensor - dtype = self.infer_type(inputs[0]) + dtype = self.infer_type(inputs[0]).dtype return self.full_impl(data, 0, dtype) def full(self, inputs, input_types): @@ -898,7 +898,7 @@ def new_full(self, inputs, input_types): dtype = _convert_dtype_value(inputs[3]) else: # if dtype is None, use the dtype of the input tensor - dtype = self.infer_type(inputs[0]) + dtype = self.infer_type(inputs[0]).dtype return self.full_impl(data, fill_value, dtype)