diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 83a045ef54bd..d57b5caf9561 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1693,20 +1693,42 @@ def _index_tensor(self, node: fx.Node) -> relax.Var: axis, index_tensor = non_none_indices[0] return self.block_builder.emit(relax.op.take(data, index_tensor, axis=axis)) - # General case: multiple non-None indices require advanced indexing + # Check if all indices can be squeezed to 1D for sequential take + def is_squeezable(idx): + if idx.struct_info.ndim == 1: + return True + if idx.struct_info.ndim == 2: + shape = idx.struct_info.shape + for d in shape: + if isinstance(d, int) and d == 1: + return True + # Check for tir.IntImm + if hasattr(d, "value") and d.value == 1: + return True + return False + + all_squeezable = all(is_squeezable(idx) for _, idx in non_none_indices) + if all_squeezable: + result = data + for axis, idx in reversed(non_none_indices): + if idx.struct_info.ndim > 1: + idx = self.block_builder.emit(relax.op.squeeze(idx)) + result = self.block_builder.emit(relax.op.take(result, idx, axis=axis)) + return result + + # General case: replace None with arange, reshaped for broadcasting + max_ndim = max((idx.struct_info.ndim for _, idx in non_none_indices), default=1) processed_indices = [] data_shape = self.shape_of(data) for i, idx in enumerate(indices): if idx is None: - dim_size = data_shape[i] arange_idx = self.block_builder.emit( - relax.op.arange( - start=relax.PrimValue(0), - end=dim_size, - step=relax.PrimValue(1), - dtype="int64", - ) + relax.op.arange(relax.PrimValue(0), data_shape[i], relax.PrimValue(1), "int64") + ) + # Reshape to [dim_size, 1, 1, ...] for broadcasting + arange_idx = self.block_builder.emit( + relax.op.reshape(arange_idx, [data_shape[i]] + [1] * (max_ndim - 1)) ) processed_indices.append(arange_idx) else: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 1b816432ce1f..92c5ab026750 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4073,29 +4073,437 @@ class expected_bicubic: def main( input: R.Tensor((1, 3, 112, 112), dtype="float32") ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): - # block 0 with R.dataflow(): - lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( - input, - R.shape([224, 224]), - roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], - layout="NCHW", - method="cubic", - coordinate_transformation_mode="half_pixel", - rounding_method="round", - cubic_alpha=-0.75, - cubic_exclude=0, - extrapolation_value=0.0, - out_dtype="void", + lv: R.Tensor((1, 3, 112, 112), dtype="float32") = R.astype(input, dtype="float32") + lv1: R.Tensor((1, 3, 112, 112), dtype="float32") = R.astype(lv, dtype="float32") + lv2: R.Tensor((224,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(224), R.prim_value(1), dtype="int64" + ) + lv3: R.Tensor((224,), dtype="float32") = R.astype(lv2, dtype="float32") + lv4: R.Tensor((224,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(224), R.prim_value(1), dtype="int64" + ) + lv5: R.Tensor((224,), dtype="float32") = R.astype(lv4, dtype="float32") + lv6: R.Tensor((224,), dtype="float32") = R.add(lv5, R.const(0.5, "float32")) + lv7: R.Tensor((224,), dtype="float32") = R.multiply(lv6, R.const(0.5, "float32")) + lv8: R.Tensor((224,), dtype="float32") = R.subtract(lv7, R.const(0.5, "float32")) + lv9: R.Tensor((224,), dtype="float32") = R.add(lv3, R.const(0.5, "float32")) + lv10: R.Tensor((224,), dtype="float32") = R.multiply(lv9, R.const(0.5, "float32")) + lv11: R.Tensor((224,), dtype="float32") = R.subtract(lv10, R.const(0.5, "float32")) + lv12: R.Tensor((224, 1), dtype="float32") = R.expand_dims(lv11, axis=[-1]) + lv13: R.Tensor((224,), dtype="float32") = R.floor(lv8) + lv14: R.Tensor((224, 1), dtype="float32") = R.floor(lv12) + lv15: R.Tensor((224, 1), dtype="float32") = R.subtract(lv12, lv14) + lv16: R.Tensor((224, 1), dtype="float32") = R.clip( + lv15, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(1.0)) + ) + lv17: R.Tensor((224,), dtype="float32") = R.subtract(lv8, lv13) + lv18: R.Tensor((224,), dtype="float32") = R.clip( + lv17, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(1.0)) + ) + lv19: R.Tensor((224,), dtype="int64") = R.astype(lv13, dtype="int64") + lv20: R.Tensor((224, 1), dtype="int64") = R.astype(lv14, dtype="int64") + lv21: R.Tensor((224, 1), dtype="int64") = R.subtract(lv20, R.const(1, "int64")) + lv22: R.Tensor((224, 1), dtype="int64") = R.add(lv20, R.const(1, "int64")) + lv23: R.Tensor((224, 1), dtype="int64") = R.add(lv20, R.const(2, "int64")) + lv24: R.Tensor((224,), dtype="int64") = R.subtract(lv19, R.const(1, "int64")) + lv25: R.Tensor((224,), dtype="int64") = R.add(lv19, R.const(1, "int64")) + lv26: R.Tensor((224,), dtype="int64") = R.add(lv19, R.const(2, "int64")) + lv27: R.Tensor((224,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv18) + lv28: R.Tensor((448,), dtype="float32") = R.concat((lv18, lv27), axis=0) + lv29: R.Tensor((2, 224), dtype="float32") = R.reshape(lv28, R.shape([2, 224])) + lv30: R.Tensor((224,), dtype="float32") = R.add(lv18, R.const(1.0, "float32")) + lv31: R.Tensor((224,), dtype="float32") = R.subtract(R.const(2.0, "float32"), lv18) + lv32: R.Tensor((448,), dtype="float32") = R.concat((lv30, lv31), axis=0) + lv33: R.Tensor((2, 224), dtype="float32") = R.reshape(lv32, R.shape([2, 224])) + lv34: R.Tensor((2, 224), dtype="float32") = R.multiply( + lv33, R.const(-0.75, "float32") + ) + lv35: R.Tensor((2, 224), dtype="float32") = R.subtract( + lv34, R.const(-3.75, "float32") + ) + lv36: R.Tensor((2, 224), dtype="float32") = R.multiply(lv35, lv33) + lv37: R.Tensor((2, 224), dtype="float32") = R.add(lv36, R.const(-6.0, "float32")) + lv38: R.Tensor((2, 224), dtype="float32") = R.multiply(lv37, lv33) + lv39: R.Tensor((2, 224), dtype="float32") = R.subtract( + lv38, R.const(-3.0, "float32") + ) + lv40: R.Tensor((2, 224), dtype="float32") = R.multiply( + lv29, R.const(1.25, "float32") + ) + lv41: R.Tensor((2, 224), dtype="float32") = R.subtract( + lv40, R.const(2.25, "float32") + ) + lv42: R.Tensor((2, 224), dtype="float32") = R.multiply(lv41, lv29) + lv43: R.Tensor((2, 224), dtype="float32") = R.multiply(lv42, lv29) + lv44: R.Tensor((2, 224), dtype="float32") = R.add(lv43, R.const(1.0, "float32")) + lv45: R.Tensor((1, 224), dtype="float32") = R.strided_slice( + lv39, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, ) - gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) + lv46: R.Tensor((1, 224), dtype="float32") = R.strided_slice( + lv39, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv47: R.Tensor((224,), dtype="float32") = R.squeeze(lv45, axis=[0]) + lv48: R.Tensor((224,), dtype="float32") = R.squeeze(lv46, axis=[0]) + lv49: R.Tensor((1, 224), dtype="float32") = R.strided_slice( + lv44, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv50: R.Tensor((1, 224), dtype="float32") = R.strided_slice( + lv44, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv51: R.Tensor((224,), dtype="float32") = R.squeeze(lv49, axis=[0]) + lv52: R.Tensor((224,), dtype="float32") = R.squeeze(lv50, axis=[0]) + lv53: R.Tensor((224, 1), dtype="float32") = R.subtract( + R.const(1.0, "float32"), lv16 + ) + lv54: R.Tensor((448, 1), dtype="float32") = R.concat((lv16, lv53), axis=0) + lv55: R.Tensor((2, 224, 1), dtype="float32") = R.reshape(lv54, R.shape([2, 224, 1])) + lv56: R.Tensor((224, 1), dtype="float32") = R.add(lv16, R.const(1.0, "float32")) + lv57: R.Tensor((224, 1), dtype="float32") = R.subtract( + R.const(2.0, "float32"), lv16 + ) + lv58: R.Tensor((448, 1), dtype="float32") = R.concat((lv56, lv57), axis=0) + lv59: R.Tensor((2, 224, 1), dtype="float32") = R.reshape(lv58, R.shape([2, 224, 1])) + lv60: R.Tensor((2, 224, 1), dtype="float32") = R.multiply( + lv59, R.const(-0.75, "float32") + ) + lv61: R.Tensor((2, 224, 1), dtype="float32") = R.subtract( + lv60, R.const(-3.75, "float32") + ) + lv62: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv61, lv59) + lv63: R.Tensor((2, 224, 1), dtype="float32") = R.add(lv62, R.const(-6.0, "float32")) + lv64: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv63, lv59) + lv65: R.Tensor((2, 224, 1), dtype="float32") = R.subtract( + lv64, R.const(-3.0, "float32") + ) + lv66: R.Tensor((2, 224, 1), dtype="float32") = R.multiply( + lv55, R.const(1.25, "float32") + ) + lv67: R.Tensor((2, 224, 1), dtype="float32") = R.subtract( + lv66, R.const(2.25, "float32") + ) + lv68: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv67, lv55) + lv69: R.Tensor((2, 224, 1), dtype="float32") = R.multiply(lv68, lv55) + lv70: R.Tensor((2, 224, 1), dtype="float32") = R.add(lv69, R.const(1.0, "float32")) + lv71: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice( + lv65, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv72: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice( + lv65, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv73: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv71, axis=[0]) + lv74: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv72, axis=[0]) + lv75: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice( + lv70, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv76: R.Tensor((1, 224, 1), dtype="float32") = R.strided_slice( + lv70, + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv77: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv75, axis=[0]) + lv78: R.Tensor((224, 1), dtype="float32") = R.squeeze(lv76, axis=[0]) + lv79: R.Tensor((224, 1), dtype="int64") = R.clip( + lv21, R.prim_value(0), R.prim_value(111) + ) + lv80: R.Tensor((224,), dtype="int64") = R.clip( + lv24, R.prim_value(0), R.prim_value(111) + ) + lv81: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv80, axis=3, mode="fast" + ) + lv82: R.Tensor((224,), dtype="int64") = R.squeeze(lv79, axis=None) + lv83: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv81, lv82, axis=2, mode="fast" + ) + lv84: R.Tensor((224, 1), dtype="int64") = R.clip( + lv21, R.prim_value(0), R.prim_value(111) + ) + lv85: R.Tensor((224,), dtype="int64") = R.clip( + lv19, R.prim_value(0), R.prim_value(111) + ) + lv86: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv85, axis=3, mode="fast" + ) + lv87: R.Tensor((224,), dtype="int64") = R.squeeze(lv84, axis=None) + lv88: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv86, lv87, axis=2, mode="fast" + ) + lv89: R.Tensor((224, 1), dtype="int64") = R.clip( + lv21, R.prim_value(0), R.prim_value(111) + ) + lv90: R.Tensor((224,), dtype="int64") = R.clip( + lv25, R.prim_value(0), R.prim_value(111) + ) + lv91: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv90, axis=3, mode="fast" + ) + lv92: R.Tensor((224,), dtype="int64") = R.squeeze(lv89, axis=None) + lv93: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv91, lv92, axis=2, mode="fast" + ) + lv94: R.Tensor((224, 1), dtype="int64") = R.clip( + lv21, R.prim_value(0), R.prim_value(111) + ) + lv95: R.Tensor((224,), dtype="int64") = R.clip( + lv26, R.prim_value(0), R.prim_value(111) + ) + lv96: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv95, axis=3, mode="fast" + ) + lv97: R.Tensor((224,), dtype="int64") = R.squeeze(lv94, axis=None) + lv98: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv96, lv97, axis=2, mode="fast" + ) + lv99: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv83, lv47) + lv100: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv88, lv51) + lv101: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv99, lv100) + lv102: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv93, lv52) + lv103: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv101, lv102) + lv104: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv98, lv48) + lv105: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv103, lv104) + lv106: R.Tensor((224, 1), dtype="int64") = R.clip( + lv20, R.prim_value(0), R.prim_value(111) + ) + lv107: R.Tensor((224,), dtype="int64") = R.clip( + lv24, R.prim_value(0), R.prim_value(111) + ) + lv108: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv107, axis=3, mode="fast" + ) + lv109: R.Tensor((224,), dtype="int64") = R.squeeze(lv106, axis=None) + lv110: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv108, lv109, axis=2, mode="fast" + ) + lv111: R.Tensor((224, 1), dtype="int64") = R.clip( + lv20, R.prim_value(0), R.prim_value(111) + ) + lv112: R.Tensor((224,), dtype="int64") = R.clip( + lv19, R.prim_value(0), R.prim_value(111) + ) + lv113: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv112, axis=3, mode="fast" + ) + lv114: R.Tensor((224,), dtype="int64") = R.squeeze(lv111, axis=None) + lv115: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv113, lv114, axis=2, mode="fast" + ) + lv116: R.Tensor((224, 1), dtype="int64") = R.clip( + lv20, R.prim_value(0), R.prim_value(111) + ) + lv117: R.Tensor((224,), dtype="int64") = R.clip( + lv25, R.prim_value(0), R.prim_value(111) + ) + lv118: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv117, axis=3, mode="fast" + ) + lv119: R.Tensor((224,), dtype="int64") = R.squeeze(lv116, axis=None) + lv120: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv118, lv119, axis=2, mode="fast" + ) + lv121: R.Tensor((224, 1), dtype="int64") = R.clip( + lv20, R.prim_value(0), R.prim_value(111) + ) + lv122: R.Tensor((224,), dtype="int64") = R.clip( + lv26, R.prim_value(0), R.prim_value(111) + ) + lv123: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv122, axis=3, mode="fast" + ) + lv124: R.Tensor((224,), dtype="int64") = R.squeeze(lv121, axis=None) + lv125: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv123, lv124, axis=2, mode="fast" + ) + lv126: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv110, lv47) + lv127: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv115, lv51) + lv128: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv126, lv127) + lv129: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv120, lv52) + lv130: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv128, lv129) + lv131: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv125, lv48) + lv132: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv130, lv131) + lv133: R.Tensor((224, 1), dtype="int64") = R.clip( + lv22, R.prim_value(0), R.prim_value(111) + ) + lv134: R.Tensor((224,), dtype="int64") = R.clip( + lv24, R.prim_value(0), R.prim_value(111) + ) + lv135: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv134, axis=3, mode="fast" + ) + lv136: R.Tensor((224,), dtype="int64") = R.squeeze(lv133, axis=None) + lv137: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv135, lv136, axis=2, mode="fast" + ) + lv138: R.Tensor((224, 1), dtype="int64") = R.clip( + lv22, R.prim_value(0), R.prim_value(111) + ) + lv139: R.Tensor((224,), dtype="int64") = R.clip( + lv19, R.prim_value(0), R.prim_value(111) + ) + lv140: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv139, axis=3, mode="fast" + ) + lv141: R.Tensor((224,), dtype="int64") = R.squeeze(lv138, axis=None) + lv142: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv140, lv141, axis=2, mode="fast" + ) + lv143: R.Tensor((224, 1), dtype="int64") = R.clip( + lv22, R.prim_value(0), R.prim_value(111) + ) + lv144: R.Tensor((224,), dtype="int64") = R.clip( + lv25, R.prim_value(0), R.prim_value(111) + ) + lv145: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv144, axis=3, mode="fast" + ) + lv146: R.Tensor((224,), dtype="int64") = R.squeeze(lv143, axis=None) + lv147: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv145, lv146, axis=2, mode="fast" + ) + lv148: R.Tensor((224, 1), dtype="int64") = R.clip( + lv22, R.prim_value(0), R.prim_value(111) + ) + lv149: R.Tensor((224,), dtype="int64") = R.clip( + lv26, R.prim_value(0), R.prim_value(111) + ) + lv150: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv149, axis=3, mode="fast" + ) + lv151: R.Tensor((224,), dtype="int64") = R.squeeze(lv148, axis=None) + lv152: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv150, lv151, axis=2, mode="fast" + ) + lv153: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv137, lv47) + lv154: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv142, lv51) + lv155: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv153, lv154) + lv156: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv147, lv52) + lv157: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv155, lv156) + lv158: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv152, lv48) + lv159: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv157, lv158) + lv160: R.Tensor((224, 1), dtype="int64") = R.clip( + lv23, R.prim_value(0), R.prim_value(111) + ) + lv161: R.Tensor((224,), dtype="int64") = R.clip( + lv24, R.prim_value(0), R.prim_value(111) + ) + lv162: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv161, axis=3, mode="fast" + ) + lv163: R.Tensor((224,), dtype="int64") = R.squeeze(lv160, axis=None) + lv164: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv162, lv163, axis=2, mode="fast" + ) + lv165: R.Tensor((224, 1), dtype="int64") = R.clip( + lv23, R.prim_value(0), R.prim_value(111) + ) + lv166: R.Tensor((224,), dtype="int64") = R.clip( + lv19, R.prim_value(0), R.prim_value(111) + ) + lv167: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv166, axis=3, mode="fast" + ) + lv168: R.Tensor((224,), dtype="int64") = R.squeeze(lv165, axis=None) + lv169: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv167, lv168, axis=2, mode="fast" + ) + lv170: R.Tensor((224, 1), dtype="int64") = R.clip( + lv23, R.prim_value(0), R.prim_value(111) + ) + lv171: R.Tensor((224,), dtype="int64") = R.clip( + lv25, R.prim_value(0), R.prim_value(111) + ) + lv172: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv171, axis=3, mode="fast" + ) + lv173: R.Tensor((224,), dtype="int64") = R.squeeze(lv170, axis=None) + lv174: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv172, lv173, axis=2, mode="fast" + ) + lv175: R.Tensor((224, 1), dtype="int64") = R.clip( + lv23, R.prim_value(0), R.prim_value(111) + ) + lv176: R.Tensor((224,), dtype="int64") = R.clip( + lv26, R.prim_value(0), R.prim_value(111) + ) + lv177: R.Tensor((1, 3, 112, 224), dtype="float32") = R.take( + lv1, lv176, axis=3, mode="fast" + ) + lv178: R.Tensor((224,), dtype="int64") = R.squeeze(lv175, axis=None) + lv179: R.Tensor((1, 3, 224, 224), dtype="float32") = R.take( + lv177, lv178, axis=2, mode="fast" + ) + lv180: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv164, lv47) + lv181: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv169, lv51) + lv182: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv180, lv181) + lv183: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv174, lv52) + lv184: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv182, lv183) + lv185: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv179, lv48) + lv186: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv184, lv185) + lv187: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv105, lv73) + lv188: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv132, lv77) + lv189: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv187, lv188) + lv190: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv159, lv78) + lv191: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv189, lv190) + lv192: R.Tensor((1, 3, 224, 224), dtype="float32") = R.multiply(lv186, lv74) + lv193: R.Tensor((1, 3, 224, 224), dtype="float32") = R.add(lv191, lv192) + lv194: R.Tensor((1, 3, 224, 224), dtype="float32") = R.astype( + lv193, dtype="float32" + ) + lv195: R.Tensor((1, 3, 224, 224), dtype="float32") = R.astype( + lv194, dtype="float32" + ) + gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv195,) R.output(gv) return gv example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),) - verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear) - verify_model(InterpolateNearest(), example_args, {}, expected_nearest) - verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic) + verify_model( + InterpolateBilinear(), example_args, {}, expected_bilinear, run_ep_decomposition=True + ) + verify_model( + InterpolateNearest(), example_args, {}, expected_nearest, run_ep_decomposition=True + ) + verify_model( + InterpolateBicubic(), example_args, {}, expected_bicubic, run_ep_decomposition=True + ) def test_mean():