diff --git a/src/transformers/_typing.py b/src/transformers/_typing.py index 193ca14b503c..cea3656f043b 100644 --- a/src/transformers/_typing.py +++ b/src/transformers/_typing.py @@ -170,3 +170,12 @@ class WhisperGenerationConfigLike(Protocol): """Protocol for Whisper-specific generation config fields accessed in generation internals.""" no_timestamps_token_id: int + + +class TypedDictSchema(Protocol): + """Protocol for TypedDict classes that expose their mutable keys.""" + + __mutable_keys__: set[str] + + +RequestSchema: TypeAlias = type[TypedDictSchema] diff --git a/src/transformers/cli/add_new_model_like.py b/src/transformers/cli/add_new_model_like.py index c3d4c694ba6c..5d5464f63130 100644 --- a/src/transformers/cli/add_new_model_like.py +++ b/src/transformers/cli/add_new_model_like.py @@ -19,7 +19,7 @@ from collections.abc import Callable from datetime import date from pathlib import Path -from typing import Annotated, Any +from typing import Annotated, Any, cast import typer @@ -57,10 +57,11 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine): body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] ) if not self.is_in_class and m.matches(node, simple_top_level_assign_structure): - assigned_variable = node.body[0].targets[0].target.value + stmt = cast(cst.Assign, node.body[0]) + assigned_variable = cast(cst.Name, stmt.targets[0].target).value if assigned_variable == "__all__": - elements = node.body[0].value.elements - self.public_classes = [element.value.value for element in elements] + elements = cast(cst.Tuple, stmt.value).elements + self.public_classes = [cast(cst.SimpleString, element.value).value for element in elements] CURRENT_YEAR = date.today().year @@ -316,7 +317,10 @@ def insert_model_in_doc_toc( with open(toc_file, "r") as f: content = f.read() - old_model_toc = re.search(rf"- local: model_doc/{old_lowercase_name}\n {{8}}title: .*?\n", content).group(0) + old_model_toc_match = re.search(rf"- local: model_doc/{old_lowercase_name}\n {{8}}title: .*?\n", content) + if old_model_toc_match is None: + raise ValueError(f"Could not find toc entry for {old_lowercase_name}") + old_model_toc = old_model_toc_match.group(0) new_toc = f" - local: model_doc/{new_lowercase_name}\n title: {new_model_paper_name}\n" add_content_to_file( repo_path / "docs" / "source" / "en" / "_toctree.yml", new_content=new_toc, add_after=old_model_toc @@ -392,7 +396,7 @@ def find_modular_structure( The new cased model name. 
""" all_classes, public_classes = find_all_classes_from_file(module_name) - import_location = ".".join(module_name.parts[-2:]).replace(".py", "") + import_location = ".".join(Path(module_name).parts[-2:]).replace(".py", "") old_cased_name = old_model_infos.camelcase_name imports = f"from ..{import_location} import {', '.join(class_ for class_ in all_classes)}" modular_classes = "\n\n".join( diff --git a/src/transformers/cli/chat.py b/src/transformers/cli/chat.py index c6e434b07481..b2ebcdd3fc91 100644 --- a/src/transformers/cli/chat.py +++ b/src/transformers/cli/chat.py @@ -18,7 +18,7 @@ import re import string import time -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable from typing import Annotated, Any from urllib.parse import urljoin, urlparse @@ -110,7 +110,9 @@ def __init__(self, model_id: str, user_id: str, base_url: str): self.user_id = user_id self.base_url = base_url - async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, str | Any | None]: + async def stream_output( + self, stream: Awaitable[AsyncIterator[ChatCompletionStreamOutput]] + ) -> tuple[str, str | Any | None]: self._console.print(f"[bold blue]<{self.model_id}>:") with Live(console=self._console, refresh_per_second=4) as live: text = "" diff --git a/src/transformers/cli/serve.py b/src/transformers/cli/serve.py index 7337eb305b61..5233289c572c 100644 --- a/src/transformers/cli/serve.py +++ b/src/transformers/cli/serve.py @@ -23,18 +23,18 @@ import threading import time import uuid -from collections.abc import Callable, Generator, Iterable +from collections.abc import Callable, Generator, Iterable, Set from contextlib import asynccontextmanager from functools import lru_cache from io import BytesIO from threading import Thread -from typing import TYPE_CHECKING, Annotated, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast import typer from huggingface_hub import scan_cache_dir from tokenizers.decoders import DecodeStream from tqdm import tqdm -from tqdm.auto import tqdm as base_tqdm +from typing_extensions import NotRequired import transformers from transformers import AutoTokenizer, BitsAndBytesConfig, GenerationConfig, PreTrainedTokenizerBase @@ -52,6 +52,9 @@ is_vision_available, ) +from .._typing import RequestSchema +from ..tokenization_utils_base import BatchEncoding + if TYPE_CHECKING: from transformers import ( @@ -74,7 +77,7 @@ ) if serve_dependencies_available: import uvicorn - from fastapi import FastAPI, HTTPException + from fastapi import FastAPI, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai.types.audio.transcription import Transcription @@ -109,8 +112,30 @@ ResponseTextDoneEvent, ) from openai.types.responses.response_create_params import ResponseCreateParamsStreaming + from openai.types.responses.response_output_text import Annotation from pydantic import BaseModel, TypeAdapter, ValidationError + def make_response_output_text(text: str) -> ResponseOutputText: + return ResponseOutputText( + type="output_text", + text=text, + annotations=cast(list[Annotation], []), + ) + + def make_completed_response_output_message( + message_id: str, content: list[ResponseOutputText] + ) -> ResponseOutputMessage: + return ResponseOutputMessage.model_validate( + { + "id": message_id, + "type": "message", + "status": "completed", + "role": "assistant", + "content": content, + 
"annotations": [], + } + ) + # Expand OpenAI's request input types with an optional `generation_config` field class TransformersResponseCreateParamsStreaming(ResponseCreateParamsStreaming, total=False): """ @@ -133,7 +158,7 @@ class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total file: bytes # Overwritten -- pydantic isn't happy with `typing.IO[bytes]`, present in the original type generation_config: str - stream: bool = False + stream: NotRequired[bool] # Contrarily to OpenAI's output types, input types are `TypedDict`, which don't have built-in validation. response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming) @@ -232,6 +257,12 @@ class Modality(enum.Enum): TTS = "TTS" +def require_batch_encoding(obj: object) -> BatchEncoding: + if not isinstance(obj, BatchEncoding): + raise TypeError("Expected BatchEncoding from `apply_chat_template` with `return_dict=True`") + return obj + + def create_generation_config_from_req( req: dict, model_generation_config: GenerationConfig, @@ -442,7 +473,7 @@ def __iter__(self): def set_tqdm_class(callback, mid): download_aggregator = DownloadAggregator(callback, mid) - class ProgressTqdm(base_tqdm): + class ProgressTqdm(tqdm): """tqdm subclass that routes progress to the correct SSE stage. Bars with ``unit="B"`` are download bars (one per file shard) — they are @@ -664,6 +695,8 @@ def __init__( self.input_validation = input_validation self.force_model = force_model self.non_blocking = non_blocking + self._loop: asyncio.AbstractEventLoop | None = None + self._thread: threading.Thread | None = None # Continuous batching configuration arguments self.cb_block_size = cb_block_size @@ -765,14 +798,20 @@ def responses(request: dict): async def audio_transcriptions(request: Request): # Parses the multipart/form-data request into the request format used by other endpoints async with request.form() as form: + file_field = form["file"] + if not isinstance(file_field, UploadFile): + raise TypeError("Expected uploaded file") + model_field = form["model"] + if not isinstance(model_field, str): + raise TypeError("Expected model name") parsed_request = TransformersTranscriptionCreateParams( - file=await form["file"].read(), - model=form["model"], + file=await file_field.read(), + model=model_field, # TODO: add other fields ) logger.debug( - f"Received file: {form['file'].filename}; MIME type: {form['file'].content_type}; " - f"size: {form['file'].size / 1024:.2f} KiB" + f"Received file: {file_field.filename}; MIME type: {file_field.content_type}; " + f"size: {file_field.size / 1024:.2f} KiB" ) self.validate_transcription_request(request=parsed_request) @@ -930,9 +969,9 @@ def reset_loaded_models(self): def _validate_request( self, request: dict, - schema: TypedDict, + schema: RequestSchema, validator: "TypeAdapter", - unused_fields: set, + unused_fields: Set[str], ): """ Validates the request against the schema, and checks for unexpected keys. @@ -940,11 +979,11 @@ def _validate_request( Args: request (`dict`): The request to validate. - schema (`TypedDict`): - The schema of the request to validate. It is a `TypedDict` definition. + schema (`RequestSchema`): + The `TypedDict` class used as the request schema. validator (`TypeAdapter`): The validator to use to validate the request. Built from `schema`. - unused_fields (`set`): + unused_fields (`Set[str]`): Fields accepted by `schema`, but not used in `transformers serve`. 
Raises: @@ -1080,7 +1119,7 @@ def chunk_to_sse_element(chunk: "ChatCompletionChunk | BaseModel") -> str: @staticmethod @lru_cache - def get_gen_models(cache_dir: str | None = None) -> list[dict[str, any]]: + def get_gen_models(cache_dir: str | None = None) -> list[dict[str, Any]]: """ List LLMs and VLMs in the cache. """ @@ -1198,19 +1237,25 @@ def continuous_batching_chat_completion(self, req: dict, request_id: str) -> "St self.running_continuous_batching_manager.logit_processor = LogitsProcessorList() self.running_continuous_batching_manager.start() + cb_manager = self.running_continuous_batching_manager + if cb_manager is None: + raise RuntimeError("Continuous batching manager failed to initialize") + # TODO (Joao, Lysandre): this should also work with tool support modality = self.get_model_modality(model, processor=processor) processor_inputs = self.get_processor_inputs_from_inbound_messages(req["messages"], modality) - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=req.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, + chat_inputs = require_batch_encoding( + processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=req.get("tools"), + return_tensors="pt", + return_dict=True, + tokenize=True, + ) ) - inputs = inputs["input_ids"][0].to(model.device) + inputs = chat_inputs.to(model.device)["input_ids"][0] def stream_chat_completion(request_id, decode_stream): from ..generation.continuous_batching import RequestStatus @@ -1221,7 +1266,7 @@ def stream_chat_completion(request_id, decode_stream): yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision) n_tokens_generated = 0 - for result in self.running_continuous_batching_manager.request_id_iter(request_id): + for result in cb_manager.request_id_iter(request_id): n_tokens_generated += 1 # Always yield the token content (even for the final FINISHED token) @@ -1257,17 +1302,16 @@ def stream_chat_completion(request_id, decode_stream): except Exception as e: logger.error(str(e)) - self.running_continuous_batching_manager.cancel_request(request_id) + cb_manager.cancel_request(request_id) yield f'data: {{"error": "{str(e)}"}}' def buffer_chat_completion(_request_id): result = None - while self.running_continuous_batching_manager.is_running() and result is None: - result = self.running_continuous_batching_manager.get_result(request_id=_request_id, timeout=1) + while cb_manager.is_running() and result is None: + result = cb_manager.get_result(request_id=_request_id, timeout=1) if result is None: raise RuntimeError(f"Request {_request_id} failed: generation loop stopped before producing a result.") - content = tokenizer.decode(result.generated_tokens) chat_completion_result = ChatCompletion( @@ -1297,7 +1341,7 @@ async def cancellation_wrapper_stream(_request_id): yield self.chunk_to_sse_element(_chunk) await asyncio.sleep(0) except asyncio.CancelledError: - self.running_continuous_batching_manager.cancel_request(_request_id) + cb_manager.cancel_request(_request_id) logger.warning(f"Request {_request_id} was cancelled.") def cancellation_wrapper_buffer(_request_id): @@ -1305,14 +1349,15 @@ def cancellation_wrapper_buffer(_request_id): try: return buffer_chat_completion(_request_id) except asyncio.CancelledError: - self.running_continuous_batching_manager.cancel_request(_request_id) + cb_manager.cancel_request(_request_id) logger.warning(f"Request {_request_id} was cancelled.") - request_id = 
self.running_continuous_batching_manager.add_request( - inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens, streaming=req.get("stream") + stream = req.get("stream", False) + request_id = cb_manager.add_request( + inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens, streaming=stream ) - if req.get("stream"): + if stream: return StreamingResponse(cancellation_wrapper_stream(request_id), media_type="text/event-stream") else: chunk = cancellation_wrapper_buffer(request_id) @@ -1428,13 +1473,15 @@ def generate_chat_completion(self, req: dict) -> "StreamingResponse | JSONRespon # 2. force generation to pick from that tool's arguments # ====== END OF TOOL PREPROCESSING LOGIC ====== - inputs = processor.apply_chat_template( - processor_inputs, - add_generation_prompt=True, - tools=req.get("tools"), - return_tensors="pt", - return_dict=True, - tokenize=True, + inputs = require_batch_encoding( + processor.apply_chat_template( + processor_inputs, + add_generation_prompt=True, + tools=req.get("tools"), + return_tensors="pt", + return_dict=True, + tokenize=True, + ) ) inputs = inputs.to(model.device) request_id = req.get("request_id", "req_0") @@ -1451,8 +1498,11 @@ def generate_chat_completion(self, req: dict) -> "StreamingResponse | JSONRespon ) if self.is_continuation(req) and not must_discard_cache: + if self.last_kv_cache is None: + raise RuntimeError("Expected last_kv_cache for continuation request") seq_len = self.last_kv_cache.get_seq_length() - if inputs["input_ids"].shape[-1] > seq_len: + input_ids = inputs["input_ids"] + if input_ids.shape[-1] > seq_len: last_kv_cache = self.last_kv_cache else: last_kv_cache = None @@ -1616,7 +1666,8 @@ def generate_with_cache(**kwargs): finally: thread.join() - if req.get("stream"): + stream = req.get("stream", False) + if stream: return StreamingResponse( map(self.chunk_to_sse_element, stream_chat_completion(generation_streamer, request_id)), media_type="text/event-stream", @@ -1692,8 +1743,12 @@ def generate_response(self, req: dict) -> Generator[str, None, None]: else: raise TypeError("inputs should be a list, dict, or str") - inputs = processor.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt") - inputs = inputs.to(model.device) + chat_inputs = processor.apply_chat_template( + inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt" + ) + if not hasattr(chat_inputs, "to"): + raise TypeError("Expected tensor-like output from apply_chat_template") + inputs = chat_inputs.to(model.device) request_id = req.get("previous_response_id", "req_0") # Temporary hack for GPT-OSS 1: don't filter special tokens @@ -1710,6 +1765,8 @@ def generate_response(self, req: dict) -> Generator[str, None, None]: last_kv_cache = None if self.is_continuation(req) and not must_discard_cache: + if self.last_kv_cache is None: + raise RuntimeError("Expected last_kv_cache for continuation request") seq_len = self.last_kv_cache.get_seq_length() if inputs.shape[-1] > seq_len: last_kv_cache = self.last_kv_cache @@ -1810,7 +1867,7 @@ def generate_with_cache(**kwargs): sequence_number=sequence_number, output_index=output_index, content_index=content_index, - part=ResponseOutputText(type="output_text", text="", annotations=[]), + part=make_response_output_text(""), ) sequence_number += 1 yield self.chunk_to_sse_element(response_content_part_added) @@ -1876,7 +1933,7 @@ def generate_with_cache(**kwargs): sequence_number=sequence_number, output_index=output_index, 
content_index=content_index, - part=ResponseOutputText(type="output_text", text=response_output_text_done.text, annotations=[]), + part=make_response_output_text(response_output_text_done.text), ) sequence_number += 1 content_index += 1 @@ -1887,13 +1944,9 @@ def generate_with_cache(**kwargs): type="response.output_item.done", sequence_number=sequence_number, output_index=output_index, - item=ResponseOutputMessage( - id=f"msg_{request_id}", - type="message", - status="completed", - role="assistant", + item=make_completed_response_output_message( + message_id=f"msg_{request_id}", content=[response_content_part_done.part], - annotations=[], ), ) sequence_number += 1 @@ -1996,10 +2049,10 @@ def generate_response_non_streaming(self, req: dict) -> dict: else: raise ValueError("inputs should be a list, dict, or str") - inputs = processor.apply_chat_template( - inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True - )["input_ids"] - inputs = inputs.to(model.device) + chat_result = require_batch_encoding( + processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt", return_dict=True) + ) + input_ids = chat_result.to(model.device)["input_ids"] request_id = req.get("previous_response_id", "req_0") # Temporary hack for GPTOSS 1: don't filter special tokens @@ -2011,13 +2064,15 @@ def generate_response_non_streaming(self, req: dict) -> dict: last_kv_cache = None if self.is_continuation(req) and not must_discard_cache: + if self.last_kv_cache is None: + raise RuntimeError("Expected last_kv_cache for continuation request") seq_len = self.last_kv_cache.get_seq_length() - if inputs.shape[-1] > seq_len: + if input_ids.shape[-1] > seq_len: last_kv_cache = self.last_kv_cache generate_output = model.generate( - inputs=inputs, - attention_mask=torch_ones_like(inputs), + inputs=input_ids, + attention_mask=torch_ones_like(input_ids), generation_config=generation_config, return_dict_in_generate=True, past_key_values=last_kv_cache, @@ -2029,13 +2084,9 @@ def generate_response_non_streaming(self, req: dict) -> dict: full_text = processor.batch_decode(generate_output.sequences, skip_special_tokens=skip_special_tokens)[0] created_at = time.time() - response_output_item = ResponseOutputMessage( - id=f"msg_{request_id}", - type="message", - status="completed", - role="assistant", - content=[ResponseOutputText(type="output_text", text=full_text, annotations=[])], - annotations=[], + response_output_item = make_completed_response_output_message( + message_id=f"msg_{request_id}", + content=[make_response_output_text(full_text)], ) response_completed = Response( id=f"resp_{request_id}", @@ -2072,7 +2123,9 @@ def generate_transcription(self, req: dict) -> Generator[str, None, None]: audio_model, audio_processor = self.load_audio_model_and_processor(model_id_and_revision) generation_streamer = TextIteratorStreamer( - audio_processor.tokenizer, skip_special_tokens=True, skip_prompt=True + audio_processor.tokenizer, + skip_special_tokens=True, + skip_prompt=True, ) generation_config = create_generation_config_from_req( req, model_generation_config=audio_model.generation_config diff --git a/src/transformers/cli/system.py b/src/transformers/cli/system.py index 65180c464724..413e465ec6c5 100644 --- a/src/transformers/cli/system.py +++ b/src/transformers/cli/system.py @@ -47,7 +47,7 @@ def env( """Print information about the environment.""" import safetensors - safetensors_version = safetensors.__version__ + safetensors_version = getattr(safetensors, "__version__", "unknown") 
accelerate_version = "not installed" accelerate_config = accelerate_config_str = "not found" @@ -114,13 +114,14 @@ def env( elif pt_xpu_available: info["Using XPU in script?"] = "" info["XPU type"] = torch.xpu.get_device_name() - elif pt_hpu_available: + elif pt_hpu_available and hasattr(torch, "hpu"): info["Using HPU in script?"] = "" info["HPU type"] = torch.hpu.get_device_name() - elif pt_npu_available: + elif pt_npu_available and hasattr(torch, "npu"): info["Using NPU in script?"] = "" info["NPU type"] = torch.npu.get_device_name() - info["CANN version"] = torch.version.cann + if hasattr(torch.version, "cann"): + info["CANN version"] = torch.version.cann print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") print(_format_dict(info)) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index b238b8b17031..7a75260e98e3 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -580,6 +580,14 @@ class ProcessorMixin(PushToHubMixin): _auto_class = None valid_processor_kwargs = ProcessingKwargs + # Dynamically set sub-processor attributes. Not every processor has all of these; + # they are populated via setattr in __init__ based on each subclass's `attributes`. + tokenizer: Any + feature_extractor: Any + image_processor: Any + video_processor: Any + chat_template: str | dict[str, str] | None + # args have to match the attributes class attribute def __init__(self, *args, **kwargs): # First, extract chat template from kwargs. It can never be a positional arg diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index e61b20c3a4cf..04f2e5972304 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -29,11 +29,12 @@ from collections.abc import Callable, Collection, Mapping, Sequence, Sized from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, NamedTuple, Union +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Union, overload import numpy as np from huggingface_hub import create_repo, is_offline_mode, list_repo_files from packaging import version +from typing_extensions import TypeVar from . import __version__ from .dynamic_module_utils import custom_object_save @@ -188,7 +189,10 @@ class TokenSpan(NamedTuple): end: int -class BatchEncoding(UserDict): +_V = TypeVar("_V", default=Any) + + +class BatchEncoding(UserDict, Generic[_V]): """ Holds the output of the [`~tokenization_utils_base.PreTrainedTokenizerBase.__call__`], [`~tokenization_utils_base.PreTrainedTokenizerBase.encode_plus`] and @@ -248,7 +252,16 @@ def n_sequences(self) -> int | None: """ return self._n_sequences - def __getitem__(self, item: int | str) -> Any | EncodingFast: + @overload + def __getitem__(self, item: str) -> _V: ... + + @overload + def __getitem__(self, item: int) -> EncodingFast: ... + + @overload + def __getitem__(self, item: slice) -> dict[str, _V]: ... + + def __getitem__(self, item: int | str | slice) -> _V | EncodingFast | dict[str, _V]: """ If the key is a string, returns the value of the dict associated to `key` ('input_ids', 'attention_mask', etc.). 
@@ -753,7 +766,7 @@ def as_tensor(value, dtype=None): return self - def to(self, device: str | torch.device, *, non_blocking: bool = False) -> BatchEncoding: + def to(self, device: str | torch.device, *, non_blocking: bool = False) -> BatchEncoding[torch.Tensor]: """ Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only). @@ -1291,7 +1304,7 @@ def __getattr__(self, key): if key not in self.__dict__: raise AttributeError(f"{self.__class__.__name__} has no attribute {key}") - return super().__getattr__(key) + return self.__dict__[key] def get_special_tokens_mask( self, token_ids_0: list[int], token_ids_1: list[int] | None = None, already_has_special_tokens: bool = False @@ -2003,7 +2016,7 @@ def save_pretrained( if push_to_hub: commit_message = kwargs.pop("commit_message", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = kwargs.pop("repo_id", str(save_directory).split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id files_timestamps = self._get_files_timestamps(save_directory) @@ -2629,6 +2642,9 @@ def pad( # Call .keys() explicitly for compatibility with TensorDict and other Mapping subclasses encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()} + if isinstance(encoded_inputs, (list, tuple)): + raise TypeError("encoded_inputs should be a mapping or a list of mappings") + # The model's main input name, usually `input_ids`, has been passed for padding if self.model_input_names[0] not in encoded_inputs: raise ValueError( diff --git a/utils/check_types.py b/utils/check_types.py index 2400863bc2d3..513f9f53ed3d 100644 --- a/utils/check_types.py +++ b/utils/check_types.py @@ -17,12 +17,14 @@ "src/transformers/_typing.py", "src/transformers/utils/**/*.py", "src/transformers/generation/**/*.py", + "src/transformers/cli/**/*.py", "src/transformers/quantizers/**/*.py", ], "check_args": [ "src/transformers/_typing.py", "src/transformers/utils", "src/transformers/generation", + "src/transformers/cli", "src/transformers/quantizers", ], "fix_args": None,
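
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new `RequestSchema` alias from `_typing.py` and the now-generic `BatchEncoding` are intended to behave together. It assumes this patch is applied and `typing_extensions >= 4.9` (which gives `TypedDict` classes a runtime `__mutable_keys__` attribute); `ExampleRequest` is a hypothetical schema invented for the example.

from typing import cast

import torch
from typing_extensions import TypedDict

from transformers import AutoTokenizer
from transformers._typing import RequestSchema


# `ExampleRequest` is hypothetical; any `typing_extensions.TypedDict` class
# exposes `__mutable_keys__` at runtime and so structurally satisfies the
# new `TypedDictSchema` protocol.
class ExampleRequest(TypedDict, total=False):
    model: str
    stream: bool


# A `cast` is used because static checkers may not match a TypedDict class
# against the protocol on their own.
schema = cast(RequestSchema, ExampleRequest)
print(schema.__mutable_keys__)  # frozenset({'model', 'stream'})

# `BatchEncoding` is now `Generic[_V]`, and `to()` is annotated as returning
# `BatchEncoding[torch.Tensor]`, so string indexing after the move is
# statically typed as a tensor via the new `__getitem__` overloads.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
encoding = tokenizer("hello world", return_tensors="pt")
input_ids: torch.Tensor = encoding.to("cpu")["input_ids"]
print(input_ids.shape)  # torch.Size([1, 2])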