diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index d963a5d160bf..8d437027e51a 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -816,10 +816,16 @@ def _convert_cropping( f"Operator {crop_type} is not supported for frontend Keras." ) int32_max = np.iinfo(np.int32).max + if data_layout == "NHWC": + begin = [0, crop_t, crop_l, 0] + end = [int32_max, in_h - crop_b, in_w - crop_r, int32_max] + else: + begin = [0, 0, crop_t, crop_l] + end = [int32_max, int32_max, in_h - crop_b, in_w - crop_r] return _op.strided_slice( inexpr, - begin=[0, 0, crop_t, crop_l], - end=[int32_max, int32_max, in_h - crop_b, in_w - crop_r], + begin=begin, + end=end, ) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 45935f87f4f4..cc6421614e0a 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -449,7 +449,15 @@ def test_forward_crop(self, keras_mod): x = keras_mod.layers.Cropping2D(cropping=0)(x) x = keras_mod.layers.Add()([x, x]) keras_model = keras_mod.models.Model(data, x) - verify_keras_frontend(keras_model) + verify_keras_frontend(keras_model, layout="NHWC") + verify_keras_frontend(keras_model, layout="NHWC") + + data = keras_mod.layers.Input(shape=(32, 32, 3)) + x = keras_mod.layers.Cropping2D(cropping=(2, 1))(data) + x = keras_mod.layers.Cropping2D(cropping=(1, 2))(x) + keras_model = keras_mod.models.Model(data, x) + verify_keras_frontend(keras_model, layout="NHWC") + verify_keras_frontend(keras_model, layout="NCHW") def test_forward_multi_inputs(self, keras_mod): data1 = keras_mod.layers.Input(shape=(32, 32, 3))