[Relax][PyTorch] Support gru op for ExportedProgram importer #18360
Conversation
Summary of Changes

Hello @tlopex, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the PyTorch ExportedProgram importer by adding support for the GRU (Gated Recurrent Unit) operator. It enables the translation and execution of PyTorch models that use GRU layers, broadening the range of supported recurrent neural network architectures. The implementation reconstructs the GRU's internal logic using Relax operations and includes dedicated tests to verify functional correctness.
Code Review
This pull request adds support for the GRU operation in the PyTorch ExportedProgram importer. The implementation correctly follows the GRU equations and includes basic tests. However, I've identified a significant performance issue in the GRU implementation due to redundant computations within a loop. Additionally, the tests could be more comprehensive to cover multi-layer GRUs and other configurations. I've provided specific suggestions to address these points.
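For reference, these are the standard GRU cell equations (as documented for torch.nn.GRU) that the importer reconstructs with Relax ops; the reset/update/new slices of `weight_ih` and `weight_hh` in the snippet below correspond to the $r$, $z$, and $n$ gates respectively:

$$
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
$$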
```python
# Reset gate weights
weight_ih_r = self.block_builder.emit(
    relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size])
)
weight_hh_r = self.block_builder.emit(
    relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size])
)

# Update gate weights
weight_ih_z = self.block_builder.emit(
    relax.op.strided_slice(
        weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]
    )
)
weight_hh_z = self.block_builder.emit(
    relax.op.strided_slice(
        weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]
    )
)

# New gate weights
weight_ih_n = self.block_builder.emit(
    relax.op.strided_slice(
        weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
    )
)
weight_hh_n = self.block_builder.emit(
    relax.op.strided_slice(
        weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]
    )
)

# Transpose weights for matmul
weight_ih_r_t = self.block_builder.emit(
    relax.op.permute_dims(weight_ih_r, axes=[1, 0])
)
weight_hh_r_t = self.block_builder.emit(
    relax.op.permute_dims(weight_hh_r, axes=[1, 0])
)
weight_ih_z_t = self.block_builder.emit(
    relax.op.permute_dims(weight_ih_z, axes=[1, 0])
)
weight_hh_z_t = self.block_builder.emit(
    relax.op.permute_dims(weight_hh_z, axes=[1, 0])
)
weight_ih_n_t = self.block_builder.emit(
    relax.op.permute_dims(weight_ih_n, axes=[1, 0])
)
weight_hh_n_t = self.block_builder.emit(
    relax.op.permute_dims(weight_hh_n, axes=[1, 0])
)
```
The weight slicing and transposition operations are performed inside the time-step loop (for t in range(seq_len)). Since these weights do not depend on the time step t, these computations are redundant and highly inefficient, especially for long sequences. They should be hoisted out of the time-step loop and computed only once per layer. The same applies to bias slicing (e.g., lines 583-588, 607-616, 635-644). This will result in a much smaller and more efficient computation graph.
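A rough sketch of that restructuring, reusing only the ops already shown above (`relax.op.strided_slice`, `relax.op.permute_dims`); the `_gate_weight` helper and the surrounding variable names are illustrative, not taken from this PR:

```python
# Inside the importer's GRU handler (sketch; variables follow the snippet above).
def _gate_weight(weight, gate_idx):
    # Hypothetical local helper: slice one gate's rows out of the stacked
    # weight matrix and transpose it once for the matmuls inside the loop.
    sliced = self.block_builder.emit(
        relax.op.strided_slice(
            weight,
            axes=[0],
            begin=[gate_idx * gate_size],
            end=[(gate_idx + 1) * gate_size],
        )
    )
    return self.block_builder.emit(relax.op.permute_dims(sliced, axes=[1, 0]))

# Hoisted: computed once per layer, before the time-step loop.
# Gate indices 0/1/2 correspond to the reset/update/new gates, as above.
weight_ih_r_t, weight_ih_z_t, weight_ih_n_t = (_gate_weight(weight_ih, g) for g in range(3))
weight_hh_r_t, weight_hh_z_t, weight_hh_n_t = (_gate_weight(weight_hh, g) for g in range(3))

for t in range(seq_len):
    # Only the per-step matmuls, bias adds, and gate activations remain here.
    ...
```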
```python
if num_layers > 1:
    # Multi-layer: params[0] is first layer's weight_ih
    weight_ih = params[0]
else:
    # Single layer: params[0] is weight_ih
    weight_ih = params[0]
```
This if/else block is redundant as both branches execute the same code (weight_ih = params[0]). This can be simplified to improve code clarity and maintainability.
```python
# For multi-layer, we need to extract the first layer's weights
# to determine hidden size. params[0] is the first layer's weight_ih
# for both single and multi-layer cases.
weight_ih = params[0]
```

```python
def test_gru():
    class BasicGRU(nn.Module):
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(
                input_size=4,
                hidden_size=8,
                num_layers=1,
                batch_first=True,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.gru(x)
            return y

    torch.manual_seed(42)
    x = torch.randn(2, 3, 4, dtype=torch.float32)
    model = BasicGRU()
    with torch.no_grad():
        pytorch_output = model(x)
    exported_program = export(model, args=(x,))
    mod = from_exported_program(exported_program)
    target = tvm.target.Target("llvm")
    ex = relax.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    x_tvm = tvm.runtime.tensor(x.numpy())
    tvm_output = vm["main"](x_tvm)
    if hasattr(tvm_output, "numpy"):
        tvm_output_np = tvm_output.numpy()
    else:
        tvm_output_np = tvm_output[0].numpy()
    assert (
        pytorch_output.shape == tvm_output_np.shape
    ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5)

    class SeqFirstGRU(nn.Module):
        def __init__(self):
            super().__init__()
            self.gru = nn.GRU(
                input_size=3,
                hidden_size=6,
                num_layers=1,
                batch_first=False,
                bidirectional=False,
            )

        def forward(self, x):
            y, _ = self.gru(x)
            return y

    torch.manual_seed(43)
    x2 = torch.randn(4, 2, 3, dtype=torch.float32)
    model2 = SeqFirstGRU()
    with torch.no_grad():
        pytorch_output2 = model2(x2)
    exported_program2 = export(model2, args=(x2,))
    mod2 = from_exported_program(exported_program2)
    ex2 = relax.build(mod2, target)
    vm2 = relax.VirtualMachine(ex2, tvm.cpu())
    x2_tvm = tvm.runtime.tensor(x2.numpy())
    tvm_output2 = vm2["main"](x2_tvm)
    if hasattr(tvm_output2, "numpy"):
        tvm_output2_np = tvm_output2.numpy()
    else:
        tvm_output2_np = tvm_output2[0].numpy()
    assert pytorch_output2.shape == tvm_output2_np.shape
    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
```
The test function test_gru contains two very similar blocks of code for testing BasicGRU (batch_first=True) and SeqFirstGRU (batch_first=False). This code duplication makes the test harder to read and maintain. Consider refactoring the common testing logic into a helper function that can be called for both GRU configurations. This helper could take the model class, input data, and other relevant parameters as arguments.
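One possible shape for that helper, reusing only the calls already present in the test; the `_verify_gru` name and exact signature are illustrative, not part of this PR:

```python
def _verify_gru(model, x, rtol=1e-4, atol=1e-5):
    # Run the PyTorch model, import the exported program into Relax,
    # build and execute it on the VM, and compare outputs.
    with torch.no_grad():
        pytorch_output = model(x)
    exported_program = export(model, args=(x,))
    mod = from_exported_program(exported_program)
    ex = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    tvm_output = vm["main"](tvm.runtime.tensor(x.numpy()))
    tvm_output_np = (
        tvm_output.numpy() if hasattr(tvm_output, "numpy") else tvm_output[0].numpy()
    )
    assert pytorch_output.shape == tvm_output_np.shape, (
        f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}"
    )
    np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=rtol, atol=atol)


# In test_gru(), after defining BasicGRU / SeqFirstGRU:
torch.manual_seed(42)
_verify_gru(BasicGRU(), torch.randn(2, 3, 4, dtype=torch.float32))
torch.manual_seed(43)
_verify_gru(SeqFirstGRU(), torch.randn(4, 2, 3, dtype=torch.float32))
```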
The _gru implementation supports multi-layer GRUs, GRUs with an initial hidden state (hx), and GRUs with or without biases. However, the tests only cover single-layer GRUs without an initial hidden state and with biases. To ensure the implementation is robust and to prevent future regressions, please add test cases for the following (a sketch of such configurations is below):
- Multi-layer GRU (num_layers > 1).
- GRU with a provided initial hidden state (hx).
- GRU without biases (bias=False in nn.GRU).
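Sketches of model definitions for those configurations (illustrative only; the class names and shapes are not from this PR). The first two could reuse the same single-input verification flow as the existing tests; InitialStateGRU would additionally need an `hx` argument of shape `(num_layers, batch, hidden_size)` when exporting and running.

```python
class MultiLayerGRU(nn.Module):
    def __init__(self):
        super().__init__()
        # num_layers > 1 exercises the stacked-layer handling.
        self.gru = nn.GRU(input_size=4, hidden_size=8, num_layers=2, batch_first=True)

    def forward(self, x):
        y, _ = self.gru(x)
        return y


class NoBiasGRU(nn.Module):
    def __init__(self):
        super().__init__()
        # bias=False exercises the code path without bias parameters.
        self.gru = nn.GRU(input_size=4, hidden_size=8, num_layers=1, batch_first=True, bias=False)

    def forward(self, x):
        y, _ = self.gru(x)
        return y


class InitialStateGRU(nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(input_size=4, hidden_size=8, num_layers=1, batch_first=True)

    def forward(self, x, hx):
        # Passing hx exercises the explicit initial-hidden-state path.
        y, _ = self.gru(x, hx)
        return y
```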
@tlopex run_decompositions() isn't an in-place operation. We need to change to something like this so that the decomposed ep is used:

```python
decomposed_ep = exported_program.run_decompositions()
return ExportedProgramImporter().from_exported_program(
    decomposed_ep,
    keep_params_as_input,
    unwrap_unit_return_tuple,
    no_bind_return_tuple,
)
```

I wrote the original code and I think it was wrong.
With proper decomposition, almost 40% (75 failed, 101 passed, 2 warnings) of the exported program frontend tests will fail, so I think it might be easier to migrate gradually.
@mshr-h Got it! I'll consider how to update it. Shall we merge this PR first and then get on it?
@tlopex Okay!
@mshr-h Sorry for the late reply. I checked the code and confirmed that about 40% of the exported program frontend tests will fail once the decomposition is applied correctly. Maybe the first step is to gradually add support for the ops that become unsupported because of decomposition? During that work, some tests may keep failing. If you think this is a good approach, I can get on it this week.
@tlopex Thanks.
Got it! Let me start fixing it, @mshr-h.
This PR supports `gru.input` for the ExportedProgram importer. This links to issue #18356.