diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 21862089944e..d53647cc684c 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -266,11 +266,12 @@ def _convert_dense( # In case of RNN dense, input shape will be (1, 1, n) if input_dim > 2: input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0]) - if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1: - raise tvm.error.OpAttributeInvalid( - f"Input shape {input_shape} is not valid for operator Dense." - ) - inexpr = _op.squeeze(inexpr, axis=[0]) + # Keras has no limitations on the shape of the input tensor, but our + # dense op expects 2D input. For inputs with more than 2 dimensions, + # all leading "batch" axes are flattened into one before the dense op. + # For example: (N, d1, d2, d3) -> (N * d1 * d2, d3) + new_batch_size = np.prod(input_shape[:-1]) + inexpr = _op.reshape(inexpr, newshape=(new_batch_size, input_shape[-1])) out = _op.nn.dense(data=inexpr, **params) if keras_layer.use_bias: bias = etab.new_const(weightList[1]) @@ -283,7 +284,8 @@ def _convert_dense( if act_type != "linear": out = _convert_activation(out, act_type, etab, data_layout) if input_dim > 2: - out = _op.expand_dims(out, axis=0) + out_shape = (*input_shape[:-1], units) + out = _op.reshape(out, newshape=out_shape) return out diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index aef137e634a7..0d05e34a155b 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -285,6 +285,16 @@ def test_forward_dense(self, keras_mod): keras_model = keras_mod.models.Model(data, x) verify_keras_frontend(keras_model, need_transpose=False) + data = keras_mod.layers.Input(shape=(120, 2560), name="image_set") + x = keras_mod.layers.Dense(1, activation="linear", name="e")(data) + keras_model = keras_mod.models.Model(data, x) + verify_keras_frontend(keras_model, 
need_transpose=False) + + data = keras_mod.layers.Input(shape=(10, 12, 2560), name="image_set") + x = keras_mod.layers.Dense(32, activation="linear", name="e")(data) + keras_model = keras_mod.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + def test_forward_permute(self, keras_mod): data = keras_mod.layers.Input(shape=(2, 3, 4)) x = keras_mod.layers.Permute([2, 3, 1])(data)