Adding support for fp16 for asr pipeline. #20864
Conversation
The documentation is not available anymore as the PR was closed or merged.
    inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if dtype is not None:
    processed = {k: v.to(dtype=dtype) for k, v in processed.items()}
Hi @Narsil,
I think this works fine for whisper models because they only have a single value, input_features.
But in the case of other models like wav2vec2, the model has multiple inputs with different dtypes: input_values, which needs to be cast from float32 to float16, and attention_mask, which I'm not sure whether to keep as int32 or cast to int16.
Yes. And as above, if you directly use the to method on processed, it will take care of that for you.
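For illustration, here is a minimal sketch (not part of this PR; the checkpoint name and dummy waveform are placeholders) of why calling `to` on the returned `BatchFeature` handles the wav2vec2 case: only floating-point tensors are cast, so the attention mask keeps its integer dtype.

```python
import torch
from transformers import AutoFeatureExtractor

# Placeholder checkpoint and dummy one-second waveform, purely for illustration.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
processed = feature_extractor(
    [0.0] * 16000,
    sampling_rate=16000,
    return_tensors="pt",
    return_attention_mask=True,
)

# BatchFeature.to (added in #20536) only casts floating-point tensors, so
# input_values becomes float16 while attention_mask keeps its integer dtype.
processed = processed.to(dtype=torch.float16)
print(processed["input_values"].dtype)    # torch.float16
print(processed["attention_mask"].dtype)  # unchanged integer dtype
```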
sgugger left a comment
Thanks for working on this! My only comment is to make sure to leverage the to method on BatchFeature (if the feature extractor here returns another type, maybe make sure its to method handles dtype arguments) so that checks like not converting int inputs are applied for free.
Otherwise LGTM!
chunk = inputs[i : i + chunk_len]
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
if dtype is not None:
    processed = {k: v.to(dtype=dtype) for k, v in processed.items()}
I believe you can call the to directly on processed, which is a BatchFeature and handles dtype in its to method thanks to #20536 (was designed for vision but I think it will apply here too).
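To make the suggestion concrete, here is a minimal sketch (the helper name `preprocess_chunk` is hypothetical, not code from this PR) of what the branch above could collapse to once the `BatchFeature`'s own `to` method is used:

```python
def preprocess_chunk(feature_extractor, chunk, dtype=None):
    # The feature extractor returns a BatchFeature, so the manual per-tensor cast
    # can be replaced by a single `.to(dtype=...)`, which skips non-floating tensors.
    processed = feature_extractor(
        chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt"
    )
    if dtype is not None:
        processed = processed.to(dtype=dtype)
    return processed
```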
- def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
+ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
+     print(f"Running with dtype {dtype}")
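For context, here is a rough sketch (not necessarily this PR's exact code) of the usual pipeline pattern for routing a user-facing kwarg such as `torch_dtype` down to the new `dtype` parameter of `preprocess`: `_sanitize_parameters` splits incoming kwargs into preprocess/forward/postprocess dictionaries.

```python
def _sanitize_parameters(self, torch_dtype=None, **kwargs):
    # Route the user-facing `torch_dtype` kwarg to the `dtype` argument of preprocess;
    # forward and postprocess receive nothing extra in this sketch.
    preprocess_params = {}
    if torch_dtype is not None:
        preprocess_params["dtype"] = torch_dtype
    return preprocess_params, {}, {}
```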
* Supporting `fp16` for asr pipeline
* Adding test.
* Style.
* Oops.
* Flake8 update ?
* Fixing flake8 ?
* Revert "Flake8 update ?" This reverts commit 0b917fc.
* Style (accidentally deleted flake8 F401.)
* Move to a bigger test (no small whisper model, and s2t doesn't seem to accept torch_dtype=fp16). Also we need to use a GPU to actually compute on fp16.
* Using BatchFeature capability.
What does this PR do?
Fixes #20862
Many things were considered before settling on this design.
* `feature_extractor(return_tensors="pt", torch_dtype=torch_dtype)`. This would have the advantage of being consistent, but not all feature extractors define this, so it would affect all of them. Then why would we use `torch_dtype` instead of the more commonplace `dtype`, which could apply to TF and Flax as well? It also feels a bit redundant to specify both `return_tensors` and `torch_dtype`; they would be good candidates to fuse into a single parameter (but that is outside the scope of this PR).
* `AutoFeatureExtractor.from_pretrained(..., torch_dtype=torch_dtype)`. This would have the advantage of being global, so users don't need to respecify it on each call. However, we can't specify `return_tensors="pt"` there either, so for consistency I didn't try to put it there.
* `ffmpeg_read(..., dtype=dtype)`. This would be nice to load the waveform directly into fp16 and just let fp16 flow through the feature extractor. However, whisper in particular uses a mel spectrogram, so feeding it fp16 audio might actually damage quality.

In the end, this solution is the simplest I could come up with: let `torch_dtype` flow to the pipeline, use it as a regular parameter, and convert the output of the feature extractor afterwards. This does incur a potential extra copy, but there is no risk of damaging the quality of the input.
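A minimal usage sketch of the end result, assuming a Whisper checkpoint (the model name and audio path below are placeholders): `torch_dtype` is passed to the pipeline and the feature extractor output is cast to it before being fed to the model; per the commit notes above, a GPU is needed to actually compute in fp16.

```python
import torch
from transformers import pipeline

asr = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-base",  # placeholder checkpoint
    torch_dtype=torch.float16,
    device=0,  # fp16 compute effectively requires a GPU
)
print(asr("sample.flac"))  # placeholder audio file path
```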
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.