
Export TensorFlow models to ONNX with dynamic input shapes#19255

Merged
sgugger merged 8 commits into huggingface:main from dwyatte:tensorflow_onnx_dynamic_shape
Oct 7, 2022

Conversation

@dwyatte
Contributor

@dwyatte dwyatte commented Sep 30, 2022

What does this PR do?

This PR exports TensorFlow models to ONNX with dynamic input shapes. Previously, they were exported with static input shapes (a batch size of 2 and a sequence length of 8). This should bring the TensorFlow-to-ONNX export mostly into parity with the PyTorch export.

Fixes #19238

  • While fixing this, I noticed the TensorFlow-to-ONNX export tests weren't actually exporting TensorFlow models, because FeaturesManager.get_model_class_for_feature returns a PyTorch model class by default. I've exposed a framework argument on these tests so that FeaturesManager.get_model_class_for_feature can return TensorFlow models. NOTE: Exporting TensorFlow to ONNX seems to be much slower than exporting PyTorch to ONNX, so CI duration will increase.
  • I've changed validate_model_outputs to check with a batch size/sequence length different from those used during export (now 3 and 9, respectively). There was a TODO about this, but it surfaced an error for BERT, CamemBERT, and RoBERTa multiple-choice tasks:

        onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:'tf_bert_for_multiple_choice/bert/encoder/layer_._0/attention/self/add_1' Status Message: /Users/runner/work/1/s/onnxruntime/core/providers/cpu/math/element_wise_ops.h:503 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1

    I suspect this is due to the way these models are defined (tracing fails to properly infer a shape somewhere). IMO this is still a net improvement, since the ONNX models exported under TensorFlow were previously non-functional except with their static input shapes. I'm skipping these specific configurations during testing for now, but someone should look into this.
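The difference between the old static export and the new dynamic one can be pictured as a shape-compatibility check: an integer fixes an axis at export time, while None leaves it dynamic. This is a toy illustration only (the helper below is hypothetical, not part of transformers):

```python
# Toy model of ONNX input-shape declarations:
# an int fixes an axis at export time, None leaves it dynamic.
def shape_compatible(declared, actual):
    """True if a runtime shape satisfies the declared input signature."""
    return len(declared) == len(actual) and all(
        d is None or d == a for d, a in zip(declared, actual)
    )

static_sig = (2, 8)         # old TF export: batch_size=2, seq_len=8 baked in
dynamic_sig = (None, None)  # this PR: both axes left dynamic

print(shape_compatible(static_sig, (3, 9)))   # False: static export rejects other shapes
print(shape_compatible(dynamic_sig, (3, 9)))  # True: dynamic export accepts them
```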

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of who to tag.
Please tag fewer than 3 people.

@Rocketknight1, @LysandreJik, @lewtun

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Sep 30, 2022

The documentation is not available anymore as the PR was closed or merged.

Member

@lewtun lewtun left a comment


Thanks a lot for enabling dynamic input shapes and ensuring our tests actually test the TF exports @dwyatte !

Can you please confirm that the slow tests pass by running:

RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py

It would also be interesting to know how much slower the TF exports are compared to the PyTorch ones, e.g. can you share some timings for a few models?

    -reference_model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
    +reference_model_inputs = config.generate_dummy_inputs(
    +    preprocessor,
    +    batch_size=config.default_fixed_batch + 1,
Member


Nice, simple idea!

@dwyatte
Contributor Author

dwyatte commented Oct 1, 2022

Can you please confirm that the slow tests pass by running RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py

There were a few failures here (16 failed, 400 passed, 16 skipped, 72972 warnings):

FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_029_clip_default - TypeError: generate_dummy_inputs() got an unexpected keyword argument 'batch_size'
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_051_deberta_v2_question_answering - onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Deserialize tensor onn...
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_050_deberta_v2_multiple_choice - AssertionError: deberta-v2, multiple-choice -> Outputs values doesn't match between reference model and ONN...
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_078_groupvit_default - TypeError: generate_dummy_inputs() got an unexpected keyword argument 'batch_size'
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_109_owlvit_default - TypeError: generate_dummy_inputs() got an unexpected keyword argument 'batch_size'
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_110_perceiver_image_classification - onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTIO...
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_111_perceiver_masked_lm - onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zer...
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_112_perceiver_sequence_classification - onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEP...
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_078_groupvit_default - TypeError: generate_dummy_inputs() got an unexpected keyword argument 'batch_size'
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_125_roformer_multiple_choice - onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got ...
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_029_clip_default - TypeError: generate_dummy_inputs() got an unexpected keyword argument 'batch_size'
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_125_roformer_multiple_choice - onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMEN...
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_109_owlvit_default - TypeError: generate_dummy_inputs() got an unexpected keyword argument 'batch_size'
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_110_perceiver_image_classification - onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_...
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_111_perceiver_masked_lm - onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION :...
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_112_perceiver_sequence_classification - onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTI...
  • clip, groupvit, and owlvit should be easy fixes to expose the relevant args (or consume via **kwargs) in their generate_dummy_inputs
  • deberta is failing in my environment even on [49d62b0](https://github.com/dwyatte/transformers/commit/49d62b01783416a89acc0b865f7cb8dbab87cd6b), the commit I branched from
  • perceiver and roformer are real errors, but they seem to be due to static input shapes, e.g.:
E           onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_57' Status Message: /Users/runner/work/1/s/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:41 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape &, onnxruntime::TensorShapeVector &, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{3,256,256}, requested shape:{2,256,8,32}
E           onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: token_type_ids for the following indices
E            index: 2 Got: 17 Expected: 15
E            Please fix either the inputs or the model.
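The Reshape failure above is an element-count mismatch: the traced graph baked in the export-time batch size of 2, so the requested shape no longer holds the same number of elements as the runtime input built with batch size 3. The arithmetic checks out directly:

```python
from math import prod

input_shape = (3, 256, 256)        # runtime input, batch_size=3
requested_shape = (2, 256, 8, 32)  # shape baked into the Reshape node at export, batch_size=2

# A reshape is only valid when the element counts match.
print(prod(input_shape))      # 196608
print(prod(requested_shape))  # 131072 -> mismatch, so onnxruntime raises
```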

A couple of options:

  • Disable tests for deberta, perceiver, and roformer for this PR while we figure out what's going on there
  • Don't include the code that automatically adds 1 to the batch size and sequence length during validation in this PR
  • Refactor the code to pass in batch_size and seq_length to validate_model_outputs to give more control over which models are tested with dynamic input shapes

What do you think / any other ideas?

@dwyatte
Contributor Author

dwyatte commented Oct 1, 2022

It would also be interesting to know how much slower the TF exports are compared to the PyTorch ones, e.g. can you share some timings for a few models?

  • bert-base-cased

    • TensorFlow: 7 passed, 6 skipped, 19031 warnings in 525.40s (0:08:45)
    • PyTorch: 7 passed, 6 skipped, 9 warnings in 83.50s (0:01:23)
  • hf-internal-testing/tiny-albert

    • TensorFlow: 6 passed, 6 skipped, 4059 warnings in 28.33s
    • PyTorch: 6 passed, 6 skipped, 9 warnings in 14.30s
  • distilbert-base-cased

    • TensorFlow: 6 passed, 6 skipped, 10241 warnings in 293.32s (0:04:53)
    • PyTorch: 6 passed, 6 skipped, 15 warnings in 40.01s

So TF is around 2-8x slower on my machine (2.3 GHz 8-Core Intel Core i9). The warnings are mainly deprecation warnings from tf2onnx
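The quoted 2-8x range follows directly from the wall-clock times above:

```python
# TF vs. PyTorch export/test wall-clock times from the runs above, in seconds.
timings = {
    "bert-base-cased": (525.40, 83.50),
    "hf-internal-testing/tiny-albert": (28.33, 14.30),
    "distilbert-base-cased": (293.32, 40.01),
}
for model, (tf_s, pt_s) in timings.items():
    print(f"{model}: TF {tf_s / pt_s:.1f}x slower")
# bert-base-cased: TF 6.3x slower
# hf-internal-testing/tiny-albert: TF 2.0x slower
# distilbert-base-cased: TF 7.3x slower
```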

@dwyatte
Contributor Author

dwyatte commented Oct 3, 2022

@lewtun any further thoughts on this PR with the goal of supporting dynamic input shapes in ONNX models exported from TensorFlow?

It's not clear to me how tests/onnx/test_onnx_v2.py is used since it doesn't block checks here. Should we skip model/task/framework configurations known to fail a la

    # ONNX inference fails on bert, camembert, and roberta multiple-choice when exported with TensorFlow.
    # Skip for now
    if name in ("bert", "camembert", "roberta") and feature == "multiple-choice" and framework == "tf":
        return
Or is it ok to leave the failures if they don't block anything? Is the increase in test time for TF models a concern if it doesn't run regularly?

I suppose part of the answer is whether we want users to experience export failures related to dynamic shapes (which the current code in this PR would do) vs removing explicit dynamic shape validation from the user experience and limiting it to tests.
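One way to keep such skips manageable as more failing combinations turn up is a data-driven deny-list instead of a growing chain of conditionals. This is a sketch only, with hypothetical names, not the actual test-suite code:

```python
# Known-failing (model, feature, framework) combinations, kept as data
# so adding a new entry is a one-line change. Names are illustrative.
SKIP_CONFIGS = {
    ("bert", "multiple-choice", "tf"),
    ("camembert", "multiple-choice", "tf"),
    ("roberta", "multiple-choice", "tf"),
}

def should_skip(name: str, feature: str, framework: str) -> bool:
    """Return True for configurations with known ONNX inference failures."""
    return (name, feature, framework) in SKIP_CONFIGS

print(should_skip("bert", "multiple-choice", "tf"))  # True
print(should_skip("bert", "multiple-choice", "pt"))  # False
```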

@lewtun
Member

lewtun commented Oct 5, 2022

Hey @dwyatte, thanks for sharing the timings! I'm currently working on dramatically shrinking all the ONNX models we use for internal testing, so a 2-8x slowdown for some models is probably OK.

Regarding how to handle the model validation:

A couple of options:

  • Disable tests for deberta, perceiver, and roformer for this PR while we figure out what's going on there
  • Don't include the code that automatically adds 1 to the batch size and sequence length during validation in this PR
  • Refactor the code to pass in batch_size and seq_length to validate_model_outputs to give more control over which models are tested with dynamic input shapes

I am in favour of option (1) and creating a separate issue to figure out what's wrong in the ONNX export of these 3 models. You can skip these tests by following the same logic you linked to above :)

@dwyatte
Contributor Author

dwyatte commented Oct 5, 2022

I am in favour of option (1) and creating a separate issue to figure out what's wrong in the ONNX export of these 3 models. You can skip these tests by following the same logic you linked to above

Created #19357 to track this. tests/onnx/test_onnx_v2.py should now be 100% passing/skipped (416 passed, 16 skipped in my env)

@dwyatte dwyatte requested a review from lewtun October 5, 2022 18:30
Member

@lewtun lewtun left a comment


Thanks a lot for opening an issue to track the problematic ONNX models @dwyatte 🔥 !

This PR LGTM, so gently pinging @sgugger for final approval.

For context in the review: @dwyatte uncovered some edge cases that our ONNX tests didn't cover. This PR currently skips the problematic model heads and we decided to tackle them in a separate issue, since this one is focused on enabling dynamic shapes for TF models

Collaborator

@sgugger sgugger left a comment


LGTM, thanks a lot for working on this!

@sgugger
Collaborator

sgugger commented Oct 7, 2022

There was some problem with CircleCI which only ran part of the test suite (and I can't manually re-run it). Could you push an empty commit on your branch (git commit -m "Trigger CI" --allow-empty)?

@dwyatte
Contributor Author

dwyatte commented Oct 7, 2022

There was some problem with CircleCI which only ran part of the test suite (and I can't manually re-run it). Could you push an empty commit on your branch (git commit -m "Trigger CI" --allow-empty)?

I think I was having the same problem described here #18351 (comment)

9496836 ran the CI under the huggingface org, so should be good to go now

@sgugger sgugger merged commit a26d71d into huggingface:main Oct 7, 2022
@sgugger
Collaborator

sgugger commented Oct 7, 2022

Thanks!

@dwyatte dwyatte deleted the tensorflow_onnx_dynamic_shape branch January 31, 2023 18:01


Development

Successfully merging this pull request may close these issues.

Exporting TensorFlow models to ONNX exports with a static batch size of 2 and sequence length of 8
