diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 4bdca2c4d533..eb16bf2a25b4 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -864,29 +864,14 @@ def _convert_reshape(inexpr, keras_layer, etab): _check_data_format(keras_layer) inshape = keras_layer.input_shape # includes batch tshape = keras_layer.target_shape # no batch - if len(inshape) == 3 and len(tshape) == 1: - # (?, a, b) -> (-1, ab) - shape = (-1, tshape[0]) - elif len(inshape) in [2, 3] and len(tshape) == 2: - # (?, cc) -> (-1, c, c) - # (?, a, b) -> (-1, c, c) - assert tshape[0] == tshape[1], "Only supports square target shapes, but got {}".format( - tshape - ) - shape = (-1,) + tshape - else: - # (?, h, w, c) -> (-1, c, H, W) - # (?, h, w, c) -> (-1, c, hw) - # (?, hw, c) -> (-1, c, h, w) - ch = inshape[-1] - assert ch == tshape[-1], ( - "Only supports last dimension in target shape being equal to " - "the channel number of input tensor." - ) - if etab.data_layout == "NCHW": - shape = (-1, ch) + tshape[:-1] - else: - shape = (-1,) + tshape[:-1] + (ch,) + shape = (-1,) + tshape + + if etab.data_layout == "NCHW" and (len(inshape) > 3 or len(tshape) > 2): + # Perform reshape in original NHWC format. + inexpr = _op.transpose(inexpr, [0] + list(range(2, len(inshape))) + [1]) + inexpr = _op.reshape(inexpr, newshape=shape) + return _op.transpose(inexpr, axes=[0, -1] + list(range(1, len(shape) - 1))) + return _op.reshape(inexpr, newshape=shape) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 05d890419aa4..561e444f077f 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -350,6 +350,16 @@ def test_forward_reshape(self, keras): x = keras.layers.Reshape(target_shape=(4, 4))(data) keras_model = keras.models.Model(data, x) verify_keras_frontend(keras_model, need_transpose=False) + # "non-square" target shape + data = keras.layers.Input(shape=(15,)) + x = keras.layers.Reshape(target_shape=(5, 3))(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + # modify channel dim + data = keras.layers.Input(shape=(3, 2, 4)) + x = keras.layers.Reshape(target_shape=(3, 8))(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model) def test_forward_crop(self, keras): data = keras.layers.Input(shape=(32, 32, 3))