-
Notifications
You must be signed in to change notification settings - Fork 33.1k
Parakeet tdt #44171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Parakeet tdt #44171
Changes from all commits
fa7d6e0
f2b4938
fa36657
05e2e34
bb5ff33
9ec79b0
33f128e
760b4b6
b33002f
48b39dd
e9f23ab
e2b97aa
6b9fc73
6c879bc
149e17f
36bfa63
2df0ccc
388c6d3
08b2b55
0c4e05a
1ddd804
f512670
07d8e35
fab050a
c438565
86d980c
895c4a0
ab21380
d0141d5
f7529d4
77b95d7
94eae66
f7d4067
f75c17b
5a49b65
881233f
b41a8ee
897753a
6c914db
756cee1
f30c536
fa95fc8
5df7f28
cd706d4
a47ed8a
13b68ce
72c1ad0
8e23b3d
1cc39fd
531f297
2c0f23a
cef6639
fd3cf9b
e63a5bf
f9d1a4f
43ee7cd
c2a0f78
1fd7ed7
7cc9d2e
e753eab
ed3fa4d
ab66b23
a5ba0c6
48279a6
1d7680d
59ddced
31490d1
d8eb1b6
1f1b912
9ab08d1
fd9f8b1
136f676
833d289
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1459,6 +1459,13 @@ def compute_transition_scores( | |
| def _validate_generation_mode( | ||
| self: "GenerativePreTrainedModel", generation_mode, generation_config, generation_mode_kwargs | ||
| ): | ||
| supported_modes = getattr(self, "_supported_generation_modes", None) | ||
| if supported_modes is not None and generation_mode not in supported_modes: | ||
| raise ValueError( | ||
| f"{self.__class__.__name__} only supports {supported_modes}, but got " | ||
| f"generation mode '{generation_mode}'." | ||
| ) | ||
|
|
||
|
Comment on lines
+1462
to
+1468
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added this to be able to do class ParakeetForTDT(ParakeetPreTrainedModel, ParakeetTDTGenerationMixin):
_supported_generation_modes = [GenerationMode.GREEDY_SEARCH] |
||
| if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs: | ||
| raise ValueError( | ||
| "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -286,6 +286,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): | |
| "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, | ||
| "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, | ||
| "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, | ||
| "tdt-loss": {"repo_id": "eustlb/tdt-loss", "revision": "v1"}, | ||
| } | ||
|
Comment on lines
288
to
290
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ErikKaum pinging you here because your YouTube kernel tutorial helped a lot for this 😊 What are the next steps to move my tdt kernel from my repo to kernels-community and compile for other environments?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @eustlb thanks for creating the kernel! btw I changed from And maybe we need to also add the source to the I guess @ErikKaum will have have best practice tips!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here I just used the same convention as for other hub kernels: |
||
|
|
||
| _KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {} | ||
|
|
@@ -372,10 +373,12 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ | |
| repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] | ||
| revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) | ||
| version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) | ||
| kernel = get_kernel(repo_id, revision=revision, version=version) | ||
| # Since we only read from `_HUB_KERNEL_MAPPING`, we can allow all kernels | ||
| kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True) | ||
|
Comment on lines
+376
to
+377
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we hardcode |
||
| mapping[kernel_name] = kernel | ||
| except FileNotFoundError: | ||
| except FileNotFoundError as e: | ||
| mapping[kernel_name] = None | ||
| logger.warning_once(f"Failed to load kernel {kernel_name}: {e}") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adding a helpful error message, otherwise kernel may not load without notifying the user! E.g. due to different Torch. For example it will now print: |
||
| except AssertionError: | ||
| # Happens when torch is built without an accelerator backend; fall back to slow path. | ||
| mapping[kernel_name] = None | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.