diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 86e0808d885f..2d224888da55 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -899,6 +899,8 @@ title: DAB-DETR - local: model_doc/deformable_detr title: Deformable DETR + - local: model_doc/deimv2 + title: DEIMv2 - local: model_doc/deit title: DeiT - local: model_doc/depth_anything diff --git a/docs/source/en/model_doc/deimv2.md b/docs/source/en/model_doc/deimv2.md new file mode 100644 index 000000000000..f19aafad2158 --- /dev/null +++ b/docs/source/en/model_doc/deimv2.md @@ -0,0 +1,65 @@ + +*This model was released on 2025-09-25 and added to Hugging Face Transformers on 2026-04-22.* + +# DEIMv2 + +## Overview + +DEIMv2 (DETR with Improved Matching v2) was proposed in [DEIMv2: Real-Time Object Detection Meets DINOv3](https://huggingface.co/papers/2509.20787) by Shihua Huang, Yongjie Hou, Longfei Liu, Xuanlong Yu, and Xi Shen. + +The abstract from the paper is the following: + +*Driven by the simple and effective Dense O2O, DEIM demonstrates faster convergence and enhanced performance. In this work, we extend it with DINOv3 features, resulting in DEIMv2. DEIMv2 spans eight model sizes from X to Atto, covering GPU, edge, and mobile deployment. For the X, L, M, and S variants, we adopt DINOv3-pretrained / distilled backbones and introduce a Spatial Tuning Adapter (STA), which efficiently converts DINOv3's single-scale output into multi-scale features and complements strong semantics with fine-grained details to enhance detection. For ultra-lightweight models (Nano, Pico, Femto, and Atto), we employ HGNetv2 with depth and width pruning to meet strict resource budgets. Together with a simplified decoder and an upgraded Dense O2O, this unified design enables DEIMv2 to achieve a superior performance-cost trade-off across diverse scenarios, establishing new state-of-the-art results. Notably, our largest model, DEIMv2-X, achieves 57.8 AP with only 50.3M parameters, surpassing prior X-scale models that require over 60M parameters for just 56.5 AP. On the compact side, DEIMv2-S is the first sub-10M model (9.71M) to exceed the 50 AP milestone on COCO, reaching 50.9 AP. 
Even the ultra-lightweight DEIMv2-Pico, with just 1.5M parameters, delivers 38.5 AP-matching YOLOv10-Nano (2.3M) with ~50% fewer parameters.* + +## Usage + +```python +from transformers import AutoImageProcessor, AutoModelForObjectDetection +from transformers.image_utils import load_image + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = load_image(url) + +image_processor = AutoImageProcessor.from_pretrained("harshaljanjani/DEIMv2_HGNetv2_N_COCO_Transformers") +model = AutoModelForObjectDetection.from_pretrained("harshaljanjani/DEIMv2_HGNetv2_N_COCO_Transformers", device_map="auto") + +inputs = image_processor(images=image, return_tensors="pt").to(model.device) +outputs = model(**inputs) + +results = image_processor.post_process_object_detection( + outputs, threshold=0.5, target_sizes=[image.size[::-1]] +) + +for result in results: + for score, label, box in zip(result["scores"], result["labels"], result["boxes"]): + box = [round(i, 2) for i in box.tolist()] + print(f"Detected {model.config.id2label[label.item()]} with confidence {round(score.item(), 3)} at location {box}") +``` + +## Deimv2Config + +[[autodoc]] Deimv2Config + +## Deimv2Model + +[[autodoc]] Deimv2Model + - forward + +## Deimv2ForObjectDetection + +[[autodoc]] Deimv2ForObjectDetection + - forward diff --git a/docs/source/en/modeling_rules.md b/docs/source/en/modeling_rules.md index d3b6e48bd7c4..0591a79f89b3 100644 --- a/docs/source/en/modeling_rules.md +++ b/docs/source/en/modeling_rules.md @@ -13,22 +13,22 @@ specific language governing permissions and limitations under the License. # Model structure rules -Transformers enforces a set of static rules on every `modeling_*.py`, `modular_*.py`, and `configuration_*.py` file. The [mlinter](https://github.com/huggingface/transformers-mlinter) tool checks them as part of `make typing` and errors out if violations are found. +Transformers enforces a set of static rules on every `modeling_*.py`, `modular_*.py`, and `configuration_*.py` file. The [mlinter](https://github.com/huggingface/transformers-mlinter) package provides the checker engine, and the repository keeps its active rule set in `utils/rules.toml`. That local TOML lets us enable, disable, or tweak rules quickly without waiting for a new `transformers-mlinter` release. These are the expected model conventions for adding or changing modeling code. They keep the codebase consistent and ensure compatibility with features like pipeline parallelism, device maps, and weight tying. ## Running the checker -`make typing` runs `mlinter` alongside the `ty` type checker. Run `mlinter` on its own with the following commands. +`make typing` runs `mlinter` alongside the `ty` type checker through the repo wrapper, so it picks up `utils/rules.toml`. Run the same wrapper directly with the following commands. ```bash -mlinter # check all modeling files -mlinter --changed-only # check only files changed vs origin/main -mlinter --list-rules # list all rules and their enabled status -mlinter --rule TRF001 # show built-in docs for a specific rule +python utils/check_modeling_structure.py # check all modeling files +python utils/check_modeling_structure.py --changed-only # check only files changed vs origin/main +python utils/check_modeling_structure.py --list-rules # list all rules and their enabled status +python utils/check_modeling_structure.py --rule TRF001 # show built-in docs for a specific rule ``` -The `--changed-only` flag is the fastest option during development. 
It only checks the files you've modified relative to the main branch. +The `--changed-only` flag is the fastest option during development. It only checks the files you've modified relative to the main branch. If you invoke `mlinter` directly instead of the wrapper, pass `--rules-toml utils/rules.toml` so local overrides are applied. ## Fixing a violation @@ -52,7 +52,7 @@ Use the rule ID to look up the fix in the [rules reference](#rules-reference). T ## Rules reference -Each rule below lists what it enforces and a diff showing the fix. Run `mlinter --rule TRF001` to see the built-in docs for any rule. +Each rule below lists what it enforces and a diff showing the fix. Run `python utils/check_modeling_structure.py --rule TRF001` to see the built-in docs for any rule with the repo's current rule set. diff --git a/docs/source/en/serve-cli/serving.md b/docs/source/en/serve-cli/serving.md index 783eb0c8dd87..83dcb9e88d9a 100644 --- a/docs/source/en/serve-cli/serving.md +++ b/docs/source/en/serve-cli/serving.md @@ -456,7 +456,7 @@ data: {"id":"f47ac10b-58cc-4372-a567-0e02b2c3d479","choices":[{"delta":{"content ### Audio-based completions -Multimodal models like [Gemma 4](https://huggingface.co/google/gemma-4-E2B-it) and [Qwen2.5-Omni](https://huggingface.co/Qwen/Qwen2.5-Omni-3B) accept audio input using the OpenAI `input_audio` content type. The audio must be base64-encoded and the format (`mp3` or `wav`) must be specified. +Multimodal models like [Gemma 4](https://huggingface.co/google/gemma-4-E2B-it) and [Qwen2.5-Omni](https://huggingface.co/Qwen/Qwen2.5-Omni-3B) accept audio input through the OpenAI `input_audio` content type. Base64-encode the audio and specify the format (`mp3` or `wav`). @@ -695,7 +695,7 @@ data: {"id":"cb997e1d-98b9-414a-be89-1880288610ef","choices":[{"delta":{"content > [!WARNING] > The `audio_url` content type is an extension not part of the OpenAI standard and may change in future versions. -As a convenience, audio can also be passed by URL using the `audio_url` content type, avoiding the need for base64 encoding. +You can also pass audio by URL with the `audio_url` content type to skip base64 encoding. ```python completion = client.chat.completions.create( @@ -717,7 +717,7 @@ completion = client.chat.completions.create( > [!WARNING] > The `video_url` content type is an extension not part of the OpenAI standard and may change in future versions. -Video input is supported using the `video_url` content type. If the model supports audio (e.g. Gemma 4, Qwen2.5-Omni), the audio track is automatically extracted from the video and processed alongside the visual frames. +Use the `video_url` content type for video input. If the model supports audio (e.g. Gemma 4, Qwen2.5-Omni), the server extracts the audio track from the video and processes it with the visual frames. > [!TIP] > Video processing requires [torchcodec](https://github.com/pytorch/torchcodec). Install it with `pip install torchcodec`. @@ -934,7 +934,7 @@ data: {"id":"cb997e1d-98b9-414a-be89-1880288610ef","choices":[{"delta":{"content -### Multi-turn conversations +### Multi-turn conversations[[completions]] To have a multi-turn conversation, include the full conversation history in the `messages` list with alternating `user` and `assistant` roles. Like all OpenAI-compatible servers, the API is stateless, so every request must contain the complete conversation history. 
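As a minimal sketch of such a history (illustrative only; the doc's full example continues in the unchanged lines right below), every request resends the whole list with alternating roles:

```python
# Illustrative history for the follow-up question shown below. The server keeps no
# state, so the earlier turns must be sent again on every request.
messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
    {"role": "user", "content": "How many people live there?"},
]
```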
@@ -954,7 +954,7 @@ completion = client.chat.completions.create( print(completion.choices[0].message.content) ``` -The follow-up question "How many people live there?" relies on the prior context, and the model answers about Paris accordingly. +The follow-up question "How many people live there?" relies on the prior context, so the model answers about Paris. ``` As of 2021, the population of Paris is approximately 2.2 million people. @@ -1466,7 +1466,7 @@ data: {"content_index":0,"delta":"This ","item_id":"msg_a1b2c3d4","output_index" > [!WARNING] > The `audio_url` content type is an extension not part of the OpenAI standard and may change in future versions. -As a convenience, audio can also be passed by URL using the `audio_url` content type, avoiding the need for base64 encoding. +You can also pass audio by URL with the `audio_url` content type to skip base64 encoding. ```python response = client.responses.create( @@ -1621,7 +1621,7 @@ data: {"content_index":0,"delta":"Based ","item_id":"msg_b2c3d4e5","output_index -### Multi-turn conversations +### Multi-turn conversations[[responses]] For multi-turn conversations, pass a list of messages with `role` keys in the `input` field. Like all OpenAI-compatible servers, the API is stateless, so every request must contain the complete conversation history. @@ -1643,7 +1643,7 @@ response = client.responses.create( print(response.output[0].content[0].text) ``` -The follow-up question "How many people live there?" relies on the prior context, and the model answers about Paris accordingly. +The follow-up question "How many people live there?" relies on the prior context, so the model answers about Paris. ``` As of 2021, Paris has a population of approximately 2.8 million people. @@ -1734,7 +1734,7 @@ The stream ends with exactly one terminal event, `ready` (success) or `error` (f ## Timeout -`transformers serve` supports different requests by different models. Each model loads on demand and stays in GPU memory. Models unload automatically after 300 seconds of inactivity to free up GPU memory. Set `--model-timeout` to a different value in seconds, or `-1` to disable unloading entirely. +`transformers serve` handles requests for any model. Each model loads on demand and stays in GPU memory. Models unload automatically after 300 seconds of inactivity to free GPU memory. Set `--model-timeout` to a different value in seconds, or `-1` to disable unloading. ```shell transformers serve --model-timeout 400 @@ -1742,7 +1742,7 @@ transformers serve --model-timeout 400 ### Loading examples -See the example responses below for a freshly downloaded model, a model loaded from your local cache (skips the download stage), and a model that already exists in memory. +The examples below show responses for a freshly downloaded model, a model loaded from your local cache (skips the download stage), and a model already in memory. @@ -1784,7 +1784,7 @@ data: {"status": "ready", "model": "org/model@main", "cached": true} The `transformers serve` server supports OpenAI-style function calling. Models trained for tool-use generate structured function calls that your application executes. > [!NOTE] -> Tool calling is currently limited to the Qwen model family. +> Tool calling works with any model whose tokenizer declares tool call tokens. Qwen and Gemma 4 work out of the box. Open an [issue](https://github.com/huggingface/transformers/issues/new/choose) to request support for a specific model. Define tools as a list of function specifications following the OpenAI format. 
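The `tools` list referenced by the new examples below is defined in an unchanged part of serving.md, so it does not appear in this diff. As a hedged sketch for reviewers, an OpenAI-format function spec for the weather example might look like this (the `get_weather` name and schema are illustrative, not taken from the patch):

```python
# Illustrative OpenAI-style function spec; the real `tools` list lives in the
# unchanged section of serving.md that the new multi-turn examples reuse.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a given city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name, e.g. San Francisco"}
                },
                "required": ["location"],
            },
        },
    }
]
```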
@@ -1846,6 +1846,79 @@ for event in response: print(event) ``` +### Multi-turn tool calling + +After the model returns a tool call, execute the function locally, then send the result back in a follow-up request to get the model's final answer. The pattern differs slightly between the two APIs. See the [OpenAI function calling guide](https://developers.openai.com/api/docs/guides/function-calling?api-mode=chat) for the full spec. + +The examples below reuse the `tools` list defined above. + + + + +Pass the tool result as a `role: "tool"` message with the matching `tool_call_id`. + +```py +# Model returns a tool call +messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}] +response = client.chat.completions.create( + model="Qwen/Qwen2.5-7B-Instruct", + messages=messages, + tools=tools, +) +assistant_message = response.choices[0].message + +# Execute the tool locally +tool_call = assistant_message.tool_calls[0] +result = {"temperature": 22, "condition": "sunny"} # your actual function call here + +# Send the tool result back +messages.append(assistant_message) +messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": json.dumps(result), +}) +final_response = client.chat.completions.create( + model="Qwen/Qwen2.5-7B-Instruct", + messages=messages, + tools=tools, +) +print(final_response.choices[0].message.content) +``` + + + + +Pass the tool result as a `function_call_output` item in the `input` list of the follow-up request. + +```py +user_message = {"role": "user", "content": "What's the weather like in San Francisco?"} +response = client.responses.create( + model="Qwen/Qwen2.5-7B-Instruct", + input=[user_message], + tools=tools, + stream=False, +) +tool_call = next(item for item in response.output if item.type == "function_call") + +result = {"temperature": 22, "condition": "sunny"} + +final_response = client.responses.create( + model="Qwen/Qwen2.5-7B-Instruct", + input=[ + user_message, + tool_call.model_dump(exclude_none=True), + {"type": "function_call_output", "call_id": tool_call.call_id, "output": json.dumps(result)}, + ], + tools=tools, + stream=False, +) +print(final_response.output_text) +``` + + + + ## Port forwarding Port forwarding lets you serve models from a remote server. Make sure you have SSH access to the server, then run this command on your local machine. diff --git a/setup.py b/setup.py index 2e6adca0315c..42c865b1b9ba 100644 --- a/setup.py +++ b/setup.py @@ -124,7 +124,9 @@ "rjieba", "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff==0.14.10", - "transformers-mlinter==0.1.0", + # When bumping `transformers-mlinter`, sync repo-local rule overrides from + # `utils/rules.toml` back into the released package. + "transformers-mlinter==0.1.1", "ty==0.0.20", # `sacrebleu` not used in `transformers`. However, it is needed in several tests, when a test calls # `evaluate.load("sacrebleu")`. 
This metric is used in the examples that we use to test the `Trainer` with, in the diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 399b0be222e9..1a721ca2a82a 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -56,7 +56,7 @@ "rjieba": "rjieba", "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "ruff": "ruff==0.14.10", - "transformers-mlinter": "transformers-mlinter==0.1.0", + "transformers-mlinter": "transformers-mlinter==0.1.1", "ty": "ty==0.0.20", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacremoses": "sacremoses", diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 308c42564295..f601a97959c6 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -1556,8 +1556,10 @@ class ContinuousBatchingConfig: Number of blocks in the KV cache. Auto-inferred from GPU memory when `None`. max_batch_tokens (`int`, *optional*): Maximum number of tokens in a batch. Auto-inferred from GPU memory when `None`. - max_memory_percent (`float`, *optional*, defaults to 0.8): - Maximum percentage of free GPU memory (after the model is loaded) to use for the KV cache. + max_memory_percent (`float`, *optional*): + Maximum percentage of free GPU memory (after the model is loaded) to use for the KV cache. When `None`, + resolved at runtime to 0.9 if there is no logit processing and 0.8 if there is, to leave headroom for + vocabulary-sized temporary tensors. max_blocks_per_request (`int`, *optional*, defaults to 0): Maximum blocks per request, used in the `flash_attn_with_kvcache` fast decode path to dimension the block table. Setting this to 0 disables the fast decode path. @@ -1607,8 +1609,9 @@ class ContinuousBatchingConfig: num_blocks: int | None = None max_batch_tokens: int | None = None - # The max percentage of free GPU memory (after the model is loaded) to use for the KV cache. - max_memory_percent: float = 0.8 + # The max percentage of free GPU memory (after the model is loaded) to use for the KV cache. If None, auto resolved + # to 0.9 (no logit processing) or 0.8 (logit processing) to leave headroom for temporary tensors. + max_memory_percent: float | None = None # This is only used in the flash_attn_with_kvcache fast decode path to dimension the block table. If it is set to 0, # the fast decode path will not be used. Currently turned off by default. @@ -1773,6 +1776,13 @@ def decide_use_async_batching(self, is_attn_mask_needed: bool) -> bool: ) return self.use_async_batching + def resolve_max_memory_percent(self, has_logit_processors: bool) -> None: + """Resolves `max_memory_percent` when unset: 0.9 without logit processors, 0.8 with them. Active processors + materialize `[N, V]` intermediates (e.g. top-p sort, softmax) that get captured into the CUDA graph pool, so + the cache has to cede some budget to that pool.""" + if self.max_memory_percent is None: + self.max_memory_percent = 0.8 if has_logit_processors else 0.9 + def resolve_sentinel_values(self) -> None: """For some parameters (padding intervals and max cached graphs), the default is a sentinel value of 0: that way, if the user specifies a value for those parameters, we know they want it used, ie. we turn on cuda graphs. 
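A quick sketch of the new resolution behavior (assuming `ContinuousBatchingConfig` can still be constructed with all-default fields, as the snippet above suggests): `max_memory_percent` stays `None` until the caller knows whether logit processors are active, and an explicit user value is never overridden.

```python
# Sketch only, not a test shipped with this PR.
from transformers.generation.configuration_utils import ContinuousBatchingConfig

cfg = ContinuousBatchingConfig()                        # max_memory_percent defaults to None
cfg.resolve_max_memory_percent(has_logit_processors=False)
assert cfg.max_memory_percent == 0.9                    # no [N, V] sampling intermediates to budget for

cfg = ContinuousBatchingConfig(max_memory_percent=0.7)  # explicit user value
cfg.resolve_max_memory_percent(has_logit_processors=True)
assert cfg.max_memory_percent == 0.7                    # never overridden
```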
diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index 9fd0d3afba11..59de60bc957c 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -182,15 +182,30 @@ def __init__( else: num_attention_masks = 1 + # Peak activations coefficients (for number of blocks and number of batch tokens) + q_bytes_per_token = config.num_attention_heads * self.head_dim + lm_head_peak = ( + 0, # number of blocks does not affect the LM head peak activation + config.hidden_size + 2 * config.vocab_size, # hidden states + logits + ) + attention_peak = ( + 2 * page_size, # old K and V, read from cache (in the worst case scenario: whole cache is read) + config.hidden_size + q_bytes_per_token + 2 * page_size, # hidden state + Q + new K and V + ) + memory_handler = PagedAttentionMemoryHandler( - block_size=self.block_size, + continuous_batching_config=continuous_batching_config, page_size=page_size, num_groups=self.num_groups, group_size=group_size, - peak_activation_per_token=(config.hidden_size + config.vocab_size), + activation_peaks=[lm_head_peak, attention_peak], num_attention_masks=num_attention_masks, - continuous_batching_config=continuous_batching_config, ) + + # If somehow the max memory percent is not yet resolved, resolve it conservatively + if continuous_batching_config.max_memory_percent is None: + continuous_batching_config.resolve_max_memory_percent(has_logit_processors=True) + num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens( num_blocks=continuous_batching_config.num_blocks, max_batch_tokens=continuous_batching_config.max_batch_tokens, @@ -316,17 +331,20 @@ def extend_read_and_write_indices( request_id: str, past_length: int, query_length: int, - read_index: list[list[int]], + read_index: list[list[int]] | None, write_index: list[list[int]], ) -> None: """Retrieve physical cache indices for reading KV states in the cache across all layer groups. This method coordinates with all cache managers to build the complete set of read indices needed for attention computation. + When read_index is None, the batch has no cache reads and we only compute the write indices. """ - for cm, read_indices, write_indices in zip(self.group_cache_managers, read_index, write_index): - indices = cm.get_read_indices(request_id, past_length, query_length) - read_indices.extend(indices) - indices = cm.get_write_indices(request_id, past_length, query_length) - write_indices.extend(indices) + # Write indices are always computed + for cm, write_indices in zip(self.group_cache_managers, write_index): + write_indices.extend(cm.get_write_indices(request_id, past_length, query_length)) + # Read indices are only computed if there are cache indices + if read_index is not None: + for cm, read_indices in zip(self.group_cache_managers, read_index): + read_indices.extend(cm.get_read_indices(request_id, past_length, query_length)) def fill_block_table( self, request_id: str, past_length: int, query_length: int, block_table: torch.Tensor @@ -355,26 +373,34 @@ def update( read_index: list[torch.Tensor], # shape [num_layer_groups, seqlen_kv + past_length] write_index: list[torch.Tensor], # shape [num_layer_groups, seqlen_q] ) -> tuple[torch.Tensor, torch.Tensor]: # shape [seqlen_kv + past_length, num_kv_heads, head_dim] - """Update the cache with new key-value states for a specific layer. 
This method writes new KV states to the - appropriate cache locations. The behavior differs based on the layer's attention type: + """Update the cache with new key-value states for a specific layer, and retrieves the relevant KV states from + the cache for attention computation. The behavior differs based on the layer's attention type: - Full attention: New KV states are written to cache, then complete sequence is read from cache - Sliding window: Old KV is read from cache along with extra spaces for the new KV, then new KV is written to cache. This is because new KV might overwrite the old KV, so we need to read the old KV first. + When the layer's read index is empty, the batch has no cache reads (all requests are non-chunked prefills): we + only write to the cache and return the input KV states directly, skipping the index_select read-back. + Returns the complete KV states (cached + new) for attention computation. """ - # Retrieve the layer read and write indices + # Retrieve the layer write index and the relevant cache tensors group_idx, layer_idx_in_group = self.layer_index_to_group_indices[layer_idx] layer_read_index = read_index[group_idx] layer_write_index = write_index[group_idx] - # Select the correct cache k_cache = self.key_cache[layer_idx_in_group] v_cache = self.value_cache[layer_idx_in_group] # Transpose the key and value states to match the cache shape, after which shape is [seqlen_kv, num_kv_heads, head_dim] key_states = key_states.transpose(1, 2).squeeze(0) value_states = value_states.transpose(1, 2).squeeze(0) + # Case: write-only, no cache read. The input KV states already contain everything the attention needs. + if layer_read_index.numel() == 0: + k_cache.index_copy_(0, layer_write_index, key_states) + v_cache.index_copy_(0, layer_write_index, value_states) + return key_states, value_states + # Case: full attention sliding_window = self.sliding_windows[layer_idx] if sliding_window == 1: @@ -509,25 +535,26 @@ class PagedAttentionMemoryHandler: _activation_dtype = torch.bfloat16 _input_dtype = torch.int32 - _upper_bound_max_batch_tokens = 256 + _upper_bound_max_batch_tokens = 1024 _upper_bound_num_blocks = 4096 def __init__( self, - block_size: int, + continuous_batching_config: ContinuousBatchingConfig, page_size: int, num_groups: int, group_size: int, - peak_activation_per_token: int, + activation_peaks: list[tuple[int, int]], num_attention_masks: int, - continuous_batching_config: ContinuousBatchingConfig, ) -> None: - """Initialize the memory handler.""" - self.block_size = block_size + """Initialize the memory handler. `activation_peaks` is a list of `(Δcn, Δcm)` pairs giving the activation memory + contributions proportional to N (pages) and M (batch tokens) for each peak. Memory must satisfy the constraint + at every peak, so we solve each polynomial independently and take the most restrictive result.""" + self.block_size = continuous_batching_config.block_size self.page_size = page_size self.num_groups = num_groups self.group_size = group_size - self.peak_activation_per_token = peak_activation_per_token + self.activation_peaks = activation_peaks self.num_attention_masks = num_attention_masks self.max_blocks_per_request = continuous_batching_config.max_blocks_per_request or 0 # This is the number of output rows for the output_ids tensor @@ -545,23 +572,29 @@ def get_available_memory(max_memory_percent: float = 1.0) -> int: # Formatting is disabled because of comment indentation, which improves readability. 
# fmt: off - def _equation_coefficients(self, cache_dtype: torch.dtype) -> tuple[int, int, int, int]: - """Returns (coeff_n, coeff_m, coeff_nm, coeff_mm) for the memory polynomial. Each addend is annotated with - the tensor it corresponds to in `ContinuousBatchingIOs._setup_static_tensors`. + def _equation_coefficients( + self, peak: tuple[int, int], cache_dtype: torch.dtype + ) -> tuple[int, int, int, int]: + """Returns `(coeff_n, coeff_m, coeff_nm, coeff_mm)` for the memory polynomial of a single activation peak. + `peak = (Δcn, Δcm)` is the peak-specific activation contribution; the rest of the coefficients are shared + across peaks. Each addend is annotated with the tensor it corresponds to in + `ContinuousBatchingIOs._setup_static_tensors` (or the forward pass, for activation terms). """ i = self._input_dtype.itemsize # int32 a = self._activation_dtype.itemsize # bfloat16 c = cache_dtype.itemsize k = self.io_multiplier # 1 sync, 2 async (IO tensors only) + delta_n, delta_m = peak # -- N terms: cost per cache page -------------------------------------------------- coeff_n = ( 2 * self.group_size * self.page_size * c # kv_cache: 2 * group_size * [N, page_size] * cache_dtype + k * self.num_groups * 8 # read_index: [num_groups, N + M] (N part only, int64) + + delta_n * a # activation peak: N-proportional part ) # -- M terms: cost per batch token ------------------------------------------------- coeff_m = ( - self.peak_activation_per_token * a # activation peak (largest hidden state per token) + delta_m * a # activation peak: M-proportional part + k * 7 * i # bulk_input: [7, M] int32, packed as 7 rows + k * self.num_output_rows * i # output_ids: [num_output_rows, M] int32 + k * self.num_groups # block_table: [bt_groups, M, max_blocks_per_req] int32 @@ -569,9 +602,9 @@ def _equation_coefficients(self, cache_dtype: torch.dtype) -> tuple[int, int, in + k * self.num_groups * 8 # write_index: [num_groups, M] int64 + k * self.num_groups * 8 # read_index: [num_groups, N + M] (M part only, int64) ) - # -- N·M terms: cost per (page × batch token) ------------------------------------- + # -- N·M terms: cost per (page × batch token) -------------------------------------- coeff_nm = k * self.num_attention_masks * a # attention_mask: [1, 1, M, N + M] (N·M part only) - # -- M² terms: cost per (batch token squared) ------------------------------------- + # -- M² terms: cost per (batch token squared) -------------------------------------- coeff_mm = k * self.num_attention_masks * a # attention_mask: [1, 1, M, N + M] (M² part only) return coeff_n, coeff_m, coeff_nm, coeff_mm @@ -590,55 +623,80 @@ def _solve_quadratic(a: float, b: float, c: float) -> float: raise ValueError(f"No positive solution (root = {root})") return root - def infer_num_blocks_and_max_batch_tokens( + def _solve_for_peak( self, - num_blocks: int | None = None, - max_batch_tokens: int | None = None, - max_memory_percent: float = 0.8, # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI - cache_dtype: torch.dtype = torch.float16, + peak: tuple[int, int], + available: int, + num_blocks: int | None, + max_batch_tokens: int | None, + cache_dtype: torch.dtype, ) -> tuple[int, int]: - """Solve for the missing variable(s) in the memory polynomial (see ``_equation_coefficients``). When both - are unknown, assumes M = m·N (m = 0.01, i.e. one batch fills ~1 % of the cache) and solves the resulting - quadratic in N. 
- """ - available = self.get_available_memory(max_memory_percent) - coeff_n, coeff_m, coeff_nm, coeff_mm = self._equation_coefficients(cache_dtype) - logger.info(f"Cache memory: {available}") + """Solve for `(num_blocks, max_batch_tokens)` against one activation peak's memory polynomial. Clamps to upper + bounds. Either input may be None; whichever is None is solved for.""" + cn, cm, cnm, cmm = self._equation_coefficients(peak, cache_dtype) if num_blocks is None and max_batch_tokens is None: # Substitute M = m·N → (coeff_nm·m + coeff_mm·m²)·N² + (coeff_n + coeff_m·m)·N − avail = 0 m = 0.01 - num_pages = self._solve_quadratic( - coeff_nm * m + coeff_mm * m**2, - coeff_n + coeff_m * m, - -available, - ) - num_blocks = min(floor(num_pages) // self.block_size, self._upper_bound_num_blocks) - max_batch_tokens = min(int(num_pages * m), self._upper_bound_max_batch_tokens) - - elif num_blocks is None: + num_pages = self._solve_quadratic(cnm * m + cmm * m**2, cn + cm * m, -available) + max_batch_tokens = int(num_pages * m) + if max_batch_tokens > self._upper_bound_max_batch_tokens: + max_batch_tokens = self._upper_bound_max_batch_tokens + # If max_batch_tokens is clamped, we recompute num_blocks below to get a higher value + num_blocks = None + else: + num_blocks = min(floor(num_pages) // self.block_size, self._upper_bound_num_blocks) + + if num_blocks is None: # M given → linear in N: (coeff_n + coeff_nm·M)·N = avail − coeff_m·M − coeff_mm·M² M = max_batch_tokens - num_pages = floor((available - coeff_m * M - coeff_mm * M**2) / (coeff_n + coeff_nm * M)) + num_pages = floor((available - cm * M - cmm * M**2) / (cn + cnm * M)) num_blocks = min(num_pages // self.block_size, self._upper_bound_num_blocks) - elif max_batch_tokens is None: # N given → quadratic in M: coeff_mm·M² + (coeff_m + coeff_nm·N)·M + (coeff_n·N − avail) = 0 N = num_blocks * self.block_size - M = self._solve_quadratic(coeff_mm, coeff_m + coeff_nm * N, coeff_n * N - available) + M = self._solve_quadratic(cmm, cm + cnm * N, cn * N - available) max_batch_tokens = min(floor(M), self._upper_bound_max_batch_tokens) + return num_blocks, max_batch_tokens + + def infer_num_blocks_and_max_batch_tokens( + self, + num_blocks: int | None = None, + max_batch_tokens: int | None = None, + max_memory_percent: float = 0.9, + cache_dtype: torch.dtype = torch.float16, + ) -> tuple[int, int]: + """Solve for the missing variable(s) in the memory polynomial (see ``_equation_coefficients``). There is one + polynomial per activation peak; we solve each independently and take the most restrictive (smallest) result. + When both `N` and `M` are unknown, assumes `M = m·N` (m = 0.01, i.e. one batch fills ~1 % of the cache) and + solves the resulting quadratic in N. 
+ """ + available = self.get_available_memory(max_memory_percent) + logger.info(f"Cache memory: {available}") + # Solve each peak independently, then take the element-wise min (tightest constraint wins) + acc_num_blocks = float("inf") + acc_max_batch_tokens = float("inf") + for peak in self.activation_peaks: + n_blocks, m_batch_tokens = self._solve_for_peak(peak, available, num_blocks, max_batch_tokens, cache_dtype) + acc_num_blocks = min(acc_num_blocks, n_blocks) + acc_max_batch_tokens = min(acc_max_batch_tokens, m_batch_tokens) + # Now update the value (cannot update in loop, it would overwrite the user-passed values) + num_blocks, max_batch_tokens = acc_num_blocks, acc_max_batch_tokens # Validate - memory_footprint = self.compute_memory_footprint( - max_batch_tokens=max_batch_tokens, num_blocks=num_blocks, cache_dtype=cache_dtype - ) + memory_footprint = self.compute_memory_footprint(num_blocks, max_batch_tokens, cache_dtype) if memory_footprint > available: raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available}") return num_blocks, max_batch_tokens def compute_memory_footprint(self, num_blocks: int, max_batch_tokens: int, cache_dtype: torch.dtype) -> int: - """Evaluate the memory polynomial at concrete (N, M) values.""" + """Evaluate the memory polynomial at concrete (N, M) values, taking the max across activation peaks.""" N = num_blocks * self.block_size M = max_batch_tokens - cn, cm, cnm, cmm = self._equation_coefficients(cache_dtype) - return cn * N + cm * M + cnm * N * M + cmm * M * M + + max_memory_footprint = 0 + for peak in self.activation_peaks: + cn, cm, cnm, cmm = self._equation_coefficients(peak, cache_dtype) + memory_footprint = cn * N + cm * M + cnm * N * M + cmm * M * M + max_memory_footprint = max(max_memory_footprint, memory_footprint) + return max_memory_footprint diff --git a/src/transformers/generation/continuous_batching/cb_logits_processors.py b/src/transformers/generation/continuous_batching/cb_logits_processors.py index 3a5f7eb8df26..619d9fefea5e 100644 --- a/src/transformers/generation/continuous_batching/cb_logits_processors.py +++ b/src/transformers/generation/continuous_batching/cb_logits_processors.py @@ -319,6 +319,8 @@ def __call__(self, scores: torch.FloatTensor, tensor_arg: torch.Tensor) -> torch return scores.masked_fill(indices_to_remove, self.filter_value) +# TODO: add non-per-request CB variants so the memory-efficient warpers work when `per_request_processors=False`. +# TODO: fuse temperature + top-k + top-p into a single pass to reuse the softmax/sort and cut activation peak. 
CLASSIC_TO_CB_PROCESSORS_MAP = { TemperatureLogitsWarper: ContinuousBatchingTemperatureLogitsWarper, TopKLogitsWarper: ContinuousBatchingTopKLogitsWarper, diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 47290b9d70b6..0521c6402ca9 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -623,26 +623,18 @@ def _sample(self, scores: torch.Tensor, logits_indices: torch.Tensor, output_ids output_ids[1, :tokens].copy_(logprobs.view(dtype=torch.int32)) @torch.inference_mode() - def warmup( - self, - model: nn.Module, - logit_processor: LogitsProcessorList, - num_query_tokens: int = 0, - num_cache_tokens: int = 0, - ) -> None: + def warmup(self, model: nn.Module) -> None: """Pre-capture CUDA graphs (or trigger compile warmup) for varlen and decode paths. In async mode, both IO - pairs are warmed up since each has its own graph buffer and static tensors.""" + pairs are warmed up since each has its own graph buffer and static tensors. The varlen path is warmed up at + the largest possible `(q, kv)` sizes so subsequent captures fit inside it without growing the pool.""" if not self._pad_inputs: logger.info("CUDA graphs and compile are disabled, skipping warmup.") return None - num_query_tokens = num_query_tokens if num_query_tokens > 0 else self.max_batch_tokens - num_query_tokens = min(num_query_tokens, self.max_batch_tokens) - num_cache_tokens = num_cache_tokens if num_cache_tokens > 0 else self.cache.block_size * num_query_tokens - num_cache_tokens = min(num_cache_tokens, self.cache.num_blocks * self.cache.block_size) - + num_query_tokens = self.max_batch_tokens num_pages = self.cache.num_blocks * self.cache.block_size + num_cache_tokens = num_pages - num_query_tokens compute_stream = self.inputs_and_outputs.compute_stream # In async mode, each IO pair has its own graph buffer and static tensors, so we warm up both @@ -677,7 +669,7 @@ def warmup( forward_fn(*forward_fn_args) logger.info(f"Varlen warmup completed in {perf_counter() - start:.2f}s") except Exception as e: - logger.warning(f"Failed to warm up varlen path: {e}") + logger.warning(f"Failed to warm up varlen path: {e}. Graph pool may fragment and OOM under load.") finally: for fs in future_states: self.cache.free_blocks(fs.state.request_id) @@ -811,12 +803,12 @@ def is_running(self) -> bool: """Check if the background generation thread is running.""" return self._generation_thread is not None and self._generation_thread.is_alive() - def warmup(self, num_query_tokens: int = 0, num_cache_tokens: int = 0) -> None: + def warmup(self) -> None: """Pre-capture CUDA graphs for varlen and decode paths by running dummy batches. Initializes the batch processor if not already done.""" if self.batch_processor is None: self.batch_processor = self._create_batch_processor() - self.batch_processor.warmup(self.model, self.logit_processor, num_query_tokens, num_cache_tokens) + self.batch_processor.warmup(self.model) self.warmed_up = True # NOTE: don't forget to update `continuous_batching_context_manager` when changing this method's definition @@ -1040,6 +1032,8 @@ def _generation_step(self) -> None: self.batch_processor._generation_step(self.model) def _create_batch_processor(self) -> ContinuousBatchProcessor: + # Resolve max_memory_percent now that we know whether any logit processors are active. 
+ self.continuous_batching_config.resolve_max_memory_percent(self.logit_processor.do_processing) # Create the PagedAttentionCache paged_attention_cache = PagedAttentionCache( self.model.config, @@ -1225,25 +1219,25 @@ def continuous_batching_context_manager( timeout: float | None = None, continuous_batching_config: ContinuousBatchingConfig | None = None, persistent_manager: bool = False, - warmup_requests: int | None = 0, + warmup: bool = True, **deprecated_kwargs, ) -> Generator[ContinuousBatchingManager]: """A context manager to safely use the continuous batching manager. Arguments are similar to the ones of `init_continuous_batching`, except for: - block: whether to block the thread when stopping the manager. Default is True. - timeout: maximum time to wait for the thread to stop. Default is None (no timeout). - - warmup_query_tokens: the number of expected requests for which to warmup. 0 is auto, None is no warmup. + - warmup: whether to pre-capture CUDA graphs at the largest sizes before running. Default is True. """ manager = self.init_continuous_batching( generation_config=generation_config, continuous_batching_config=continuous_batching_config, **deprecated_kwargs, ) - if not (warmup_requests is None or manager.warmed_up): + if warmup and not manager.warmed_up: # Warmup is long (~30 sec): best to signal the user it's happening than let them think the manager is stuck - logger.warning("Warming up for coninuous batching...") + logger.warning("Warming up for continuous batching...") start = perf_counter() - manager.warmup(num_query_tokens=warmup_requests, num_cache_tokens=0) + manager.warmup() logger.warning(f"Warming up completed in {perf_counter() - start:.2f}s.") manager.start() try: @@ -1320,7 +1314,7 @@ def generate_batch( block=True, timeout=5, persistent_manager=persistent_manager, - warmup_requests=len(inputs) if warmup else None, + warmup=warmup, **deprecated_kwargs, ) logging_cm = logging_redirect_tqdm([logger]) diff --git a/src/transformers/generation/continuous_batching/input_outputs.py b/src/transformers/generation/continuous_batching/input_outputs.py index 134941c2526f..fbe7890a15b9 100644 --- a/src/transformers/generation/continuous_batching/input_outputs.py +++ b/src/transformers/generation/continuous_batching/input_outputs.py @@ -14,7 +14,6 @@ from contextlib import nullcontext from dataclasses import dataclass from functools import partial -from itertools import count from typing import Any import torch @@ -250,10 +249,11 @@ def _transfer_inputs( # Only transfer block_table for decode-only batches (when it's actually used) if self.use_block_table: other.block_table.copy_(self.block_table, non_blocking=non_blocking) - # Otherwise, we transfer the read and write indices + # Otherwise, we transfer the write indices (and read indices if the batch uses any cache reads) else: other.write_index_storage.copy_(self.write_index_storage, non_blocking=non_blocking) - other.read_index_storage.copy_(self.read_index_storage, non_blocking=non_blocking) + if self.max_kv_read > 0: + other.read_index_storage.copy_(self.read_index_storage, non_blocking=non_blocking) # Transfer the attention masks if needed if self.attention_mask is not None and other.attention_mask is not None: for layer_type in self.attention_mask.keys(): @@ -373,14 +373,15 @@ def prepare_batch_tensors( self.requests_in_batch = [] self.req_id_to_new_token_position = {} - # Prepare accumulators + # Prepare accumulators. 
For batches with no past cache to read, we leave read_index empty: the cache.update + # will detect the 0-size indices and skip the read. input_ids = [] position_ids = [] cumulative_seqlens_q = [0] logits_indices = [] cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k.keys()} - read_index = [[] for _ in range(self.cache.num_groups)] write_index = [[] for _ in range(self.cache.num_groups)] + read_index = None if self.max_kv_read == 0 else [[] for _ in range(self.cache.num_groups)] # Go through all the requests in the batch for i, future_state in enumerate(requests_in_batch): @@ -448,14 +449,16 @@ def prepare_batch_tensors( sliding_window=self.sliding_window if layer_type == "sliding_attention" else 1, ) - # If we are not using the block table, we populate the read and write indices + # If we are not using the block table, we populate the write indices (and maybe the read indices) if not self.use_block_table: to_index_tensor = partial(torch.tensor, dtype=torch.int64, device=self.device) - for i, group_read_indices, group_write_indices in zip(count(), read_index, write_index): - self.read_index_storage[i, : len(group_read_indices)] = to_index_tensor(group_read_indices) + for i, group_write_indices in enumerate(write_index): self.write_index_storage[i, : len(group_write_indices)] = to_index_tensor(group_write_indices) - self.true_read_sizes[i] = len(group_read_indices) self.true_write_sizes[i] = len(group_write_indices) + if read_index is not None: + for i, group_read_indices in enumerate(read_index): + self.read_index_storage[i, : len(group_read_indices)] = to_index_tensor(group_read_indices) + self.true_read_sizes[i] = len(group_read_indices) def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: """Get model keyword arguments for the current batch, eventually padding the query dimension and KV dimensions @@ -500,10 +503,14 @@ def get_model_kwargs(self, use_padding: bool = False) -> dict[str, Any]: # For the attributes that are lists of tensors, we construct list of tensor references for i in range(self.cache.num_groups): - read_index_size = kv_size if use_padding else self.true_read_sizes[i] write_index_size = q_size if use_padding else self.true_write_sizes[i] - kwargs.read_index.append(self.read_index_storage[i, :read_index_size]) kwargs.write_index.append(self.write_index_storage[i, :write_index_size]) + # If there is no cache to read, pass a list of empty tensors so `cache.update` uses the write-only fast path + if self.max_kv_read == 0: + read_index_size = 0 + else: + read_index_size = kv_size if use_padding else self.true_read_sizes[i] + kwargs.read_index.append(self.read_index_storage[i, :read_index_size]) # For the attributes that are dict of tensors, we first fill the dict with the actual values for layer_type, seqlens_k in self.cumulative_seqlens_k.items(): @@ -531,11 +538,11 @@ def get_cb_kwargs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.carry_over_ids, self.output_ids, self.output_ids def _get_graph_key(self) -> tuple[int, ...]: - # Keys for varlen path - if self.max_kv_read > 0: - return (self.num_q_tokens, self.max_kv_read, *self.max_seqlen_k.values()) # Keys for decode fast path - return (self.num_q_tokens,) + if self.use_block_table: + return (self.num_q_tokens,) + # Keys for varlen path + return (self.num_q_tokens, self.max_kv_read, *self.max_seqlen_k.values()) def get_graph(self) -> torch.cuda.CUDAGraph | None: key = self._get_graph_key() diff --git 
a/src/transformers/generation/continuous_batching/requests.py b/src/transformers/generation/continuous_batching/requests.py index 05bf65725c5a..381c94bc2dc9 100644 --- a/src/transformers/generation/continuous_batching/requests.py +++ b/src/transformers/generation/continuous_batching/requests.py @@ -27,6 +27,7 @@ import psutil # This is a temporary token ID used to represent a token that is not yet generated +# TODO: update this to 0 and check it breaks nothing + simplify carry over and time new logic TMP_TOKEN_ID = -1 @@ -45,9 +46,11 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]: device = torch.device("cuda") torch.cuda.empty_cache() torch.cuda.synchronize() - total_memory = torch.cuda.get_device_properties(device).total_memory + # Use mem_get_info to get actual free memory: device_properties().total_memory returns the physical device + # total which ignores CUDA context and driver overhead (~0.5 GiB), leading to overcommit. + free_memory, total_memory = torch.cuda.mem_get_info(device) reserved_memory = torch.cuda.memory_reserved(device) - allocated_memory = torch.cuda.memory_allocated(device) + allocated_memory = total_memory - free_memory elif is_torch_xpu_available(): device = torch.device("xpu") torch.xpu.empty_cache() diff --git a/src/transformers/generation/continuous_batching/scheduler.py b/src/transformers/generation/continuous_batching/scheduler.py index f35d2e968342..284c202267c5 100644 --- a/src/transformers/generation/continuous_batching/scheduler.py +++ b/src/transformers/generation/continuous_batching/scheduler.py @@ -205,7 +205,7 @@ def _process_candidates( """ scheduled_requests = [] one_allocation_failed = False - decode_fast_path = True + decode_fast_path = self.cache.max_blocks_per_request > 0 # best way to check if decode fast path availability safety_margins = safety_margin * self.cache.num_blocks original_token_budget, original_cache_budget = token_budget, cache_budget @@ -219,17 +219,22 @@ def _process_candidates( ) break - # Check cache budget + # Infer the tokens that will be present in the batch if token budget is enough + request_tokens = self._infer_request_tokens(state, request_ids_to_remove_from_waiting) + # Account for token budget + request_len = min(len(request_tokens), token_budget) + + # This block checks cache budget: decode batches have infinite budget, but varlen batches don't, because KV + # cache is read through a fixed-sized index tensor. 
We keep track of the current budget in case the batch + # goes from decode to varlen + is_decode_eligible = request_len == 1 and state.position_offset < self.max_decode_fast_path_length read_cache_needed = state.current_len() if self.read_cache_limit is not None: read_cache_needed = min(read_cache_needed, self.read_cache_limit) - if cache_budget < read_cache_needed: + # A request that would change the batch from decode to varlen is rejected if the cache budget is too low + if not (decode_fast_path and is_decode_eligible) and cache_budget < read_cache_needed: continue - # Infer the tokens that will be present in the batch if token budget is enough - request_tokens = self._infer_request_tokens(state, request_ids_to_remove_from_waiting) - # Account for token budget - request_len = min(len(request_tokens), token_budget) # Check there will be enough cache for the new tokens allocation_successful = self._allocate_blocks_if_needed(state, request_len) @@ -273,7 +278,7 @@ def _process_candidates( request_ids_to_remove_from_waiting.add(req_id) # Early exit of the loop if we have no budget left - if token_budget == 0 or cache_budget == 0: + if token_budget == 0 or (cache_budget <= 0 and not decode_fast_path): break num_q_tokens = original_token_budget - token_budget diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 213b91e3a115..a6b9a517b20d 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -13,7 +13,6 @@ # limitations under the License. import torch import torch.nn as nn -import triton from torch.nn import functional as F from ..activations import ACT2FN @@ -159,6 +158,11 @@ def _load_deepgemm_kernel(): _deepgemm_available = True +def _cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return (a + b - 1) // b + + def w8a8_fp8_matmul( A: torch.Tensor, B: torch.Tensor, @@ -603,8 +607,8 @@ def __init__( if self.has_gate: gu_proj_out, gu_proj_in = 2 * self.intermediate_dim, self.hidden_dim self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, gu_proj_out, gu_proj_in, dtype=dtype)) - gu_scale_out = triton.cdiv(gu_proj_out, self.block_size[0]) if self.block_size is not None else 1 - gu_scale_in = triton.cdiv(gu_proj_in, self.block_size[1]) if self.block_size is not None else 1 + gu_scale_out = _cdiv(gu_proj_out, self.block_size[0]) if self.block_size is not None else 1 + gu_scale_in = _cdiv(gu_proj_in, self.block_size[1]) if self.block_size is not None else 1 self.gate_up_proj_scale_inv = nn.Parameter( torch.empty(self.num_experts, gu_scale_out, gu_scale_in, dtype=torch.float32) ) @@ -612,8 +616,8 @@ def __init__( else: u_proj_out, u_proj_in = self.intermediate_dim, self.hidden_dim self.up_proj = nn.Parameter(torch.empty(self.num_experts, u_proj_out, u_proj_in, dtype=dtype)) - u_scale_out = triton.cdiv(u_proj_out, self.block_size[0]) if self.block_size is not None else 1 - u_scale_in = triton.cdiv(u_proj_in, self.block_size[1]) if self.block_size is not None else 1 + u_scale_out = _cdiv(u_proj_out, self.block_size[0]) if self.block_size is not None else 1 + u_scale_in = _cdiv(u_proj_in, self.block_size[1]) if self.block_size is not None else 1 self.up_proj_scale_inv = nn.Parameter( torch.empty(self.num_experts, u_scale_out, u_scale_in, dtype=torch.float32) ) @@ -621,8 +625,8 @@ def __init__( d_proj_out, d_proj_in = self.hidden_dim, self.intermediate_dim self.down_proj = nn.Parameter(torch.empty(self.num_experts, d_proj_out, d_proj_in, dtype=dtype)) - 
d_scale_out = triton.cdiv(d_proj_out, self.block_size[0]) if self.block_size is not None else 1 - d_scale_in = triton.cdiv(d_proj_in, self.block_size[1]) if self.block_size is not None else 1 + d_scale_out = _cdiv(d_proj_out, self.block_size[0]) if self.block_size is not None else 1 + d_scale_in = _cdiv(d_proj_in, self.block_size[1]) if self.block_size is not None else 1 self.down_proj_scale_inv = nn.Parameter( torch.empty(self.num_experts, d_scale_out, d_scale_in, dtype=torch.float32) ) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index b1e6c74ddf10..70a343424aa8 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -289,6 +289,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, + "sonic-moe": {"repo_id": "kernels-community/sonic-moe", "version": 1}, } _KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {} diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index d17522d26daa..c8a8e87f3621 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -23,6 +23,7 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) +from .sonicmoe import sonicmoe_experts_forward if is_torch_available(): @@ -31,6 +32,7 @@ logger = logging.get_logger(__name__) + # Examples of experts class with its eager mm implementation # class Experts(torch.nn.Module): # """Collection of expert weights stored as 3D tensors.""" @@ -458,6 +460,7 @@ class ExpertsInterface(GeneralInterface): """Interface for registering custom experts forward functions.""" _global_mapping = { + "sonicmoe": sonicmoe_experts_forward, "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, } @@ -498,6 +501,7 @@ def use_experts_implementation( experts_class: type[torch.nn.Module] | None = None, *, experts_interface: ExpertsInterface = ALL_EXPERTS_FUNCTIONS, + is_concatenated: bool = True, is_transposed: bool = False, has_bias: bool = False, has_gate: bool = True, @@ -509,10 +513,16 @@ The experts class to modify. If not provided, returns a decorator that can be applied to the class. experts_interface (`ExpertsInterface`, *optional*, defaults to `ALL_EXPERTS_FUNCTIONS`): The experts interface to use for dispatching the forward method. + is_concatenated (`bool`, *optional*, defaults to `True`): + Whether the expert weights are stored in concatenated layout [gate;up] + or interleaved layout [gate0, up0, gate1, up1, ...]. is_transposed (`bool`, *optional*, defaults to `False`): Whether the expert weights are stored in transposed format. has_bias (`bool`, *optional*, defaults to `False`): - Whether the expert layers include bias terms. + Whether the expert layers include bias terms or not. + has_gate (`bool`, *optional*, defaults to `True`): + Whether the experts use a gating mechanism, i.e. whether it has + gate_up_proj weights or just up_proj weights. Returns: `type[torch.nn.Module]`: The modified experts class.
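Per the docstring above, calling `use_experts_implementation` without a class returns a decorator, so the new flag plugs in roughly as below. This is a hedged sketch: `MyInterleavedExperts` is a made-up class name and the import path is assumed from the file being patched.

```python
import torch.nn as nn

from transformers.integrations.moe import use_experts_implementation


# Hypothetical experts module whose fused gate/up weights are stored interleaved
# ([gate0, up0, gate1, up1, ...]) instead of concatenated ([gate; up]).
@use_experts_implementation(is_concatenated=False, is_transposed=False, has_bias=False, has_gate=True)
class MyInterleavedExperts(nn.Module):
    ...
```

With the flag recorded on the instance (see the `__init__` patch below), kernels such as the new "sonicmoe" entry can forward it directly as their `concat_layout` argument.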
@@ -529,6 +539,7 @@ def __init__(self, config, *args, **kwargs): self.has_gate = has_gate self.has_bias = has_bias self.is_transposed = is_transposed + self.is_concatenated = is_concatenated @wraps(original_forward) def forward(self, *args, **kwargs): diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py new file mode 100644 index 000000000000..e322bb4bc061 --- /dev/null +++ b/src/transformers/integrations/sonicmoe.py @@ -0,0 +1,124 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SonicMoE integration: fused MoE using CuteDSL kernels from `kernels-community/sonic-moe`. + +Provides `sonicmoe_experts_forward` registered as "sonicmoe" in the ExpertsInterface. +Requirements: CUDA, `kernels`, `nvidia-cutlass-dsl`, has_gate=True. +""" + +import functools + +import torch + +from ..utils import logging +from .hub_kernels import lazy_load_kernel + + +logger = logging.get_logger(__name__) + +# Map activation function names from HF config to SonicMoE epilogue names +ACT_MAP = {"silu": "swiglu", "gelu": "geglu", "relu": "reglu"} + + +@functools.cache +def _load_sonic_kernel(): + """ + Load sonic-moe once and return its required symbols. + + Raises: + ImportError if the kernel or required symbols are not found. + + Returns: + Tuple of (ActivationType, moe_general_routing_inputs function) from the sonic-moe kernel. + """ + + kernel = lazy_load_kernel("sonic-moe") + if kernel is None: + raise ImportError( + "sonic-moe kernel not found. Make sure you have the `kernels` and `nvidia-cutlass-dsl` packages installed." + ) + + ActivationType = getattr(getattr(kernel, "enums", None), "ActivationType", None) + moe_general_routing_inputs = getattr(kernel, "moe_general_routing_inputs", None) + + missing = [ + name + for name, attr in [ + ("enums.ActivationType", ActivationType), + ("moe_general_routing_inputs", moe_general_routing_inputs), + ] + if attr is None + ] + if missing: + raise ImportError( + f"sonic-moe kernel is missing required symbols: {', '.join(missing)}. " + "Make sure you have the `kernels` package and `nvidia-cutlass-dsl` installed." 
+ ) + + return ActivationType, moe_general_routing_inputs + + +def sonicmoe_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if not self.has_gate: + raise ValueError("sonicmoe requires gated experts (has_gate=True)") + if hidden_states.device.type != "cuda": + raise ValueError("sonicmoe requires CUDA device") + + ActivationType, moe_general_routing_inputs = _load_sonic_kernel() + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + + # Flatten — token_indices must be int32, sorted ascending (required by sonic-moe) + token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1).int() + router_scores = top_k_weights.reshape(-1).to(hidden_states.dtype) + expert_ids = top_k_index.reshape(-1).int() + + # Map activation function + act_name = getattr(self.config, "hidden_act", "silu").lower() + activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) + + # Permute weights as expected by sonic-moe (E=num_experts, H=hidden_size, I=intermediate_size). + # Non-transposed: gate_up_proj is (E, 2*I, H), down_proj is (E, H, I) -> permute(1, 2, 0). + # Transposed: gate_up_proj is (E, H, 2*I), down_proj is (E, I, H) -> permute(2, 1, 0). + perm = (2, 1, 0) if self.is_transposed else (1, 2, 0) + w1 = self.gate_up_proj.permute(*perm) # (2*I, H, E) + w2 = self.down_proj.permute(*perm) # (I, H, E) + b1 = self.gate_up_proj_bias if self.has_bias else None + b2 = self.down_proj_bias if self.has_bias else None + + output, _ = moe_general_routing_inputs( + hidden_states, + router_scores, + token_idx, + expert_ids, + w1, + b1, + w2, + b2, + E=self.num_experts, + activation_type=activation_type, + stream_id=torch.cuda.current_stream(device).cuda_stream, + is_inference_mode_enabled=not torch.is_grad_enabled(), + concat_layout=self.is_concatenated, + ) + + return output diff --git a/src/transformers/loss/loss_deimv2.py b/src/transformers/loss/loss_deimv2.py new file mode 100644 index 000000000000..5c8573f7da44 --- /dev/null +++ b/src/transformers/loss/loss_deimv2.py @@ -0,0 +1,261 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
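As a sanity check on the routing flattening done in `sonicmoe_experts_forward` above, the following standalone sketch (ours, not part of the kernel) reproduces the expansion of `top_k_index`/`top_k_weights` into one row per (token, expert) assignment, with int32 indices and token indices in ascending order, which is the format the fused kernel is fed.

```python
import torch

num_tokens, num_top_k = 3, 2
top_k_index = torch.tensor([[0, 2], [1, 3], [0, 1]])   # chosen experts per token
top_k_weights = torch.rand(num_tokens, num_top_k)      # router weights per assignment

# One flat entry per (token, expert) assignment; token indices stay sorted ascending.
token_idx = torch.arange(num_tokens).unsqueeze(1).expand(-1, num_top_k).reshape(-1).int()
expert_ids = top_k_index.reshape(-1).int()
router_scores = top_k_weights.reshape(-1)

print(token_idx.tolist())    # [0, 0, 1, 1, 2, 2]
print(expert_ids.tolist())   # [0, 2, 1, 3, 0, 1]
print(router_scores.shape)   # torch.Size([6])
```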
+ + +import torch +import torch.nn.functional as F + +from ..utils import is_vision_available +from .loss_d_fine import DFineLoss, _set_aux_loss, _set_aux_loss2 +from .loss_for_object_detection import box_iou + + +if is_vision_available(): + from transformers.image_transforms import center_to_corners_format + + +class Deimv2Loss(DFineLoss): + def __init__(self, config): + super().__init__(config) + self.weight_dict = { + "loss_mal": config.weight_loss_mal, + "loss_bbox": config.weight_loss_bbox, + "loss_giou": config.weight_loss_giou, + "loss_fgl": config.weight_loss_fgl, + "loss_ddf": config.weight_loss_ddf, + } + self.losses = ["mal", "boxes", "local"] + self.mal_alpha = config.mal_alpha + self.use_dense_one_to_one = config.use_dense_one_to_one + + def loss_labels_mal(self, outputs, targets, indices, num_boxes): + """Compute the Matching Aware Loss (MAL), which uses IoU-weighted soft labels + instead of hard one-hot targets, with focal-style weighting controlled by `mal_alpha`. + """ + idx = self._get_source_permutation_idx(indices) + + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + ious, _ = box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes)) + ious = torch.diag(ious).detach() + + src_logits = outputs["logits"] + target_classes_original = torch.cat([t["class_labels"][i] for t, (_, i) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device + ) + target_classes[idx] = target_classes_original + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + + target_score_original = torch.zeros_like(target_classes, dtype=src_logits.dtype) + target_score_original[idx] = ious.to(target_score_original.dtype) + target_score = target_score_original.unsqueeze(-1) * target + + pred_score = F.sigmoid(src_logits).detach() + target_score = target_score.pow(self.gamma) + if self.mal_alpha is not None: + weight = self.mal_alpha * pred_score.pow(self.gamma) * (1 - target) + target + else: + weight = pred_score.pow(self.gamma) * (1 - target) + target + + loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none") + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {"loss_mal": loss} + + def _get_dense_o2o_indices(self, indices, indices_aux_list): + results = [] + for indices_aux in indices_aux_list: + indices = [ + (torch.cat([idx1[0], idx2[0]]), torch.cat([idx1[1], idx2[1]])) + for idx1, idx2 in zip(indices.copy(), indices_aux.copy()) + ] + + for index in [torch.cat([idx[0][:, None], idx[1][:, None]], 1) for idx in indices]: + unique, counts = torch.unique(index, return_counts=True, dim=0) + count_sort_indices = torch.argsort(counts, descending=True) + unique_sorted = unique[count_sort_indices] + column_to_row = {} + for idx_pair in unique_sorted: + row_idx, col_idx = idx_pair[0].item(), idx_pair[1].item() + if row_idx not in column_to_row: + column_to_row[row_idx] = col_idx + final_rows = torch.tensor(list(column_to_row.keys()), device=index.device) + final_cols = torch.tensor(list(column_to_row.values()), device=index.device) + results.append((final_rows.long(), final_cols.long())) + return results + + def get_loss(self, loss, outputs, targets, indices, num_boxes): + loss_map = { + "cardinality": self.loss_cardinality, + "local": self.loss_local, + "boxes": self.loss_boxes, + "focal": self.loss_labels_focal, + "vfl": 
self.loss_labels_vfl, + "mal": self.loss_labels_mal, + } + if loss not in loss_map: + raise ValueError(f"Loss {loss} not supported") + return loss_map[loss](outputs, targets, indices, num_boxes) + + def forward(self, outputs, targets): + """ + This performs the loss computation. + + Args: + outputs (`dict`, *optional*): + Dictionary of tensors, see the output specification of the model for the format. + targets (`list[dict]`, *optional*): + List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the + losses applied, see each loss' doc. + """ + if not self.use_dense_one_to_one: + return super().forward(outputs, targets) + + # Retrieve the matching between the outputs of the last layer and the targets + outputs_without_aux = {k: v for k, v in outputs.items() if "auxiliary_outputs" not in k} + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes across all nodes, for normalization purposes + num_boxes = sum(len(t["class_labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + num_boxes = torch.clamp(num_boxes, min=1).item() + + # Handle auxiliary outputs matching + cached_indices = [] + indices_aux_list = [] + if "auxiliary_outputs" in outputs: + for auxiliary_outputs in outputs["auxiliary_outputs"]: + aux_indices = self.matcher(auxiliary_outputs, targets) + cached_indices.append(aux_indices) + indices_aux_list.append(aux_indices) + + # Dense one-to-one matching + indices_go = self._get_dense_o2o_indices(indices, indices_aux_list) + num_boxes_go = sum(len(x[0]) for x in indices_go) + num_boxes_go = torch.as_tensor([num_boxes_go], dtype=torch.float, device=next(iter(outputs.values())).device) + num_boxes_go = torch.clamp(num_boxes_go, min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + use_union = loss in ("boxes", "local") + indices_in = indices_go if use_union else indices + num_boxes_in = num_boxes_go if use_union else num_boxes + l_dict = self.get_loss(loss, outputs, targets, indices_in, num_boxes_in) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + losses.update(l_dict) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if "auxiliary_outputs" in outputs: + for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]): + for loss in self.losses: + use_union = loss in ("boxes", "local") + indices_in = indices_go if use_union else cached_indices[i] + num_boxes_in = num_boxes_go if use_union else num_boxes + l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices_in, num_boxes_in) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + l_dict = {k + f"_aux_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + # In case of cdn auxiliary losses. For deimv2 + if "dn_auxiliary_outputs" in outputs: + if "denoising_meta_values" not in outputs: + raise ValueError( + "The output must have the 'denoising_meta_values` key. " + "Please, ensure that 'outputs' includes a 'denoising_meta_values' entry." 
+ ) + dn_indices = self.get_cdn_matched_indices(outputs["denoising_meta_values"], targets) + dn_num_boxes = num_boxes * outputs["denoising_meta_values"]["dn_num_group"] + for i, auxiliary_outputs in enumerate(outputs["dn_auxiliary_outputs"]): + for loss in self.losses: + l_dict = self.get_loss(loss, auxiliary_outputs, targets, dn_indices, dn_num_boxes) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + l_dict = {k + f"_dn_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +def Deimv2ForObjectDetectionLoss( + logits, + labels, + device, + pred_boxes, + config, + outputs_class=None, + outputs_coord=None, + enc_topk_logits=None, + enc_topk_bboxes=None, + denoising_meta_values=None, + predicted_corners=None, + initial_reference_points=None, + **kwargs, +): + criterion = Deimv2Loss(config) + criterion.to(device) + + outputs_loss = {"logits": logits, "pred_boxes": pred_boxes.clamp(min=0, max=1)} + auxiliary_outputs = None + + if config.auxiliary_loss: + if denoising_meta_values is not None: + dn_out_coord, normal_out_coord = torch.split( + outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2 + ) + dn_out_class, normal_out_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2) + # https://github.com/Intellindust-AI-Lab/DEIMv2/blob/main/engine/deim/deim_decoder.py#L562-L571 + # The original splits denoising queries in the decoder; here it happens in the loss since the decoder returns unsplit tensors. + _, normal_logits = torch.split(logits, denoising_meta_values["dn_num_split"], dim=1) + _, normal_pred_boxes = torch.split(pred_boxes, denoising_meta_values["dn_num_split"], dim=1) + dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2) + dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2) + + outputs_loss["logits"] = normal_logits + outputs_loss["pred_boxes"] = normal_pred_boxes.clamp(min=0, max=1) + else: + normal_out_coord = outputs_coord.clamp(min=0, max=1) + normal_out_class = outputs_class + out_corners = predicted_corners + out_refs = initial_reference_points + + auxiliary_outputs = _set_aux_loss2( + normal_out_class[:, :-1].transpose(0, 1), + normal_out_coord[:, :-1].transpose(0, 1), + out_corners[:, :-1].transpose(0, 1), + out_refs[:, :-1].transpose(0, 1), + out_corners[:, -1], + normal_out_class[:, -1], + ) + + outputs_loss["auxiliary_outputs"] = auxiliary_outputs + outputs_loss["auxiliary_outputs"].extend( + _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]) + ) + + if denoising_meta_values is not None: + dn_auxiliary_outputs = _set_aux_loss2( + dn_out_class.transpose(0, 1), + dn_out_coord.transpose(0, 1), + dn_out_corners.transpose(0, 1), + dn_out_refs.transpose(0, 1), + dn_out_corners[:, -1], + dn_out_class[:, -1], + ) + outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs + outputs_loss["denoising_meta_values"] = denoising_meta_values + + loss_dict = criterion(outputs_loss, labels) + + loss = sum(loss_dict.values()) + return loss, loss_dict, auxiliary_outputs diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index df269477e9ec..51564d299e55 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -19,6 +19,7 @@ from .loss_d_fine import DFineForObjectDetectionLoss from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, 
DeformableDetrForSegmentationLoss +from .loss_deimv2 import Deimv2ForObjectDetectionLoss from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss from .loss_lw_detr import LwDetrForObjectDetectionLoss @@ -163,6 +164,7 @@ def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs): "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss, "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss, "DFineForObjectDetection": DFineForObjectDetectionLoss, + "Deimv2ForObjectDetection": Deimv2ForObjectDetectionLoss, "CsmForConditionalGeneration": ForCausalLMLoss, "LwDetrForObjectDetection": LwDetrForObjectDetectionLoss, } diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index eb092019b678..d58c9a52fd33 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -66,10 +66,12 @@ ) from .integrations.deepspeed import _load_state_dict_into_zero3_model from .integrations.eager_paged import eager_paged_attention_forward +from .integrations.finegrained_fp8 import ALL_FP8_EXPERTS_FUNCTIONS from .integrations.flash_attention import flash_attention_forward from .integrations.flash_paged import paged_attention_forward from .integrations.flex_attention import flex_attention_forward from .integrations.hub_kernels import allow_all_hub_kernels, is_kernel +from .integrations.moe import ALL_EXPERTS_FUNCTIONS from .integrations.peft import maybe_load_adapters from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.sdpa_paged import sdpa_attention_paged_forward @@ -1969,11 +1971,14 @@ def get_correct_attn_implementation(self, requested_attention: str | None, is_in def get_correct_experts_implementation(self, requested_experts: str | None) -> str: applicable_experts = "grouped_mm" if requested_experts is None else requested_experts - if applicable_experts not in ["eager", "grouped_mm", "batched_mm", "deepgemm"]: + base_experts_fns = ["eager"] + list(set(ALL_EXPERTS_FUNCTIONS.keys()) | set(ALL_FP8_EXPERTS_FUNCTIONS.keys())) + valid_experts_str_list = [f'`experts_implementation="{fn}"`' for fn in base_experts_fns] + valid_experts_str_list[-1] = "and " + valid_experts_str_list[-1] + valid_experts_str = ", ".join(valid_experts_str_list) + if applicable_experts not in base_experts_fns: message = ( f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are ' - '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"`, `"experts_implementation=batched_mm"` ' - 'and `"experts_implementation=deepgemm"`.' + f"{valid_experts_str}." 
) raise ValueError(message) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 3bf3878ea229..eae18ac34b7c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -96,6 +96,7 @@ from .deepseek_vl import * from .deepseek_vl_hybrid import * from .deformable_detr import * + from .deimv2 import * from .deit import * from .deprecated import * from .depth_anything import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index d1d331a0d42f..f443a8db642c 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -125,6 +125,7 @@ ("deepseek_vl", "DeepseekVLConfig"), ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"), ("deformable_detr", "DeformableDetrConfig"), + ("deimv2", "Deimv2Config"), ("deit", "DeiTConfig"), ("depth_anything", "DepthAnythingConfig"), ("depth_pro", "DepthProConfig"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index c74ee27519ff..c624f49083d2 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -76,6 +76,7 @@ ("convnextv2", {"torchvision": "ConvNextImageProcessor", "pil": "ConvNextImageProcessorPil"}), ("cvt", {"torchvision": "ConvNextImageProcessor", "pil": "ConvNextImageProcessorPil"}), ("data2vec-vision", {"torchvision": "BeitImageProcessor", "pil": "BeitImageProcessorPil"}), + ("deimv2", {"torchvision": "RTDetrImageProcessor", "pil": "RTDetrImageProcessorPil"}), ("depth_anything", {"torchvision": "DPTImageProcessor", "pil": "DPTImageProcessorPil"}), ("dinat", {"torchvision": "ViTImageProcessor", "pil": "ViTImageProcessorPil"}), ("dinov2", {"torchvision": "BitImageProcessor", "pil": "BitImageProcessorPil"}), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 3250eba7ba68..b4d928647561 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -116,6 +116,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("deepseek_vl", "DeepseekVLModel"), ("deepseek_vl_hybrid", "DeepseekVLHybridModel"), ("deformable_detr", "DeformableDetrModel"), + ("deimv2", "Deimv2Model"), ("deit", "DeiTModel"), ("depth_pro", "DepthProModel"), ("detr", "DetrModel"), @@ -1113,6 +1114,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("d_fine", "DFineForObjectDetection"), ("dab-detr", "DabDetrForObjectDetection"), ("deformable_detr", "DeformableDetrForObjectDetection"), + ("deimv2", "Deimv2ForObjectDetection"), ("detr", "DetrForObjectDetection"), ("lw_detr", "LwDetrForObjectDetection"), ("pp_doclayout_v2", "PPDocLayoutV2ForObjectDetection"), diff --git a/src/transformers/models/d_fine/modeling_d_fine.py b/src/transformers/models/d_fine/modeling_d_fine.py index 1c758f8b1dcd..f1d23356fb2b 100644 --- a/src/transformers/models/d_fine/modeling_d_fine.py +++ b/src/transformers/models/d_fine/modeling_d_fine.py @@ -874,6 +874,7 @@ class DFinePreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + super()._init_weights(module) # initialize linear layer bias value according to a given probability value. 
if isinstance(module, (DFineForObjectDetection, DFineDecoder)): if module.class_embed is not None: @@ -919,15 +920,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.enc_score_head.weight) init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) init.constant_(module.gate.bias, bias) @@ -937,10 +929,6 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, nn.LayerNorm): - init.ones_(module.weight) - init.zeros_(module.bias) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: diff --git a/src/transformers/models/d_fine/modular_d_fine.py b/src/transformers/models/d_fine/modular_d_fine.py index ba5798ad93cb..49289f075037 100644 --- a/src/transformers/models/d_fine/modular_d_fine.py +++ b/src/transformers/models/d_fine/modular_d_fine.py @@ -23,6 +23,7 @@ from ...backbone_utils import consolidate_backbone_kwargs_to_config from ...configuration_utils import PreTrainedConfig from ...image_transforms import corners_to_center_format +from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging, torch_compilable_check from ..auto import AutoConfig @@ -678,6 +679,7 @@ class DFinePreTrainedModel(RTDetrPreTrainedModel): @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" + PreTrainedModel._init_weights(self, module) # initialize linear layer bias value according to a given probability value. 
if isinstance(module, (DFineForObjectDetection, DFineDecoder)): if module.class_embed is not None: @@ -723,15 +725,6 @@ def _init_weights(self, module): init.xavier_uniform_(module.enc_score_head.weight) init.constant_(module.enc_score_head.bias, bias) - if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - init.zeros_(module.bias) - if getattr(module, "running_mean", None) is not None: - init.zeros_(module.running_mean) - init.ones_(module.running_var) - init.zeros_(module.num_batches_tracked) - if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) init.constant_(module.gate.bias, bias) @@ -741,10 +734,6 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].bias, 0) init.constant_(module.reg_conf.layers[-1].weight, 0) - if isinstance(module, nn.LayerNorm): - init.ones_(module.weight) - init.zeros_(module.bias) - if hasattr(module, "weight_embedding") and self.config.learn_initial_query: init.xavier_uniform_(module.weight_embedding.weight) if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: diff --git a/src/transformers/models/deimv2/__init__.py b/src/transformers/models/deimv2/__init__.py new file mode 100644 index 000000000000..2140a69a54f2 --- /dev/null +++ b/src/transformers/models/deimv2/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_deimv2 import * + from .modeling_deimv2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/deimv2/configuration_deimv2.py b/src/transformers/models/deimv2/configuration_deimv2.py new file mode 100644 index 000000000000..f307202b0985 --- /dev/null +++ b/src/transformers/models/deimv2/configuration_deimv2.py @@ -0,0 +1,266 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deimv2/modular_deimv2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deimv2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.dataclasses import strict + +from ...backbone_utils import consolidate_backbone_kwargs_to_config +from ...configuration_utils import PreTrainedConfig +from ...utils import auto_docstring +from ..auto import AutoConfig + + +# TODO: Attribute map assignment logic should be fixed in modular +# as well as super() call parsing because otherwise we cannot re-write args after initialization +@auto_docstring(checkpoint="Intellindust/DEIMv2_HGNetv2_N_COCO") +@strict +class Deimv2Config(PreTrainedConfig): + r""" + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. + encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. + eval_size (`list[int]` or `tuple[int, int]`, *optional*): + Height and width used to computes the effective height and width of the position embeddings after taking + into account the stride. + normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. + decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. 
+ learn_initial_query (`bool`, *optional*, defaults to `False`): + Indicates whether the initial query embeddings for the decoder should be learned during training. + anchor_image_size (`tuple[int, int]`, *optional*): + Height and width of the input image used during evaluation to generate the bounding box anchors. + with_box_refine (`bool`, *optional*, defaults to `True`): + Whether to apply iterative bounding box refinement. + matcher_alpha (`float`, *optional*, defaults to 0.25): + Parameter alpha used by the Hungarian Matcher. + matcher_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used by the Hungarian Matcher. + matcher_class_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the class loss used by the Hungarian Matcher. + matcher_bbox_cost (`float`, *optional*, defaults to 5.0): + The relative weight of the bounding box loss used by the Hungarian Matcher. + matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss of used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal loss should be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + weight_loss_fgl (`float`, *optional*, defaults to 0.15): + Relative weight of the fine-grained localization loss in the object detection loss. + weight_loss_ddf (`float`, *optional*, defaults to 1.5): + Relative weight of the decoupled distillation focal loss in the object detection loss. + eval_idx (`int`, *optional*, defaults to -1): + Index of the decoder layer to use for evaluation. + layer_scale (`float`, *optional*, defaults to `1.0`): + Scaling factor for the hidden dimension in later decoder layers. + max_num_bins (`int`, *optional*, defaults to 32): + Maximum number of bins for the distribution-guided bounding box refinement. + reg_scale (`float`, *optional*, defaults to 4.0): + Scale factor for the regression distribution. + depth_mult (`float`, *optional*, defaults to 1.0): + Multiplier for the number of blocks in RepNCSPELAN5 layers. + top_prob_values (`int`, *optional*, defaults to 4): + Number of top probability values to consider from each corner's distribution. + lqe_hidden_dim (`int`, *optional*, defaults to 64): + Hidden dimension size for the Location Quality Estimator (LQE) network. + lqe_layers (`int`, *optional*, defaults to 2): + Number of layers in the Location Quality Estimator MLP. + decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Offset scale used in deformable attention. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. + up (`float`, *optional*, defaults to 0.5): + Controls the upper bounds of the Weighting Function. + weight_loss_mal (`float`, *optional*, defaults to 1.0): + Relative weight of the matching auxiliary loss in the object detection loss. 
+ use_dense_one_to_one (`bool`, *optional*, defaults to `True`): + Whether to use dense one-to-one matching across decoder layers. + mal_alpha (`float`, *optional*): + Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. + encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): + Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-FINE's `"cat"`. + spatial_tuning_adapter_inplanes (`int`, *optional*, defaults to 16): + Number of input planes for the STA convolutional stem. + encoder_type (`str`, *optional*, defaults to `"hybrid"`): + Type of encoder to use. `"hybrid"` uses the full HybridEncoder with AIFI, FPN, and PAN. + `"lite"` uses the lightweight LiteEncoder with GAP fusion for smaller variants (Atto, Femto, Pico). + use_gateway (`bool`, *optional*, defaults to `True`): + Whether to use the gateway mechanism (cross-attention gating) in decoder layers. When `False`, + uses RMSNorm on the encoder attention output instead. + share_bbox_head (`bool`, *optional*, defaults to `False`): + Whether to share the bounding box prediction head across all decoder layers. + encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): + Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. + `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). + `False` for RepNCSPELAN5 (used by DINOv3 variants). + """ + + model_type = "deimv2" + sub_configs = {"backbone_config": AutoConfig} + layer_types = ["basic", "bottleneck"] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "encoder_attention_heads", + } + + initializer_range: float = 0.01 + initializer_bias_prior_prob: float | None = None + layer_norm_eps: float = 1e-5 + batch_norm_eps: float = 1e-5 + backbone_config: dict | PreTrainedConfig | None = None + freeze_backbone_batch_norms: bool = True + + # encoder HybridEncoder + encoder_hidden_dim: int = 256 + encoder_in_channels: list[int] | tuple[int, ...] = (512, 1024, 2048) + feat_strides: list[int] | tuple[int, ...] = (8, 16, 32) + encoder_layers: int = 1 + encoder_ffn_dim: int = 1024 + encoder_attention_heads: int = 8 + dropout: float | int = 0.0 + activation_dropout: float | int = 0.0 + encode_proj_layers: list[int] | tuple[int, ...] = (2,) + positional_encoding_temperature: int = 10000 + encoder_activation_function: str = "gelu" + activation_function: str = "silu" + + eval_size: list[int] | tuple[int, int] | None = None + normalize_before: bool = False + hidden_expansion: float = 1.0 + + # decoder Deimv2Transformer + d_model: int = 256 + num_queries: int = 300 + decoder_in_channels: list[int] | tuple[int, ...] 
= (256, 256, 256) + decoder_ffn_dim: int = 1024 + num_feature_levels: int = 3 + decoder_n_points: int | list[int] = 4 + decoder_layers: int = 6 + decoder_attention_heads: int = 8 + decoder_activation_function: str = "relu" + attention_dropout: float | int = 0.0 + num_denoising: int = 100 + label_noise_ratio: float = 0.5 + box_noise_scale: float = 1.0 + learn_initial_query: bool = False + anchor_image_size: int | list[int] | None = None + with_box_refine: bool = True + + # Loss + matcher_alpha: float = 0.25 + matcher_gamma: float = 2.0 + matcher_class_cost: float = 2.0 + matcher_bbox_cost: float = 5.0 + matcher_giou_cost: float = 2.0 + use_focal_loss: bool = True + auxiliary_loss: bool = True + focal_loss_alpha: float = 0.75 + focal_loss_gamma: float = 2.0 + weight_loss_vfl: float = 1.0 + weight_loss_bbox: float = 5.0 + weight_loss_giou: float = 2.0 + weight_loss_fgl: float = 0.15 + weight_loss_ddf: float = 1.5 + eos_coefficient: float = 1e-4 + eval_idx: int = -1 + layer_scale: int | float = 1.0 + max_num_bins: int = 32 + reg_scale: float = 4.0 + depth_mult: float = 1.0 + top_prob_values: int = 4 + lqe_hidden_dim: int = 64 + lqe_layers: int = 2 + decoder_offset_scale: float = 0.5 + decoder_method: str = "default" + up: float = 0.5 + tie_word_embeddings: bool = True + is_encoder_decoder: bool = True + weight_loss_mal: float = 1.0 + use_dense_one_to_one: bool = True + mal_alpha: float | None = None + encoder_fuse_op: str = "sum" + spatial_tuning_adapter_inplanes: int = 16 + encoder_type: str = "hybrid" + use_gateway: bool = True + share_bbox_head: bool = False + encoder_has_trailing_conv: bool = True + + def __post_init__(self, **kwargs): + self.backbone_config, kwargs = consolidate_backbone_kwargs_to_config( + backbone_config=self.backbone_config, + default_config_type="hgnet_v2", + default_config_kwargs={"out_indices": [2, 3, 4]}, + **kwargs, + ) + self.head_dim = self.d_model // self.decoder_attention_heads + super().__post_init__(**kwargs) + + def validate_architecture(self): + """Part of `@strict`-powered validation. Validates the architecture of the config.""" + if isinstance(self.decoder_n_points, list): + if len(self.decoder_n_points) != self.num_feature_levels: + raise ValueError( + f"Length of decoder_n_points list ({len(self.decoder_n_points)}) must match num_feature_levels ({self.num_feature_levels})." + ) + + if self.head_dim * self.decoder_attention_heads != self.d_model: + raise ValueError( + f"Embedded dimension {self.d_model} must be divisible by decoder_attention_heads {self.decoder_attention_heads}" + ) + + +__all__ = ["Deimv2Config"] diff --git a/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py new file mode 100644 index 000000000000..7207a95dffd0 --- /dev/null +++ b/src/transformers/models/deimv2/convert_deimv2_original_pytorch_checkpoint_to_hf.py @@ -0,0 +1,789 @@ +# Copyright 2026 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
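Before the conversion script, a brief illustrative note on the config defined just above: `Deimv2Config.__post_init__` derives `head_dim` from `d_model` and `decoder_attention_heads`, and `validate_architecture` rejects a list-valued `decoder_n_points` whose length differs from `num_feature_levels`. A minimal usage sketch, assuming the class is importable from `transformers` as this PR wires it up and that `@strict` validation runs at construction time:

```python
from transformers import Deimv2Config

# Defaults: d_model=256, decoder_attention_heads=8 -> head_dim=32, which passes validation.
config = Deimv2Config()
print(config.head_dim)  # 32

# A list-valued decoder_n_points must have one entry per feature level (3 by default),
# so this is expected to raise a ValueError during validation.
try:
    Deimv2Config(decoder_n_points=[4, 4])
except ValueError as err:
    print(err)
```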
+ +import argparse +import json +import re +from io import BytesIO +from pathlib import Path + +import httpx +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from safetensors.torch import load_file +from torchvision import transforms + +from transformers import Deimv2Config, Deimv2ForObjectDetection, RTDetrImageProcessor +from transformers.models.dinov3_vit.configuration_dinov3_vit import DINOv3ViTConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +HGNETV2_BACKBONE_CONFIGS = { + "B0": { + "hidden_sizes": [128, 256, 512, 1024], + "stem_channels": [3, 16, 16], + "stage_in_channels": [16, 64, 256, 512], + "stage_mid_channels": [16, 32, 64, 128], + "stage_out_channels": [64, 256, 512, 1024], + "stage_num_blocks": [1, 1, 2, 1], + "stage_downsample": [False, True, True, True], + "stage_light_block": [False, False, True, True], + "stage_kernel_size": [3, 3, 5, 5], + "stage_numb_of_layers": [3, 3, 3, 3], + }, + "Atto": { + "hidden_sizes": [64, 256, 256], + "stem_channels": [3, 16, 16], + "stage_in_channels": [16, 64, 256], + "stage_mid_channels": [16, 32, 64], + "stage_out_channels": [64, 256, 256], + "stage_num_blocks": [1, 1, 1], + "stage_downsample": [False, True, True], + "stage_light_block": [False, False, True], + "stage_kernel_size": [3, 3, 3], + "stage_numb_of_layers": [3, 3, 3], + }, + "Femto": { + "hidden_sizes": [64, 256, 512], + "stem_channels": [3, 16, 16], + "stage_in_channels": [16, 64, 256], + "stage_mid_channels": [16, 32, 64], + "stage_out_channels": [64, 256, 512], + "stage_num_blocks": [1, 1, 1], + "stage_downsample": [False, True, True], + "stage_light_block": [False, False, True], + "stage_kernel_size": [3, 3, 5], + "stage_numb_of_layers": [3, 3, 3], + }, + "Pico": { + "hidden_sizes": [64, 256, 512], + "stem_channels": [3, 16, 16], + "stage_in_channels": [16, 64, 256], + "stage_mid_channels": [16, 32, 64], + "stage_out_channels": [64, 256, 512], + "stage_num_blocks": [1, 1, 2], + "stage_downsample": [False, True, True], + "stage_light_block": [False, False, True], + "stage_kernel_size": [3, 3, 5], + "stage_numb_of_layers": [3, 3, 3], + }, +} +HGNETV2_BACKBONE_CONFIGS["B1"] = HGNETV2_BACKBONE_CONFIGS["B0"] +HGNETV2_BACKBONE_CONFIGS["B2"] = HGNETV2_BACKBONE_CONFIGS["B0"] + + +MODEL_NAME_TO_HUB_REPO = { + "deimv2_hgnetv2_n_coco": "Intellindust/DEIMv2_HGNetv2_N_COCO", + "deimv2_hgnetv2_pico_coco": "Intellindust/DEIMv2_HGNetv2_PICO_COCO", + "deimv2_hgnetv2_femto_coco": "Intellindust/DEIMv2_HGNetv2_FEMTO_COCO", + "deimv2_hgnetv2_atto_coco": "Intellindust/DEIMv2_HGNetv2_ATTO_COCO", + "deimv2_dinov3_s_coco": "Intellindust/DEIMv2_DINOv3_S_COCO", + "deimv2_dinov3_m_coco": "Intellindust/DEIMv2_DINOv3_M_COCO", + "deimv2_dinov3_l_coco": "Intellindust/DEIMv2_DINOv3_L_COCO", + "deimv2_dinov3_x_coco": "Intellindust/DEIMv2_DINOv3_X_COCO", +} + + +def get_deimv2_config(model_name: str) -> Deimv2Config: + repo_id = MODEL_NAME_TO_HUB_REPO[model_name] + config_path = hf_hub_download(repo_id=repo_id, filename="config.json") + with open(config_path) as f: + orig_config = json.load(f) + + # COCO labels + id2label = json.load( + open(hf_hub_download("huggingface/label-files", "coco-detection-mmdet-id2label.json", repo_type="dataset")) + ) + id2label = {int(k): v for k, v in id2label.items()} + + decoder_cfg = orig_config["DEIMTransformer"] + if "HybridEncoder" in orig_config: + encoder_cfg = orig_config["HybridEncoder"] + encoder_type = "hybrid" + elif "LiteEncoder" in orig_config: + encoder_cfg = orig_config["LiteEncoder"] 
+ encoder_type = "lite" + else: + raise ValueError(f"No encoder config found. Available keys: {list(orig_config.keys())}") + + config = Deimv2Config() + config.num_labels = 80 + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + # Encoder settings + config.encoder_type = encoder_type + config.encoder_hidden_dim = encoder_cfg["hidden_dim"] + config.encoder_in_channels = encoder_cfg["in_channels"] + config.feat_strides = encoder_cfg.get("feat_strides", [16] if encoder_type == "lite" else [8, 16, 32]) + config.activation_function = encoder_cfg.get("act", "silu") + config.depth_mult = float(encoder_cfg.get("depth_mult", 1.0)) + config.hidden_expansion = float(encoder_cfg.get("expansion", 1.0)) + config.encoder_fuse_op = encoder_cfg.get("fuse_op", "sum") + config.encoder_ffn_dim = encoder_cfg.get("dim_feedforward", 1024) + config.encoder_attention_heads = encoder_cfg.get("nhead", 8) + config.dropout = encoder_cfg.get("dropout", 0.0) + config.encode_proj_layers = encoder_cfg.get("use_encoder_idx", [2]) + config.encoder_activation_function = encoder_cfg.get("enc_act", "gelu") + if encoder_type == "lite": + config.encoder_layers = 0 + + # Decoder settings + config.d_model = decoder_cfg["hidden_dim"] + config.decoder_ffn_dim = decoder_cfg["dim_feedforward"] + config.decoder_layers = decoder_cfg["num_layers"] + config.num_feature_levels = decoder_cfg["num_levels"] + config.decoder_n_points = decoder_cfg["num_points"] + config.num_queries = decoder_cfg["num_queries"] + config.num_denoising = decoder_cfg.get("num_denoising", 100) + config.label_noise_ratio = float(decoder_cfg.get("label_noise_ratio", 0.5)) + config.box_noise_scale = float(decoder_cfg.get("box_noise_scale", 1.0)) + config.max_num_bins = decoder_cfg.get("reg_max", 32) + config.reg_scale = float(decoder_cfg.get("reg_scale", 4.0)) + config.eval_idx = decoder_cfg.get("eval_idx", -1) + config.layer_scale = decoder_cfg.get("layer_scale", 1) + config.decoder_in_channels = decoder_cfg["feat_channels"] + config.eval_size = list(decoder_cfg["eval_spatial_size"]) if "eval_spatial_size" in decoder_cfg else None + config.decoder_activation_function = decoder_cfg.get("activation", "silu") + config.share_bbox_head = decoder_cfg.get("share_bbox_head", False) + config.use_gateway = decoder_cfg.get("use_gateway", True) + + # Backbone settings + if "HGNetv2" in orig_config: + backbone_cfg = orig_config["HGNetv2"] + backbone_name = backbone_cfg.get("name", "B0") + return_idx = backbone_cfg.get("return_idx", [2, 3]) + config.backbone_config.out_indices = [i + 1 for i in return_idx] + config.backbone_config.use_learnable_affine_block = backbone_cfg.get("use_lab", True) + + if backbone_name not in HGNETV2_BACKBONE_CONFIGS: + raise ValueError(f"Unknown HGNetv2 variant: {backbone_name}") + for attr, value in HGNETV2_BACKBONE_CONFIGS[backbone_name].items(): + setattr(config.backbone_config, attr, value) + + num_stages = len(config.backbone_config.hidden_sizes) + config.backbone_config.depths = config.backbone_config.stage_numb_of_layers + config.backbone_config.stage_names = ["stem"] + [f"stage{i}" for i in range(1, num_stages + 1)] + elif "DINOv3STAs" in orig_config: + dinov3_cfg = orig_config["DINOv3STAs"] + name = dinov3_cfg["name"] + interaction_indexes = dinov3_cfg["interaction_indexes"] + config.spatial_tuning_adapter_inplanes = dinov3_cfg.get("conv_inplane", 16) + + is_dinov3 = "dinov3" in name + + DINOV3_PRESETS = { + "vit_tiny": { + "ffn_ratio": 4, + "use_gated_mlp": False, + "layerscale_value": 1.0, + 
"num_register_tokens": 0, + "pos_embed_rescale": None, + "key_bias": False, + }, + "vit_tinyplus": { + "ffn_ratio": 4, + "use_gated_mlp": False, + "layerscale_value": 1.0, + "num_register_tokens": 0, + "pos_embed_rescale": None, + "key_bias": False, + }, + "dinov3_vits16": { + "vit_hidden_size": 384, + "vit_num_heads": 6, + "ffn_ratio": 4, + "use_gated_mlp": False, + "layerscale_value": 1e-5, + "num_register_tokens": 4, + "pos_embed_rescale": 2.0, + "key_bias": True, + }, + "dinov3_vits16plus": { + "vit_hidden_size": 384, + "vit_num_heads": 6, + "ffn_ratio": 6, + "use_gated_mlp": True, + "layerscale_value": 1e-5, + "num_register_tokens": 4, + "pos_embed_rescale": 2.0, + "key_bias": True, + }, + } + preset = DINOV3_PRESETS[name] + + if is_dinov3: + vit_hidden_size = preset["vit_hidden_size"] + vit_num_heads = preset["vit_num_heads"] + else: + vit_hidden_size = dinov3_cfg.get("embed_dim", 192) + vit_num_heads = dinov3_cfg.get("num_heads", 3) + + ffn_ratio = preset["ffn_ratio"] + if preset["use_gated_mlp"]: + hidden_features = vit_hidden_size * ffn_ratio + d = int(hidden_features * 2 / 3) + intermediate_size = d + (-d % 8) + else: + intermediate_size = vit_hidden_size * ffn_ratio + + out_indices = [idx + 1 for idx in interaction_indexes] + config.backbone_config = DINOv3ViTConfig( + hidden_size=vit_hidden_size, + num_attention_heads=vit_num_heads, + num_hidden_layers=12, + intermediate_size=intermediate_size, + layerscale_value=preset["layerscale_value"], + use_gated_mlp=preset["use_gated_mlp"], + num_register_tokens=preset["num_register_tokens"], + pos_embed_rescale=preset["pos_embed_rescale"], + key_bias=preset["key_bias"], + rope_theta=100.0, + out_indices=out_indices, + apply_layernorm=is_dinov3, + reshape_hidden_states=True, + ) + config.encoder_has_trailing_conv = False + else: + raise ValueError(f"Unknown backbone in config: {list(orig_config.keys())}") + + config.head_dim = config.d_model // config.decoder_attention_heads + return config + + +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + # Backbone stem mappings + r"backbone\.stem\.(stem\w+)\.conv\.weight": r"model.conv_encoder.model.embedder.\1.convolution.weight", + # Stem normalization + r"backbone\.stem\.(stem\w+)\.bn\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.model.embedder.\1.normalization.\2", + # Stem lab parameters + r"backbone\.stem\.(stem\w+)\.lab\.(scale|bias)": r"model.conv_encoder.model.embedder.\1.lab.\2", + # Backbone stages mappings + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv\.weight": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.layers.\3.convolution.weight", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.bn\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.layers.\3.normalization.\4", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.lab\.(scale|bias)": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.layers.\3.lab.\4", + # Conv1/Conv2 layers + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv1\.conv\.weight": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.convolution.weight", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv1\.bn\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.normalization.\4", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv1\.lab\.(scale|bias)": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.layers.\3.conv1.lab.\4", + 
r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv2\.conv\.weight": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.convolution.weight", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv2\.bn\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.normalization.\4", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.layers\.(\d+)\.conv2\.lab\.(scale|bias)": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.layers.\3.conv2.lab.\4", + # Backbone stages aggregation + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation\.(\d+)\.conv\.weight": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.aggregation.\3.convolution.weight", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation\.(\d+)\.bn\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.aggregation.\3.normalization.\4", + r"backbone\.stages\.(\d+)\.blocks\.(\d+)\.aggregation\.(\d+)\.lab\.(scale|bias)": r"model.conv_encoder.model.encoder.stages.\1.blocks.\2.aggregation.\3.lab.\4", + # Downsample + r"backbone\.stages\.(\d+)\.downsample\.conv\.weight": r"model.conv_encoder.model.encoder.stages.\1.downsample.convolution.weight", + r"backbone\.stages\.(\d+)\.downsample\.bn\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.model.encoder.stages.\1.downsample.normalization.\2", + r"backbone\.stages\.(\d+)\.downsample\.lab\.(scale|bias)": r"model.conv_encoder.model.encoder.stages.\1.downsample.lab.\2", + # Encoder mappings + # Input projections + r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.conv_encoder.encoder_input_proj.\1.conv.weight", + r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.encoder_input_proj.\1.norm.\2", + # AIFI transformer encoder layers + r"encoder\.encoder\.(\d+)\.layers\.0\.self_attn\.out_proj\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.self_attn.o_proj.\2", + r"encoder\.encoder\.(\d+)\.layers\.0\.linear1\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.mlp.layers.0.\2", + r"encoder\.encoder\.(\d+)\.layers\.0\.linear2\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.mlp.layers.1.\2", + r"encoder\.encoder\.(\d+)\.layers\.0\.norm1\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.self_attn_layer_norm.\2", + r"encoder\.encoder\.(\d+)\.layers\.0\.norm2\.(weight|bias)": r"model.encoder.aifi.\1.layers.0.final_layer_norm.\2", + # Encoder projections and convolutions + r"encoder\.lateral_convs\.(\d+)\.conv\.weight": r"model.encoder.lateral_convs.\1.conv.weight", + r"encoder\.lateral_convs\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.lateral_convs.\1.norm.\2", + # FPN blocks - complete structure + # Basic convolutions + r"encoder\.fpn_blocks\.(\d+)\.cv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv1.norm.\2", + r"encoder\.fpn_blocks\.(\d+)\.cv4\.conv\.weight": r"model.encoder.fpn_blocks.\1.conv2.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.conv2.norm.\2", + # CSP Rep1 path + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.conv1.norm.\2", + 
r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv1.norm.\3", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv2.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.bottlenecks.\2.conv2.norm.\3", + # CSP Rep2 path + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.conv1.norm.\2", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv1.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv1.norm.\3", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv2.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.bottlenecks.\2.conv2.norm.\3", + # FPN trailing convs + r"encoder\.fpn_blocks\.(\d+)\.cv2\.1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep1.conv2.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep1.conv2.norm.\2", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.1\.conv\.weight": r"model.encoder.fpn_blocks.\1.csp_rep2.conv2.conv.weight", + r"encoder\.fpn_blocks\.(\d+)\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_blocks.\1.csp_rep2.conv2.norm.\2", + # PAN blocks - complete structure + r"encoder\.pan_blocks\.(\d+)\.cv1\.conv\.weight": r"model.encoder.pan_blocks.\1.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv1.norm.\2", + r"encoder\.pan_blocks\.(\d+)\.cv4\.conv\.weight": r"model.encoder.pan_blocks.\1.conv2.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.conv2.norm.\2", + # CSP Rep1 path + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.conv1.norm.\2", + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv1.norm.\3", + r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv2.conv.weight", + 
r"encoder\.pan_blocks\.(\d+)\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.bottlenecks.\2.conv2.norm.\3", + # CSP Rep2 path + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.conv1.norm.\2", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv1.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv1.norm.\3", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv2.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.bottlenecks.\2.conv2.norm.\3", + # PAN trailing convs + r"encoder\.pan_blocks\.(\d+)\.cv2\.1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep1.conv2.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep1.conv2.norm.\2", + r"encoder\.pan_blocks\.(\d+)\.cv3\.1\.conv\.weight": r"model.encoder.pan_blocks.\1.csp_rep2.conv2.conv.weight", + r"encoder\.pan_blocks\.(\d+)\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_blocks.\1.csp_rep2.conv2.norm.\2", + # Downsample convolutions + r"encoder\.downsample_convs\.(\d+)\.0\.cv(\d+)\.conv\.weight": r"model.encoder.downsample_convs.\1.conv\2.conv.weight", + r"encoder\.downsample_convs\.(\d+)\.0\.cv(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.downsample_convs.\1.conv\2.norm.\3", + # Decoder layers + r"decoder\.input_proj\.(\d+)\.0\.weight": r"model.decoder_input_proj.\1.conv.weight", + r"decoder\.input_proj\.(\d+)\.1\.(weight|bias|running_mean|running_var)": r"model.decoder_input_proj.\1.norm.\2", + r"decoder\.decoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)": r"model.decoder.layers.\1.self_attn.o_proj.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.sampling_offsets\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.sampling_offsets.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.attention_weights\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.attention_weights.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.value_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.value_proj.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.output_proj\.(weight|bias)": r"model.decoder.layers.\1.encoder_attn.output_proj.\2", + r"decoder\.decoder\.layers\.(\d+)\.cross_attn\.num_points_scale": r"model.decoder.layers.\1.encoder_attn.num_points_scale", + r"decoder\.decoder\.layers\.(\d+)\.norm1\.scale": r"model.decoder.layers.\1.self_attn_layer_norm.weight", + r"decoder\.decoder\.layers\.(\d+)\.norm3\.scale": r"model.decoder.layers.\1.final_layer_norm.weight", + r"decoder\.decoder\.layers\.(\d+)\.swish_ffn\.w3\.(weight|bias)": r"model.decoder.layers.\1.mlp.down_proj.\2", + r"decoder\.decoder\.layers\.(\d+)\.gateway\.gate\.(weight|bias)": r"model.decoder.layers.\1.gateway.gate.\2", + r"decoder\.decoder\.layers\.(\d+)\.gateway\.norm\.scale": 
r"model.decoder.layers.\1.gateway.norm.weight", + # LQE layers + r"decoder\.decoder\.lqe_layers\.(\d+)\.reg_conf\.layers\.(\d+)\.(weight|bias)": r"model.decoder.lqe_layers.\1.reg_conf.layers.\2.\3", + # Decoder heads and projections + r"decoder\.dec_score_head\.(\d+)\.(weight|bias)": r"model.decoder.class_embed.\1.\2", + r"decoder\.dec_bbox_head\.(\d+)\.layers\.(\d+)\.(weight|bias)": r"model.decoder.bbox_embed.\1.layers.\2.\3", + r"decoder\.pre_bbox_head\.layers\.(\d+)\.(weight|bias)": r"model.decoder.pre_bbox_head.layers.\1.\2", + r"decoder\.query_pos_head\.layers\.(\d+)\.(weight|bias)": r"model.decoder.query_pos_head.layers.\1.\2", + # Encoder output and score heads + r"decoder\.enc_output\.proj\.(weight|bias)": r"model.enc_output.0.\1", + r"decoder\.enc_output\.norm\.(weight|bias)": r"model.enc_output.1.\1", + r"decoder\.enc_score_head\.(weight|bias)": r"model.enc_score_head.\1", + r"decoder\.enc_bbox_head\.layers\.(\d+)\.(weight|bias)": r"model.enc_bbox_head.layers.\1.\2", + # Denoising class embed + r"decoder\.denoising_class_embed\.weight": r"model.denoising_class_embed.weight", + # Decoder parameters + r"decoder\.decoder\.up": r"model.decoder.up", + r"decoder\.decoder\.reg_scale": r"model.decoder.reg_scale", +} + +LITE_ENCODER_KEY_MAPPING = { + # LiteEncoder input_proj + r"encoder\.input_proj\.(\d+)\.conv\.weight": r"model.encoder.input_proj.\1.conv.weight", + r"encoder\.input_proj\.(\d+)\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.input_proj.\1.norm.\2", + # Downsamples + r"encoder\.down_sample1\.1\.weight": r"model.encoder.down_conv1.conv.weight", + r"encoder\.down_sample1\.2\.(weight|bias|running_mean|running_var)": r"model.encoder.down_conv1.norm.\1", + r"encoder\.down_sample2\.1\.weight": r"model.encoder.down_conv2.conv.weight", + r"encoder\.down_sample2\.2\.(weight|bias|running_mean|running_var)": r"model.encoder.down_conv2.norm.\1", + # GAP_Fusion + r"encoder\.bi_fusion\.cv\.conv\.weight": r"model.encoder.bi_fusion_conv.conv.weight", + r"encoder\.bi_fusion\.cv\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.bi_fusion_conv.norm.\1", + # FPN block (RepNCSPELAN5) + r"encoder\.fpn_block\.cv1\.conv\.weight": r"model.encoder.fpn_block.conv1.conv.weight", + r"encoder\.fpn_block\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.conv1.norm.\1", + r"encoder\.fpn_block\.cv4\.conv\.weight": r"model.encoder.fpn_block.conv2.conv.weight", + r"encoder\.fpn_block\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.conv2.norm.\1", + r"encoder\.fpn_block\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.conv1.conv.weight", + r"encoder\.fpn_block\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.conv1.norm.\1", + r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv1.conv.weight", + r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv1.norm.\2", + r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv2.conv.weight", + r"encoder\.fpn_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.bottlenecks.\1.conv2.norm.\2", + r"encoder\.fpn_block\.cv2\.1\.conv\.weight": r"model.encoder.fpn_block.csp_rep1.conv2.conv.weight", + 
r"encoder\.fpn_block\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep1.conv2.norm.\1", + r"encoder\.fpn_block\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.conv1.conv.weight", + r"encoder\.fpn_block\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.conv1.norm.\1", + r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv1.conv.weight", + r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv1.norm.\2", + r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv2.conv.weight", + r"encoder\.fpn_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.bottlenecks.\1.conv2.norm.\2", + r"encoder\.fpn_block\.cv3\.1\.conv\.weight": r"model.encoder.fpn_block.csp_rep2.conv2.conv.weight", + r"encoder\.fpn_block\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.fpn_block.csp_rep2.conv2.norm.\1", + # PAN block (same structure as FPN) + r"encoder\.pan_block\.cv1\.conv\.weight": r"model.encoder.pan_block.conv1.conv.weight", + r"encoder\.pan_block\.cv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.conv1.norm.\1", + r"encoder\.pan_block\.cv4\.conv\.weight": r"model.encoder.pan_block.conv2.conv.weight", + r"encoder\.pan_block\.cv4\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.conv2.norm.\1", + r"encoder\.pan_block\.cv2\.0\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep1.conv1.conv.weight", + r"encoder\.pan_block\.cv2\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.conv1.norm.\1", + r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv1.conv.weight", + r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv1.norm.\2", + r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv2.conv.weight", + r"encoder\.pan_block\.cv2\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.bottlenecks.\1.conv2.norm.\2", + r"encoder\.pan_block\.cv2\.1\.conv\.weight": r"model.encoder.pan_block.csp_rep1.conv2.conv.weight", + r"encoder\.pan_block\.cv2\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep1.conv2.norm.\1", + r"encoder\.pan_block\.cv3\.0\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep2.conv1.conv.weight", + r"encoder\.pan_block\.cv3\.0\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.conv1.norm.\1", + r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.conv\.weight": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv1.conv.weight", + r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv1.norm.\2", + r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.conv\.weight": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv2.conv.weight", + 
r"encoder\.pan_block\.cv3\.0\.bottlenecks\.(\d+)\.conv2\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.bottlenecks.\1.conv2.norm.\2", + r"encoder\.pan_block\.cv3\.1\.conv\.weight": r"model.encoder.pan_block.csp_rep2.conv2.conv.weight", + r"encoder\.pan_block\.cv3\.1\.norm\.(weight|bias|running_mean|running_var)": r"model.encoder.pan_block.csp_rep2.conv2.norm.\1", +} + +DECODER_NO_GATEWAY_KEY_MAPPING = { + r"decoder\.decoder\.layers\.(\d+)\.norm2\.scale": r"model.decoder.layers.\1.encoder_attn_layer_norm.weight", +} + +DINOV3_KEY_MAPPING = { + # ViT embeddings + r"backbone\.dinov3\.patch_embed\.proj\.(weight|bias)": r"model.conv_encoder.backbone.embeddings.patch_embeddings.\1", + r"backbone\.dinov3\.cls_token": r"model.conv_encoder.backbone.embeddings.cls_token", + r"backbone\.dinov3\.storage_tokens": r"model.conv_encoder.backbone.embeddings.register_tokens", + r"backbone\.dinov3\.mask_token": r"model.conv_encoder.backbone.embeddings.mask_token", + # ViT blocks + r"backbone\.dinov3\.blocks\.(\d+)\.norm1\.(weight|bias)": r"model.conv_encoder.backbone.model.layer.\1.norm1.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.norm2\.(weight|bias)": r"model.conv_encoder.backbone.model.layer.\1.norm2.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.attn\.qkv\.(weight|bias)": r"model.conv_encoder.backbone.model.layer.\1.attention.qkv.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.attn\.proj\.(weight|bias)": r"model.conv_encoder.backbone.model.layer.\1.attention.o_proj.\2", + # Standard MLP (S/M/L) + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc1\.(weight|bias)": r"model.conv_encoder.backbone.model.layer.\1.mlp.up_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.fc2\.(weight|bias)": r"model.conv_encoder.backbone.model.layer.\1.mlp.down_proj.\2", + # SwiGLU MLP (X only) + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w1\.(weight|bias)": r"model.conv_encoder.backbone.model.layer.\1.mlp.gate_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w2\.(weight|bias)": r"model.conv_encoder.backbone.model.layer.\1.mlp.up_proj.\2", + r"backbone\.dinov3\.blocks\.(\d+)\.mlp\.w3\.(weight|bias)": r"model.conv_encoder.backbone.model.layer.\1.mlp.down_proj.\2", + # LayerScale (L/X only) + r"backbone\.dinov3\.blocks\.(\d+)\.ls1\.gamma": r"model.conv_encoder.backbone.model.layer.\1.layer_scale1.lambda1", + r"backbone\.dinov3\.blocks\.(\d+)\.ls2\.gamma": r"model.conv_encoder.backbone.model.layer.\1.layer_scale2.lambda1", + # Norm (L/X only) + r"backbone\.dinov3\.norm\.(weight|bias)": r"model.conv_encoder.backbone.norm.\1", + # STA adapter + r"backbone\.sta\.stem\.0\.(weight)": r"model.conv_encoder.spatial_tuning_adapter.stem_conv.conv.\1", + r"backbone\.sta\.stem\.1\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.spatial_tuning_adapter.stem_conv.norm.\1", + r"backbone\.sta\.conv2\.0\.(weight)": r"model.conv_encoder.spatial_tuning_adapter.conv2.conv.\1", + r"backbone\.sta\.conv2\.1\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.spatial_tuning_adapter.conv2.norm.\1", + r"backbone\.sta\.conv3\.1\.(weight)": r"model.conv_encoder.spatial_tuning_adapter.conv3.conv.\1", + r"backbone\.sta\.conv3\.2\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.spatial_tuning_adapter.conv3.norm.\1", + r"backbone\.sta\.conv4\.1\.(weight)": r"model.conv_encoder.spatial_tuning_adapter.conv4.conv.\1", + r"backbone\.sta\.conv4\.2\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.spatial_tuning_adapter.conv4.norm.\1", + # Fusion projection convs/norms + 
r"backbone\.convs\.(\d+)\.weight": r"model.conv_encoder.fusion_proj.\1.conv.weight", + r"backbone\.norms\.(\d+)\.(weight|bias|running_mean|running_var)": r"model.conv_encoder.fusion_proj.\1.norm.\2", +} + + +def convert_old_keys_to_new_keys(state_dict, config=None): + mapping = dict(ORIGINAL_TO_CONVERTED_KEY_MAPPING) + + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" if config else False + + if config is not None: + if config.encoder_type == "lite": + for k in list(mapping.keys()): + if ( + k.startswith(r"encoder\.input_proj") + or k.startswith(r"encoder\.lateral") + or k.startswith(r"encoder\.fpn_blocks") + or k.startswith(r"encoder\.pan_blocks") + or k.startswith(r"encoder\.downsample") + or k.startswith(r"encoder\.encoder") + ): + del mapping[k] + mapping.update(LITE_ENCODER_KEY_MAPPING) + + if not config.use_gateway: + mapping.update(DECODER_NO_GATEWAY_KEY_MAPPING) + for k in list(mapping.keys()): + if "gateway" in k: + del mapping[k] + + if is_dinov3: + for k in list(mapping.keys()): + if k.startswith(r"backbone\.") or k.startswith(r"encoder\.input_proj"): + del mapping[k] + mapping.update(DINOV3_KEY_MAPPING) + + for original_key, converted_key in mapping.items(): + for key in list(state_dict.keys()): + new_key = re.sub(f"^{original_key}$", converted_key, key) + if new_key != key: + state_dict[new_key] = state_dict.pop(key) + return state_dict + + +def read_in_q_k_v(state_dict, config): + encoder_hidden_dim = config.encoder_hidden_dim + d_model = config.d_model + + # first: transformer encoder + for i in range(config.encoder_layers): + # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"encoder.encoder.{i}.layers.0.self_attn.in_proj_weight", None) + in_proj_bias = state_dict.pop(f"encoder.encoder.{i}.layers.0.self_attn.in_proj_bias", None) + if in_proj_weight is not None: + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.q_proj.weight"] = in_proj_weight[ + :encoder_hidden_dim + ] + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.k_proj.weight"] = in_proj_weight[ + encoder_hidden_dim : 2 * encoder_hidden_dim + ] + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.v_proj.weight"] = in_proj_weight[ + -encoder_hidden_dim: + ] + if in_proj_bias is not None: + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.q_proj.bias"] = in_proj_bias[:encoder_hidden_dim] + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.k_proj.bias"] = in_proj_bias[ + encoder_hidden_dim : 2 * encoder_hidden_dim + ] + state_dict[f"model.encoder.aifi.{i}.layers.0.self_attn.v_proj.bias"] = in_proj_bias[-encoder_hidden_dim:] + + # next: transformer decoder (which is a bit more complex because it also includes cross-attention) + for i in range(config.decoder_layers): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = state_dict.pop(f"decoder.decoder.layers.{i}.self_attn.in_proj_weight", None) + in_proj_bias = state_dict.pop(f"decoder.decoder.layers.{i}.self_attn.in_proj_bias", None) + if in_proj_weight is not None: + # next, add query, keys and values (in that order) to the state dict + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:d_model] + state_dict[f"model.decoder.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:d_model] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[d_model : 
2 * d_model] + state_dict[f"model.decoder.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[d_model : 2 * d_model] + state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-d_model:] + state_dict[f"model.decoder.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-d_model:] + + +def split_swiglu_fused_weights(state_dict, config): + for i in range( + 2 * config.decoder_layers + - (config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx) + - 1 + ): + for param in ["weight", "bias"]: + fused_key = f"decoder.decoder.layers.{i}.swish_ffn.w12.{param}" + if fused_key in state_dict: + fused = state_dict.pop(fused_key) + gate, up = fused.chunk(2, dim=0) + state_dict[f"model.decoder.layers.{i}.mlp.gate_proj.{param}"] = gate + state_dict[f"model.decoder.layers.{i}.mlp.up_proj.{param}"] = up + + +def strip_dinov3_model_prefix(state_dict): + for key in list(state_dict.keys()): + if "backbone.dinov3._model." in key: + new_key = key.replace("backbone.dinov3._model.", "backbone.dinov3.") + state_dict[new_key] = state_dict.pop(key) + return state_dict + + +def read_in_q_k_v_vit(state_dict, config): + has_key_bias = getattr(config.backbone_config, "key_bias", True) + prefix = "model.conv_encoder.backbone.model" + for i in range(config.backbone_config.num_hidden_layers): + qkv_key = f"{prefix}.layer.{i}.attention.qkv.weight" + if qkv_key in state_dict: + qkv_w = state_dict.pop(qkv_key) + q, k, v = qkv_w.chunk(3, dim=0) + state_dict[f"{prefix}.layer.{i}.attention.q_proj.weight"] = q + state_dict[f"{prefix}.layer.{i}.attention.k_proj.weight"] = k + state_dict[f"{prefix}.layer.{i}.attention.v_proj.weight"] = v + + qkv_bias_key = f"{prefix}.layer.{i}.attention.qkv.bias" + if qkv_bias_key in state_dict: + qkv_b = state_dict.pop(qkv_bias_key) + q_b, k_b, v_b = qkv_b.chunk(3, dim=0) + state_dict[f"{prefix}.layer.{i}.attention.q_proj.bias"] = q_b + if has_key_bias: + state_dict[f"{prefix}.layer.{i}.attention.k_proj.bias"] = k_b + state_dict[f"{prefix}.layer.{i}.attention.v_proj.bias"] = v_b + + +def load_original_state_dict(repo_id): + filepath = hf_hub_download(repo_id=repo_id, filename="model.safetensors") + return load_file(filepath) + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + with httpx.stream("GET", url) as response: + image = Image.open(BytesIO(response.read())) + return image + + +@torch.no_grad() +def convert_deimv2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, repo_id): + """ + Copy/paste/tweak model's weights to our Deimv2 structure. 
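+    Conversion steps: drop tracked buffers and cached anchors, split the fused QKV and SwiGLU weights,
+    remap the remaining keys with the regex tables above, then load the result into
+    Deimv2ForObjectDetection and sanity-check the preprocessing and a forward pass before saving or pushing.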
+ """ + hub_repo = MODEL_NAME_TO_HUB_REPO[model_name] + config = get_deimv2_config(model_name) + state_dict = load_original_state_dict(hub_repo) + + logger.info(f"Converting model {model_name} from {hub_repo}...") + logger.info(f"Original state dict has {len(state_dict)} keys") + + state_dict.pop("decoder.valid_mask", None) + state_dict.pop("decoder.anchors", None) + + for key in list(state_dict.keys()): + if key.endswith(".num_batches_tracked"): + state_dict.pop(key) + + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" + + if is_dinov3: + strip_dinov3_model_prefix(state_dict) + for key in list(state_dict.keys()): + if "rope_embed.periods" in key or "qkv.bias_mask" in key: + state_dict.pop(key) + + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict, config) + + # split fused SwiGLU weights (w12) into separate gate_proj and up_proj + split_swiglu_fused_weights(state_dict, config) + + state_dict = convert_old_keys_to_new_keys(state_dict, config) + + if is_dinov3: + read_in_q_k_v_vit(state_dict, config) + mask_key = "model.conv_encoder.backbone.embeddings.mask_token" + if mask_key in state_dict and state_dict[mask_key].dim() == 2: + state_dict[mask_key] = state_dict[mask_key].unsqueeze(1) + + if "model.enc_output.0.weight" not in state_dict: + d_model = config.d_model + state_dict["model.enc_output.0.weight"] = torch.eye(d_model) + state_dict["model.enc_output.0.bias"] = torch.zeros(d_model) + state_dict["model.enc_output.1.weight"] = torch.ones(d_model) + state_dict["model.enc_output.1.bias"] = torch.zeros(d_model) + + if config.share_bbox_head: + num_decoder_layers = config.decoder_layers + for key in list(state_dict.keys()): + if "model.decoder.bbox_embed.0." in key: + for i in range(1, num_decoder_layers): + new_key = key.replace("bbox_embed.0.", f"bbox_embed.{i}.") + if new_key not in state_dict: + state_dict[new_key] = state_dict[key] + + # for two_stage + for key in list(state_dict.keys()): + if "bbox_embed" in key or ("class_embed" in key and "denoising_" not in key): + new_key = key.split("model.decoder.")[-1] + if new_key != key and new_key not in state_dict: + state_dict[new_key] = state_dict[key] + + model = Deimv2ForObjectDetection(config) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + + expected_missing = {"mask_token", "register_tokens", "layer_scale1", "layer_scale2", "backbone.norm"} + unexpected_missing = [k for k in missing if not any(e in k for e in expected_missing)] + if unexpected_missing: + logger.warning(f"Missing keys ({len(unexpected_missing)}): {unexpected_missing[:10]}...") + elif missing: + logger.info( + f"All {len(missing)} missing keys are expected model-init defaults (mask_token, register_tokens, layer_scale)" + ) + if unexpected: + logger.warning(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}...") + + model.eval() + + if is_dinov3: + image_processor = RTDetrImageProcessor( + do_normalize=True, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.229, 0.224, 0.225], + ) + else: + image_processor = RTDetrImageProcessor() + + img = prepare_img() + + if is_dinov3: + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + else: + transformations = transforms.Compose( + [ + transforms.Resize([640, 640], interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + ] + 
) + original_pixel_values = transformations(img).unsqueeze(0) + encoding = image_processor(images=img, return_tensors="pt") + pixel_values = encoding["pixel_values"] + + if not torch.allclose(original_pixel_values, pixel_values, atol=1e-4): + max_diff = (original_pixel_values - pixel_values).abs().max().item() + logger.warning(f"Image preprocessing mismatch! Max diff: {max_diff:.6f}") + if max_diff > 1e-2: + raise ValueError(f"Image preprocessing mismatch too large: {max_diff}") + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + pixel_values = pixel_values.to(device) + + outputs = model(pixel_values) + logger.info(f"Logits shape: {outputs.logits.shape}") + logger.info(f"Boxes shape: {outputs.pred_boxes.shape}") + logger.info(f"Logits sample: {outputs.logits[0, :3, :3]}") + logger.info(f"Boxes sample: {outputs.pred_boxes[0, :3, :3]}") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + logger.info(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + logger.info(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + push_repo = repo_id or f"deimv2-{model_name}" + logger.info(f"Pushing to hub: {push_repo}") + config.push_to_hub(repo_id=push_repo) + model.push_to_hub(repo_id=push_repo) + image_processor.push_to_hub(repo_id=push_repo) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + default="deimv2_hgnetv2_n_coco", + type=str, + choices=list(MODEL_NAME_TO_HUB_REPO.keys()), + help="Model name to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output directory.", + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether to push to the hub.") + parser.add_argument("--repo_id", type=str, default=None, help="Hub repo_id to push to.") + args = parser.parse_args() + convert_deimv2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.repo_id) diff --git a/src/transformers/models/deimv2/modeling_deimv2.py b/src/transformers/models/deimv2/modeling_deimv2.py new file mode 100644 index 000000000000..fe0f002890c5 --- /dev/null +++ b/src/transformers/models/deimv2/modeling_deimv2.py @@ -0,0 +1,2178 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deimv2/modular_deimv2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deimv2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from collections.abc import Callable +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from ... import initialization as init +from ...activations import ACT2CLS +from ...backbone_utils import load_backbone +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...integrations import use_kernel_forward_from_hub +from ...modeling_outputs import ModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache +from ...utils import TransformersKwargs, auto_docstring, torch_compilable_check, torch_int +from ...utils.generic import can_return_tuple, merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from .configuration_deimv2 import Deimv2Config + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the Deimv2Decoder. This class adds two attributes to + BaseModelOutputWithCrossAttentions, namely: + - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer) + - a stacked tensor of intermediate reference points. + """ +) +class Deimv2DecoderOutput(ModelOutput): + r""" + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked initial reference points (initial reference points of each layer of the decoder). + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_logits: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + intermediate_predicted_corners: torch.FloatTensor | None = None + initial_reference_points: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for outputs of the RT-DETR encoder-decoder model. 
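+    DEIMv2 reuses this RT-DETR-style encoder-decoder output layout, so the fields below follow its conventions.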
+ """ +) +class Deimv2ModelOutput(ModelOutput): + r""" + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points used for the first decoder layer. + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values. 
+ """ + + last_hidden_state: torch.FloatTensor | None = None + intermediate_hidden_states: torch.FloatTensor | None = None + intermediate_logits: torch.FloatTensor | None = None + intermediate_reference_points: torch.FloatTensor | None = None + intermediate_predicted_corners: torch.FloatTensor | None = None + initial_reference_points: torch.FloatTensor | None = None + decoder_hidden_states: tuple[torch.FloatTensor] | None = None + decoder_attentions: tuple[torch.FloatTensor] | None = None + cross_attentions: tuple[torch.FloatTensor] | None = None + encoder_last_hidden_state: torch.FloatTensor | None = None + encoder_hidden_states: tuple[torch.FloatTensor] | None = None + encoder_attentions: tuple[torch.FloatTensor] | None = None + init_reference_points: torch.FloatTensor | None = None + enc_topk_logits: torch.FloatTensor | None = None + enc_topk_bboxes: torch.FloatTensor | None = None + enc_outputs_class: torch.FloatTensor | None = None + enc_outputs_coord_logits: torch.FloatTensor | None = None + denoising_meta_values: dict | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type for DEIMv2 encoder modules (HybridEncoder and LiteEncoder). + Attentions are only available for HybridEncoder variants with AIFI layers. + """ +) +class Deimv2EncoderOutput(ModelOutput): + r""" + feature_maps (`list[torch.FloatTensor]`): + List of multi-scale feature maps from the encoder, one per feature level. + """ + + feature_maps: list[torch.FloatTensor] = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] | None = None + + +@use_kernel_forward_from_hub("RMSNorm") +class Deimv2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Deimv2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Deimv2SwiGLUFFN(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + hidden_features = config.decoder_ffn_dim // 2 + self.gate_proj = nn.Linear(config.d_model, hidden_features, bias=True) + self.up_proj = nn.Linear(config.d_model, hidden_features, bias=True) + self.down_proj = nn.Linear(hidden_features, config.d_model, bias=True) + self.act_fn = nn.SiLU() + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Deimv2Gate(nn.Module): + def __init__(self, d_model: int): + super().__init__() + self.gate = nn.Linear(2 * d_model, 2 * d_model) + self.norm = Deimv2RMSNorm(d_model) + + def forward(self, second_residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: + gate_input = torch.cat([second_residual, hidden_states], dim=-1) + gates = torch.sigmoid(self.gate(gate_input)) + gate1, gate2 = gates.chunk(2, dim=-1) + hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states) + return hidden_states + + +class Deimv2MLP(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"): + 
super().__init__() + self.num_layers = num_layers + hidden_dims = [hidden_dim] * (num_layers - 1) + input_dims = [input_dim] + hidden_dims + output_dims = hidden_dims + [output_dim] + self.layers = nn.ModuleList(nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims)) + self.act = ACT2CLS[act]() + + def forward(self, stat_features: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features) + return stat_features + + +def multi_scale_deformable_attention_v2( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + num_points_list: list[int], + method="default", +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape + value_list = ( + value.permute(0, 2, 3, 1) + .flatten(0, 1) + .split([height * width for height, width in value_spatial_shapes], dim=-1) + ) + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + elif method == "discrete": + sampling_grids = sampling_locations + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_grids = sampling_grids.split(num_points_list, dim=-2) + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[level_id] + # batch_size*num_heads, hidden_dim, num_queries, num_points + if method == "default": + sampling_value_l_ = nn.functional.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + elif method == "discrete": + sampling_coord = (sampling_grid_l_ * torch.tensor([[width, height]], device=value.device) + 0.5).to( + torch.int64 + ) + + # Separate clamping for x and y coordinates + sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1) + sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1) + + # Combine the clamped coordinates + sampling_coord = torch.stack([sampling_coord_x, sampling_coord_y], dim=-1) + sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2) + sampling_idx = ( + torch.arange(sampling_coord.shape[0], device=value.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] + sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape( + batch_size * num_heads, hidden_dim, num_queries, num_points_list[level_id] + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.permute(0, 2, 1, 3).reshape( + batch_size * num_heads, 1, num_queries, sum(num_points_list) + ) + output = ( + 
(torch.concat(sampling_value_list, dim=-1) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.transpose(1, 2).contiguous() + + +class Deimv2MultiscaleDeformableAttention(nn.Module): + def __init__(self, config: Deimv2Config): + """ + D-Fine version of multiscale deformable attention + """ + super().__init__() + self.d_model = config.d_model + self.n_heads = config.decoder_attention_heads + self.n_levels = config.num_feature_levels + self.offset_scale = config.decoder_offset_scale + self.decoder_method = config.decoder_method + self.n_points = config.decoder_n_points + + if isinstance(self.n_points, list): + num_points_list = self.n_points + else: + num_points_list = [self.n_points for _ in range(self.n_levels)] + + self.num_points_list = num_points_list + num_points_scale = [1 / n for n in self.num_points_list for _ in range(n)] + self.register_buffer("num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32)) + + self.total_points = self.n_heads * sum(self.num_points_list) + + self.sampling_offsets = nn.Linear(self.d_model, self.total_points * 2) + self.attention_weights = nn.Linear(self.d_model, self.total_points) + + self.ms_deformable_attn_core = multi_scale_deformable_attention_v2 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + reference_points=None, + encoder_hidden_states=None, + spatial_shapes=None, + spatial_shapes_list=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + + torch_compilable_check( + (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == sequence_length, + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states", + ) + + # Reshape for multi-head attention + value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + if attention_mask is not None: + value = value.masked_fill(~attention_mask[..., None], float(0)) + + sampling_offsets: torch.Tensor = self.sampling_offsets(hidden_states) + sampling_offsets = sampling_offsets.reshape( + batch_size, num_queries, self.n_heads, sum(self.num_points_list), 2 + ) + + attention_weights = self.attention_weights(hidden_states).reshape( + batch_size, num_queries, self.n_heads, sum(self.num_points_list) + ) + attention_weights = F.softmax(attention_weights, dim=-1) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.tensor(spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.n_levels, 1, 2) + sampling_locations = ( + reference_points.reshape(batch_size, sequence_length, 1, self.n_levels, 1, 2) + + sampling_offsets / offset_normalizer + ) + elif reference_points.shape[-1] == 4: + # reference_points [8, 480, None, 1, 4] + # sampling_offsets [8, 480, 8, 12, 2] + num_points_scale = self.num_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1) + offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError( + f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead." 
+ ) + + output = self.ms_deformable_attn_core( + value, + spatial_shapes_list, + sampling_locations, + attention_weights, + self.num_points_list, + self.decoder_method, + ) + + return output, attention_weights + + +class Deimv2ConvNormLayer(nn.Module): + def __init__( + self, + config: Deimv2Config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + groups: int = 1, + padding: int | None = None, + activation: str | None = None, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + groups=groups, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class Deimv2RepVggBlock(nn.Module): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: Deimv2Config, in_channels: int, out_channels: int): + super().__init__() + + activation = config.activation_function + hidden_channels = in_channels + self.conv1 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1) + self.conv2 = Deimv2ConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0) + self.activation = nn.Identity() if activation is None else ACT2CLS[activation]() + + def forward(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +class Deimv2CSPRepLayer(nn.Module): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. + Differs from DFineCSPRepLayer: uses a single conv that splits into residual + processing path + (instead of two separate convs), and has an optional trailing conv controlled by `encoder_has_trailing_conv`. + """ + + def __init__( + self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 + ): + super().__init__() + activation = config.activation_function + hidden_channels = int(out_channels * expansion) + self.conv1 = Deimv2ConvNormLayer(config, in_channels, hidden_channels * 2, 1, 1, activation=activation) + self.bottlenecks = nn.ModuleList( + [Deimv2RepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] + ) + self.conv2 = ( + Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + if config.encoder_has_trailing_conv + else nn.Identity() + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual, hidden_states = self.conv1(hidden_states).chunk(2, dim=1) + for bottleneck in self.bottlenecks: + hidden_states = bottleneck(hidden_states) + return self.conv2(residual + hidden_states) + + +class Deimv2RepNCSPELAN5(nn.Module): + """ + Rep(VGG) N(etwork) CSP (Cross Stage Partial) ELAN (Efficient Layer Aggregation Network) block. + Similar to DFineRepNCSPELAN4 but without intermediate convolutions between CSP branches, + resulting in a simpler 4-way concatenation (2 split halves + 2 CSP branches) instead of D-FINE's + 4-branch design with interleaved convolutions. 
+ """ + + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): + super().__init__() + activation = config.activation_function + in_channels = config.encoder_hidden_dim + out_channels = config.encoder_hidden_dim + split_channels = config.encoder_hidden_dim * 2 + csp_channels = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv1 = Deimv2ConvNormLayer(config, in_channels, split_channels, 1, 1, activation=activation) + self.csp_rep1 = Deimv2CSPRepLayer(config, split_channels // 2, csp_channels, num_blocks=numb_blocks) + self.csp_rep2 = Deimv2CSPRepLayer(config, csp_channels, csp_channels, num_blocks=numb_blocks) + self.conv2 = Deimv2ConvNormLayer( + config, split_channels + (2 * csp_channels), out_channels, 1, 1, activation=activation + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states_1, hidden_states_2 = self.conv1(hidden_states).chunk(2, dim=1) + hidden_states_3 = self.csp_rep1(hidden_states_2) + hidden_states_4 = self.csp_rep2(hidden_states_3) + merged_hidden_states = torch.cat([hidden_states_1, hidden_states_2, hidden_states_3, hidden_states_4], dim=1) + return self.conv2(merged_hidden_states) + + +class Deimv2SCDown(nn.Module): + def __init__(self, config: Deimv2Config, kernel_size: int, stride: int): + super().__init__() + self.conv1 = Deimv2ConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1) + self.conv2 = Deimv2ConvNormLayer( + config, + config.encoder_hidden_dim, + config.encoder_hidden_dim, + kernel_size, + stride, + config.encoder_hidden_dim, + ) + + def forward(self, input_features: torch.Tensor) -> torch.Tensor: + input_features = self.conv1(input_features) + input_features = self.conv2(input_features) + return input_features + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float | None = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Deimv2SelfAttention(nn.Module): + """ + Multi-headed self-attention from 'Attention Is All You Need' paper. + + In DEIMV2, position embeddings are added to both queries and keys (but not values) in self-attention. 
+ """ + + def __init__( + self, + config: Deimv2Config, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.config = config + self.head_dim = hidden_size // num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False + + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings are added to both queries and keys (but not values). + """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states + + query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Deimv2EncoderLayer(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.normalize_before = config.normalize_before + self.hidden_size = config.encoder_hidden_dim + + # self-attention + self.self_attn = Deimv2SelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.num_attention_heads, + dropout=config.dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.mlp = Deimv2MLP( + self.hidden_size, config.encoder_ffn_dim, self.hidden_size, 2, config.encoder_activation_function + ) + self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + spatial_position_embeddings: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + spatial_position_embeddings (`torch.FloatTensor`, *optional*): + Spatial position embeddings (2D positional encodings of image locations), to be added to both + the queries and keys in self-attention (but not to values). 
+ """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=spatial_position_embeddings, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + residual = hidden_states + + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if not torch.isfinite(hidden_states).all(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + return hidden_states + + +class Deimv2SinePositionEmbedding(nn.Module): + """ + 2D sinusoidal position embedding used in RT-DETR hybrid encoder. + """ + + def __init__(self, embed_dim: int = 256, temperature: int = 10000): + super().__init__() + self.embed_dim = embed_dim + self.temperature = temperature + + @compile_compatible_method_lru_cache(maxsize=32) + def forward( + self, + width: int, + height: int, + device: torch.device | str, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Generate 2D sinusoidal position embeddings. + + Returns: + Position embeddings of shape (1, height*width, embed_dim) + """ + grid_w = torch.arange(torch_int(width), device=device).to(dtype) + grid_h = torch.arange(torch_int(height), device=device).to(dtype) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy") + if self.embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = self.embed_dim // 4 + omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim + omega = 1.0 / (self.temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :] + + +class Deimv2AIFILayer(nn.Module): + """ + AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder. + """ + + def __init__(self, config: Deimv2Config): + super().__init__() + self.config = config + self.encoder_hidden_dim = config.encoder_hidden_dim + self.eval_size = config.eval_size + + self.position_embedding = Deimv2SinePositionEmbedding( + embed_dim=self.encoder_hidden_dim, + temperature=config.positional_encoding_temperature, + ) + self.layers = nn.ModuleList([Deimv2EncoderLayer(config) for _ in range(config.encoder_layers)]) + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`): + Feature map to process. 
+ """ + batch_size = hidden_states.shape[0] + height, width = hidden_states.shape[2:] + + hidden_states = hidden_states.flatten(2).permute(0, 2, 1) + + if self.training or self.eval_size is None: + pos_embed = self.position_embedding( + width=width, + height=height, + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + else: + pos_embed = None + + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=None, + spatial_position_embeddings=pos_embed, + **kwargs, + ) + + hidden_states = ( + hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous() + ) + + return hidden_states + + +class Deimv2SpatialTuningAdapter(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + inplanes = config.spatial_tuning_adapter_inplanes + self.stem_conv = Deimv2ConvNormLayer(config, 3, inplanes, 3, 2, activation="gelu") + self.stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.conv2 = Deimv2ConvNormLayer(config, inplanes, 2 * inplanes, 3, 2) + self.conv3 = Deimv2ConvNormLayer(config, 2 * inplanes, 4 * inplanes, 3, 2) + self.conv4 = Deimv2ConvNormLayer(config, 4 * inplanes, 4 * inplanes, 3, 2) + self.act_fn = nn.GELU() + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states_1 = self.stem_pool(self.stem_conv(pixel_values)) + hidden_states_2 = self.conv2(hidden_states_1) + hidden_states_3 = self.conv3(self.act_fn(hidden_states_2)) + hidden_states_4 = self.conv4(self.act_fn(hidden_states_3)) + return hidden_states_2, hidden_states_3, hidden_states_4 + + +class Deimv2FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +def replace_batch_norm(model): + r""" + Recursively replace all `torch.nn.BatchNorm2d` with `Deimv2FrozenBatchNorm2d`. 
+ + Args: + model (torch.nn.Module): + input model + """ + for name, module in model.named_children(): + if isinstance(module, nn.BatchNorm2d): + new_module = Deimv2FrozenBatchNorm2d(module.num_features) + + if module.weight.device != torch.device("meta"): + new_module.weight.copy_(module.weight) + new_module.bias.copy_(module.bias) + new_module.running_mean.copy_(module.running_mean) + new_module.running_var.copy_(module.running_var) + + model._modules[name] = new_module + + if len(list(module.children())) > 0: + replace_batch_norm(module) + + +class Deimv2ConvEncoder(nn.Module): + """ + Convolutional backbone using the modeling_deimv2_resnet.py. + + nn.BatchNorm2d layers are replaced by Deimv2FrozenBatchNorm2d as defined above. + https://github.com/lyuwenyu/RT-DETR/blob/main/Deimv2_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + with torch.no_grad(): + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + self.encoder_input_proj = nn.ModuleList( + [ + Deimv2ConvNormLayer(config, in_channel, config.encoder_hidden_dim, 1, 1) + if config.encoder_type != "lite" + else nn.Identity() + for in_channel in self.intermediate_channel_sizes + ] + ) + + def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> list[torch.Tensor]: + features = self.model(pixel_values, **kwargs).feature_maps + return [proj(feat) for proj, feat in zip(self.encoder_input_proj, features)] + + +class Deimv2DINOv3ConvEncoder(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.backbone = load_backbone(config) + + self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) + + embed_dim = config.backbone_config.hidden_size + hidden_dim = config.encoder_hidden_dim + spatial_tuning_adapter_channels = config.spatial_tuning_adapter_inplanes + self.fusion_proj = nn.ModuleList( + [ + Deimv2ConvNormLayer(config, embed_dim + spatial_tuning_adapter_channels * 2, hidden_dim, 1, 1), + Deimv2ConvNormLayer(config, embed_dim + spatial_tuning_adapter_channels * 4, hidden_dim, 1, 1), + Deimv2ConvNormLayer(config, embed_dim + spatial_tuning_adapter_channels * 4, hidden_dim, 1, 1), + ] + ) + + def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> list[torch.Tensor]: + backbone_output = self.backbone(pixel_values, **kwargs) + feature_maps = backbone_output.feature_maps + + patch_size = self.backbone.config.patch_size + height_patches = pixel_values.shape[2] // patch_size + width_patches = pixel_values.shape[3] // patch_size + + semantic_features = [] + num_scales = len(feature_maps) + for i, feat in enumerate(feature_maps): + resize_height = int(height_patches * 2 ** (num_scales - 2 - i)) + resize_width = int(width_patches * 2 ** (num_scales - 2 - i)) + spatial = F.interpolate(feat, size=[resize_height, resize_width], mode="bilinear", align_corners=False) + semantic_features.append(spatial) + + detail_features = self.spatial_tuning_adapter(pixel_values) + + outputs = [] + for i, (semantic_feature, detail_feature) in enumerate(zip(semantic_features, detail_features)): + fused = torch.cat([semantic_feature, detail_feature], dim=1) + outputs.append(self.fusion_proj[i](fused)) + + return outputs + + +class Deimv2Integral(nn.Module): + """ + A static layer that calculates integral results from a distribution. 
+ + This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`, + where Pr(n) is the softmax probability vector representing the discrete + distribution, and W(n) is the non-uniform Weighting Function. + + Args: + max_num_bins (int): Max number of the discrete bins. Default is 32. + It can be adjusted based on the dataset or task requirements. + """ + + def __init__(self, config: Deimv2Config): + super().__init__() + self.max_num_bins = config.max_num_bins + + def forward(self, pred_corners: torch.Tensor, project: torch.Tensor) -> torch.Tensor: + batch_size, num_queries, _ = pred_corners.shape + pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1) + pred_corners = F.linear(pred_corners, project.to(pred_corners.device)).reshape(-1, 4) + pred_corners = pred_corners.reshape(batch_size, num_queries, -1) + return pred_corners + + +class Deimv2LQE(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.top_prob_values = config.top_prob_values + self.max_num_bins = config.max_num_bins + self.reg_conf = Deimv2MLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers) + + def forward(self, scores: torch.Tensor, pred_corners: torch.Tensor) -> torch.Tensor: + batch_size, length, _ = pred_corners.size() + prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1) + prob_topk, _ = prob.topk(self.top_prob_values, dim=-1) + stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1) + quality_score = self.reg_conf(stat.reshape(batch_size, length, -1)) + scores = scores + quality_score + return scores + + +class Deimv2DecoderLayer(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.hidden_size = config.d_model + + # self-attention + self.self_attn = Deimv2SelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) + self.encoder_attn = Deimv2MultiscaleDeformableAttention(config=config) + self.mlp = Deimv2SwiGLUFFN(config) + self.final_layer_norm = Deimv2RMSNorm(config.d_model) + self.gateway = Deimv2Gate(config.d_model) if config.use_gateway else None + self.use_gateway = config.use_gateway + self.encoder_attn_layer_norm = None if config.use_gateway else Deimv2RMSNorm(config.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + reference_points: torch.Tensor | None = None, + spatial_shapes: torch.Tensor | None = None, + spatial_shapes_list: list[tuple[int, int]] | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, hidden_size)`. + object_queries_position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings for the object query slots. These are added to both queries and keys + in the self-attention layer (not values). + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. 
+ encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, hidden_size)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + """ + residual = hidden_states + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + + # Cross-Attention + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + hidden_states, _ = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gateway is not None: + hidden_states = self.gateway(residual, hidden_states) + else: + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states + + +@auto_docstring +class Deimv2PreTrainedModel(PreTrainedModel): + config: Deimv2Config + base_model_prefix = "deimv2" + main_input_name = "pixel_values" + input_modalities = ("image",) + _no_split_modules = [r"Deimv2HybridEncoder", r"Deimv2LiteEncoder", r"Deimv2DecoderLayer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_attention_backend = True + _supports_flex_attn = True + + @torch.no_grad() + def _init_weights(self, module): + """Initialize the weights""" + super()._init_weights(module) + # initialize linear layer bias value according to a given probability value. 
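+        # Classification heads below get bias -log((1 - p) / p) so that their initial sigmoid output equals the
+        # prior probability p, where p is `config.initializer_bias_prior_prob` or 1 / (num_labels + 1) if unset.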
+ if isinstance(module, (Deimv2ForObjectDetection, Deimv2Decoder)): + if module.class_embed is not None: + for layer in module.class_embed: + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + init.xavier_uniform_(layer.weight) + init.constant_(layer.bias, bias) + + if module.bbox_embed is not None: + for layer in module.bbox_embed: + init.constant_(layer.layers[-1].weight, 0) + init.constant_(layer.layers[-1].bias, 0) + + if hasattr(module, "reg_scale"): + init.constant_(module.reg_scale, self.config.reg_scale) + + if hasattr(module, "up"): + init.constant_(module.up, self.config.up) + + if isinstance(module, Deimv2MultiscaleDeformableAttention): + init.constant_(module.sampling_offsets.weight, 0.0) + default_dtype = torch.get_default_dtype() + thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( + 2.0 * math.pi / module.n_heads + ) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values + grid_init = grid_init.reshape(module.n_heads, 1, 2).tile([1, sum(module.num_points_list), 1]) + scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) + grid_init *= scaling + init.copy_(module.sampling_offsets.bias, grid_init.flatten()) + + init.constant_(module.attention_weights.weight, 0.0) + init.constant_(module.attention_weights.bias, 0.0) + + num_points_scale = [1 / n for n in module.num_points_list for _ in range(n)] + init.copy_(module.num_points_scale, torch.tensor(num_points_scale, dtype=torch.float32)) + + if isinstance(module, Deimv2Model): + prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) + bias = float(-math.log((1 - prior_prob) / prior_prob)) + init.xavier_uniform_(module.enc_score_head.weight) + init.constant_(module.enc_score_head.bias, bias) + + if isinstance(module, Deimv2Gate): + bias = float(-math.log((1 - 0.5) / 0.5)) + init.constant_(module.gate.bias, bias) + init.constant_(module.gate.weight, 0) + + if isinstance(module, Deimv2LQE): + init.constant_(module.reg_conf.layers[-1].bias, 0) + init.constant_(module.reg_conf.layers[-1].weight, 0) + + if hasattr(module, "weight_embedding") and self.config.learn_initial_query: + init.xavier_uniform_(module.weight_embedding.weight) + if hasattr(module, "denoising_class_embed") and self.config.num_denoising > 0: + init.xavier_uniform_(module.denoising_class_embed.weight) + + if isinstance(module, Deimv2SwiGLUFFN): + init.xavier_uniform_(module.gate_proj.weight) + init.constant_(module.gate_proj.bias, 0) + init.xavier_uniform_(module.up_proj.weight) + init.constant_(module.up_proj.bias, 0) + init.xavier_uniform_(module.down_proj.weight) + init.constant_(module.down_proj.bias, 0) + + +class Deimv2LiteEncoder(Deimv2PreTrainedModel): + # LiteEncoder has no transformer layers, so hidden_states are recorded from the conv projections. 
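+    # The backbone feature maps are projected to `encoder_hidden_dim`, an extra downsampled level is appended and
+    # enriched with its global-average-pooled context (`bi_fusion_conv`), and a single FPN / PAN block pair
+    # produces the output feature maps.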
+ _can_record_outputs = { + "hidden_states": [ + OutputRecorder(Deimv2ConvNormLayer, layer_name="input_proj"), + OutputRecorder(Deimv2ConvNormLayer, layer_name="bi_fusion_conv"), + ], + } + + def __init__(self, config: Deimv2Config): + super().__init__(config) + hidden_dim = config.encoder_hidden_dim + activation = config.activation_function + + self.input_proj = nn.ModuleList( + [Deimv2ConvNormLayer(config, in_channel, hidden_dim, 1, 1) for in_channel in config.encoder_in_channels] + ) + + self.down_pool1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.down_conv1 = Deimv2ConvNormLayer(config, hidden_dim, hidden_dim, 1, 1, activation=activation) + self.down_pool2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.down_conv2 = Deimv2ConvNormLayer(config, hidden_dim, hidden_dim, 1, 1, activation=activation) + + self.bi_fusion_conv = Deimv2ConvNormLayer(config, hidden_dim, hidden_dim, 1, 1, activation=activation) + + num_blocks = round(3 * config.depth_mult) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + + self.post_init() + + @merge_with_config_defaults + @capture_outputs + def forward(self, inputs_embeds: list[torch.Tensor], **kwargs: Unpack[TransformersKwargs]) -> Deimv2EncoderOutput: + projected_features = [self.input_proj[i](feature) for i, feature in enumerate(inputs_embeds)] + projected_features.append(self.down_conv1(self.down_pool1(projected_features[-1]))) + + projected_features[-1] = self.bi_fusion_conv( + projected_features[-1] + F.adaptive_avg_pool2d(projected_features[-1], 1) + ) + + outputs = [] + fused_feature = projected_features[0] + F.interpolate(projected_features[1], scale_factor=2.0, mode="nearest") + outputs.append(self.fpn_block(fused_feature)) + + fused_feature = projected_features[1] + self.down_conv2(self.down_pool2(outputs[-1])) + outputs.append(self.pan_block(fused_feature)) + + return Deimv2EncoderOutput(feature_maps=outputs) + + +def fuse_feature_maps(feature_map_1: torch.Tensor, feature_map_2: torch.Tensor, fuse_op: str = "sum") -> torch.Tensor: + """Fuses two feature maps via element-wise sum or channel-wise concatenation.""" + if fuse_op == "sum": + return feature_map_1 + feature_map_2 + return torch.cat([feature_map_1, feature_map_2], dim=1) + + +class Deimv2HybridEncoder(Deimv2PreTrainedModel): + """ + DEIMv2 variant of DFineHybridEncoder. Uses element-wise sum fusion (`fuse_feature_maps`) instead of + D-FINE's channel concatenation, Deimv2RepNCSPELAN5 (simplified 4-way concat) instead of DFineRepNCSPELAN4, + and returns Deimv2EncoderOutput with feature_maps instead of BaseModelOutput with last_hidden_state. 
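+
+    The data flow is: AIFI self-attention over the levels listed in `encode_proj_layers`, a top-down FPN pass
+    (lateral 1x1 convs followed by Deimv2RepNCSPELAN5 blocks), then a bottom-up PAN pass (Deimv2SCDown
+    downsampling followed by Deimv2RepNCSPELAN5 blocks).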
+ """ + + _can_record_outputs = { + "hidden_states": Deimv2AIFILayer, + "attentions": Deimv2SelfAttention, + } + + def __init__(self, config: Deimv2Config): + super().__init__(config) + self.config = config + self.in_channels = config.encoder_in_channels + self.num_fpn_stages = len(self.in_channels) - 1 + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + self.fuse_op = config.encoder_fuse_op + + self.aifi = nn.ModuleList([Deimv2AIFILayer(config) for _ in range(len(self.encode_proj_layers))]) + + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + self.lateral_convs.append( + Deimv2ConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1) + ) + num_blocks = round(3 * config.depth_mult) + self.fpn_blocks.append(Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks)) + + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append(Deimv2SCDown(config, 3, 2)) + num_blocks = round(3 * config.depth_mult) + self.pan_blocks.append(Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks)) + + self.post_init() + + @merge_with_config_defaults + @capture_outputs(tie_last_hidden_states=False) + def forward( + self, + inputs_embeds: list[torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Deimv2EncoderOutput: + r""" + Args: + inputs_embeds (`list[torch.FloatTensor]`): + Multi-scale feature maps from the backbone (one tensor per feature level) passed to the encoder. 
+ """ + feature_maps = inputs_embeds + + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs) + + # top-down FPN + fpn_feature_maps = [feature_maps[-1]] + for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)): + backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1] + top_fpn_feature_map = fpn_feature_maps[-1] + top_fpn_feature_map = lateral_conv(top_fpn_feature_map) + fpn_feature_maps[-1] = top_fpn_feature_map + top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest") + fused_feature_map = fuse_feature_maps(top_fpn_feature_map, backbone_feature_map, self.fuse_op) + new_fpn_feature_map = fpn_block(fused_feature_map) + fpn_feature_maps.append(new_fpn_feature_map) + + fpn_feature_maps.reverse() + + # bottom-up PAN + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)): + top_pan_feature_map = pan_feature_maps[-1] + fpn_feature_map = fpn_feature_maps[idx + 1] + downsampled_feature_map = downsample_conv(top_pan_feature_map) + fused_feature_map = fuse_feature_maps(downsampled_feature_map, fpn_feature_map, self.fuse_op) + new_pan_feature_map = pan_block(fused_feature_map) + pan_feature_maps.append(new_pan_feature_map) + + return Deimv2EncoderOutput(feature_maps=pan_feature_maps) + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def weighting_function(max_num_bins: int, up: torch.Tensor, reg_scale: int) -> torch.Tensor: + """ + Generates the non-uniform Weighting Function W(n) for bounding box regression. + + Args: + max_num_bins (int): Max number of the discrete bins. + up (Tensor): Controls upper bounds of the sequence, + where maximum offset is ±up * H / W. + reg_scale (float): Controls the curvature of the Weighting Function. + Larger values result in flatter weights near the central axis W(max_num_bins/2)=0 + and steeper weights at both ends. + Returns: + Tensor: Sequence of Weighting Function. + """ + upper_bound1 = abs(up[0]) * abs(reg_scale) + upper_bound2 = abs(up[0]) * abs(reg_scale) * 2 + step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2)) + left_values = [-((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)] + right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)] + values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2] + values = torch.cat(values, 0) + return values + + +def distance2bbox(points, distance: torch.Tensor, reg_scale: float) -> torch.Tensor: + """ + Decodes edge-distances into bounding box coordinates. + + Args: + points (`torch.Tensor`): + (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height] + distance (`torch.Tensor`): + (batch_size, num_boxes, 4) or (num_boxes, 4), representing distances from the point to the left, top, right, and bottom boundaries. + reg_scale (`float`): + Controls the curvature of the Weighting Function. 
+ Returns: + `torch.Tensor`: Bounding boxes in (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height] + """ + reg_scale = abs(reg_scale) + top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale) + top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale) + bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale) + bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale) + + bboxes = torch.stack([top_left_x, top_left_y, bottom_right_x, bottom_right_y], -1) + + return corners_to_center_format(bboxes) + + +class Deimv2Decoder(Deimv2PreTrainedModel): + """ + D-FINE Decoder implementing Fine-grained Distribution Refinement (FDR). + + This decoder refines object detection predictions through iterative updates across multiple layers, + utilizing attention mechanisms, location quality estimators, and distribution refinement techniques + to improve bounding box accuracy and robustness. + """ + + _can_record_outputs = { + "hidden_states": Deimv2DecoderLayer, + "attentions": Deimv2SelfAttention, + "cross_attentions": Deimv2MultiscaleDeformableAttention, + } + + def __init__(self, config: Deimv2Config): + super().__init__(config) + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + + self.dropout = config.dropout + self.layers = nn.ModuleList( + [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers)] + + [Deimv2DecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)] + ) + self.query_pos_head = Deimv2MLP(4, config.d_model, config.d_model, 3, config.decoder_activation_function) + + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + self.reg_scale = nn.Parameter(torch.tensor([config.reg_scale]), requires_grad=False) + self.max_num_bins = config.max_num_bins + self.d_model = config.d_model + self.layer_scale = config.layer_scale + self.pre_bbox_head = Deimv2MLP(config.hidden_size, config.hidden_size, 4, 3) + self.integral = Deimv2Integral(config) + self.num_head = config.decoder_attention_heads + self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False) + self.lqe_layers = nn.ModuleList([Deimv2LQE(config) for _ in range(config.decoder_layers)]) + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + def forward( + self, + encoder_hidden_states: torch.Tensor, + reference_points: torch.Tensor, + inputs_embeds: torch.Tensor, + spatial_shapes, + level_start_index=None, + spatial_shapes_list=None, + encoder_attention_mask=None, + memory_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> Deimv2DecoderOutput: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + The query embeddings that are passed into the decoder. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. 
Mask values selected
+                in `[0, 1]`:
+                - 1 for pixels that are real (i.e. **not masked**),
+                - 0 for pixels that are padding (i.e. **masked**).
+            reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage` else `(batch_size, num_queries, 2)`, *optional*):
+                Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
+            spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
+                Spatial shapes of the feature maps.
+            level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
+                Indexes for the start of each feature level. In range `[0, sequence_length]`.
+            """
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+
+        # decoder layers
+        intermediate = ()
+        intermediate_reference_points = ()
+        intermediate_logits = ()
+        intermediate_predicted_corners = ()
+        initial_reference_points = ()
+
+        output_detach = pred_corners_undetach = 0
+
+        project = weighting_function(self.max_num_bins, self.up, self.reg_scale)
+        ref_points_detach = F.sigmoid(reference_points)
+
+        for i, decoder_layer in enumerate(self.layers):
+            ref_points_input = ref_points_detach.unsqueeze(2)
+            query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10)
+
+            hidden_states = decoder_layer(
+                hidden_states,
+                position_embeddings=query_pos_embed,
+                reference_points=ref_points_input,
+                spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                **kwargs,
+            )
+
+            if i == 0:
+                # Initial bounding box predictions with inverse sigmoid refinement
+                new_reference_points = F.sigmoid(
+                    self.pre_bbox_head(hidden_states) + inverse_sigmoid(ref_points_detach)
+                )
+                ref_points_initial = new_reference_points.detach()
+
+            # Refine bounding box corners using FDR, integrating previous layer's corrections
+            if self.bbox_embed is not None:
+                pred_corners = self.bbox_embed[i](hidden_states + output_detach) + pred_corners_undetach
+                inter_ref_bbox = distance2bbox(
+                    ref_points_initial, self.integral(pred_corners, project), self.reg_scale
+                )
+                pred_corners_undetach = pred_corners
+                ref_points_detach = inter_ref_bbox.detach()
+
+            output_detach = hidden_states.detach()
+
+            intermediate += (hidden_states,)
+
+            if self.class_embed is not None and (self.training or i == self.eval_idx):
+                scores = self.class_embed[i](hidden_states)
+                # Add initial logits and reference points with pre-bbox head
+                if i == 0:
+                    intermediate_logits += (scores,)
+                    intermediate_reference_points += (new_reference_points,)
+                # LQE does not affect the performance here.
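+                # The LQE head summarizes the predicted corner distributions (top-k probabilities per edge plus
+                # their mean) with a small MLP and adds the resulting quality estimate to the classification logits.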
+ scores = self.lqe_layers[i](scores, pred_corners) + intermediate_logits += (scores,) + intermediate_reference_points += (inter_ref_bbox,) + initial_reference_points += (ref_points_initial,) + intermediate_predicted_corners += (pred_corners,) + + # Keep batch_size as first dimension + intermediate = torch.stack(intermediate) + if self.class_embed is not None and self.bbox_embed is not None: + intermediate_logits = torch.stack(intermediate_logits, dim=1) + intermediate_predicted_corners = torch.stack(intermediate_predicted_corners, dim=1) + initial_reference_points = torch.stack(initial_reference_points, dim=1) + intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1) + + return Deimv2DecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_logits=intermediate_logits, + intermediate_reference_points=intermediate_reference_points, + intermediate_predicted_corners=intermediate_predicted_corners, + initial_reference_points=initial_reference_points, + ) + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising_queries=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + """ + Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes. + + Args: + targets (`list[dict]`): + The target objects, each containing 'class_labels' and 'boxes' for objects in an image. + num_classes (`int`): + Total number of classes in the dataset. + num_queries (`int`): + Number of query slots in the transformer. + class_embed (`callable`): + A function or a model layer to embed class labels. + num_denoising_queries (`int`, *optional*, defaults to 100): + Number of denoising queries. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + Ratio of noise applied to labels. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale of noise applied to bounding boxes. + Returns: + `tuple` comprising various elements: + - **input_query_class** (`torch.FloatTensor`) -- + Class queries with applied label noise. + - **input_query_bbox** (`torch.FloatTensor`) -- + Bounding box queries with applied box noise. + - **attn_mask** (`torch.FloatTensor`) -- + Attention mask for separating denoising and reconstruction queries. + - **denoising_meta_values** (`dict`) -- + Metadata including denoising positive indices, number of groups, and split sizes. + """ + + if num_denoising_queries <= 0: + return None, None, None, None + + num_ground_truths = [len(t["class_labels"]) for t in targets] + device = targets[0]["class_labels"].device + + max_gt_num = max(num_ground_truths) + if max_gt_num == 0: + return None, None, None, None + + num_groups_denoising_queries = num_denoising_queries // max_gt_num + num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries + # pad gt to max_num of a batch + batch_size = len(num_ground_truths) + + input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device) + + for i in range(batch_size): + num_gt = num_ground_truths[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["class_labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. 
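+    # Concretely, each group holds 2 * max_gt_num queries: the first half stays close to the ground truth
+    # (positives), while the second half receives larger box noise and acts as negatives for contrastive denoising.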
+    input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
+    input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
+    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
+    # positive and negative mask
+    negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
+    negative_gt_mask[:, max_gt_num:] = 1
+    negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
+    positive_gt_mask = 1 - negative_gt_mask
+    # contrastive denoising training positive index
+    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
+    denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
+    denoise_positive_idx = torch.split(
+        denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
+    )
+    # total denoising queries
+    num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
+
+    if label_noise_ratio > 0:
+        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
+        # randomly replace a fraction of the labels with a new class
+        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
+        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
+
+    if box_noise_scale > 0:
+        known_bbox = center_to_corners_format(input_query_bbox)
+        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
+        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
+        rand_part = torch.rand_like(input_query_bbox)
+        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
+        rand_part *= rand_sign
+        known_bbox += rand_part * diff
+        known_bbox.clip_(min=0.0, max=1.0)
+        input_query_bbox = corners_to_center_format(known_bbox)
+        input_query_bbox = inverse_sigmoid(input_query_bbox)
+
+    input_query_class = class_embed(input_query_class)
+
+    target_size = num_denoising_queries + num_queries
+    attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
+    # matching queries cannot see the denoising (reconstruction) queries
+    attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
+
+    # denoising groups cannot see each other
+    for i in range(num_groups_denoising_queries):
+        idx_block_start = max_gt_num * 2 * i
+        idx_block_end = max_gt_num * 2 * (i + 1)
+        attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
+        attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
+
+    denoising_meta_values = {
+        "dn_positive_idx": denoise_positive_idx,
+        "dn_num_group": num_groups_denoising_queries,
+        "dn_num_split": [num_denoising_queries, num_queries],
+    }
+
+    return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
+
+
+@auto_docstring(
+    custom_intro="""
+    DEIMv2 Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
+ """ +) +class Deimv2Model(Deimv2PreTrainedModel): + def __init__(self, config: Deimv2Config): + super().__init__(config) + + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" + self.conv_encoder = Deimv2DINOv3ConvEncoder(config) if is_dinov3 else Deimv2ConvEncoder(config) + self.encoder = ( + Deimv2LiteEncoder(config) if config.encoder_type == "lite" else Deimv2HybridEncoder(config=config) + ) + + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = Deimv2MLP(config.d_model, config.d_model, 4, 3) + + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj = [] + in_channels = config.decoder_in_channels[-1] + for _ in range(num_backbone_outs): + decoder_input_proj.append( + nn.Identity() + if config.hidden_size == config.decoder_in_channels[-1] + else Deimv2ConvNormLayer(config, in_channels, config.d_model, 1, 1) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj.append( + nn.Identity() + if config.hidden_size == config.decoder_in_channels[-1] + else Deimv2ConvNormLayer(config, in_channels, config.d_model, 3, 2) + ) + self.decoder_input_proj = nn.ModuleList(decoder_input_proj) + self.decoder = Deimv2Decoder(config) + + self.post_init() + + def freeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(True) + + @compile_compatible_method_lru_cache(maxsize=32) + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)] + for s in self.config.feat_strides + ] + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid( + torch.arange(end=height, device=device).to(dtype), + torch.arange(end=width, device=device).to(dtype), + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], -1) + grid_xy = grid_xy.unsqueeze(0) + 0.5 + grid_xy[..., 0] /= width + grid_xy[..., 1] /= height + wh = torch.ones_like(grid_xy) * grid_size * (2.0**level) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = torch.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device)) + + return anchors, valid_mask + + @auto_docstring + @can_return_tuple + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> 
tuple[torch.FloatTensor] | Deimv2ModelOutput: + r""" + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes + in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`. + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, Deimv2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/Deimv2_r50vd") + >>> model = Deimv2Model.from_pretrained("PekingU/Deimv2_r50vd") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + # Overrides DFineModel.forward: DEIMv2 uses a unified conv_encoder (backbone + projection) instead of + # D-FINE's separate backbone + encoder_input_proj, and returns feature_maps instead of last_hidden_state. + if pixel_values is None and inputs_embeds is None: + raise ValueError("You have to specify either pixel_values or inputs_embeds") + + # extract multi-scale features via conv_encoder (backbone + projection in one step) + if inputs_embeds is None: + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + # TODO: pass pixel_mask to backbone once DINOv3 supports it + proj_feats = self.conv_encoder(pixel_values) + else: + batch_size = inputs_embeds.shape[0] + device = inputs_embeds.device + proj_feats = inputs_embeds + + encoder_outputs = self.encoder( + proj_feats, + **kwargs, + ) + + # Equivalent to def _get_encoder_input + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412 + sources = [] + for level, source in enumerate(encoder_outputs.feature_maps): + sources.append(self.decoder_input_proj[level](source)) + + # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage + if self.config.num_feature_levels > len(sources): + sources.append(self.decoder_input_proj[len(sources)](encoder_outputs.feature_maps[-1])) + for i in range(len(sources), self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs.feature_maps[-1])) + + # Prepare encoder inputs (by flattening) + source_flatten = [] + spatial_shapes_list = [] + spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long) + for level, source in enumerate(sources): + height, width = source.shape[-2:] + spatial_shapes[level, 0] = height + spatial_shapes[level, 1] = width + spatial_shapes_list.append((height, width)) + source = source.flatten(2).transpose(1, 2) + 
source_flatten.append(source) + source_flatten = torch.cat(source_flatten, 1) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + # prepare denoising training + if self.training and self.config.num_denoising > 0 and labels is not None: + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + device = source_flatten.device + dtype = source_flatten.dtype + + # prepare input for decoder + if self.training or self.config.anchor_image_size is None: + # Pass spatial_shapes as tuple to make it hashable and make sure + # lru_cache is working for generate_anchors() + spatial_shapes_tuple = tuple(spatial_shapes_list) + anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) + else: + anchors, valid_mask = self.anchors, self.valid_mask + anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) + + # use the valid_mask to selectively retain values in the feature map where the mask is True + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) + ) + + # extract region features + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + **kwargs, + ) + + return Deimv2ModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners, + initial_reference_points=decoder_outputs.initial_reference_points, + 
decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.feature_maps, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type of [`Deimv2ForObjectDetection`]. + """ +) +class Deimv2ObjectDetectionOutput(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). You can use [`~Deimv2ImageProcessor.post_process_object_detection`] to retrieve the + unnormalized (absolute) bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked initial reference points (initial reference points of each layer of the decoder). + init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. 
+    enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Classification logits of the top `config.num_queries` scoring region proposals selected in the encoder.
+    enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Normalized bounding boxes (after sigmoid) of the top `config.num_queries` scoring region proposals selected in the encoder.
+    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
+        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
+        foreground and background).
+    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
+        Logits of predicted bounding boxes coordinates in the first stage.
+    denoising_meta_values (`dict`):
+        Extra dictionary with the denoising-related values.
+    """
+
+    loss: torch.FloatTensor | None = None
+    loss_dict: dict | None = None
+    logits: torch.FloatTensor | None = None
+    pred_boxes: torch.FloatTensor | None = None
+    auxiliary_outputs: list[dict] | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    intermediate_hidden_states: torch.FloatTensor | None = None
+    intermediate_logits: torch.FloatTensor | None = None
+    intermediate_reference_points: torch.FloatTensor | None = None
+    intermediate_predicted_corners: torch.FloatTensor | None = None
+    initial_reference_points: torch.FloatTensor | None = None
+    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    decoder_attentions: tuple[torch.FloatTensor] | None = None
+    cross_attentions: tuple[torch.FloatTensor] | None = None
+    encoder_last_hidden_state: torch.FloatTensor | None = None
+    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
+    encoder_attentions: tuple[torch.FloatTensor] | None = None
+    init_reference_points: tuple[torch.FloatTensor] | None = None
+    enc_topk_logits: torch.FloatTensor | None = None
+    enc_topk_bboxes: torch.FloatTensor | None = None
+    enc_outputs_class: torch.FloatTensor | None = None
+    enc_outputs_coord_logits: torch.FloatTensor | None = None
+    denoising_meta_values: dict | None = None
+
+
+@auto_docstring(
+    custom_intro="""
+    DEIMv2 Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits to be further
+    decoded into scores and classes.
+ """ +) +class Deimv2ForObjectDetection(Deimv2PreTrainedModel): + _tied_weights_keys = { + r"bbox_embed.(?![0])\d+": r"bbox_embed.0", + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + + def __init__(self, config: Deimv2Config): + super().__init__(config) + + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + self.model = Deimv2Model(config) + scaled_dim = round(config.layer_scale * config.hidden_size) + num_pred = config.decoder_layers + self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) + if config.share_bbox_head: + shared_bbox = Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + self.bbox_embed = nn.ModuleList([shared_bbox] * num_pred) + else: + self.bbox_embed = nn.ModuleList( + [ + Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) + for _ in range(config.decoder_layers - self.eval_idx - 1) + ] + ) + + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + self.post_init() + + def _set_aux_loss(self, outputs_class, outputs_coord): + return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + @auto_docstring + @can_return_tuple + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.FloatTensor] | Deimv2ObjectDetectionOutput: + r""" + Example: + + ```python + >>> import torch + >>> from transformers.image_utils import load_image + >>> from transformers import AutoImageProcessor, Deimv2ForObjectDetection + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> image_processor = AutoImageProcessor.from_pretrained("harshaljanjani/DEIMv2_HGNetv2_N_COCO_Transformers") + >>> model = Deimv2ForObjectDetection.from_pretrained("harshaljanjani/DEIMv2_HGNetv2_N_COCO_Transformers") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = torch.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes) + >>> result = results[0] # first image in batch + + >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... 
) + ``` + """ + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + labels=labels, + **kwargs, + ) + + denoising_meta_values = outputs.denoising_meta_values if self.training else None + + outputs_class = outputs.intermediate_logits + outputs_coord = outputs.intermediate_reference_points + predicted_corners = outputs.intermediate_predicted_corners + initial_reference_points = outputs.initial_reference_points + + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] + + loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None + if labels is not None: + enc_topk_logits = outputs.enc_topk_logits + enc_topk_bboxes = outputs.enc_topk_bboxes + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + self.device, + pred_boxes, + self.config, + outputs_class, + outputs_coord, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + denoising_meta_values=denoising_meta_values, + predicted_corners=predicted_corners, + initial_reference_points=initial_reference_points, + **kwargs, + ) + + return Deimv2ObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_logits=outputs.intermediate_logits, + intermediate_reference_points=outputs.intermediate_reference_points, + intermediate_predicted_corners=outputs.intermediate_predicted_corners, + initial_reference_points=outputs.initial_reference_points, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + init_reference_points=outputs.init_reference_points, + enc_topk_logits=outputs.enc_topk_logits, + enc_topk_bboxes=outputs.enc_topk_bboxes, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + denoising_meta_values=outputs.denoising_meta_values, + ) + + @property + def _tied_weights_keys(self): + keys = { + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + if self.config.share_bbox_head: + keys[r"model\.decoder\.bbox_embed\.(?![0])\d+"] = r"model.decoder.bbox_embed.0" + keys[r"bbox_embed.(?![0])\d+"] = r"bbox_embed.0" + return keys + + +__all__ = ["Deimv2Model", "Deimv2PreTrainedModel", "Deimv2ForObjectDetection"] diff --git a/src/transformers/models/deimv2/modular_deimv2.py b/src/transformers/models/deimv2/modular_deimv2.py new file mode 100644 index 000000000000..e675fd12114c --- /dev/null +++ b/src/transformers/models/deimv2/modular_deimv2.py @@ -0,0 +1,945 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub.dataclasses import strict + +from ... import initialization as init +from ...backbone_utils import load_backbone +from ...modeling_outputs import ModelOutput +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..auto import AutoConfig +from ..d_fine.configuration_d_fine import DFineConfig +from ..d_fine.modeling_d_fine import ( + DFineAIFILayer, + DFineConvEncoder, + DFineConvNormLayer, + DFineDecoder, + DFineDecoderLayer, + DFineDecoderOutput, + DFineEncoderLayer, + DFineForObjectDetection, + DFineGate, + DFineHybridEncoder, + DFineIntegral, + DFineLQE, + DFineMLP, + DFineModel, + DFineModelOutput, + DFineMultiscaleDeformableAttention, + DFinePreTrainedModel, + DFineRepVggBlock, + DFineSCDown, + get_contrastive_denoising_training_group, +) +from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm + + +logger = logging.get_logger(__name__) + + +@auto_docstring(checkpoint="Intellindust/DEIMv2_HGNetv2_N_COCO") +@strict +class Deimv2Config(DFineConfig): + r""" + initializer_bias_prior_prob (`float`, *optional*): + The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`. + If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights. + freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`): + Whether to freeze the batch normalization layers in the backbone. + encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`): + Multi level features input for encoder. + feat_strides (`list[int]`, *optional*, defaults to `[8, 16, 32]`): + Strides used in each feature map. + encode_proj_layers (`list[int]`, *optional*, defaults to `[2]`): + Indexes of the projected layers to be used in the encoder. + positional_encoding_temperature (`int`, *optional*, defaults to 10000): + The temperature parameter used to create the positional encodings. + encoder_activation_function (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. + eval_size (`list[int]` or `tuple[int, int]`, *optional*): + Height and width used to computes the effective height and width of the position embeddings after taking + into account the stride. + normalize_before (`bool`, *optional*, defaults to `False`): + Determine whether to apply layer normalization in the transformer encoder layer before self-attention and + feed-forward modules. + hidden_expansion (`float`, *optional*, defaults to 1.0): + Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer. + num_queries (`int`, *optional*, defaults to 300): + Number of object queries. + decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`): + Multi level features dimension for decoder. + num_feature_levels (`int`, *optional*, defaults to 3): + The number of input feature levels. + decoder_n_points (`int`, *optional*, defaults to 4): + The number of sampled keys in each feature level for each attention head in the decoder. 
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the decoder. + num_denoising (`int`, *optional*, defaults to 100): + The total number of denoising tasks or queries to be used for contrastive denoising. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + The fraction of denoising labels to which random noise should be added. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale or magnitude of noise to be added to the bounding boxes. + learn_initial_query (`bool`, *optional*, defaults to `False`): + Indicates whether the initial query embeddings for the decoder should be learned during training. + anchor_image_size (`tuple[int, int]`, *optional*): + Height and width of the input image used during evaluation to generate the bounding box anchors. + with_box_refine (`bool`, *optional*, defaults to `True`): + Whether to apply iterative bounding box refinement. + matcher_alpha (`float`, *optional*, defaults to 0.25): + Parameter alpha used by the Hungarian Matcher. + matcher_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used by the Hungarian Matcher. + matcher_class_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the class loss used by the Hungarian Matcher. + matcher_bbox_cost (`float`, *optional*, defaults to 5.0): + The relative weight of the bounding box loss used by the Hungarian Matcher. + matcher_giou_cost (`float`, *optional*, defaults to 2.0): + The relative weight of the giou loss of used by the Hungarian Matcher. + use_focal_loss (`bool`, *optional*, defaults to `True`): + Parameter informing if focal loss should be used. + focal_loss_alpha (`float`, *optional*, defaults to 0.75): + Parameter alpha used to compute the focal loss. + focal_loss_gamma (`float`, *optional*, defaults to 2.0): + Parameter gamma used to compute the focal loss. + weight_loss_vfl (`float`, *optional*, defaults to 1.0): + Relative weight of the varifocal loss in the object detection loss. + weight_loss_bbox (`float`, *optional*, defaults to 5.0): + Relative weight of the L1 bounding box loss in the object detection loss. + weight_loss_giou (`float`, *optional*, defaults to 2.0): + Relative weight of the generalized IoU loss in the object detection loss. + weight_loss_fgl (`float`, *optional*, defaults to 0.15): + Relative weight of the fine-grained localization loss in the object detection loss. + weight_loss_ddf (`float`, *optional*, defaults to 1.5): + Relative weight of the decoupled distillation focal loss in the object detection loss. + eval_idx (`int`, *optional*, defaults to -1): + Index of the decoder layer to use for evaluation. + layer_scale (`float`, *optional*, defaults to `1.0`): + Scaling factor for the hidden dimension in later decoder layers. + max_num_bins (`int`, *optional*, defaults to 32): + Maximum number of bins for the distribution-guided bounding box refinement. + reg_scale (`float`, *optional*, defaults to 4.0): + Scale factor for the regression distribution. + depth_mult (`float`, *optional*, defaults to 1.0): + Multiplier for the number of blocks in RepNCSPELAN5 layers. + top_prob_values (`int`, *optional*, defaults to 4): + Number of top probability values to consider from each corner's distribution. + lqe_hidden_dim (`int`, *optional*, defaults to 64): + Hidden dimension size for the Location Quality Estimator (LQE) network. + lqe_layers (`int`, *optional*, defaults to 2): + Number of layers in the Location Quality Estimator MLP. 
+ decoder_offset_scale (`float`, *optional*, defaults to 0.5): + Offset scale used in deformable attention. + decoder_method (`str`, *optional*, defaults to `"default"`): + The method to use for the decoder: `"default"` or `"discrete"`. + up (`float`, *optional*, defaults to 0.5): + Controls the upper bounds of the Weighting Function. + weight_loss_mal (`float`, *optional*, defaults to 1.0): + Relative weight of the matching auxiliary loss in the object detection loss. + use_dense_one_to_one (`bool`, *optional*, defaults to `True`): + Whether to use dense one-to-one matching across decoder layers. + mal_alpha (`float`, *optional*): + Alpha parameter for the Matching Auxiliary Loss (MAL). If `None`, uses `focal_loss_alpha`. + encoder_fuse_op (`str`, *optional*, defaults to `"sum"`): + Fusion operation used in the encoder FPN. DEIMv2 uses `"sum"` instead of D-FINE's `"cat"`. + spatial_tuning_adapter_inplanes (`int`, *optional*, defaults to 16): + Number of input planes for the STA convolutional stem. + encoder_type (`str`, *optional*, defaults to `"hybrid"`): + Type of encoder to use. `"hybrid"` uses the full HybridEncoder with AIFI, FPN, and PAN. + `"lite"` uses the lightweight LiteEncoder with GAP fusion for smaller variants (Atto, Femto, Pico). + use_gateway (`bool`, *optional*, defaults to `True`): + Whether to use the gateway mechanism (cross-attention gating) in decoder layers. When `False`, + uses RMSNorm on the encoder attention output instead. + share_bbox_head (`bool`, *optional*, defaults to `False`): + Whether to share the bounding box prediction head across all decoder layers. + encoder_has_trailing_conv (`bool`, *optional*, defaults to `True`): + Whether the encoder's CSP blocks include a trailing 3x3 convolution after the bottleneck path. + `True` for RepNCSPELAN4 (used by HGNetV2 N and LiteEncoder variants). + `False` for RepNCSPELAN5 (used by DINOv3 variants). + """ + + model_type = "deimv2" + sub_configs = {"backbone_config": AutoConfig} + + eval_size: list[int] | tuple[int, int] | None = None + weight_loss_mal: float = 1.0 + use_dense_one_to_one: bool = True + mal_alpha: float | None = None + encoder_fuse_op: str = "sum" + spatial_tuning_adapter_inplanes: int = 16 + encoder_type: str = "hybrid" + use_gateway: bool = True + share_bbox_head: bool = False + encoder_has_trailing_conv: bool = True + + +class Deimv2DecoderOutput(DFineDecoderOutput): + pass + + +class Deimv2ModelOutput(DFineModelOutput): + pass + + +@dataclass +@auto_docstring( + custom_intro=""" + Output type for DEIMv2 encoder modules (HybridEncoder and LiteEncoder). + Attentions are only available for HybridEncoder variants with AIFI layers. + """ +) +class Deimv2EncoderOutput(ModelOutput): + r""" + feature_maps (`list[torch.FloatTensor]`): + List of multi-scale feature maps from the encoder, one per feature level. + """ + + feature_maps: list[torch.FloatTensor] = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] 
| None = None + + +class Deimv2RMSNorm(LlamaRMSNorm): + pass + + +class Deimv2SwiGLUFFN(LlamaMLP): + def __init__(self, config: Deimv2Config): + nn.Module.__init__(self) + hidden_features = config.decoder_ffn_dim // 2 + self.gate_proj = nn.Linear(config.d_model, hidden_features, bias=True) + self.up_proj = nn.Linear(config.d_model, hidden_features, bias=True) + self.down_proj = nn.Linear(hidden_features, config.d_model, bias=True) + self.act_fn = nn.SiLU() + + +class Deimv2Gate(DFineGate): + def __init__(self, d_model: int): + super().__init__(d_model) + self.norm = Deimv2RMSNorm(d_model) + + +class Deimv2MLP(DFineMLP): + pass + + +class Deimv2MultiscaleDeformableAttention(DFineMultiscaleDeformableAttention): + pass + + +class Deimv2ConvNormLayer(DFineConvNormLayer): + pass + + +class Deimv2RepVggBlock(DFineRepVggBlock): + pass + + +class Deimv2CSPRepLayer(nn.Module): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. + Differs from DFineCSPRepLayer: uses a single conv that splits into residual + processing path + (instead of two separate convs), and has an optional trailing conv controlled by `encoder_has_trailing_conv`. + """ + + def __init__( + self, config: Deimv2Config, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 + ): + super().__init__() + activation = config.activation_function + hidden_channels = int(out_channels * expansion) + self.conv1 = Deimv2ConvNormLayer(config, in_channels, hidden_channels * 2, 1, 1, activation=activation) + self.bottlenecks = nn.ModuleList( + [Deimv2RepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] + ) + self.conv2 = ( + Deimv2ConvNormLayer(config, hidden_channels, out_channels, 3, 1, activation=activation) + if config.encoder_has_trailing_conv + else nn.Identity() + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual, hidden_states = self.conv1(hidden_states).chunk(2, dim=1) + for bottleneck in self.bottlenecks: + hidden_states = bottleneck(hidden_states) + return self.conv2(residual + hidden_states) + + +class Deimv2RepNCSPELAN5(nn.Module): + """ + Rep(VGG) N(etwork) CSP (Cross Stage Partial) ELAN (Efficient Layer Aggregation Network) block. + Similar to DFineRepNCSPELAN4 but without intermediate convolutions between CSP branches, + resulting in a simpler 4-way concatenation (2 split halves + 2 CSP branches) instead of D-FINE's + 4-branch design with interleaved convolutions. 
+ """ + + def __init__(self, config: Deimv2Config, numb_blocks: int = 3): + super().__init__() + activation = config.activation_function + in_channels = config.encoder_hidden_dim + out_channels = config.encoder_hidden_dim + split_channels = config.encoder_hidden_dim * 2 + csp_channels = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv1 = Deimv2ConvNormLayer(config, in_channels, split_channels, 1, 1, activation=activation) + self.csp_rep1 = Deimv2CSPRepLayer(config, split_channels // 2, csp_channels, num_blocks=numb_blocks) + self.csp_rep2 = Deimv2CSPRepLayer(config, csp_channels, csp_channels, num_blocks=numb_blocks) + self.conv2 = Deimv2ConvNormLayer( + config, split_channels + (2 * csp_channels), out_channels, 1, 1, activation=activation + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states_1, hidden_states_2 = self.conv1(hidden_states).chunk(2, dim=1) + hidden_states_3 = self.csp_rep1(hidden_states_2) + hidden_states_4 = self.csp_rep2(hidden_states_3) + merged_hidden_states = torch.cat([hidden_states_1, hidden_states_2, hidden_states_3, hidden_states_4], dim=1) + return self.conv2(merged_hidden_states) + + +class Deimv2SCDown(DFineSCDown): + pass + + +class Deimv2EncoderLayer(DFineEncoderLayer): + pass + + +class Deimv2AIFILayer(DFineAIFILayer): + pass + + +class Deimv2SpatialTuningAdapter(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + inplanes = config.spatial_tuning_adapter_inplanes + self.stem_conv = Deimv2ConvNormLayer(config, 3, inplanes, 3, 2, activation="gelu") + self.stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.conv2 = Deimv2ConvNormLayer(config, inplanes, 2 * inplanes, 3, 2) + self.conv3 = Deimv2ConvNormLayer(config, 2 * inplanes, 4 * inplanes, 3, 2) + self.conv4 = Deimv2ConvNormLayer(config, 4 * inplanes, 4 * inplanes, 3, 2) + self.act_fn = nn.GELU() + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states_1 = self.stem_pool(self.stem_conv(pixel_values)) + hidden_states_2 = self.conv2(hidden_states_1) + hidden_states_3 = self.conv3(self.act_fn(hidden_states_2)) + hidden_states_4 = self.conv4(self.act_fn(hidden_states_3)) + return hidden_states_2, hidden_states_3, hidden_states_4 + + +def fuse_feature_maps(feature_map_1: torch.Tensor, feature_map_2: torch.Tensor, fuse_op: str = "sum") -> torch.Tensor: + """Fuses two feature maps via element-wise sum or channel-wise concatenation.""" + if fuse_op == "sum": + return feature_map_1 + feature_map_2 + return torch.cat([feature_map_1, feature_map_2], dim=1) + + +class Deimv2ConvEncoder(DFineConvEncoder): + def __init__(self, config): + super().__init__(config) + self.encoder_input_proj = nn.ModuleList( + [ + Deimv2ConvNormLayer(config, in_channel, config.encoder_hidden_dim, 1, 1) + if config.encoder_type != "lite" + else nn.Identity() + for in_channel in self.intermediate_channel_sizes + ] + ) + + def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> list[torch.Tensor]: + features = self.model(pixel_values, **kwargs).feature_maps + return [proj(feat) for proj, feat in zip(self.encoder_input_proj, features)] + + +class Deimv2DINOv3ConvEncoder(nn.Module): + def __init__(self, config: Deimv2Config): + super().__init__() + self.backbone = load_backbone(config) + + self.spatial_tuning_adapter = Deimv2SpatialTuningAdapter(config) + + embed_dim = config.backbone_config.hidden_size + hidden_dim = config.encoder_hidden_dim + 
spatial_tuning_adapter_channels = config.spatial_tuning_adapter_inplanes + self.fusion_proj = nn.ModuleList( + [ + Deimv2ConvNormLayer(config, embed_dim + spatial_tuning_adapter_channels * 2, hidden_dim, 1, 1), + Deimv2ConvNormLayer(config, embed_dim + spatial_tuning_adapter_channels * 4, hidden_dim, 1, 1), + Deimv2ConvNormLayer(config, embed_dim + spatial_tuning_adapter_channels * 4, hidden_dim, 1, 1), + ] + ) + + def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> list[torch.Tensor]: + backbone_output = self.backbone(pixel_values, **kwargs) + feature_maps = backbone_output.feature_maps + + patch_size = self.backbone.config.patch_size + height_patches = pixel_values.shape[2] // patch_size + width_patches = pixel_values.shape[3] // patch_size + + semantic_features = [] + num_scales = len(feature_maps) + for i, feat in enumerate(feature_maps): + resize_height = int(height_patches * 2 ** (num_scales - 2 - i)) + resize_width = int(width_patches * 2 ** (num_scales - 2 - i)) + spatial = F.interpolate(feat, size=[resize_height, resize_width], mode="bilinear", align_corners=False) + semantic_features.append(spatial) + + detail_features = self.spatial_tuning_adapter(pixel_values) + + outputs = [] + for i, (semantic_feature, detail_feature) in enumerate(zip(semantic_features, detail_features)): + fused = torch.cat([semantic_feature, detail_feature], dim=1) + outputs.append(self.fusion_proj[i](fused)) + + return outputs + + +class Deimv2Integral(DFineIntegral): + pass + + +class Deimv2LQE(DFineLQE): + pass + + +class Deimv2DecoderLayer(DFineDecoderLayer): + def __init__(self, config: Deimv2Config): + super().__init__(config) + self.encoder_attn = Deimv2MultiscaleDeformableAttention(config=config) + self.self_attn_layer_norm = Deimv2RMSNorm(config.d_model) + self.final_layer_norm = Deimv2RMSNorm(config.d_model) + self.mlp = Deimv2SwiGLUFFN(config) + self.use_gateway = config.use_gateway + self.gateway = Deimv2Gate(config.d_model) if config.use_gateway else None + self.encoder_attn_layer_norm = None if config.use_gateway else Deimv2RMSNorm(config.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor | None = None, + reference_points: torch.Tensor | None = None, + spatial_shapes: torch.Tensor | None = None, + spatial_shapes_list: list[tuple[int, int]] | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + + # Cross-Attention + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + hidden_states, _ = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gateway is not None: + hidden_states = self.gateway(residual, hidden_states) + else: + 
hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states + + +class Deimv2PreTrainedModel(DFinePreTrainedModel): + _no_split_modules = [r"Deimv2HybridEncoder", r"Deimv2LiteEncoder", r"Deimv2DecoderLayer"] + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + + if isinstance(module, Deimv2SwiGLUFFN): + init.xavier_uniform_(module.gate_proj.weight) + init.constant_(module.gate_proj.bias, 0) + init.xavier_uniform_(module.up_proj.weight) + init.constant_(module.up_proj.bias, 0) + init.xavier_uniform_(module.down_proj.weight) + init.constant_(module.down_proj.bias, 0) + + +class Deimv2LiteEncoder(Deimv2PreTrainedModel): + # LiteEncoder has no transformer layers, so hidden_states are recorded from the conv projections. + _can_record_outputs = { + "hidden_states": [ + OutputRecorder(Deimv2ConvNormLayer, layer_name="input_proj"), + OutputRecorder(Deimv2ConvNormLayer, layer_name="bi_fusion_conv"), + ], + } + + def __init__(self, config: Deimv2Config): + super().__init__(config) + hidden_dim = config.encoder_hidden_dim + activation = config.activation_function + + self.input_proj = nn.ModuleList( + [Deimv2ConvNormLayer(config, in_channel, hidden_dim, 1, 1) for in_channel in config.encoder_in_channels] + ) + + self.down_pool1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.down_conv1 = Deimv2ConvNormLayer(config, hidden_dim, hidden_dim, 1, 1, activation=activation) + self.down_pool2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.down_conv2 = Deimv2ConvNormLayer(config, hidden_dim, hidden_dim, 1, 1, activation=activation) + + self.bi_fusion_conv = Deimv2ConvNormLayer(config, hidden_dim, hidden_dim, 1, 1, activation=activation) + + num_blocks = round(3 * config.depth_mult) + self.fpn_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + self.pan_block = Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks) + + self.post_init() + + @merge_with_config_defaults + @capture_outputs + def forward(self, inputs_embeds: list[torch.Tensor], **kwargs: Unpack[TransformersKwargs]) -> Deimv2EncoderOutput: + projected_features = [self.input_proj[i](feature) for i, feature in enumerate(inputs_embeds)] + projected_features.append(self.down_conv1(self.down_pool1(projected_features[-1]))) + + projected_features[-1] = self.bi_fusion_conv( + projected_features[-1] + F.adaptive_avg_pool2d(projected_features[-1], 1) + ) + + outputs = [] + fused_feature = projected_features[0] + F.interpolate(projected_features[1], scale_factor=2.0, mode="nearest") + outputs.append(self.fpn_block(fused_feature)) + + fused_feature = projected_features[1] + self.down_conv2(self.down_pool2(outputs[-1])) + outputs.append(self.pan_block(fused_feature)) + + return Deimv2EncoderOutput(feature_maps=outputs) + + +class Deimv2HybridEncoder(DFineHybridEncoder): + """ + DEIMv2 variant of DFineHybridEncoder. Uses element-wise sum fusion (`fuse_feature_maps`) instead of + D-FINE's channel concatenation, Deimv2RepNCSPELAN5 (simplified 4-way concat) instead of DFineRepNCSPELAN4, + and returns Deimv2EncoderOutput with feature_maps instead of BaseModelOutput with last_hidden_state. 
+ """ + + def __init__(self, config: Deimv2Config): + Deimv2PreTrainedModel.__init__(self, config) + self.config = config + self.in_channels = config.encoder_in_channels + self.num_fpn_stages = len(self.in_channels) - 1 + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + self.fuse_op = config.encoder_fuse_op + + self.aifi = nn.ModuleList([Deimv2AIFILayer(config) for _ in range(len(self.encode_proj_layers))]) + + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + self.lateral_convs.append( + Deimv2ConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1) + ) + num_blocks = round(3 * config.depth_mult) + self.fpn_blocks.append(Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks)) + + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append(Deimv2SCDown(config, 3, 2)) + num_blocks = round(3 * config.depth_mult) + self.pan_blocks.append(Deimv2RepNCSPELAN5(config, numb_blocks=num_blocks)) + + self.post_init() + + def forward( + self, + inputs_embeds: list[torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Deimv2EncoderOutput: + r""" + Args: + inputs_embeds (`list[torch.FloatTensor]`): + Multi-scale feature maps from the backbone (one tensor per feature level) passed to the encoder. + """ + feature_maps = inputs_embeds + + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs) + + # top-down FPN + fpn_feature_maps = [feature_maps[-1]] + for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)): + backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1] + top_fpn_feature_map = fpn_feature_maps[-1] + top_fpn_feature_map = lateral_conv(top_fpn_feature_map) + fpn_feature_maps[-1] = top_fpn_feature_map + top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest") + fused_feature_map = fuse_feature_maps(top_fpn_feature_map, backbone_feature_map, self.fuse_op) + new_fpn_feature_map = fpn_block(fused_feature_map) + fpn_feature_maps.append(new_fpn_feature_map) + + fpn_feature_maps.reverse() + + # bottom-up PAN + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)): + top_pan_feature_map = pan_feature_maps[-1] + fpn_feature_map = fpn_feature_maps[idx + 1] + downsampled_feature_map = downsample_conv(top_pan_feature_map) + fused_feature_map = fuse_feature_maps(downsampled_feature_map, fpn_feature_map, self.fuse_op) + new_pan_feature_map = pan_block(fused_feature_map) + pan_feature_maps.append(new_pan_feature_map) + + return Deimv2EncoderOutput(feature_maps=pan_feature_maps) + + +class Deimv2Decoder(DFineDecoder): + def __init__(self, config: Deimv2Config): + super().__init__(config=config) + self.query_pos_head = Deimv2MLP(4, config.d_model, config.d_model, 3, config.decoder_activation_function) + + +class Deimv2Model(DFineModel): + def __init__(self, config: Deimv2Config): + Deimv2PreTrainedModel.__init__(self, 
config) + + is_dinov3 = getattr(config.backbone_config, "model_type", None) == "dinov3_vit" + self.conv_encoder = Deimv2DINOv3ConvEncoder(config) if is_dinov3 else Deimv2ConvEncoder(config) + self.encoder = ( + Deimv2LiteEncoder(config) if config.encoder_type == "lite" else Deimv2HybridEncoder(config=config) + ) + + if config.num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + if config.learn_initial_query: + self.weight_embedding = nn.Embedding(config.num_queries, config.d_model) + + self.enc_output = nn.Sequential( + nn.Linear(config.d_model, config.d_model), + nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = Deimv2MLP(config.d_model, config.d_model, 4, 3) + + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj = [] + in_channels = config.decoder_in_channels[-1] + for _ in range(num_backbone_outs): + decoder_input_proj.append( + nn.Identity() + if config.hidden_size == config.decoder_in_channels[-1] + else Deimv2ConvNormLayer(config, in_channels, config.d_model, 1, 1) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj.append( + nn.Identity() + if config.hidden_size == config.decoder_in_channels[-1] + else Deimv2ConvNormLayer(config, in_channels, config.d_model, 3, 2) + ) + self.decoder_input_proj = nn.ModuleList(decoder_input_proj) + self.decoder = Deimv2Decoder(config) + + self.post_init() + + def forward( + self, + pixel_values: torch.FloatTensor, + pixel_mask: torch.LongTensor | None = None, + encoder_outputs: torch.FloatTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: list[dict] | None = None, + **kwargs: Unpack[TransformersKwargs], + ): + # Overrides DFineModel.forward: DEIMv2 uses a unified conv_encoder (backbone + projection) instead of + # D-FINE's separate backbone + encoder_input_proj, and returns feature_maps instead of last_hidden_state. 
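+        # Summary of the flow below (descriptive comment, mirrors the existing code): the unified conv_encoder
+        # extracts multi-scale features, the hybrid or lite encoder fuses them, the top-k encoder proposals
+        # seed the object queries, and the decoder iteratively refines boxes and class logits.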
+ if pixel_values is None and inputs_embeds is None: + raise ValueError("You have to specify either pixel_values or inputs_embeds") + + # extract multi-scale features via conv_encoder (backbone + projection in one step) + if inputs_embeds is None: + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + + # TODO: pass pixel_mask to backbone once DINOv3 supports it + proj_feats = self.conv_encoder(pixel_values) + else: + batch_size = inputs_embeds.shape[0] + device = inputs_embeds.device + proj_feats = inputs_embeds + + encoder_outputs = self.encoder( + proj_feats, + **kwargs, + ) + + # Equivalent to def _get_encoder_input + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412 + sources = [] + for level, source in enumerate(encoder_outputs.feature_maps): + sources.append(self.decoder_input_proj[level](source)) + + # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage + if self.config.num_feature_levels > len(sources): + sources.append(self.decoder_input_proj[len(sources)](encoder_outputs.feature_maps[-1])) + for i in range(len(sources), self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs.feature_maps[-1])) + + # Prepare encoder inputs (by flattening) + source_flatten = [] + spatial_shapes_list = [] + spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long) + for level, source in enumerate(sources): + height, width = source.shape[-2:] + spatial_shapes[level, 0] = height + spatial_shapes[level, 1] = width + spatial_shapes_list.append((height, width)) + source = source.flatten(2).transpose(1, 2) + source_flatten.append(source) + source_flatten = torch.cat(source_flatten, 1) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + # prepare denoising training + if self.training and self.config.num_denoising > 0 and labels is not None: + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + device = source_flatten.device + dtype = source_flatten.dtype + + # prepare input for decoder + if self.training or self.config.anchor_image_size is None: + # Pass spatial_shapes as tuple to make it hashable and make sure + # lru_cache is working for generate_anchors() + spatial_shapes_tuple = tuple(spatial_shapes_list) + anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype) + else: + anchors, valid_mask = self.anchors, self.valid_mask + anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype) + + # use the valid_mask to selectively retain values in the feature map where the mask is True + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + 
enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1]) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1]) + ) + + # extract region features + if self.config.learn_initial_query: + target = self.weight_embedding.tile([batch_size, 1, 1]) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + init_reference_points = reference_points_unact.detach() + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + **kwargs, + ) + + return Deimv2ModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners, + initial_reference_points=decoder_outputs.initial_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.feature_maps, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +class Deimv2ForObjectDetection(DFineForObjectDetection): + _no_split_modules = AttributeError() # Don't have the same restriction as DFine + + @property + def _tied_weights_keys(self): + keys = { + r"class_embed.(?![0])\d+": r"^class_embed.0", + "class_embed": "model.decoder.class_embed", + "bbox_embed": "model.decoder.bbox_embed", + } + if self.config.share_bbox_head: + keys[r"model\.decoder\.bbox_embed\.(?![0])\d+"] = r"model.decoder.bbox_embed.0" + keys[r"bbox_embed.(?![0])\d+"] = r"bbox_embed.0" + return keys + + def __init__(self, config: Deimv2Config): + Deimv2PreTrainedModel.__init__(self, config) + + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + self.model = Deimv2Model(config) + scaled_dim = round(config.layer_scale * config.hidden_size) + num_pred = config.decoder_layers + self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) + if config.share_bbox_head: + shared_bbox = Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + self.bbox_embed = nn.ModuleList([shared_bbox] * 
num_pred)
+        else:
+            self.bbox_embed = nn.ModuleList(
+                [
+                    Deimv2MLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3)
+                    for _ in range(self.eval_idx + 1)
+                ]
+                + [
+                    Deimv2MLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3)
+                    for _ in range(config.decoder_layers - self.eval_idx - 1)
+                ]
+            )
+
+        self.model.decoder.class_embed = self.class_embed
+        self.model.decoder.bbox_embed = self.bbox_embed
+        self.post_init()
+
+    def forward(self, **super_kwargs):
+        r"""
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers.image_utils import load_image
+        >>> from transformers import AutoImageProcessor, Deimv2ForObjectDetection
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = load_image(url)
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("harshaljanjani/DEIMv2_HGNetv2_N_COCO_Transformers")
+        >>> model = Deimv2ForObjectDetection.from_pretrained("harshaljanjani/DEIMv2_HGNetv2_N_COCO_Transformers")
+
+        >>> # prepare image for the model
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+
+        >>> # forward pass
+        >>> outputs = model(**inputs)
+
+        >>> logits = outputs.logits
+        >>> list(logits.shape)
+        [1, 300, 80]
+
+        >>> boxes = outputs.pred_boxes
+        >>> list(boxes.shape)
+        [1, 300, 4]
+
+        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+        >>> target_sizes = torch.tensor([image.size[::-1]])
+        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)
+        >>> result = results[0]  # first image in batch
+
+        >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
+        ...     box = [round(i, 2) for i in box.tolist()]
+        ...     print(
+        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
+        ...         f"{round(score.item(), 3)} at location {box}"
+        ...     
) + ``` + """ + super().forward(**super_kwargs) + + +__all__ = [ + "Deimv2Config", + "Deimv2Model", + "Deimv2PreTrainedModel", + "Deimv2ForObjectDetection", +] diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 00f9ac601b0e..55381a7e3c21 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -65,7 +65,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -@use_experts_implementation(is_transposed=True, has_bias=True) +@use_experts_implementation(is_concatenated=False, is_transposed=True, has_bias=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index f7c89cab08e5..3354acef2196 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -62,7 +62,7 @@ def forward(self, hidden_states): return (self.weight * hidden_states).to(input_dtype) # main diff with Llama -@use_experts_implementation(is_transposed=True, has_bias=True) +@use_experts_implementation(is_concatenated=False, is_transposed=True, has_bias=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() diff --git a/src/transformers/models/openai_privacy_filter/modular_openai_privacy_filter.py b/src/transformers/models/openai_privacy_filter/modular_openai_privacy_filter.py index fc77aafbdcf5..422235d9da91 100644 --- a/src/transformers/models/openai_privacy_filter/modular_openai_privacy_filter.py +++ b/src/transformers/models/openai_privacy_filter/modular_openai_privacy_filter.py @@ -21,6 +21,7 @@ from torch.nn import functional as F from ...configuration_utils import PreTrainedConfig +from ...integrations import use_experts_implementation from ...masking_utils import create_bidirectional_sliding_window_mask from ...modeling_layers import GenericForTokenClassification from ...modeling_outputs import BaseModelOutput @@ -213,6 +214,7 @@ def forward( return attn_output, attn_weights +@use_experts_implementation(is_transposed=True, has_bias=True) class OpenAIPrivacyFilterExperts(GptOssExperts): def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: # Concatenated layout instead of interleaving diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index bf5e0c431e42..bb1344a43dcf 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1813,6 +1813,8 @@ def apply_chat_template( images, videos = [], [] for message in conversation: content = message.get("content") or [] + if isinstance(content, str): + continue visuals = [ content_block for content_block in content if content_block["type"] in ["image", "video"] ] diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index a76f73aeb562..fd117b08023b 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -184,6 +184,7 @@ def get_weight_conversions(self): source_patterns=[ "_weight_qdata", "_weight_scale_and_zero", + "_weight_per_tensor_scale", "_weight_scale", "_weight_zero_point", "_weight_act_pre_scale", diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index ff3e54be374f..cd7c95f7bf4e 100644 --- 
a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -1274,16 +1274,16 @@ def test_memory_prediction( max_blocks_per_request=max_bpr, return_logprobs=logprobs, use_async_batching=use_async_batching, + block_size=block_size, ) handler = PagedAttentionMemoryHandler( - block_size=block_size, + continuous_batching_config=cb_config, page_size=page_size, num_groups=num_groups, group_size=group_size, - peak_activation_per_token=peak_act, + activation_peaks=[(0, peak_act)], num_attention_masks=num_attn_masks, - continuous_batching_config=cb_config, ) N = self.NUM_BLOCKS * block_size # num_pages diff --git a/tests/models/deimv2/__init__.py b/tests/models/deimv2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/deimv2/test_modeling_deimv2.py b/tests/models/deimv2/test_modeling_deimv2.py new file mode 100644 index 000000000000..eb516c449fbc --- /dev/null +++ b/tests/models/deimv2/test_modeling_deimv2.py @@ -0,0 +1,1734 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch DEIMv2 model.""" + +import copy +import inspect +import math +import tempfile +import unittest +from functools import cached_property + +from parameterized import parameterized + +from transformers import ( + AutoImageProcessor, + Deimv2Config, + DINOv3ViTConfig, + HGNetV2Config, + is_torch_available, +) +from transformers.testing_utils import ( + require_torch, + require_torch_accelerator, + require_vision, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import Deimv2ForObjectDetection, Deimv2Model + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + _test_eager_matches_sdpa_inference, + floats_tensor, +) +from ...test_pipeline_mixin import PipelineTesterMixin +from ...test_processing_common import url_to_local_path + + +# TODO: Replace with the official Transformers ckpt once uploaded. 
+CHECKPOINT = "harshaljanjani/DEIMv2_HGNetv2_N_COCO_Transformers" +CHECKPOINT_LITE = "harshaljanjani/DEIMv2_HGNetv2_PICO_COCO_Transformers" +CHECKPOINT_DINOV3 = "harshaljanjani/DEIMv2_DINOv3_S_COCO_Transformers" + + +class Deimv2ModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + backbone_config=None, + encoder_hidden_dim=32, + encoder_in_channels=[128, 256, 512], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=64, + encoder_attention_heads=2, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + d_model=32, + num_queries=30, + decoder_in_channels=[32, 32, 32], + decoder_ffn_dim=64, + num_feature_levels=3, + decoder_n_points=[3, 6, 3], + decoder_n_levels=3, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.backbone_config = backbone_config + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_n_levels = decoder_n_levels + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.decoder_offset_scale = decoder_offset_scale + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.reg_max = reg_max + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + self.hidden_expansion = hidden_expansion + + 
self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if self.use_labels: + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + hidden_sizes = [64, 128, 256, 512] + backbone_config = HGNetV2Config( + stage_in_channels=[16, 64, 128, 256], + stage_mid_channels=[16, 32, 64, 128], + stage_out_channels=[64, 128, 256, 512], + stage_num_blocks=[1, 1, 2, 1], + stage_downsample=[False, True, True, True], + stage_light_block=[False, False, True, True], + stage_kernel_size=[3, 3, 5, 5], + stage_numb_of_layers=[3, 3, 3, 3], + embeddings_size=10, + hidden_sizes=hidden_sizes, + depths=[1, 1, 2, 1], + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + stem_channels=[3, 16, 16], + use_lab=True, + ) + return Deimv2Config( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=self.encoder_in_channels, + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_n_levels=self.decoder_n_levels, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + decoder_offset_scale=self.decoder_offset_scale, + eval_idx=self.eval_idx, + layer_scale=self.layer_scale, + reg_max=self.reg_max, + reg_scale=self.reg_scale, + depth_mult=self.depth_mult, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def create_and_check_deimv2_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2Model(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + 
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_deimv2_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2ForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class Deimv2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Deimv2Model, Deimv2ForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-feature-extraction": Deimv2Model, "object-detection": Deimv2ForObjectDetection} + if is_torch_available() + else {} + ) + is_encoder_decoder = True + test_resize_embeddings = False + + test_missing_keys = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "Deimv2ForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = Deimv2ModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Deimv2Config, + has_text_modality=False, + common_properties=["hidden_size", "num_attention_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_deimv2_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_model(*config_and_inputs) + + def test_deimv2_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="Multi-scale deformable attention is incompatible with nn.DataParallel") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip( + reason="Deimv2 is a vision model but inputs_embeds is in the forward signature (inherited from D-FINE)" + ) + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Forward signature has inputs_embeds but no input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Base test asserts get_input_embeddings() returns nn.Embedding which vision models lack") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Decoder heads are shared via reference assignment so untied saving is not applicable") + def test_load_save_without_tied_weights(self): + pass + + # Override: Multi-scale deformable attention 
outputs have different shapes than standard self-attention + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + out_len = len(outputs) + + correct_outlen = 15 + + if "labels" in inputs_dict: + correct_outlen += 1 + if model_class.__name__ == "Deimv2ForObjectDetection": + correct_outlen += 2 + + self.assertEqual(out_len, correct_outlen) + + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [ + self.model_tester.decoder_attention_heads, + self.model_tester.num_queries, + self.model_tester.num_queries, + ], + ) + + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_queries, + self.model_tester.decoder_attention_heads, + self.model_tester.decoder_n_levels * self.model_tester.decoder_n_points + if isinstance(self.model_tester.decoder_n_points, int) + else sum(self.model_tester.decoder_n_points), + ], + ) + + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions + + self.assertEqual(len(self_attentions), self.model_tester.encoder_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else 
outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.encoder_in_channels) - 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[1].shape[-2:]), + [ + self.model_tester.image_size // self.model_tester.feat_strides[-1], + self.model_tester.image_size // self.model_tester.feat_strides[-1], + ], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.decoder_layers + 1 + ) + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # Override: Custom gradient retention check for multi-scale deformable attention outputs + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_attentions = outputs.encoder_attentions[0] + encoder_hidden_states.retain_grad() + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + # Override: Deimv2 uses pixel_values as main input, not input_ids + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_backbone_selection(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _validate_backbone_init(config): + for model_class in self.all_model_classes: + model = model_class(copy.deepcopy(config)) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if model_class.__name__ == "Deimv2ForObjectDetection": + expected_shape = ( + self.model_tester.batch_size, + self.model_tester.num_queries, + self.model_tester.num_labels, + ) + self.assertEqual(outputs.logits.shape, expected_shape) + self.assertEqual(len(model.model.conv_encoder.intermediate_channel_sizes), 3) + else: + 
self.assertEqual(len(model.conv_encoder.intermediate_channel_sizes), 3) + + self.assertTrue(outputs) + + config_dict = config.to_dict() + config_dict["encoder_in_channels"] = [24, 40, 432] + config_dict["backbone"] = "tf_mobilenetv3_small_075" + config_dict["backbone_config"] = None + config_dict["use_timm_backbone"] = True + config_dict["backbone_kwargs"] = {"out_indices": [2, 3, 4]} + config = config.__class__(**config_dict) + _validate_backbone_init(config) + + config_dict = config.to_dict() + config_dict["backbone"] = "microsoft/resnet-18" + config_dict["backbone_config"] = None + config_dict["use_timm_backbone"] = False + config_dict["use_pretrained_backbone"] = True + config_dict["backbone_kwargs"] = {"out_indices": [2, 3, 4]} + config = config.__class__(**config_dict) + _validate_backbone_init(config) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_with_different_dtypes(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device).to(dtype) + model.eval() + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + with torch.no_grad(): + _ = model(**self._prepare_for_class(inputs_dict, model_class)) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_equivalence_for_static_and_dynamic_anchors(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + h, w = inputs_dict["pixel_values"].shape[-2:] + + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + + for model_class in self.all_model_classes: + with tempfile.TemporaryDirectory() as tmpdirname: + model_class(config).save_pretrained(tmpdirname) + model_static = model_class.from_pretrained( + tmpdirname, anchor_image_size=[h, w], device_map=torch_device, dtype=dtype + ).eval() + model_dynamic = model_class.from_pretrained( + tmpdirname, anchor_image_size=None, device_map=torch_device, dtype=dtype + ).eval() + + self.assertIsNotNone(model_static.config.anchor_image_size) + self.assertIsNone(model_dynamic.config.anchor_image_size) + + with torch.no_grad(): + outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class)) + outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class)) + + torch.testing.assert_close( + outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=1e-4, atol=1e-4 + ) + + +class Deimv2LiteEncoderModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + encoder_hidden_dim=32, + encoder_in_channels=[256], + feat_strides=[16, 32], + dropout=0.0, + activation_dropout=0.0, + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + d_model=32, + num_queries=10, + decoder_in_channels=[32, 32], + decoder_ffn_dim=64, + num_feature_levels=2, + 
decoder_n_points=[4, 2], + decoder_n_levels=2, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = 0 + self.encoder_ffn_dim = 64 + self.encoder_attention_heads = 2 + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = [] + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_n_levels = decoder_n_levels + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.decoder_offset_scale = decoder_offset_scale + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.reg_max = reg_max + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + self.hidden_expansion = hidden_expansion + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if self.use_labels: + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + backbone_config = HGNetV2Config( + stage_in_channels=[16, 64, 128], + stage_mid_channels=[16, 32, 64], + stage_out_channels=[64, 128, 256], + stage_num_blocks=[1, 1, 1], + stage_downsample=[False, True, True], + stage_light_block=[False, False, True], + stage_kernel_size=[3, 3, 3], + stage_numb_of_layers=[3, 3, 3], + embeddings_size=10, + hidden_sizes=[64, 128, 256], + depths=[1, 
1, 1], + out_features=["stage3"], + out_indices=[3], + stem_channels=[3, 16, 16], + use_lab=True, + ) + return Deimv2Config( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=self.encoder_in_channels, + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_n_levels=self.decoder_n_levels, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + decoder_offset_scale=self.decoder_offset_scale, + eval_idx=self.eval_idx, + layer_scale=self.layer_scale, + reg_max=self.reg_max, + reg_scale=self.reg_scale, + depth_mult=self.depth_mult, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + encoder_type="lite", + use_gateway=False, + share_bbox_head=False, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def create_and_check_deimv2_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2Model(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_deimv2_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2ForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class Deimv2LiteEncoderModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Deimv2Model, Deimv2ForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-feature-extraction": Deimv2Model, 
"object-detection": Deimv2ForObjectDetection} + if is_torch_available() + else {} + ) + is_encoder_decoder = True + test_resize_embeddings = False + has_attentions = False + + test_missing_keys = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "Deimv2ForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = Deimv2LiteEncoderModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Deimv2Config, + has_text_modality=False, + common_properties=["hidden_size", "num_attention_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_deimv2_lite_encoder_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_model(*config_and_inputs) + + def test_deimv2_lite_encoder_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="Multi-scale deformable attention is incompatible with nn.DataParallel") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip( + reason="Deimv2 is a vision model but inputs_embeds is in the forward signature (inherited from D-FINE)" + ) + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Forward signature has inputs_embeds but no input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Base test asserts get_input_embeddings() returns nn.Embedding which vision models lack") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="Decoder heads are shared via reference assignment so untied saving is not applicable") + def test_load_save_without_tied_weights(self): + pass + + @unittest.skip( + reason="LiteEncoder has no encoder_hidden_states so the base test fails accessing encoder_hidden_states[0]" + ) + def test_retain_grad_hidden_states_attentions(self): + pass + + # Override: LiteEncoder has no encoder hidden states, only decoder hidden states + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = self.model_tester.decoder_layers + 1 + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + 
del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # Override: Deimv2 uses pixel_values as main input, not input_ids + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + +class Deimv2DINOv3ModelTester: + def __init__( + self, + parent, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + encoder_hidden_dim=32, + encoder_in_channels=[32, 32, 32], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=64, + encoder_attention_heads=2, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + d_model=32, + num_queries=30, + decoder_in_channels=[32, 32, 32], + decoder_ffn_dim=64, + num_feature_levels=3, + decoder_n_points=4, + decoder_n_levels=3, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + sta_inplanes=8, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_n_levels = decoder_n_levels + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.decoder_offset_scale = decoder_offset_scale + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.reg_max = reg_max + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + 
self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + self.hidden_expansion = hidden_expansion + self.sta_inplanes = sta_inplanes + + self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + pixel_mask = torch.ones([self.batch_size, self.image_size, self.image_size], device=torch_device) + + labels = None + if self.use_labels: + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = torch.randint( + high=self.num_labels, size=(self.n_targets,), device=torch_device + ) + target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + backbone_config = DINOv3ViTConfig( + hidden_size=32, + num_attention_heads=2, + num_hidden_layers=4, + intermediate_size=64, + num_register_tokens=1, + layerscale_value=1.0, + use_gated_mlp=False, + rope_theta=100.0, + patch_size=16, + image_size=self.image_size, + out_indices=[2, 3, 4], + apply_layernorm=False, + reshape_hidden_states=True, + ) + return Deimv2Config( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=self.encoder_in_channels, + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_n_levels=self.decoder_n_levels, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + decoder_offset_scale=self.decoder_offset_scale, + eval_idx=self.eval_idx, + layer_scale=self.layer_scale, + reg_max=self.reg_max, + reg_scale=self.reg_scale, + depth_mult=self.depth_mult, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + sta_inplanes=self.sta_inplanes, + encoder_has_trailing_conv=False, + ) + + def prepare_config_and_inputs_for_common(self): + config, pixel_values, pixel_mask, labels = self.prepare_config_and_inputs() + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + def create_and_check_deimv2_model(self, config, pixel_values, pixel_mask, labels): + model = 
Deimv2Model(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.num_queries, self.d_model)) + + def create_and_check_deimv2_object_detection_head_model(self, config, pixel_values, pixel_mask, labels): + model = Deimv2ForObjectDetection(config=config) + model.to(torch_device) + model.eval() + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask) + result = model(pixel_values) + + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels) + + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) + self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) + + +@require_torch +class Deimv2DINOv3ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Deimv2Model, Deimv2ForObjectDetection) if is_torch_available() else () + pipeline_model_mapping = ( + {"image-feature-extraction": Deimv2Model, "object-detection": Deimv2ForObjectDetection} + if is_torch_available() + else {} + ) + is_encoder_decoder = True + test_resize_embeddings = False + + test_missing_keys = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == "Deimv2ForObjectDetection": + labels = [] + for i in range(self.model_tester.batch_size): + target = {} + target["class_labels"] = torch.ones( + size=(self.model_tester.n_targets,), device=torch_device, dtype=torch.long + ) + target["boxes"] = torch.ones( + self.model_tester.n_targets, 4, device=torch_device, dtype=torch.float + ) + labels.append(target) + inputs_dict["labels"] = labels + + return inputs_dict + + def setUp(self): + self.model_tester = Deimv2DINOv3ModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=Deimv2Config, + has_text_modality=False, + common_properties=["hidden_size", "num_attention_heads"], + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_deimv2_dinov3_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_model(*config_and_inputs) + + def test_deimv2_dinov3_object_detection_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deimv2_object_detection_head_model(*config_and_inputs) + + @unittest.skip(reason="Multi-scale deformable attention is incompatible with nn.DataParallel") + def test_multi_gpu_data_parallel_forward(self): + pass + + @unittest.skip( + reason="Deimv2 is a vision model but inputs_embeds is in the forward signature (inherited from D-FINE)" + ) + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Forward signature has inputs_embeds but no input_ids") + def test_inputs_embeds_matches_input_ids(self): + pass + + @unittest.skip(reason="Base test asserts get_input_embeddings() returns nn.Embedding which vision models lack") + def test_model_get_set_embeddings(self): + pass + + 
@unittest.skip(reason="Decoder heads are shared via reference assignment so untied saving is not applicable") + def test_load_save_without_tied_weights(self): + pass + + @unittest.skip(reason="DINOv3 RoPE with dynamic interpolation causes torch.compile inductor overflow") + def test_sdpa_can_compile_dynamic(self): + pass + + # Override: DINOv3 backbone requires wider tolerances for SDPA vs eager comparison + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + def test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + atols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-3, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-3, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-3, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-3, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + _test_eager_matches_sdpa_inference( + self, + name, + torch_dtype, + padding_side, + use_attention_mask, + output_attentions, + enable_kernels, + atols=atols, + rtols=rtols, + ) + + # Override: DINOv3 backbone numerical precision requires wider tolerances + def test_batching_equivalence(self): + super().test_batching_equivalence(atol=1e-4, rtol=1e-4) + + @unittest.skip(reason="Flex attention test requires decoder_input_ids which detection models don't have") + def test_flex_attention_with_grads(self): + pass + + # Override: Multi-scale deformable attention outputs have different shapes than standard self-attention + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions + self.assertEqual(len(attentions), self.model_tester.encoder_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + out_len = len(outputs) + + correct_outlen = 15 + + if "labels" in inputs_dict: + correct_outlen += 1 + if 
model_class.__name__ == "Deimv2ForObjectDetection": + correct_outlen += 2 + + self.assertEqual(out_len, correct_outlen) + + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [ + self.model_tester.decoder_attention_heads, + self.model_tester.num_queries, + self.model_tester.num_queries, + ], + ) + + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.decoder_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_queries, + self.model_tester.decoder_attention_heads, + self.model_tester.decoder_n_levels * self.model_tester.decoder_n_points + if isinstance(self.model_tester.decoder_n_points, int) + else sum(self.model_tester.decoder_n_points), + ], + ) + + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + else: + added_hidden_states = 2 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions + + self.assertEqual(len(self_attentions), self.model_tester.encoder_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [ + self.model_tester.encoder_attention_heads, + self.model_tester.encoder_seq_length, + self.model_tester.encoder_seq_length, + ], + ) + + # Override: Encoder hidden states are multi-scale feature maps, not a standard sequence of layer outputs + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", len(self.model_tester.encoder_in_channels) - 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[1].shape[-2:]), + [ + self.model_tester.image_size // self.model_tester.feat_strides[-1], + self.model_tester.image_size // self.model_tester.feat_strides[-1], + ], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.decoder_layers + 1 + ) + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.num_queries, self.model_tester.d_model], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, 
model_class) + + # Override: Custom gradient retention check for multi-scale deformable attention outputs + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + encoder_hidden_states = outputs.encoder_hidden_states[0] + encoder_attentions = outputs.encoder_attentions[0] + encoder_hidden_states.retain_grad() + encoder_attentions.retain_grad() + + decoder_attentions = outputs.decoder_attentions[0] + decoder_attentions.retain_grad() + + cross_attentions = outputs.cross_attentions[0] + cross_attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(encoder_hidden_states.grad) + self.assertIsNotNone(encoder_attentions.grad) + self.assertIsNotNone(decoder_attentions.grad) + self.assertIsNotNone(cross_attentions.grad) + + # Override: Deimv2 uses pixel_values as main input, not input_ids + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + arg_names = [*signature.parameters.keys()] + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_with_different_dtypes(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device).to(dtype) + model.eval() + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + with torch.no_grad(): + _ = model(**self._prepare_for_class(inputs_dict, model_class)) + + @parameterized.expand(["float32", "float16", "bfloat16"]) + @require_torch_accelerator + @slow + def test_inference_equivalence_for_static_and_dynamic_anchors(self, dtype_str): + dtype = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_str] + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + h, w = inputs_dict["pixel_values"].shape[-2:] + + for key, tensor in inputs_dict.items(): + if tensor.dtype == torch.float32: + inputs_dict[key] = tensor.to(dtype) + + for model_class in self.all_model_classes: + with tempfile.TemporaryDirectory() as tmpdirname: + model_class(config).save_pretrained(tmpdirname) + model_static = model_class.from_pretrained( + tmpdirname, anchor_image_size=[h, w], device_map=torch_device, dtype=dtype + ).eval() + model_dynamic = model_class.from_pretrained( + tmpdirname, anchor_image_size=None, device_map=torch_device, dtype=dtype + ).eval() + + self.assertIsNotNone(model_static.config.anchor_image_size) + self.assertIsNone(model_dynamic.config.anchor_image_size) + + with torch.no_grad(): + outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class)) + outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class)) + + 
torch.testing.assert_close( + outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=5e-3, atol=5e-3 + ) + + +def prepare_img(): + from transformers.image_utils import load_image + + url = url_to_local_path("http://images.cocodataset.org/val2017/000000039769.jpg") + return load_image(url) + + +@require_torch +@require_vision +@slow +class Deimv2ModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return AutoImageProcessor.from_pretrained(CHECKPOINT, use_fast=False) + + def test_inference_object_detection_head(self): + model = Deimv2ForObjectDetection.from_pretrained(CHECKPOINT).to(torch_device) + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + expected_shape_logits = torch.Size((1, 300, model.config.num_labels)) + self.assertEqual(outputs.logits.shape, expected_shape_logits) + + expected_logits = torch.tensor( + [[-4.0859, -6.9373, -5.4723], [-5.5887, -6.0078, -6.4360], [-6.1448, -6.8509, -6.8703]] + ).to(torch_device) + expected_boxes = torch.tensor( + [[0.1886, 0.1662, 0.2875], [0.0690, 0.1814, 0.9368], [0.2510, 0.2141, 0.9115]] + ).to(torch_device) + + torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, atol=2e-4, rtol=2e-4) + + expected_shape_boxes = torch.Size((1, 300, 4)) + self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) + torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=2e-4, rtol=2e-4) + + results = image_processor.post_process_object_detection( + outputs, threshold=0.0, target_sizes=[image.size[::-1]] + )[0] + + expected_scores = torch.tensor([0.7606, 0.3165, 0.2726, 0.2488], device=torch_device) + expected_labels = [65, 65, 15, 59] + expected_slice_boxes = torch.tensor( + [ + [4.0781e01, 6.8216e01, 1.7560e02, 1.1085e02], + [4.8195e01, 7.5405e01, 2.1123e02, 9.1451e01], + [1.1296e01, 6.8089e01, 6.1285e02, 4.0393e02], + [1.9821e01, -9.0347e01, 7.0787e02, 3.7968e02], + ], + device=torch_device, + ) + + torch.testing.assert_close(results["scores"][:4], expected_scores, atol=1e-3, rtol=1e-4) + self.assertSequenceEqual(results["labels"][:4].tolist(), expected_labels) + torch.testing.assert_close(results["boxes"][:4], expected_slice_boxes[:4], atol=5e-3, rtol=5e-4) + + +@require_torch +@require_vision +@slow +class Deimv2LiteEncoderIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return AutoImageProcessor.from_pretrained(CHECKPOINT_LITE, use_fast=False) + + def test_inference_object_detection_head(self): + model = Deimv2ForObjectDetection.from_pretrained(CHECKPOINT_LITE).to(torch_device) + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels)) + self.assertEqual(outputs.logits.shape, expected_shape_logits) + + expected_logits = torch.tensor( + [[-2.6151, -6.4701, -6.3505], [-3.8592, -6.2610, -7.2720], [-2.3801, -4.3216, -3.5101]] + ).to(torch_device) + expected_boxes = torch.tensor( + [[0.7994, 0.2984, 0.3822], [0.5536, 0.5362, 0.0392], [0.3501, 0.4577, 0.7440]] + ).to(torch_device) + + torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, atol=2e-4, rtol=2e-4) + + expected_shape_boxes = torch.Size((1, 
model.config.num_queries, 4)) + self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) + torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=2e-4, rtol=2e-4) + + +@require_torch +@require_vision +@slow +class Deimv2DINOv3IntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return AutoImageProcessor.from_pretrained(CHECKPOINT_DINOV3, use_fast=False) + + def test_inference_object_detection_head(self): + model = Deimv2ForObjectDetection.from_pretrained(CHECKPOINT_DINOV3).to(torch_device) + image_processor = self.default_image_processor + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + expected_shape_logits = torch.Size((1, 300, model.config.num_labels)) + self.assertEqual(outputs.logits.shape, expected_shape_logits) + + expected_logits = torch.tensor( + [[-2.1404, -2.8207, -3.2710], [-2.3058, -2.7178, -3.2924], [-3.2780, -4.0269, -4.6266]] + ).to(torch_device) + expected_boxes = torch.tensor( + [[0.5258, 0.7694, 0.7997], [0.3734, 0.1949, 0.7989], [0.5082, 0.5847, 0.8590]] + ).to(torch_device) + + torch.testing.assert_close(outputs.logits[0, :3, :3], expected_logits, atol=2e-4, rtol=2e-4) + + expected_shape_boxes = torch.Size((1, 300, 4)) + self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) + torch.testing.assert_close(outputs.pred_boxes[0, :3, :3], expected_boxes, atol=2e-4, rtol=2e-4) diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py index 0d6d7e0446d0..65a622163c88 100644 --- a/tests/models/gemma3n/test_modeling_gemma3n.py +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -993,7 +993,7 @@ def test_model_4b_bf16(self): output_text = self.processor.batch_decode(output, skip_special_tokens=True) EXPECTED_TEXTS = Expectations({ ("cuda", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'], - ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean under a clear blue sky. The cow is facing the viewer'], + ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'], }).get_expectation() # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) @@ -1077,7 +1077,7 @@ def test_model_4b_batch(self): output_text = self.processor.batch_decode(output, skip_special_tokens=True) EXPECTED_TEXTS = Expectations({ ("cuda", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly', "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. 
\n\nHere's a breakdown of the differences:\n\n* **Subject:** The first image features a cow"], - ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean under a clear blue sky. The cow is facing the viewer', "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Subject Matter:** The first image shows a"], + ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. The sky is blue with a few white clouds. The', "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Subject:** The first image features a cow"], ("xpu", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. The cow is facing the viewer with its head slightly turned', "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Subject:** The first image features a cow"], }).get_expectation() # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) @@ -1104,7 +1104,7 @@ def test_model_4b_image(self): EXPECTED_TEXTS = Expectations({ ("cuda", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'], ("xpu", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'], - ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean under a clear blue sky. The cow is facing the viewer'], + ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach next to a clear blue ocean. The cow is facing the viewer with its head slightly'], }).get_expectation() # fmt: skip self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES) self.assertEqual(output_text, EXPECTED_TEXTS) @@ -1146,7 +1146,7 @@ def test_model_4b_multiimage(self): EXPECTED_TEXTS = Expectations({ ("cuda", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nIn the image, I see a street scene in what appears to be a Chinatown district. Here are some of the key elements:\n\n* **A'], ("xpu", None): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nIn the image, I see a street scene in what appears to be a Chinatown district. Here are the key elements:\n\n* **A prominent red'], - ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nIn the image, I see a street scene in what appears to be a Chinatown district. 
\n\nHere are some key elements:\n\n* **A'], + ("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nIn the image, I see a street scene in what appears to be a Chinatown district. Here are some of the key elements:\n\n* **A'], }).get_expectation() # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS) @@ -1191,7 +1191,7 @@ def test_generation_beyond_sliding_window(self): EXPECTED_COMPLETIONS = Expectations({ ("cuda", None): [" and the people are so friendly. I'm so glad I came here. I'm so", ", green, yellow, orange, purple, pink, brown, black, white.\n\nHere'"], - ("rocm", (9, 4)): [" and the food is delicious. I'm so glad I came here. I'm so glad", ", green, yellow, orange, purple, pink, brown, black, white.\n\nHere'"], + ("rocm", (9, 4)): [' and the food is delicious. The staff is friendly and helpful. The atmosphere is relaxed and welcoming.', ", green, yellow, orange, purple, pink, brown, black, white.\n\nHere'"], }).get_expectation() # fmt: skip self.assertEqual(output_text, EXPECTED_COMPLETIONS) diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index ebcc08816d95..b188b4f9a0c3 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -39,6 +39,7 @@ from torchao.dtypes import ( AffineQuantizedTensor, ) + from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8Tensor, @@ -587,13 +588,14 @@ class TorchAoSerializationTest(unittest.TestCase): test_params = ( [ - (Int8WeightOnlyConfig(version=2), ALL_DEVICES_COMMON), - (Int8DynamicActivationInt8WeightConfig(version=2), ALL_DEVICES_COMMON), - (Float8DynamicActivationFloat8WeightConfig(), Expectations({("cuda", None): "What are we having for dinner?\n\nJess: (smiling) I", ("xpu", None): "What are we having for dinner?\n\nJess: (smiling) I"})), - (Float8WeightOnlyConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): COMMON_OUTPUT})), - (Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), Expectations({("cuda", None): "What are we having for dinner?\nRed, white, and green beans,", ("xpu", None): COMMON_OUTPUT})), - (Int8DynamicActivationIntxWeightConfig(), Expectations({("cpu", None): COMMON_OUTPUT, ("cuda", 9): COMMON_OUTPUT, ("cuda", 8): "What are we having for dinner?\n\nJEN: (smiling) I", ("xpu", None): COMMON_OUTPUT})), - (IntxWeightOnlyConfig(), ALL_DEVICES_COMMON), + ("Int8WeightOnlyConfig", Int8WeightOnlyConfig(version=2), ALL_DEVICES_COMMON), + ("Int8DynamicActivationInt8WeightConfig", Int8DynamicActivationInt8WeightConfig(version=2), ALL_DEVICES_COMMON), + ("Float8DynamicActivationFloat8WeightConfig", Float8DynamicActivationFloat8WeightConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): "What are we having for dinner?\n\nJess: (smiling) I"})), + ("Float8WeightOnlyConfig", Float8WeightOnlyConfig(), Expectations({("cuda", None): COMMON_OUTPUT, ("xpu", None): COMMON_OUTPUT})), + ("Int4WeightOnlyConfig", Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), Expectations({("cuda", None): "What are we having for dinner?\nRed, white, and green beans,", ("xpu", None): COMMON_OUTPUT})), + ("Int8DynamicActivationIntxWeightConfig", Int8DynamicActivationIntxWeightConfig(), Expectations({("cpu", None): COMMON_OUTPUT, ("cuda", 9): COMMON_OUTPUT, ("cuda", 8): "What are we having for dinner?\n\nJEN: 
(smiling) I", ("xpu", None): COMMON_OUTPUT})), + ("IntxWeightOnlyConfig", IntxWeightOnlyConfig(), ALL_DEVICES_COMMON), + ("NVFP4DynamicActivationNVFP4WeightConfig", NVFP4DynamicActivationNVFP4WeightConfig(), Expectations({("cuda", None): "What are we having for dinner?\n\n10. Avoid using \"I"})), ] if is_torchao_available() else [] @@ -609,8 +611,12 @@ def _check_serialization(self, device, config, expected_output): if isinstance(config, (Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig)): if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 9): self.skipTest(f"{type(config).__name__} requires CUDA capability >= (8, 9)") + if isinstance(config, NVFP4DynamicActivationNVFP4WeightConfig): + if torch.cuda.is_available() and torch.cuda.get_device_capability() < (10, 0): + self.skipTest(f"{type(config).__name__} requires CUDA capability >= (10, 0) (SM100)") quant_config = TorchAoConfig(config) - dtype = torch.bfloat16 if isinstance(config, Int4WeightOnlyConfig) else "auto" + needs_bfloat16 = isinstance(config, Int4WeightOnlyConfig | NVFP4DynamicActivationNVFP4WeightConfig) + dtype = torch.bfloat16 if needs_bfloat16 else "auto" quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, dtype=dtype, @@ -629,7 +635,7 @@ def _check_serialization(self, device, config, expected_output): self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), expected_output) @parameterized.expand(test_params, skip_on_empty=True) - def test_serialization_cpu(self, config, expected_outputs): + def test_serialization_cpu(self, _name, config, expected_outputs): try: expected = expected_outputs.find_expectation(("cpu", None, None)) except ValueError: @@ -638,7 +644,7 @@ def test_serialization_cpu(self, config, expected_outputs): @parameterized.expand(test_params, skip_on_empty=True) @require_torch_accelerator - def test_serialization_accelerator(self, config, expected_outputs): + def test_serialization_accelerator(self, _name, config, expected_outputs): try: expected = expected_outputs.get_expectation() except ValueError: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b909212b62cd..bc8f65891445 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -51,7 +51,11 @@ is_deepspeed_zero3_enabled, unset_hf_deepspeed_config, ) -from transformers.integrations.moe import batched_mm_experts_forward, grouped_mm_experts_forward +from transformers.integrations.moe import ( + batched_mm_experts_forward, + grouped_mm_experts_forward, + sonicmoe_experts_forward, +) from transformers.modeling_layers import GradientCheckpointingLayer from transformers.modeling_utils import FLASH_ATTN_KERNEL_FALLBACK, _get_tied_weight_keys from transformers.models.auto import get_values @@ -110,6 +114,7 @@ GENERATION_CONFIG_NAME, SAFE_WEIGHTS_NAME, ModelOutput, + is_kernels_available, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device, ) @@ -576,59 +581,49 @@ def _test_eager_matches_batched_and_grouped_inference(self, name, dtype): model.save_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname).eval().to(torch_device).to(dtype) - with torch.no_grad(): - inputs_dict = {k: v.to(dtype) if torch.is_floating_point(v) else v for k, v in inputs_dict.items()} - prepared_inputs = self._prepare_for_class(inputs_dict, model_class) - - mock_batched_mm_forward = Mock(wraps=batched_mm_experts_forward) - mock_grouped_mm_forward = Mock(wraps=grouped_mm_experts_forward) - with ( - # This is needed because we call 
the functions through the interface's global mapping - patch.dict( - "transformers.integrations.moe.ALL_EXPERTS_FUNCTIONS._global_mapping", - {"batched_mm": mock_batched_mm_forward, "grouped_mm": mock_grouped_mm_forward}, - ), - ): - model.set_experts_implementation("eager") - self.assertEqual(model.config._experts_implementation, "eager") - outputs_eager = model(**prepared_inputs) - mock_batched_mm_forward.assert_not_called() - mock_grouped_mm_forward.assert_not_called() + inputs_dict = {k: v.to(dtype) if torch.is_floating_point(v) else v for k, v in inputs_dict.items()} + prepared_inputs = self._prepare_for_class(inputs_dict, model_class) - mock_batched_mm_forward.reset_mock() - mock_grouped_mm_forward.reset_mock() + implementations = ["eager", "batched_mm", "grouped_mm"] + mocks = { + "batched_mm": Mock(wraps=batched_mm_experts_forward), + "grouped_mm": Mock(wraps=grouped_mm_experts_forward), + } - model.set_experts_implementation("batched_mm") - self.assertEqual(model.config._experts_implementation, "batched_mm") - outputs_batched_mm = model(**prepared_inputs) - mock_grouped_mm_forward.assert_not_called() - mock_batched_mm_forward.assert_called() - - mock_batched_mm_forward.reset_mock() - mock_grouped_mm_forward.reset_mock() - - model.set_experts_implementation("grouped_mm") - self.assertEqual(model.config._experts_implementation, "grouped_mm") - outputs_grouped_mm = model(**prepared_inputs) - mock_batched_mm_forward.assert_not_called() - mock_grouped_mm_forward.assert_called() - - mock_batched_mm_forward.reset_mock() - mock_grouped_mm_forward.reset_mock() - - # extract output tensors for comparison - outputs_eager = _get_output_tensors(outputs_eager) - outputs_batched_mm = _get_output_tensors(outputs_batched_mm) - outputs_grouped_mm = _get_output_tensors(outputs_grouped_mm) - - # make sure we have collected some tensors from the outputs - self.assertTrue(outputs_eager, "No outputs from eager implementation") - self.assertTrue(outputs_batched_mm, "No outputs from batched_mm implementation") - self.assertTrue(outputs_grouped_mm, "No outputs from grouped_mm implementation") - - # make sure all implementations give numerically close outputs - torch.testing.assert_close(outputs_eager, outputs_batched_mm, rtol=1e-4, atol=1e-4) - torch.testing.assert_close(outputs_eager, outputs_grouped_mm, rtol=1e-4, atol=1e-4) + if ( + dtype != torch.float32 + and is_kernels_available() + and torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + ): + # we also need nvidia-cutlass-dsl and apache-tvm-ffi + mocks["sonicmoe"] = Mock(wraps=sonicmoe_experts_forward) + implementations.append("sonicmoe") + + outputs = {} + # This is needed because we call the functions through the interface's global mapping + with patch.dict("transformers.integrations.moe.ALL_EXPERTS_FUNCTIONS._global_mapping", mocks): + for impl in implementations: + model.set_experts_implementation(impl) + self.assertEqual(model.config._experts_implementation, impl) + + with torch.no_grad(): + outputs[impl] = _get_output_tensors(model(**prepared_inputs)) + + self.assertTrue(outputs[impl], f"No outputs from {impl} implementation") + + for name, mock in mocks.items(): + if name == impl: + mock.assert_called() + else: + mock.assert_not_called() + + mock.reset_mock() + + # all non-eager implementations must numerically match eager + eager_outputs = outputs.pop("eager") + for impl, impl_outputs in outputs.items(): + torch.testing.assert_close(eager_outputs, impl_outputs, rtol=1e-4, atol=1e-4) def _config_zero_init(config): 
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 6a27b6b5e0fb..fab48f9ddb8a 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -2823,6 +2823,20 @@ def test_error_wrong_attn_implementation(self): self.assertTrue('The only possible arguments are `attn_implementation="eager"' in str(cm.exception)) + def test_registered_experts_implementation_is_valid(self): + from transformers.integrations.moe import ALL_EXPERTS_FUNCTIONS + + def custom_experts_forward(*args, **kwargs): + pass + + experts_implementation = "custom_experts" + model = BaseModel(PreTrainedConfig()) + + with patch.dict(ALL_EXPERTS_FUNCTIONS._global_mapping, {}, clear=False): + ALL_EXPERTS_FUNCTIONS.register(experts_implementation, custom_experts_forward) + + self.assertEqual(model.get_correct_experts_implementation(experts_implementation), experts_implementation) + def test_not_available_flash(self): if is_flash_attn_2_available(): self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash") diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 502cbf461e45..5368ad0c61bd 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -115,6 +115,7 @@ "MaskFormerDetrConfig": True, "DetrConfig": True, "DFineConfig": True, + "Deimv2Config": True, # Mixed encoder variants (hybrid/lite) + DFine inheritance "GroundingDinoConfig": True, "MMGroundingDinoConfig": True, "RTDetrConfig": True, diff --git a/utils/check_modeling_rules_doc.py b/utils/check_modeling_rules_doc.py index 24e7b17fd925..8eaf8e57012d 100644 --- a/utils/check_modeling_rules_doc.py +++ b/utils/check_modeling_rules_doc.py @@ -13,7 +13,7 @@ # limitations under the License. """ Keep `## Rules reference` section of docs/source/en/modeling_rules.md in sync -with the rules defined in the mlinter package. +with the rules defined in utils/rules.toml via the installed mlinter package. Usage (from the root of the repo): @@ -31,21 +31,22 @@ """ import argparse -import os +from pathlib import Path CHECKER_CONFIG = { "name": "modeling_rules_doc", "label": "Modeling rules documentation", - # Depends on the installed `mlinter` package output, which cannot be expressed - # as repo file globs for the checker cache. + # Depends on utils/rules.toml plus the installed `mlinter` package output, + # which cannot be fully expressed as repo file globs for the checker cache. "file_globs": None, - "check_args": [], - "fix_args": ["--fix_and_overwrite"], + "check_args": ["--rules-toml", "utils/rules.toml"], + "fix_args": ["--rules-toml", "utils/rules.toml", "--fix_and_overwrite"], } -ROOT = os.path.dirname(os.path.dirname(__file__)) -DOC_PATH = os.path.join(ROOT, "docs", "source", "en", "modeling_rules.md") +ROOT = Path(__file__).resolve().parent.parent +DOC_PATH = ROOT / "docs" / "source" / "en" / "modeling_rules.md" +RULES_TOML_PATH = ROOT / "utils" / "rules.toml" BEGIN_MARKER = "" END_MARKER = "" @@ -54,21 +55,29 @@ def _require_mlinter(): try: import mlinter + from mlinter import mlinter as mlinter_impl except ModuleNotFoundError as error: raise ModuleNotFoundError( "This script requires the standalone `transformers-mlinter` package. " 'Install the repo quality dependencies with `pip install -e ".[quality]"` and retry.' 
         ) from error
-    return mlinter
+    return mlinter, mlinter_impl
 
 
-def generate_rules_reference() -> str:
-    return _require_mlinter().render_rules_reference()
+def _resolve_path(path: Path) -> Path:
+    return path if path.is_absolute() else ROOT / path
 
 
-def check_modeling_rules_doc(overwrite: bool = False):
-    with open(DOC_PATH, encoding="utf-8") as f:
+def generate_rules_reference(rule_specs_path: Path = RULES_TOML_PATH) -> str:
+    mlinter, mlinter_impl = _require_mlinter()
+    # Reuse mlinter's registry-switching helper so docs rendering reflects the repo-local rule file.
+    with mlinter_impl._using_rule_specs(_resolve_path(rule_specs_path)):
+        return mlinter.render_rules_reference()
+
+
+def check_modeling_rules_doc(overwrite: bool = False, rule_specs_path: Path = RULES_TOML_PATH):
+    with DOC_PATH.open(encoding="utf-8") as f:
         content = f.read()
 
     begin_idx = content.find(BEGIN_MARKER)
@@ -80,7 +89,7 @@ def check_modeling_rules_doc(overwrite: bool = False):
     )
 
     after_begin = begin_idx + len(BEGIN_MARKER)
-    expected = "\n\n" + generate_rules_reference() + "\n"
+    expected = "\n\n" + generate_rules_reference(rule_specs_path) + "\n"
     current = content[after_begin:end_idx]
 
     if current == expected:
@@ -88,22 +97,28 @@
     if overwrite:
         new_content = content[:after_begin] + expected + content[end_idx:]
-        with open(DOC_PATH, "w", encoding="utf-8") as f:
+        with DOC_PATH.open("w", encoding="utf-8") as f:
             f.write(new_content)
         print(f"Updated rules reference in {DOC_PATH}")
     else:
         raise ValueError(
             "The rules reference section in docs/source/en/modeling_rules.md is out of sync "
-            "with the mlinter package's rules. Run `make fix-repo` to regenerate it."
+            "with utils/rules.toml. Run `make fix-repo` to regenerate it."
         )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--rules-toml",
+        type=Path,
+        default=RULES_TOML_PATH,
+        help="Path to a rules TOML file. Defaults to utils/rules.toml.",
+    )
     parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
     args = parser.parse_args()
 
     try:
-        check_modeling_rules_doc(args.fix_and_overwrite)
+        check_modeling_rules_doc(args.fix_and_overwrite, args.rules_toml)
     except ModuleNotFoundError as error:
         raise SystemExit(str(error)) from error
diff --git a/utils/check_modeling_structure.py b/utils/check_modeling_structure.py
index 447eabf8b8a6..6078672d7349 100644
--- a/utils/check_modeling_structure.py
+++ b/utils/check_modeling_structure.py
@@ -1,6 +1,23 @@
 #!/usr/bin/env python
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Thin local entrypoint for the external mlinter package."""
 
+import sys
+from pathlib import Path
+
+
 CHECKER_CONFIG = {
     "name": "modeling_structure",
     "label": "Modeling file structure",
@@ -9,10 +26,12 @@
         "src/transformers/models/**/modular_*.py",
         "src/transformers/models/**/configuration_*.py",
     ],
-    "check_args": [],
+    "check_args": ["--rules-toml", "utils/rules.toml"],
     "fix_args": None,
 }
 
+RULES_TOML_PATH = Path(__file__).resolve().with_name("rules.toml")
+
 
 def _require_mlinter():
     try:
@@ -26,8 +45,16 @@
     return mlinter
 
 
+def _add_default_rules_toml(argv: list[str]) -> list[str]:
+    if any(arg == "--rules-toml" or arg.startswith("--rules-toml=") for arg in argv[1:]):
+        return argv
+
+    return [argv[0], "--rules-toml", str(RULES_TOML_PATH), *argv[1:]]
+
+
 if __name__ == "__main__":
     try:
+        sys.argv = _add_default_rules_toml(sys.argv)
         raise SystemExit(_require_mlinter().main())
     except ModuleNotFoundError as error:
         raise SystemExit(str(error)) from error
diff --git a/utils/rules.toml b/utils/rules.toml
new file mode 100644
index 000000000000..1c7de0e729b0
--- /dev/null
+++ b/utils/rules.toml
@@ -0,0 +1,251 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This file can carry repo-local rule overrides for faster iteration between
+# `transformers-mlinter` releases.
+# Keep it synced with the upstream package's rules.toml when possible so local
+# behavior does not drift from the published checker longer than necessary.
+
+version = 1
+
+[rules.TRF001]
+description = "Class-level config_class on PreTrainedModel should match Config naming."
+default_enabled = true
+allowlist_models = ["qwen3_omni_moe"]
+
+[rules.TRF001.explanation]
+what_it_does = "Checks naming consistency between PreTrainedModel and config_class."
+why_bad = "Mismatched config_class can break loading, auto classes, and developer expectations."
+diff = '''
+ class AcmePreTrainedModel(PreTrainedModel):
+-    config_class = WileConfig
++    config_class = AcmeConfig
+'''
+
+[rules.TRF002]
+description = "base_model_prefix should be a non-empty canonical string when defined on PreTrainedModel classes."
+default_enabled = true
+allowlist_models = ["lighton_ocr"]
+
+[rules.TRF002.explanation]
+what_it_does = "Checks that base_model_prefix, when set, is a non-empty, whitespace-free string literal."
+why_bad = "Invalid prefixes can break weight loading key mapping and base model access patterns."
+diff = '''
+ class AcmePreTrainedModel(PreTrainedModel):
+-    base_model_prefix = ""
++    base_model_prefix = "model"
+'''
+
+[rules.TRF003]
+description = "forward() should use capture_output/can_return_tuple decorators instead of manual return_dict branching."
+default_enabled = false
+allowlist_models = []
+
+[rules.TRF003.explanation]
+what_it_does = "Detects forward methods that use the old 'if not return_dict: return (x,)' pattern."
+why_bad = "The old return_dict branching pattern is error-prone and verbose. Use the capture_output or can_return_tuple decorators instead."
+diff = '''
+-def forward(self, x, return_dict=None):
+-    if not return_dict:
+-        return (x,)
+-    return AcmeModelOutput(last_hidden_state=x)
++@can_return_tuple
++def forward(self, x):
++    return AcmeModelOutput(last_hidden_state=x)
+'''
+
+[rules.TRF004]
+description = "Models must never override tie_weights. Use _tied_weights_keys instead."
+default_enabled = true
+allowlist_models = ["data2vec", "hubert", "sew", "sew_d", "unispeech", "unispeech_sat", "wav2vec2", "wav2vec2_conformer", "wavlm"]
+
+[rules.TRF004.explanation]
+what_it_does = "Checks that no model class defines a tie_weights method."
+why_bad = "Overriding tie_weights leads to bad consequences for loading, device_map computation, and saving. Use _tied_weights_keys class attribute to declare tied weights instead."
+diff = '''
+-def tie_weights(self):
+-    self.lm_head.weight = self.emb.weight
++class AcmeForCausalLM(AcmePreTrainedModel):
++    _tied_weights_keys = ["lm_head.weight"]
+'''
+
+[rules.TRF005]
+description = "_no_split_modules, when defined, should be a list/tuple of non-empty strings."
+default_enabled = true
+allowlist_models = ["d_fine", "deformable_detr", "glm46v", "lw_detr", "pp_doclayout_v3", "rt_detr", "rt_detr_v2", "voxtral", "voxtral_realtime"]
+
+[rules.TRF005.explanation]
+what_it_does = "Checks the shape of _no_split_modules when present."
+why_bad = "Malformed values can break device-map partitioning and sharding behavior."
+diff = '''
+-_no_split_modules = [SomeLayerClass, ""]
++_no_split_modules = ["AcmeDecoderLayer", "AcmeAttention"]
+'''
+
+[rules.TRF006]
+description = "forward with cache arguments should reference cache control/state variables consistently."
+default_enabled = true
+allowlist_models = ["chinese_clip", "evolla", "idefics2", "llama4"]
+
+[rules.TRF006.explanation]
+what_it_does = "Checks forward signatures that expose cache arguments for usage of those arguments in method body."
+why_bad = "Unused cache arguments can indicate incomplete caching support and inconsistent API behavior."
+diff = '''
+ def forward(self, x, past_key_values=None, use_cache=False):
++    if use_cache:
++        ...
+     return x
+'''
+
+[rules.TRF007]
+description = "self.post_init() in __init__ should remain at the end of initialization for PreTrainedModel classes."
+default_enabled = true
+allowlist_models = ["distilbert", "lxmert", "mt5", "pix2struct", "pop2piano", "switch_transformers", "t5"]
+
+[rules.TRF007.explanation]
+what_it_does = "Checks for self attribute assignments after self.post_init() in __init__."
+why_bad = "Mutating model structure after post_init can bypass intended initialization/finalization logic."
+diff = '''
+ def __init__(self, config):
+     ...
+-    self.post_init()
+-    self.proj = nn.Linear(...)
++    self.proj = nn.Linear(...)
++    self.post_init()
+'''
+
+[rules.TRF008]
+description = "Doc decorators on PreTrainedModel classes should avoid empty add_start_docstrings usage."
+default_enabled = true
+
+[rules.TRF008.explanation]
+what_it_does = "Checks add_start_docstrings usage on model classes for non-empty docstring arguments."
+why_bad = "Empty decorator usage produces unclear docs and weakens generated API documentation quality."
+diff = '''
+-@add_start_docstrings("")
++@add_start_docstrings("The Acme model.")
+ class AcmeModel(AcmePreTrainedModel):
+     ...
+'''
+
+[rules.TRF009]
+description = "modeling_*.py should avoid importing implementation code from another model package."
+default_enabled = true
+allowlist_models = ["dpr", "maskformer", "sam3_video", "vision_text_dual_encoder"]
+
+[rules.TRF009.explanation]
+what_it_does = "Checks modeling files for cross-model imports such as transformers.models.other_model.* or from ..other_model.* imports."
+why_bad = "Cross-model implementation imports violate the single-file policy and make model behavior harder to inspect and maintain."
+diff = '''
+-from transformers.models.llama.modeling_llama import LlamaAttention
++# Keep implementation local to this file.
++# If reusing code, copy it with a # Copied from comment.
+'''
+
+[rules.TRF010]
+description = "Direct config definitions must use @strict(accept_kwargs=True)."
+default_enabled = true
+allowlist_models = ["nemotron_h", "vibevoice_asr"]
+
+[rules.TRF010.explanation]
+what_it_does = "Checks direct PreTrainedConfig/PretrainedConfig subclasses in configuration_*.py and modular_*.py for an explicit @strict(accept_kwargs=True) decorator."
+why_bad = "Without strict, new config classes miss the repo's runtime type-validation contract and drift from the dataclass-based config standard."
+diff = '''
++@strict(accept_kwargs=True)
+ class AcmeConfig(PreTrainedConfig):
+     ...
+'''
+
+[rules.TRF011]
+description = "forward() must not access non-nn.Module attributes on submodules (breaks pipeline parallelism with Identity replacement)."
+default_enabled = true
+allowlist_models = []
+
+[rules.TRF011.explanation]
+what_it_does = "In forward() methods of PreTrainedModel subclasses, checks for attribute accesses on submodules that would not exist on torch.nn.Identity. This includes attribute accesses on loop variables iterating over self.layers, and self.submodule.attribute-style chains where the accessed attribute is not a standard nn.Module attribute."
+why_bad = "Pipeline parallelism may replace any submodule with torch.nn.Identity. Accessing custom attributes (e.g. decoder_layer.attention_type) on a replaced module raises AttributeError at runtime. Per-layer metadata should be read from self.config instead."
+diff = '''
+ def forward(self, ...):
+-    for decoder_layer in self.layers:
++    for i, decoder_layer in enumerate(self.layers):
+         hidden_states = decoder_layer(
+             hidden_states,
+-            attention_mask=causal_mask_mapping[decoder_layer.attention_type],
++            attention_mask=causal_mask_mapping[self.config.layer_types[i]],
+         )
+'''
+
+[rules.TRF012]
+description = "_init_weights must use init primitives, not in-place operations on module weights."
+default_enabled = true
+allowlist_models = []
+
+[rules.TRF012.explanation]
+what_it_does = "Checks that _init_weights(self, module) does not use in-place operations (e.g. .normal_(), .zero_()) directly on module weights."
+why_bad = "We rely on internal flags set on parameters to track whether they need re-initialization. In-place ops bypass this mechanism. Use the `init` primitives instead."
+diff = '''
++from transformers import initialization as init
++
+ def _init_weights(self, module):
+-    module.weight.normal_(mean=0.0, std=0.02)
++    init.normal_(module.weight, mean=0.0, std=0.02)
+'''
+
+[rules.TRF013]
+description = "PreTrainedModel __init__ must call self.post_init()."
+default_enabled = true
+allowlist_models = []
+
+[rules.TRF013.explanation]
+what_it_does = "Checks that every PreTrainedModel subclass with an __init__ method calls self.post_init(). In modular files, calling super().__init__() is also accepted since it propagates post_init from the parent."
+why_bad = "post_init performs essential finalization (weight initialization, gradient checkpointing setup, etc.). Omitting it causes subtle runtime bugs."
+diff = '''
+ class AcmeModel(AcmePreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.layers = nn.ModuleList(...)
++        self.post_init()
+'''
+
+[rules.TRF014]
+description = "`trust_remote_code` should never be used in native model integrations."
+default_enabled = true
+allowlist_models = []
+
+[rules.TRF014.explanation]
+what_it_does = "Checks whether `trust_remote_code` is passed or used in code (e.g. as kwarg) within native model integration files."
+why_bad = "`trust_remote_code` allows arbitrary loading, including binaries, which should only be a power feature for users, not a standard use-case. Native integrations must not depend on it, as remote code cannot be reviewed or maintained within transformers."
+diff = '''
+ class AcmeModel(AcmePreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+-        self.model = AutoModel.from_pretrained(..., trust_remote_code=True)
++        self.model = AutoModel.from_pretrained(...)
+'''
+
+[rules.TRF015]
+description = "Models with non-empty _tied_weights_keys must have tie_word_embeddings in their Config."
+default_enabled = true
+allowlist_models = []
+
+[rules.TRF015.explanation]
+what_it_does = "When a PreTrainedModel subclass defines _tied_weights_keys as a non-empty collection, checks that the corresponding configuration file declares a tie_word_embeddings field."
+why_bad = "Without tie_word_embeddings in the config, users cannot control weight tying behavior. The model ties weights unconditionally, breaking serialization round-trips and preventing fine-tuning with untied heads."
+diff = '''
+ # configuration_foo.py
+ @strict(accept_kwargs=True)
+ class FooConfig(PreTrainedConfig):
+     hidden_size: int = 768
++    tie_word_embeddings: bool = True
+'''
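+
+# Entries in this file follow the general shape sketched below; TRF999 is a
+# placeholder identifier used purely to illustrate the expected fields
+# (allowlist_models is omitted on some rules), not a real rule:
+#
+#   [rules.TRF999]
+#   description = "One-line summary of what the rule enforces."
+#   default_enabled = true
+#   allowlist_models = []
+#
+#   [rules.TRF999.explanation]
+#   what_it_does = "What the checker looks for."
+#   why_bad = "Why a violation is a problem."
+#   diff = '''
+#   -bad_pattern()
+#   +good_pattern()
+#   '''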