
Adding support for fp16 for asr pipeline. #20864

Merged
Narsil merged 10 commits into huggingface:main from Narsil:support_fp16_asr
Dec 23, 2022

Conversation

Contributor

@Narsil Narsil commented Dec 21, 2022

What does this PR do?

Fixes #20862

Many things were considered before settling for this design.

  • feature_extractor(return_tensors="pt", torch_dtype=torch_dtype). This would have the advantage of being consistent, but not all feature extractors define this parameter, so it would affect all of them. And why use torch_dtype instead of the more commonplace dtype, which could apply to TF and Flax as well? It also feels a bit redundant to specify both return_tensors and torch_dtype; they would be good candidates to fuse into one parameter (but that's outside the scope of this PR).
  • AutoFeatureExtractor.from_pretrained(..., torch_dtype=torch_dtype). This would have the advantage of being global, so users wouldn't need to respecify it on each call. However, we can't specify return_tensors="pt" there either, so for consistency I didn't put it there.
  • ffmpeg_read(..., dtype=dtype). This would be nice for loading the waveform directly in fp16 and just letting fp16 flow through the feature_extractor. However, Whisper in particular uses a mel spectrogram, so using fp16 audio might actually damage quality.

In the end, this solution is the simplest I could come up with: let torch_dtype flow to the pipeline, use it as a regular parameter, and convert the output of the feature_extractor afterwards.

This does incur a potential extra copy, but there is no risk of damaging the quality of the input.
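The chosen flow can be sketched with a toy stand-in (hypothetical names; numpy arrays replace the real feature extractor and torch tensors):

```python
import numpy as np

# Toy stand-in for a feature extractor: returns a dict of arrays,
# always computed at full float32 precision.
def toy_extractor(audio, sampling_rate=16000):
    return {"input_features": np.asarray(audio, dtype=np.float32)}

def preprocess(raw_audio, dtype=None):
    # 1. Extract features at full precision first...
    processed = toy_extractor(raw_audio)
    # 2. ...then cast afterwards, so feature extraction itself
    #    (e.g. the mel spectrogram) is never run on fp16 audio.
    if dtype is not None:
        processed = {k: v.astype(dtype) for k, v in processed.items()}
    return processed

out = preprocess([0.1, 0.2, 0.3], dtype=np.float16)
```

Casting after extraction is what keeps the spectrogram computation at full precision, which is the quality concern the design weighs above.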


Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Dec 21, 2022

The documentation is not available anymore as the PR was closed or merged.

      inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
  )
+ if dtype is not None:
+     processed = {k: v.to(dtype=dtype) for k, v in processed.items()}
Contributor

@bofenghuang bofenghuang Dec 22, 2022


Hi @Narsil,

I think this works fine for Whisper models because they only have a single input, input_features.

But for other models like wav2vec2, the model has multiple inputs of different dtypes: input_values, which needs to be cast from float32 to float16, and attention_mask, which I'm not sure whether to keep as int32 or cast to int16.

Collaborator


Yes. And as above, if you directly use the to method on processed, it will take care of that for you.
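The difference can be illustrated with a small numpy sketch (hypothetical names; the real fix relies on BatchFeature's to method, which applies this float-only rule to torch tensors):

```python
import numpy as np

# wav2vec2-style output: a float input plus an integer mask.
features = {
    "input_values": np.zeros(8, dtype=np.float32),   # should become fp16
    "attention_mask": np.ones(8, dtype=np.int32),    # should stay int32
}

# Naive cast: the dict comprehension converts *everything*,
# silently turning the integer mask into float16 too.
naive = {k: v.astype(np.float16) for k, v in features.items()}

# Float-only cast: skip non-floating entries, which is the check
# that a BatchFeature-style `to` method gives you for free.
safe = {
    k: v.astype(np.float16) if np.issubdtype(v.dtype, np.floating) else v
    for k, v in features.items()
}
```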

Contributor Author


Done. Thanks, TIL

Collaborator

@sgugger sgugger left a comment


Thanks for working on this! My only comment is to make sure to leverage the to method on BatchFeature (if the feature extractor here returns another type, maybe make sure its to method handles dtype arguments) so that checks like not converting int inputs are applied for free.

Otherwise LGTM!

  chunk = inputs[i : i + chunk_len]
  processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
+ if dtype is not None:
+     processed = {k: v.to(dtype=dtype) for k, v in processed.items()}
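For context, the chunked path applies the same per-chunk conversion; a toy sketch (numpy stand-ins; chunked_features and the single input_features key are hypothetical):

```python
import numpy as np

def chunked_features(inputs, chunk_len, dtype=None):
    """Yield per-chunk feature dicts with the optional cast applied,
    mirroring the hunk above."""
    for i in range(0, len(inputs), chunk_len):
        # Slice the waveform, extract "features", then cast.
        chunk = np.asarray(inputs[i : i + chunk_len], dtype=np.float32)
        processed = {"input_features": chunk}
        if dtype is not None:
            processed = {k: v.astype(dtype) for k, v in processed.items()}
        yield processed

chunks = list(chunked_features([0.0] * 10, chunk_len=4, dtype=np.float16))
# 10 samples with chunk_len=4 yield chunks of sizes 4, 4, and 2.
```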
Collaborator


I believe you can call to directly on processed, which is a BatchFeature and handles dtype in its to method thanks to #20536 (it was designed for vision, but I think it will apply here too).
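A minimal sketch of that container behavior (ToyBatchFeature is hypothetical; the real BatchFeature lives in transformers and converts torch tensors):

```python
import numpy as np

class ToyBatchFeature(dict):
    """Dict-like container whose `to` converts only floating-point
    members, mimicking the behavior added to BatchFeature in #20536."""
    def to(self, dtype):
        return ToyBatchFeature({
            k: v.astype(dtype) if np.issubdtype(v.dtype, np.floating) else v
            for k, v in self.items()
        })

processed = ToyBatchFeature({
    "input_features": np.zeros((2, 4), dtype=np.float32),
    "attention_mask": np.ones((2, 4), dtype=np.int64),
})
# The `if dtype is not None: {k: v.to(...)}` dance collapses to one call:
processed = processed.to(np.float16)
```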

Contributor Author


Done!


- def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
+ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
+     print(f"Running with dtype {dtype}")
Collaborator


To be cleaned up ;-)

Contributor Author


Oops

      inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
  )
+ if dtype is not None:
+     processed = {k: v.to(dtype=dtype) for k, v in processed.items()}
Collaborator


Yes. And as above, if you directly use the to method on processed, it will take care of that for you.

@Narsil Narsil merged commit f7f0ec2 into huggingface:main Dec 23, 2022
@Narsil Narsil deleted the support_fp16_asr branch December 23, 2022 09:18
MKhalusova pushed a commit to MKhalusova/transformers that referenced this pull request Dec 28, 2022
* Supporting `fp16` for asr pipeline

* Adding test.

* Style.

* Oops.

* Flake8 update ?

* Fixing flake8 ?

* Revert "Flake8 update ?"

This reverts commit 0b917fc.

* Style (acctidentally deleted flake8 F401.)

* Move to a bigger test (no small whisper model, and s2t doesn't seem to
accept torch_dtype=fp16).

Also we need to use a GPU to actually compute on fp16.

* Using BatchFeature capability.
silverriver pushed a commit to silverriver/transformers that referenced this pull request Jan 6, 2023


Development

Successfully merging this pull request may close these issues.

Run AutomaticSpeechRecognitionPipeline with FP16

4 participants