diff --git a/exir/tensor.py b/exir/tensor.py index ee2633654e8..ee074cf7119 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -37,7 +37,11 @@ def contiguous_stride_from_shape(shape: torch.Size) -> Tuple[int]: strides.append(accum) # For sizes[i] == 0, treat it as 1 to be consistent with core Pytorch # This preserves the PT equivalent behavior for dims with 0 elements - if sz != 0: + if isinstance(sz, int): + if sz != 0: + accum *= sz + else: + # Unbacked symints may error on the != 0 check accum *= sz return tuple(reversed(strides))