
Allow user-managed Pool in Wav2Vec2ProcessorWithLM.batch_decode#18351

Merged
sgugger merged 4 commits into huggingface:main from falcaopetri:w2v2-lm-pool
Oct 18, 2022

Conversation

@falcaopetri (Contributor) commented Jul 29, 2022

What does this PR do?

This PR addresses two issues:

  1. Creating a fresh pool on every call to Wav2Vec2ProcessorWithLM.batch_decode adds significant overhead when it's called multiple times (fixes #17879: Wav2Vec2ProcessorWithLM degraded performance when transcribing multiple files)
  2. pyctcdecode can't use spawn Pools (supersedes #17070: fix: 🐛 changing context of multiprocessing while decoding for Windows)

Changes:

  • adds a pool argument to Wav2Vec2ProcessorWithLM.batch_decode. This allows a user-managed multiprocessing.Pool to be (re)used across multiple calls to batch_decode.
  • updates pyctcdecode version requirement. The new version contains code to handle invalid pools.
  • adds tests covering Wav2Vec2ProcessorWithLM for TF, PyTorch, and Flax.
  • adds usage example in Wav2Vec2ProcessorWithLM.batch_decode's docs.

An important implementation reference is multiprocessing's Contexts and start methods. In short, batch_decode's multiprocessing capabilities are only useful on Unix, where the default start method is fork. This PR introduces some checks in this regard. They can be removed once kensho-technologies/pyctcdecode#65 is resolved.
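The pool-reuse idea can be sketched with a self-contained stand-in for batch_decode (the decode function, its data, and the helper names below are illustrative, not the processor's actual API):

```python
from multiprocessing import get_context

def _decode_one(logits_row):
    # Stand-in for per-utterance beam-search decoding: pick the argmax index.
    return max(range(len(logits_row)), key=logits_row.__getitem__)

def batch_decode(batch, pool=None):
    # Mirrors the PR's change: reuse the caller's pool when given,
    # instead of spinning up (and tearing down) a fresh one on every call.
    if pool is None:
        return [_decode_one(row) for row in batch]
    return pool.map(_decode_one, batch)

batches = [[[0.1, 0.9], [0.8, 0.2]], [[0.3, 0.7]]]
# Create the pool once, with the fork start method, and reuse it across calls.
with get_context("fork").Pool(2) as pool:
    results = [batch_decode(b, pool=pool) for b in batches]
print(results)  # [[1, 0], [1]]
```

Amortizing the pool start-up cost this way is where the speed-up for multi-batch decoding comes from.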

Breaking change

The new pool argument can break currently valid code such as processor.batch_decode(logits, 5), where the second positional argument previously meant num_processes. If that's an issue, some considerations are:

  • pool and num_processes are mutually exclusive, but a unique arg like num_processes_or_pool seemed weird
  • we could add pool as the last argument
  • we could force kwargs-only
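The breaking change can be illustrated with plain signatures (these are simplified stand-ins, not the real method signatures, and the exact parameter order in the PR may differ):

```python
import inspect

def batch_decode_old(logits, num_processes=None):
    pass

def batch_decode_new(logits, pool=None, num_processes=None):
    pass

# Previously, the second positional argument bound to num_processes...
old_call = inspect.signature(batch_decode_old).bind("logits", 5)
# ...but with pool inserted before it, the same call now binds 5 to pool.
new_call = inspect.signature(batch_decode_new).bind("logits", 5)
print(old_call.arguments)  # {'logits': 'logits', 'num_processes': 5}
print(new_call.arguments)  # {'logits': 'logits', 'pool': 5}
```

This is why keyword-only arguments are one of the options above: they make such silent rebinding impossible.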

Checklist

I couldn't install all the deps, so I didn't run some of the tests or build the documentation changes. Let's see what CI says about:

  • Test test_decoder_batch_1 from test_processor_wav2vec2_with_lm.py (it uses fork, but it fails on my Mac, probably due to platform-specific multiprocessing behavior) Fixed in 17efdddbeeac75eeacddff3bcc51ced70ec19217
  • All other tests
  • Wav2Vec2ProcessorWithLM.batch_decode's new Tips
  • Wav2Vec2ProcessorWithLM.batch_decode's new usage example

Also, the new tests are a copy and paste of previous tests. I couldn't figure out how to factor out the duplicated code (I'm more used to pytest's fixtures). Ideas to cut down the duplication are welcome.

After merge

Once merged, it would be nice to adapt and re-run some scripts, such as the evaluation of patrickvonplaten/wav2vec2-large-960h-lv60-self-4-gram. In this case, for example, there are two improvements:

  • The current evaluation creates a fresh Pool for each inferred audio, which adds a huge overhead compared with sequential decoding or with a user-managed pool as introduced in this PR.
  • The current evaluation uses Wav2Vec2ProcessorWithLM.batch_decode without batched=True, which probably means there's no advantage in using parallel batch decoding. The usage example from this PR enables proper parallel batch decoding.

Before submitting

Who can review?

@patrickvonplaten

@HuggingFaceDocBuilderDev commented Jul 29, 2022

The documentation is not available anymore as the PR was closed or merged.

@sanchit-gandhi (Contributor)

@anton-l are you looking into this? I'm not super familiar with the pyctcdecode backend, but can get up to speed and take a look for @falcaopetri if need be :)

@anton-l (Member) commented Aug 9, 2022

Hi @falcaopetri! Looks like test_decoder_batch_1 indeed fails on Linux too, maybe you could take a look and find a workaround?

@falcaopetri (Contributor, Author) commented Aug 13, 2022

Hi everyone. The test setup was wrong; the new commit's diff should be self-explanatory. Let me know if there's something still missing.

You can also check this colab: a 34 s → 5 s speed-up just from removing the overhead, even with a single process; using more than one process should yield a roughly linear additional improvement when decoding multiple batches.

@falcaopetri (Contributor, Author)

Hi @anton-l, @sanchit-gandhi. I've merged main trying to fix the Add model like runner CI step, but with no success. Could you take a look at that, and at the overall PR, please?

@github-actions github-actions Bot closed this Sep 25, 2022
@huggingface huggingface deleted a comment from github-actions Bot Sep 27, 2022
@patrickvonplaten (Contributor)

This PR looks very nice to me! Let's try to get it merged

@patrickvonplaten (Contributor)

@falcaopetri let's see what the tests say and then IMO we can merge if they are all green

@falcaopetri (Contributor, Author)

Thanks for re-opening it @patrickvonplaten! I've just fixed a code quality issue that made check_code_quality fail. But now CircleCI seems to be having some (internal?) problem.

Also, should I rebase everything to get a nice and clean history, or is that not necessary?

@falcaopetri (Contributor, Author)

@patrickvonplaten, I've rebased everything so we (i) get a cleaner merge and (ii) re-trigger CircleCI. Unfortunately, CircleCI setup is still failing.

@patrickvonplaten (Contributor)

It seems there is an issue with your CircleCI permissions; the tests won't run. Could you try refreshing your permissions as shown here?

@falcaopetri (Contributor, Author)

Oh I see, I didn't realize CircleCI was talking about my own user's credentials. Thanks for the heads up.
I've refreshed my credentials and force pushed things again to trigger CI steps.

For future reference, I did make a mistake in the process: after refreshing credentials, CircleCI prompted me to set up a project within my own organization, which then made the run_test* steps fail with "Resource class docker for xlarge is not available for your project, or is not a valid resource class. This message will often appear if the pricing plan for this project does not support docker use." Deleting my own project and force pushing made the huggingface org run the steps.

@falcaopetri (Contributor, Author)

I just realized that my Wav2Vec2ProcessorWithLM.batch_decode example triggers:

```
WARNING:datasets.fingerprint:Parameter 'fn_kwargs'={'pool': <multiprocessing.pool.Pool object at 0x7f64ecc54990>} of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.
```

@patrickvonplaten, is there a way to indicate that a map arg shouldn't be taken into account when hashing the transformation? Or maybe pool could be a global variable?

@patrickvonplaten (Contributor)

Uff yeah good question. Gently pinging @lhoestq and/or @mariosasko here as this seems to be datasets related

@patrickvonplaten (Contributor)

On the other hand, I don't think caching the results with datasets is a must here - I think we can merge this PR without it.

@patrickvonplaten (Contributor) left a comment

Thanks a lot for the PR @falcaopetri - it's good to merge for me. Would love to get a review from @sanchit-gandhi here before merging :-)

@lhoestq (Member) commented Oct 5, 2022

@patrickvonplaten, is there a way to indicate that a map arg shouldn't be taken into account when hashing the transformation? Or maybe pool could be a global variable?

You can pass new_fingerprint= in map with a unique fingerprint (string) that depends on the parameters of your transform.

Alternatively, you can just disable caching in datasets to remove the warning.
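The first suggestion can be sketched as deriving a stable fingerprint from the picklable parameters only, leaving the Pool out of the hash (the make_fingerprint helper below is illustrative, not part of the datasets API; the map call in the comment follows lhoestq's description):

```python
import hashlib

def make_fingerprint(**params):
    # Hash only the picklable, behavior-determining parameters of the
    # transform; the unhashable Pool object is deliberately left out.
    blob = repr(sorted(params.items())).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:16]

fp = make_fingerprint(num_processes=4, beam_width=100)
# The fingerprint is stable across calls and argument orderings, so it could
# be passed as e.g. dataset.map(fn, fn_kwargs={"pool": pool}, new_fingerprint=fp).
same = fp == make_fingerprint(beam_width=100, num_processes=4)
print(same)  # True
```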

@sgugger (Collaborator) left a comment

Thanks a lot for your PR! Just left a couple of nits for styling but otherwise LGTM!

Comment on lines 365 to 367

@sgugger (Collaborator):
This example is a bit too long for a docstring IMO, can you move it to the doc page instead? (in docs/source/em/models/wav2vec2_with_lm.mdx)

@falcaopetri (Contributor, Author):

I'm not sure if I understood your suggestion @sgugger. Wav2Vec2ProcessorWithLM seems to be documented within docs/source/en/model_doc/wav2vec2.mdx. Should I add the example as a custom section (after Overview?), similarly to Speech2Text#Inference?

@sgugger (Collaborator):

You can add the example after the documentation of the processor class.

@falcaopetri (Contributor, Author):

Sorry to bother you, @sgugger, but I still can't grasp how to proceed. I've tried things like below, but doc-builder renders the <div>s incorrectly.

[[autodoc]] Wav2Vec2ProcessorWithLM
    ...
    - batch_decode
Example:

```python
>>> ...
```
    - decode

@sgugger (Collaborator):

The example should be after the whole autodoc block. You can introduce it with a small sentence.

@falcaopetri (Contributor, Author):

Thanks for the explanation @sgugger. I've refactored the example. FYI, I added a link from batch_decode to the example section (see here). I hope it's ok to have this cross-reference (a .py docstring referencing a section from a .mdx file).

Now I've also realized that [~model.wav2vec2...] were not rendered properly. Are they available only inside an autodoc context?

@sgugger (Collaborator):

See my comment below, the links are not right (you don't need the full path but you do need the class name!)

Comment thread tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py Outdated
@sanchit-gandhi (Contributor) left a comment

Very nice PR, thank you @falcaopetri!

Also, the new tests are a copy and paste of previous tests (...) Ideas to cut down the duplicated code are welcomed

There's no issue with having duplicated code for the different tests - it's very readable and easy to understand, which is exactly what we want in Transformers 🤗

Comment thread src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py Outdated
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Comment thread docs/source/en/model_doc/wav2vec2.mdx Outdated
Comment on lines +78 to +79
If you are planning to decode multiple batches of audios, you should consider using [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] and passing an instantiated `multiprocessing.Pool`.
Otherwise, [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] performance will be slower than calling [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] for each audio individually, as it internally instantiates a new `Pool` for every call. See the example below:
@sgugger (Collaborator):

Suggested change
If you are planning to decode multiple batches of audios, you should consider using [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] and passing an instantiated `multiprocessing.Pool`.
Otherwise, [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] performance will be slower than calling [`~model.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`] for each audio individually, as it internally instantiates a new `Pool` for every call. See the example below:
If you are planning to decode multiple batches of audios, you should consider using [`~Wav2Vec2ProcessorWithLM.batch_decode`] and passing an instantiated `multiprocessing.Pool`.
Otherwise, [`~Wav2Vec2ProcessorWithLM.batch_decode`] performance will be slower than calling [`~Wav2Vec2ProcessorWithLM.decode`] for each audio individually, as it internally instantiates a new `Pool` for every call. See the example below:

@sgugger sgugger merged commit af150e4 into huggingface:main Oct 18, 2022
@sgugger (Collaborator) commented Oct 18, 2022

Thanks again for all your work on this!

@ydshieh (Collaborator) commented Oct 20, 2022

Hi @falcaopetri Thank you for adding this! After merging to the main branch, we have 3 test failures (when running on GPU)

```
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py::TFWav2Vec2ModelIntegrationTest::test_wav2vec2_with_lm_invalid_pool
  (line 243) RuntimeError: context has already been set
tests/models/wav2vec2/test_modeling_wav2vec2.py::Wav2Vec2ModelIntegrationTest::test_wav2vec2_with_lm_invalid_pool
  (line 243) RuntimeError: context has already been set
tests/models/wav2vec2/test_modeling_wav2vec2.py::Wav2Vec2ModelIntegrationTest::test_wav2vec2_with_lm_pool
  (line 1639) TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
```

More detailed information can be found here, along with the raw logs.

Would you like to take a look here 🙏? Thank you 🤗

@falcaopetri (Contributor, Author)

Hi @ydshieh. I'm really sorry about that. Both errors are my fault while setting up the tests.

The RuntimeError: context has already been set emerged only during the GPU tests, probably because the tests are executed differently: a PR's pytest is invoked with -n 8, while your tests here aren't. I.e., the RuntimeError was a real issue, but it was hidden during the PR because tests were being executed in different processes.
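For reference, that error comes from multiprocessing's global start-method state, which may only be set once per process; a minimal reproduction (Unix-only, since it uses fork - the specific methods chosen here are illustrative):

```python
import multiprocessing as mp

# The first (forced) call mimics a context set earlier in the test session.
mp.set_start_method("spawn", force=True)
try:
    # A second call without force=True is what blows up in the tests.
    mp.set_start_method("fork")
except RuntimeError as err:
    message = str(err)
print(message)  # context has already been set

# get_context returns an independent context without touching global state,
# so it is safe regardless of what was set before.
ctx = mp.get_context("fork")
with ctx.Pool(2) as pool:
    decoded = pool.map(abs, [-1, -2, 3])
print(decoded)  # [1, 2, 3]
```

Using get_context instead of set_start_method is why the fix keeps tests independent of execution order.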

TypeError: can't convert cuda:0 device type tensor to numpy was caused by a careless change I made after copy-pasting other tests.

This should fix both issues: main...falcaopetri:transformers:fix-w2v2-lm-pool. What is your workflow in this case? Should I open a new PR?

@ydshieh (Collaborator) commented Oct 21, 2022

Thank you so much @falcaopetri. Yes, it would be nice if you can open a PR. Otherwise, we can do it ourselves by looking at your branch above ❤️.

[Context of our CI] Our CircleCI tests run on CPU, so we use -n 8. After merging into main, we run (more) tests on GPU machines, so we set -n 1 to avoid multiple processes accessing the GPU at the same time.



Development

Successfully merging this pull request may close these issues.

Wav2Vec2ProcessorWithLM degraded performance when transcribing multiple files

8 participants