Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/transformers/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,12 @@ class WhisperGenerationConfigLike(Protocol):
"""Protocol for Whisper-specific generation config fields accessed in generation internals."""

no_timestamps_token_id: int


class TypedDictSchema(Protocol):
"""Protocol for TypedDict classes that expose their mutable keys."""

__mutable_keys__: set[str]


RequestSchema: TypeAlias = type[TypedDictSchema]
16 changes: 10 additions & 6 deletions src/transformers/cli/add_new_model_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Callable
from datetime import date
from pathlib import Path
from typing import Annotated, Any
from typing import Annotated, Any, cast

import typer

Expand Down Expand Up @@ -57,10 +57,11 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine):
body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
)
if not self.is_in_class and m.matches(node, simple_top_level_assign_structure):
assigned_variable = node.body[0].targets[0].target.value
stmt = cast(cst.Assign, node.body[0])
assigned_variable = cast(cst.Name, stmt.targets[0].target).value
if assigned_variable == "__all__":
elements = node.body[0].value.elements
self.public_classes = [element.value.value for element in elements]
elements = cast(cst.Tuple, stmt.value).elements
self.public_classes = [cast(cst.SimpleString, element.value).value for element in elements]


CURRENT_YEAR = date.today().year
Expand Down Expand Up @@ -316,7 +317,10 @@ def insert_model_in_doc_toc(
with open(toc_file, "r") as f:
content = f.read()

old_model_toc = re.search(rf"- local: model_doc/{old_lowercase_name}\n {{8}}title: .*?\n", content).group(0)
old_model_toc_match = re.search(rf"- local: model_doc/{old_lowercase_name}\n {{8}}title: .*?\n", content)
if old_model_toc_match is None:
raise ValueError(f"Could not find toc entry for {old_lowercase_name}")
old_model_toc = old_model_toc_match.group(0)
new_toc = f" - local: model_doc/{new_lowercase_name}\n title: {new_model_paper_name}\n"
add_content_to_file(
repo_path / "docs" / "source" / "en" / "_toctree.yml", new_content=new_toc, add_after=old_model_toc
Expand Down Expand Up @@ -392,7 +396,7 @@ def find_modular_structure(
The new cased model name.
"""
all_classes, public_classes = find_all_classes_from_file(module_name)
import_location = ".".join(module_name.parts[-2:]).replace(".py", "")
import_location = ".".join(Path(module_name).parts[-2:]).replace(".py", "")
old_cased_name = old_model_infos.camelcase_name
imports = f"from ..{import_location} import {', '.join(class_ for class_ in all_classes)}"
modular_classes = "\n\n".join(
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/cli/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import re
import string
import time
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Awaitable
from typing import Annotated, Any
from urllib.parse import urljoin, urlparse

Expand Down Expand Up @@ -110,7 +110,9 @@ def __init__(self, model_id: str, user_id: str, base_url: str):
self.user_id = user_id
self.base_url = base_url

async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, str | Any | None]:
async def stream_output(
self, stream: Awaitable[AsyncIterator[ChatCompletionStreamOutput]]
) -> tuple[str, str | Any | None]:
self._console.print(f"[bold blue]<{self.model_id}>:")
with Live(console=self._console, refresh_per_second=4) as live:
text = ""
Expand Down
Loading
Loading