Merged
295 changes: 295 additions & 0 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -391,6 +391,300 @@ def _lstm(self, node: fx.Node) -> relax.Var:
output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
return output

def _gru(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
input_tensor = args[0]
hx = args[1] if len(args) > 1 else None
params = args[2] if len(args) > 2 else None
has_biases = args[3] if len(args) > 3 else True
num_layers = args[4] if len(args) > 4 else 1
_dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference
_train = args[6] if len(args) > 6 else False # Not used in inference
bidirectional = args[7] if len(args) > 7 else False
batch_first = args[8] if len(args) > 8 else False

if bidirectional:
raise NotImplementedError("Bidirectional GRU is not yet supported")

input_shape = self.shape_of(input_tensor)
if batch_first:
batch_size, seq_len, input_size = input_shape
else:
seq_len, batch_size, input_size = input_shape

if isinstance(seq_len, tvm.tir.IntImm):
seq_len = seq_len.value
if isinstance(batch_size, tvm.tir.IntImm):
batch_size = batch_size.value
if isinstance(input_size, tvm.tir.IntImm):
input_size = input_size.value

if params and len(params) >= 2:
# For multi-layer, we need to extract the first layer's weights
# to determine hidden size
if num_layers > 1:
# Multi-layer: params[0] is first layer's weight_ih
weight_ih = params[0]
else:
# Single layer: params[0] is weight_ih
weight_ih = params[0]
Comment on lines +425 to +430 (Contributor, severity: medium)

This if/else block is redundant, as both branches execute the same code (weight_ih = params[0]). It can be simplified to improve code clarity and maintainability:

            # For multi-layer, we need to extract the first layer's weights
            # to determine hidden size. params[0] is the first layer's weight_ih
            # for both single and multi-layer cases.
            weight_ih = params[0]

# Extract hidden size from weight dimensions
# weight_ih has shape (3 * hidden_size, input_size)
weight_ih_shape = self.shape_of(weight_ih)
hidden_size = weight_ih_shape[0] // 3 # 3 gates: reset, update, new
else:
# Fallback to a default hidden size
hidden_size = 16

# Implement actual GRU computation using Relax operations
# GRU equations:
# r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
# z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
# n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
# h_t = (1 - z_t) * n_t + z_t * h_{t-1}
dtype = input_tensor.struct_info.dtype

# Reshape input for processing
if batch_first:
# Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size)
input_reshaped = self.block_builder.emit(
relax.op.permute_dims(input_tensor, axes=[1, 0, 2])
)
else:
input_reshaped = input_tensor

# Initialize hidden states for all layers
if hx is not None:
# hx shape: (num_layers, batch_size, hidden_size)
h_states = []
for layer in range(num_layers):
h_layer = self.block_builder.emit(
relax.op.take(hx, relax.const(layer, "int64"), axis=0, mode="clip")
)
h_states.append(h_layer)
else:
h_states = []
for layer in range(num_layers):
h_layer = self.block_builder.emit(
relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
)
h_states.append(h_layer)

outputs = []

for t in range(seq_len):
# Get input at time t: (batch_size, input_size)
x_t = self.block_builder.emit(
relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
)

# Process through each layer
current_input = x_t
new_h_states = []

for layer in range(num_layers):
# Get layer parameters
if params and len(params) >= 4 * num_layers:
# Multi-layer case: params are organized as
# [layer0_ih, layer0_hh, layer0_bias_ih, layer0_bias_hh, layer1_ih, ...]
param_offset = layer * 4
weight_ih = params[param_offset]
weight_hh = params[param_offset + 1]
bias_ih = params[param_offset + 2] if has_biases else None
bias_hh = params[param_offset + 3] if has_biases else None
elif params and len(params) >= 4:
# Single layer case
weight_ih = params[0]
weight_hh = params[1]
bias_ih = params[2] if has_biases else None
bias_hh = params[3] if has_biases else None
else:
# Fallback: create zero weights
weight_ih = self.block_builder.emit(
relax.op.zeros(
relax.ShapeExpr(
(3 * hidden_size, input_size if layer == 0 else hidden_size)
),
dtype,
)
)
weight_hh = self.block_builder.emit(
relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype)
)
bias_ih = None
bias_hh = None

# Get previous hidden state for this layer
h_prev = h_states[layer]

# Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n)
gate_size = hidden_size

# Reset gate weights
weight_ih_r = self.block_builder.emit(
relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
)
weight_hh_r = self.block_builder.emit(
relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
)

# Update gate weights
weight_ih_z = self.block_builder.emit(
relax.op.strided_slice(
weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
)
)
weight_hh_z = self.block_builder.emit(
relax.op.strided_slice(
weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
)
)

# New gate weights
weight_ih_n = self.block_builder.emit(
relax.op.strided_slice(
weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
)
)
weight_hh_n = self.block_builder.emit(
relax.op.strided_slice(
weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
)
)

# Transpose weights for matmul
weight_ih_r_t = self.block_builder.emit(
relax.op.permute_dims(weight_ih_r, axes=[1, 0])
)
weight_hh_r_t = self.block_builder.emit(
relax.op.permute_dims(weight_hh_r, axes=[1, 0])
)
weight_ih_z_t = self.block_builder.emit(
relax.op.permute_dims(weight_ih_z, axes=[1, 0])
)
weight_hh_z_t = self.block_builder.emit(
relax.op.permute_dims(weight_hh_z, axes=[1, 0])
)
weight_ih_n_t = self.block_builder.emit(
relax.op.permute_dims(weight_ih_n, axes=[1, 0])
)
weight_hh_n_t = self.block_builder.emit(
relax.op.permute_dims(weight_hh_n, axes=[1, 0])
)
Comment on lines +523 to +573 (Contributor, severity: high)

The weight slicing and transposition operations are performed inside the time-step loop (for t in range(seq_len)). Since these weights do not depend on the time step t, these computations are redundant and highly inefficient, especially for long sequences. They should be hoisted out of the time-step loop and computed only once per layer. The same applies to bias slicing (e.g., lines 583-588, 607-616, 635-644). This will result in a much smaller and more efficient computation graph.
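
For illustration only, a minimal standalone NumPy sketch of the hoisting this comment asks for, written outside the Relax emitter (the function gru_layer_forward and all names below are hypothetical, not from this PR): the per-gate weight slices, transposes, and bias slices are computed once before the time-step loop, so only the per-step matmuls and elementwise math remain inside it.

import numpy as np

def _sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_layer_forward(x, h0, w_ih, w_hh, b_ih=None, b_hh=None):
    """Single-layer GRU over x of shape (seq_len, batch, input_size).

    Gate order follows PyTorch: reset, update, new (r, z, n).
    w_ih has shape (3 * hidden, input); w_hh has shape (3 * hidden, hidden).
    """
    hidden = w_hh.shape[1]
    zeros = np.zeros(hidden, dtype=x.dtype)

    # Hoisted out of the time-step loop: these slices/transposes depend only on
    # the layer parameters, not on t, so they are computed exactly once.
    w_ir, w_iz, w_in = (w_ih[i * hidden:(i + 1) * hidden].T for i in range(3))
    w_hr, w_hz, w_hn = (w_hh[i * hidden:(i + 1) * hidden].T for i in range(3))
    b_ir, b_iz, b_in = (zeros, zeros, zeros) if b_ih is None else (
        b_ih[:hidden], b_ih[hidden:2 * hidden], b_ih[2 * hidden:])
    b_hr, b_hz, b_hn = (zeros, zeros, zeros) if b_hh is None else (
        b_hh[:hidden], b_hh[hidden:2 * hidden], b_hh[2 * hidden:])

    h, outputs = h0, []
    for t in range(x.shape[0]):  # only the per-step math stays in the loop
        x_t = x[t]
        r = _sigmoid(x_t @ w_ir + b_ir + h @ w_hr + b_hr)
        z = _sigmoid(x_t @ w_iz + b_iz + h @ w_hz + b_hz)
        n = np.tanh(x_t @ w_in + b_in + r * (h @ w_hn + b_hn))
        h = (1.0 - z) * n + z * h
        outputs.append(h)
    return np.stack(outputs, axis=0), h

The same structure would apply to the Relax translation: emit the strided_slice/permute_dims (and bias slices) once per layer before entering the for t loop, and reuse the resulting vars at each step.
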

# Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr)
r_ih = self.block_builder.emit(
relax.op.linear_algebra.matmul(current_input, weight_ih_r_t)
)
r_hh = self.block_builder.emit(
relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t)
)
if bias_ih is not None and bias_hh is not None:
bias_ih_r = self.block_builder.emit(
relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size])
)
bias_hh_r = self.block_builder.emit(
relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size])
)
r_t = self.block_builder.emit(
relax.op.sigmoid(
relax.op.add(
relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r
)
)
)
else:
r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh)))

# Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz)
z_ih = self.block_builder.emit(
relax.op.linear_algebra.matmul(current_input, weight_ih_z_t)
)
z_hh = self.block_builder.emit(
relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t)
)
if bias_ih is not None and bias_hh is not None:
bias_ih_z = self.block_builder.emit(
relax.op.strided_slice(
bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
)
)
bias_hh_z = self.block_builder.emit(
relax.op.strided_slice(
bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
)
)
z_t = self.block_builder.emit(
relax.op.sigmoid(
relax.op.add(
relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z
)
)
)
else:
z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh)))

# Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn))
n_ih = self.block_builder.emit(
relax.op.linear_algebra.matmul(current_input, weight_ih_n_t)
)
n_hh = self.block_builder.emit(
relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t)
)
if bias_ih is not None and bias_hh is not None:
bias_ih_n = self.block_builder.emit(
relax.op.strided_slice(
bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
)
)
bias_hh_n = self.block_builder.emit(
relax.op.strided_slice(
bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
)
)
n_t = self.block_builder.emit(
relax.op.tanh(
relax.op.add(
relax.op.add(n_ih, bias_ih_n),
relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)),
)
)
)
else:
n_t = self.block_builder.emit(
relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh)))
)

# Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1}
one_minus_z = self.block_builder.emit(
relax.op.subtract(relax.const(1.0, dtype), z_t)
)
h_t = self.block_builder.emit(
relax.op.add(
relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev)
)
)

new_h_states.append(h_t)

current_input = h_t

# Update hidden states for next time step
h_states = new_h_states

# Store output (from the last layer)
outputs.append(h_states[-1])

# Stack outputs: (seq_len, batch_size, hidden_size)
output = self.block_builder.emit(relax.op.stack(outputs, axis=0))

# Reshape back to batch_first if needed
if batch_first:
# (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))

return output

########## Manipulation ##########

def _narrow(self, node: fx.Node) -> relax.Var:
@@ -652,6 +946,7 @@ def create_convert_map(
"layer_norm.default": self._layer_norm,
"linear.default": self._linear,
"lstm.input": self._lstm,
"gru.input": self._gru,
"max_pool1d.default": self._max_pool1d,
"max_pool2d.default": self._max_pool2d,
"max_pool3d.default": self._max_pool3d,
71 changes: 71 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -6050,5 +6050,76 @@ def main(
verify_model(TensorNoneModel(), example_args, {}, Expected)


def test_gru():
class BasicGRU(nn.Module):
def __init__(self):
super().__init__()
self.gru = nn.GRU(
input_size=4,
hidden_size=8,
num_layers=1,
batch_first=True,
bidirectional=False,
)

def forward(self, x):
y, _ = self.gru(x)
return y

torch.manual_seed(42)
x = torch.randn(2, 3, 4, dtype=torch.float32)
model = BasicGRU()
with torch.no_grad():
pytorch_output = model(x)
exported_program = export(model, args=(x,))
mod = from_exported_program(exported_program)
target = tvm.target.Target("llvm")
ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
x_tvm = tvm.runtime.tensor(x.numpy())
tvm_output = vm["main"](x_tvm)
if hasattr(tvm_output, "numpy"):
tvm_output_np = tvm_output.numpy()
else:
tvm_output_np = tvm_output[0].numpy()
assert (
pytorch_output.shape == tvm_output_np.shape
), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)

class SeqFirstGRU(nn.Module):
def __init__(self):
super().__init__()
self.gru = nn.GRU(
input_size=3,
hidden_size=6,
num_layers=1,
batch_first=False,
bidirectional=False,
)

def forward(self, x):
y, _ = self.gru(x)
return y

torch.manual_seed(43)
x2 = torch.randn(4, 2, 3, dtype=torch.float32)
model2 = SeqFirstGRU()
with torch.no_grad():
pytorch_output2 = model2(x2)
exported_program2 = export(model2, args=(x2,))
mod2 = from_exported_program(exported_program2)
ex2 = relax.build(mod2, target)
vm2 = relax.VirtualMachine(ex2, tvm.cpu())
x2_tvm = tvm.runtime.tensor(x2.numpy())
tvm_output2 = vm2["main"](x2_tvm)
if hasattr(tvm_output2, "numpy"):
tvm_output2_np = tvm_output2.numpy()
else:
tvm_output2_np = tvm_output2[0].numpy()
assert pytorch_output2.shape == tvm_output2_np.shape
np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
Comment on lines +6053 to +6121 (Contributor, severity: medium)

The test function test_gru contains two very similar blocks of code for testing BasicGRU (batch_first=True) and SeqFirstGRU (batch_first=False). This code duplication makes the test harder to read and maintain. Consider refactoring the common testing logic into a helper function that can be called for both GRU configurations. This helper could take the model class, input data, and other relevant parameters as arguments.
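
As one possible shape for that refactor, a hedged sketch: the helper name _check_gru and its parameters are illustrative, not part of the PR, and it assumes the imports already used in this test file (torch, export, from_exported_program, tvm, relax, np).

def _check_gru(model, x, rtol=1e-4, atol=1e-5):
    """Export a GRU module, compile it with Relax, and compare against PyTorch."""
    with torch.no_grad():
        pytorch_output = model(x)
    exported_program = export(model, args=(x,))
    mod = from_exported_program(exported_program)
    ex = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    tvm_output = vm["main"](tvm.runtime.tensor(x.numpy()))
    tvm_output_np = tvm_output.numpy() if hasattr(tvm_output, "numpy") else tvm_output[0].numpy()
    assert pytorch_output.shape == tvm_output_np.shape
    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=rtol, atol=atol)

# The two existing cases would then collapse to:
#     torch.manual_seed(42); _check_gru(BasicGRU(), torch.randn(2, 3, 4))
#     torch.manual_seed(43); _check_gru(SeqFirstGRU(), torch.randn(4, 2, 3))
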
Comment on lines +6053 to +6121 (Contributor, severity: medium)

The _gru implementation supports multi-layer GRUs, GRUs with an initial hidden state (hx), and GRUs with/without biases. However, the tests only cover single-layer GRUs without an initial hidden state and with biases. To ensure the implementation is robust and prevent future regressions, please add test cases for:

  • Multi-layer GRU (num_layers > 1).
  • GRU with a provided initial hidden state (hx).
  • GRU without biases (bias=False in nn.GRU).
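
Following up on the list above, a hedged sketch of one such additional case, combining num_layers=2, bias=False, and an explicit h0. The class name MultiLayerGRU and the shapes are illustrative, not from the PR; whether it passes against the current _gru lowering is exactly what the added coverage would establish. It reuses the same verification flow as the existing tests.

class MultiLayerGRU(nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(
            input_size=4,
            hidden_size=8,
            num_layers=2,
            batch_first=True,
            bias=False,
            bidirectional=False,
        )

    def forward(self, x, h0):
        y, _ = self.gru(x, h0)
        return y

torch.manual_seed(44)
x3 = torch.randn(2, 3, 4, dtype=torch.float32)  # (batch, seq_len, input_size)
h0 = torch.zeros(2, 2, 8, dtype=torch.float32)  # (num_layers, batch, hidden_size)
model3 = MultiLayerGRU()
with torch.no_grad():
    pytorch_output3 = model3(x3, h0)
exported_program3 = export(model3, args=(x3, h0))
mod3 = from_exported_program(exported_program3)
ex3 = relax.build(mod3, tvm.target.Target("llvm"))
vm3 = relax.VirtualMachine(ex3, tvm.cpu())
tvm_output3 = vm3["main"](tvm.runtime.tensor(x3.numpy()), tvm.runtime.tensor(h0.numpy()))
tvm_output3_np = tvm_output3.numpy() if hasattr(tvm_output3, "numpy") else tvm_output3[0].numpy()
np.testing.assert_allclose(pytorch_output3.numpy(), tvm_output3_np, rtol=1e-4, atol=1e-5)
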



if __name__ == "__main__":
tvm.testing.main()