diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py index 70bf7e93fb96..3e50b2cf0e91 100644 --- a/.circleci/create_circleci_config.py +++ b/.circleci/create_circleci_config.py @@ -328,6 +328,15 @@ def job_name(self): parallelism=6, ) +tensor_parallel_ci_job = CircleCIJob( + "tensor_parallel_ci", + additional_env={"RUN_TENSOR_PARALLEL_TESTS": True}, + docker_image=[{"image": "huggingface/transformers-torch-light"}], + install_steps=["uv pip install .", "uv pip install torchao"], + marker="is_tensor_parallel_test", + parallelism=6, +) + # We also include a `dummy.py` file in the files to be doc-tested to prevent edge case failure. Otherwise, the pytest # hangs forever during test collection while showing `collecting 0 items / 21 errors`. (To see this, we have to remove # the bash output redirection.) @@ -358,7 +367,8 @@ def job_name(self): REPO_UTIL_TESTS = [repo_utils_job] DOC_TESTS = [doc_test_job] TRAINING_CI_TESTS = [training_ci_job] -ALL_TESTS = REGULAR_TESTS + EXAMPLES_TESTS + PIPELINE_TESTS + REPO_UTIL_TESTS + DOC_TESTS + [custom_tokenizers_job] + [exotic_models_job] + TRAINING_CI_TESTS # fmt: skip +TENSOR_PARALLEL_CI_TESTS = [tensor_parallel_ci_job] +ALL_TESTS = REGULAR_TESTS + EXAMPLES_TESTS + PIPELINE_TESTS + REPO_UTIL_TESTS + DOC_TESTS + [custom_tokenizers_job] + [exotic_models_job] + TRAINING_CI_TESTS + TENSOR_PARALLEL_CI_TESTS # fmt: skip def create_circleci_config(folder=None): diff --git a/conftest.py b/conftest.py index 4137d0fe7e3d..c194a058b1c4 100644 --- a/conftest.py +++ b/conftest.py @@ -91,6 +91,7 @@ def pytest_configure(config): config.addinivalue_line("markers", "flash_attn_test: mark test which tests flash attention functionality") config.addinivalue_line("markers", "flash_attn_3_test: mark test which tests flash attention 3 functionality") config.addinivalue_line("markers", "training_ci: mark test for training CI validation") + config.addinivalue_line("markers", "tensor_parallel_ci: mark test for tensor parallel CI validation") os.environ["DISABLE_SAFETENSORS_CONVERSION"] = "true" diff --git a/docs/source/en/weightconverter.md b/docs/source/en/weightconverter.md index 4312f1277688..a3aaa443c292 100644 --- a/docs/source/en/weightconverter.md +++ b/docs/source/en/weightconverter.md @@ -16,21 +16,163 @@ rendered properly in your Markdown viewer. # Dynamic weight loading -Checkpoints are often serialized in a format that does not match what a model expects at runtime. Quantization and parallelism frequently require reshaping, splitting, or merging tensors into the expected model format instead of loading weights as-is. +Checkpoints are often serialized in a format that does not match what a model expects at runtime. Common scenarios include: + +1. **Fused weights**: Checkpoints store separate `gate_proj` and `up_proj` weights, but the model uses a fused `gate_up_proj` for efficiency. +2. **MoE expert consolidation**: Individual expert weights (`experts.0.weight`, `experts.1.weight`, ...) need to be stacked into a single 3D tensor. +3. **Legacy naming**: Old checkpoints use different naming conventions (e.g., `LayerNorm.gamma` vs `LayerNorm.weight`). +4. **Quantization**: Weights may be stored in quantized formats that need deserialization. Dynamic weight loading addresses this by applying scheduled, reversible operations to checkpoint tensors as they are loaded. Transformers makes this available through [`WeightConverter`], which maps one or more source keys to target keys by running a list of composable conversion operations. 
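+
+For example, a checkpoint that stores separate gate and up projections while the model expects a fused `gate_up_proj` can be described with a single converter. The sketch below is illustrative only: the key patterns and the concatenation dimension depend on the actual model layout.
+
+```py
+from transformers.core_model_loading import Concatenate, WeightConverter
+
+# Illustrative patterns; real entries live in conversion_mapping.py.
+WeightConverter(
+    source_patterns=["mlp.gate_proj.weight", "mlp.up_proj.weight"],
+    target_patterns="mlp.gate_up_proj.weight",
+    operations=[Concatenate(dim=0)],  # fuse along the output-feature dimension
+)
+```
+
+Because every operation is reversible, saving the model splits the fused tensor back into the original checkpoint keys.
+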
This approach adapts to new weight layouts, and supports loading quantized mixture-of-experts (MoEs) or enabling tensor parallelism and MoEs. This guide demonstrates how to use the [`WeightConverter`] to convert tensors. Your [`WeightConverter`] should be added inside [_build_checkpoint_conversion_mapping()](https://github.com/huggingface/transformers/blob/4c9fde2a2a3aece0bcf1be93f696e88297da9397/src/transformers/conversion_mapping.py#L34) in the [conversion_mapping.py](https://github.com/huggingface/transformers/blob/main/src/transformers/conversion_mapping.py) file. +## Full loading pipeline + +All models go through the dynamic weight loading system. Conversion mapping is an **optional step within that system** that only activates when the model has entries in `_MODEL_TO_CONVERSION_PATTERN`. + +``` +Checkpoint File → from_pretrained() → convert_and_load_state_dict_in_model() + ↓ + ┌───────────────────────────────────────────────────────────┐ + │ For each weight in checkpoint: │ + │ 1. Match renamed/processed source key to model parameter │ + │ 2. Shard the weight and send to device (async) │ + │ 3. Collect tensors with the same source_pattern together │ + │ (e.g. MoE experts, gate_up_proj) │ + │ 4. Apply dequantization/deserialization (if pre-quant) │ + │ 5. Apply conversion (if defined) │ + │ 6. Apply quantization (if enabled and step 4 not used) │ + │ 7. Set parameter on model │ + └───────────────────────────────────────────────────────────┘ +``` + +| Step | When it activates | +|------|-------------------| +| Dynamic loading | Always, for all models | +| Conversion mapping | Only when `model_type` is in `_MODEL_TO_CONVERSION_PATTERN` | +| TP sharding | Only when `tp_plan="auto"` and model has `base_model_tp_plan` | +| Dequantization/deserialization | Only when loading a pre-quantized checkpoint | +| Quantization | Only when a quantization config is provided and weights are not pre-quantized | + +### Dense models (e.g., Llama) + +For most dense models, the checkpoint format matches the model format directly, so no conversion mapping is needed. Some models may still require renaming (e.g., legacy naming conventions). TP sharding still applies when enabled. + +``` +Checkpoint: Model: +model.layers.0.self_attn.q_proj.weight → model.layers.0.self_attn.q_proj.weight +model.layers.0.self_attn.k_proj.weight → model.layers.0.self_attn.k_proj.weight +model.layers.0.mlp.gate_proj.weight → model.layers.0.mlp.gate_proj.weight +model.layers.0.mlp.up_proj.weight → model.layers.0.mlp.up_proj.weight +model.layers.0.mlp.down_proj.weight → model.layers.0.mlp.down_proj.weight +```x + +Legacy checkpoints may use older naming conventions that are handled by built-in renamings applied to all models: + +``` +Checkpoint: Model: +LayerNorm.gamma → LayerNorm.weight +LayerNorm.beta → LayerNorm.bias +``` + +### MoE models (e.g., Mixtral) + +For MoE models, the checkpoint format differs from the model format. Conversion mapping transforms separate expert weights into fused 3D tensors, and TP sharding applies after conversion. + +``` +Checkpoint: Model: +experts.0.w1.weight ─┐ +experts.1.w1.weight │ MergeModulelist +... ├───────────────→ experts.gate_up_proj (8, hidden, 2*intermediate) +experts.0.w3.weight │ + Concatenate +experts.1.w3.weight ─┘ +``` + +## Architecture + +The system is built around several key components defined in `src/transformers/core_model_loading.py`: + +**Phase 1 — Per-key processing** (iterates over checkpoint keys): + +1. **Rename key** via `WeightRenaming` (e.g. 
`block_sparse_moe` -> `mlp`) +2. **Match pattern** via `WeightConverter` (e.g. `experts.*.w1.weight`) +3. **Shard (TP) and send to device** asynchronously via `ThreadPoolExecutor` +4. **Collect** tensors with the same `source_pattern` together (e.g. all MoE expert weights, gate + up projections) + +**Phase 2 — Per-mapping processing** (iterates over collected mappings): + +1. **Dequantize/deserialize** (pre-quantized checkpoints only) +2. **Apply `ConversionOps` chain**: `Chunk`, `Concatenate`, `MergeModulelist`, `Transpose`, etc. +3. **Quantize** on-the-fly (if not pre-quantized) +4. **Set parameter** on model + +### WeightTransform + +The base class that handles pattern matching and tensor collection: + +- **Pattern compilation**: Converts glob-style patterns (`*.weight`) to regex. +- **Key renaming**: `rename_source_key()` transforms checkpoint keys to model keys. +- **Tensor collection**: `add_tensor()` gathers related tensors for batch processing. +- **Reversibility**: `reverse_transform()` creates the inverse operation for saving. + +```python +@dataclass(slots=True) +class WeightTransform: + source_patterns: str | list[str] # Checkpoint key patterns + target_patterns: str | list[str] # Model key patterns + compiled_sources: re.Pattern # Compiled regex for matching + distributed_operation: TensorParallelLayer | None + quantization_operation: ConversionOps | None + collected_tensors: dict[str, list[Future]] # Gathered tensors + layer_targets: dict[str, set[str]] # Target key tracking +``` + +### WeightRenaming + +[`WeightRenaming`] is a specialized [`WeightTransform`] for simple 1:1 key renaming without tensor operations: + +```py +# Legacy checkpoint compatibility +WeightRenaming("LayerNorm.gamma", "LayerNorm.weight") + +# Module path changes +WeightRenaming(".block_sparse_moe.", ".mlp.") + +# Adding prefixes +WeightRenaming("(.+)", "timm_model.\\1") +``` + +### WeightConverter + +[`WeightConverter`] extends [`WeightTransform`] with a list of [`ConversionOps`]: + +```python +@dataclass(slots=True) +class WeightConverter(WeightTransform): + operations: list[ConversionOps] # Chain of operations +``` + +It supports many-to-one (e.g., concatenating `gate` + `up` → `gate_up`), one-to-many (e.g., splitting `qkv` → `q`, `k`, `v`), and chained operations applied sequentially. + ## Conversion operations The [`WeightConverter`] class has several operations that are executed when [`~PreTrainedModel.from_pretrained`] is called for transforming checkpoint source tensors into model target tensors. Operations are fully reversible. Saving reverses the conversions and returns the original checkpoint so you can easily work across different frameworks. +| Operation | Reverse | +|-----------|---------| +| [`Chunk(dim)`] | [`Concatenate(dim)`] | +| [`Concatenate(dim)`] | [`Chunk(dim)`] | +| [`MergeModulelist(dim)`] | [`SplitModulelist(dim)`] | +| [`SplitModulelist(dim)`] | [`MergeModulelist(dim)`] | +| [`Transpose(d0, d1)`] | [`Transpose(d1, d0)`] | +| [`Force16BytesAlignment`] | [`Force16BytesAlignment`] (idempotent) | + ### Chunk -The [`Chunk`] operation is used to split a tensor. For example, if a model expects Q, K, and V as three separate tensors instead of a single tensor. +The [`Chunk`] operation splits a tensor into equal parts along a dimension. For example, if a model expects Q, K, and V as three separate tensors instead of a single tensor. 
```py WeightConverter( @@ -42,7 +184,7 @@ WeightConverter( ### Concatenate -The [`Concatenate`] operation allows you to fuse separate tensors into a single tensor. For example, if a model expects Q, K, and V as a single tensor instead of separate tensors. +The [`Concatenate`] operation fuses separate tensors into a single tensor. For example, if a model expects Q, K, and V as a single tensor instead of separate tensors. ```py WeightConverter( @@ -54,7 +196,7 @@ WeightConverter( ### MergeModulelist -[`MergeModulelist`] merges a list of tensors into a single tensor. For example, you can compose [`MergeModulelist`] with [`Concatenate`] to stack the experts in a MoE and pack them into one tensor. +[`MergeModulelist`] merges a list of 2D tensors into a single 3D tensor. For example, you can compose [`MergeModulelist`] with [`Concatenate`] to stack the experts in a MoE and pack them into one tensor. ```py WeightConverter( @@ -69,7 +211,7 @@ WeightConverter( ### SplitModulelist -[`SplitModulelist`] splits a tensor back into a list of tensors. For example, you can split a stack of experts back into individual experts. +[`SplitModulelist`] splits a 3D tensor back into a list of 2D tensors. For example, you can split a stack of experts back into individual experts. ```py WeightConverter( @@ -94,6 +236,124 @@ WeightConverter( ) ``` +### Transpose + +[`Transpose`] swaps dimensions of a tensor. Useful for converting weight layouts between different conventions. + +```py +WeightConverter( + source_patterns="mlp.gate.weight", + target_patterns="mlp.text_moe.gate.weight", + operations=[Transpose(dim0=0, dim1=1)], +) +``` + +### Force16BytesAlignment + +[`Force16BytesAlignment`] clones a tensor if it is not 16-byte aligned. This is required for `torch._grouped_mm` and TMA/SIMD operations. It is idempotent: applying it more than once has no additional effect. + +## Operation chaining + +Operations can be chained to perform complex transformations. The operations execute in order, with each operation's output becoming the next operation's input. + +### Example: Mixtral MoE conversion + +```python +WeightConverter( + source_patterns=[ + ".experts.*.w1.weight", # gate_proj per expert + ".experts.*.w3.weight", # up_proj per expert + ], + target_patterns=".experts.gate_up_proj", + operations=[ + MergeModulelist(dim=0), # Stack all experts: (n_experts, in, out) + Concatenate(dim=1), # Fuse gate+up: (n_experts, in, 2*out) + ], +) +``` + +**Data flow:** +``` +Input: + ".experts.*.w1.weight": [tensor_0, tensor_1, ..., tensor_7] # 8 experts + ".experts.*.w3.weight": [tensor_0, tensor_1, ..., tensor_7] # 8 experts + +After MergeModulelist(dim=0): + ".experts.*.w1.weight": (8, 4096, 14336) # stacked gate + ".experts.*.w3.weight": (8, 4096, 14336) # stacked up + +After Concatenate(dim=1): + ".experts.gate_up_proj": (8, 4096, 28672) # fused gate_up +``` + +### Pattern matching + +The `*` in patterns acts as a wildcard: +- During loading, it matches any numeric index (`experts.0.`, `experts.1.`, etc.). +- Tensors with the same pattern (differing only in index) are grouped together. +- The order of collection is preserved for correct concatenation. + +## Tensor parallelism integration + +The dynamic loading system integrates with tensor parallelism (TP) through the `TensorParallelLayer` hierarchy defined in `src/transformers/integrations/tensor_parallel.py`. + +When TP is enabled, tensors are sharded **during** materialization, not after. This means each rank only loads the portion of the tensor it needs. 
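+
+For a column-wise sharded weight, that portion is a contiguous block of output features. Here is a minimal sketch of the slice a rank keeps (illustrative only, not the library's `shard_tensor` implementation):
+
+```python
+import torch
+
+def colwise_shard(full_weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
+    # Column-wise parallelism shards dim -2 (output features); this sketch
+    # assumes the dimension divides evenly by world_size.
+    out_features = full_weight.shape[-2]
+    block = out_features // world_size
+    return full_weight[..., rank * block : (rank + 1) * block, :].contiguous()
+
+# With world_size=2, rank 0 keeps the first half of the output features:
+# colwise_shard(torch.empty(4096, 4096), rank=0, world_size=2).shape == (2048, 4096)
+```
+
+In the loader, the actual `shard_tensor` call is wrapped in a job and submitted to a thread pool, so sharding overlaps with checkpoint I/O: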
+ +```python +def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, device, dtype): + def _job(): + return sharding_method.shard_tensor(tensor, tensor_idx=tensor_idx, device=device, dtype=dtype) + return thread_pool.submit(_job) +``` + +### Available parallel styles + +| Style | Weight Shard Dim | Description | +|-------|------------------|-------------| +| `colwise` | -2 | Column-wise: output features sharded | +| `rowwise` | -1 | Row-wise: input features sharded | +| `packed_colwise` | -2 | For fused weights (gate_up_proj) | +| `packed_rowwise` | -1 | For fused weights | +| `embedding_rowwise` | 0 | Vocabulary parallelism | +| `grouped_gemm` | 0 | Expert parallelism for MoE | +| `sequence_parallel` | None | No weight sharding | + +### Packed weight handling + +For fused weights like `gate_up_proj`, special care is needed to shard correctly: + +```python +def get_packed_weights(param, empty_param, device_mesh, rank, dim): + """ + Interleaves gate and up shards correctly. + + Packed tensor: [G0 G1 G2 G3 | U0 U1 U2 U3] + + With TP=2: + - Rank 0 gets: [G0 G1 | U0 U1] + - Rank 1 gets: [G2 G3 | U2 U3] + """ +``` + +The TP operation is stored in the [`WeightTransform`] and applied after conversion operations: + +```python +if matched_tp_pattern := tp_plan_alt.search(renamed_key): + tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]] + mapping.distributed_operation = tp_layer( + device_mesh=device_mesh, + rank=device_mesh.get_local_rank(), + empty_param=empty_param.clone() + ) +``` + +## Quantization integration + +Quantization hooks into the loading pipeline in two ways, depending on whether the checkpoint is already quantized: + +- **Pre-quantized checkpoints**: The quantizer provides [`WeightConverter`] instances (via `get_weight_conversions()`) that deserialize quantized tensors. Checkpoint dtypes are preserved to avoid unwanted casts. +- **On-the-fly quantization**: The quantizer provides a quantization operation that is applied after conversion ops, quantizing weights as they are loaded. + ## Fast and efficient model loading Loading a model is faster and uses less memory because the loader knows which tensors are required for operations and schedules their materialization lazily. @@ -105,6 +365,44 @@ If your system runs other heavy processes, multiple threads may slow down loadin > [!NOTE] > The default is 4 threads for asynchronous parameter loading. This provides the best trade-off across loading scenarios and hardware. The work is mostly I/O bound, but depending on accelerator hardware and the `dtype` required at loading, it can become CPU/GPU-bound if the `dtype` differs from the serialized one (this requires an additional copy operation). +### Async vs sync loading + +```python +def spawn_materialize(thread_pool, tensor, device, dtype) -> Future | Callable: + def _job(): + return _materialize_copy(tensor, device, dtype) + + if thread_pool is not None: + return thread_pool.submit(_job) # Async: returns Future + else: + return _job # Sync: returns Callable (deferred execution) +``` + +Sync loading is used when: +- `HF_DEACTIVATE_ASYNC_LOAD=1` environment variable is set. +- Disk offloading is enabled (memory constraints require sequential loading). + +### Materialization flow + +``` +1. Checkpoint iteration (Phase 1): + - For each key, submit materialization job to ThreadPoolExecutor + - Job returns Future (async) or Callable (sync) + - Collect into the matching WeightConverter/WeightRenaming + +2. 
Per-mapping processing (Phase 2, one mapping at a time): + - materialize_tensors() waits for this mapping's Futures only + - Apply conversion operations chain (self.operations) + - Apply quantization operation (if on-the-fly) + - Set parameters on model + - Delete realized tensors immediately + +3. Cleanup: + - Thread pool shutdown (with cancel_futures=True for interrupts) +``` + +### Memory efficiency + When converting a weight, the converter waits for all required tensors to materialize if they haven't loaded yet. For example, the [`MergeModulelist`] operation requires all weights in `ModuleList` to be loaded before merging. Concatenating tensors requires a temporary copy, so operations like [`MergeModulelist`] and [`Concatenate`] need 2x the memory of the underlying tensors during conversion. Once merged, only the resulting tensor stays in memory. The theoretical worst-case memory peak is the model size plus the tensors required for the largest [`MergeModulelist`] or [`Concatenate`] operation. @@ -118,6 +416,109 @@ For example, a MoE model using [`MergeModulelist`] for experts on each layer, th These worst-case scenarios are uncommon. The actual memory peak tends to stay close to the model size. +## Reversibility + +The system supports saving models with the inverse transformations, enabling round-trip save/load: + +```python +def revert_weight_conversion(model, state_dict): + """Applies reverse conversions for saving.""" + weight_conversions = getattr(model, "_weight_conversions", None) + + # Reverse all transforms + reverse_weight_conversion = [ + conversion.reverse_transform() for conversion in weight_conversions + ] + + # Apply in reverse + for first_param_name, reversed_converter in conversion_mapping.items(): + realized_value = reversed_converter.convert(first_param_name, model=model) +``` + +Target patterns may contain regex elements that need processing for the reverse direction: + +```python +def process_target_pattern(pattern: str) -> tuple[str, str | None]: + """ + - Removes `^` and `$` anchors + - Removes negative lookahead/lookbehind + - Detects capturing groups, replaces with \1 + """ +``` + +## Real examples + +### Mixtral-style MoE + +**Checkpoint format:** +``` +model.layers.0.block_sparse_moe.experts.0.w1.weight # gate per expert +model.layers.0.block_sparse_moe.experts.0.w2.weight # down per expert +model.layers.0.block_sparse_moe.experts.0.w3.weight # up per expert +... +model.layers.0.block_sparse_moe.experts.7.w1.weight +``` + +**Model format:** +``` +model.layers.0.mlp.experts.gate_up_proj # (8, 4096, 28672) +model.layers.0.mlp.experts.down_proj # (8, 14336, 4096) +``` + +**Conversion mapping** (from `conversion_mapping.py`): +```python +"mixtral": [ + WeightRenaming(".block_sparse_moe.", ".mlp."), + WeightConverter( + source_patterns=[".experts.*.w1.weight", ".experts.*.w3.weight"], + target_patterns=".experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + source_patterns=[".experts.*.w2.weight"], + target_patterns=".experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), +], +``` + +### Custom operations (ERNIE 4.5 VL MoE) + +When the built-in operations aren't sufficient, you can create a custom [`ConversionOps`] subclass. For example, ERNIE 4.5 VL MoE needs to split a shared expert list between text and vision modalities — something no single built-in op handles. 
The custom `ErnieFuseAndSplitTextVisionExperts` operation splits and re-stacks experts across two target keys: + +```python +"ernie4_5_vl_moe": [ + WeightRenaming("vision_model", "vision_tower"), + WeightConverter( + source_patterns=["experts.*.down_proj.weight"], + target_patterns=[ + "text_moe.experts.down_proj", + "vision_moe.experts.down_proj", + ], + operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)], + ), +], +``` + +Custom ops must implement `convert()` and the `reverse_op` property to support round-trip save/load. + +### Model type aliases + +Many models share conversion patterns: + +```python +_MODEL_TO_CONVERSION_PATTERN = { + "mixtral": "mixtral", + "minimax": "mixtral", + "qwen2_moe": "qwen2_moe", + "deepseek_v2": "qwen2_moe", + "deepseek_v3": "qwen2_moe", + "qwen3_moe": "qwen2_moe", + "olmoe": "qwen2_moe", + ... +} +``` + ## Reusing the dynamic loading building blocks Dynamic weight loading is not limited to full model checkpoints. The same building blocks let you load *any* set of @@ -141,4 +542,13 @@ At a high level, the contract looks like this: - `_finalize_load_state_dict(...)` to move any missing/mismatched tensors off `meta`, initialize them, and tie weights. - `log_state_dict_report(...)` to report missing/unexpected/mismatched keys (and conversion errors). -These APIs are expose to allow you to handle custom code, custom weight format, but also make sure you benefit from the highest and most efficient weight loading, sharding and good quality of life of `transformers` API! \ No newline at end of file +These APIs are exposed to allow you to handle custom code, custom weight formats, but also make sure you benefit from the highest and most efficient weight loading, sharding and good quality of life of `transformers` API! 
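+
+As an illustration, here is a minimal sketch of the renaming and converter entries you might declare for a hypothetical custom checkpoint before handing it to the loading APIs above. The key patterns are invented for the example, but the entries have the same shape as the built-in mappings shown earlier.
+
+```py
+from transformers.core_model_loading import (
+    Concatenate,
+    MergeModulelist,
+    WeightConverter,
+    WeightRenaming,
+)
+
+custom_conversions = [
+    # Align module paths first (hypothetical source prefix).
+    WeightRenaming(".moe_block.", ".mlp."),
+    # Stack per-expert gate/up weights and fuse them into one 3D tensor.
+    WeightConverter(
+        source_patterns=[".experts.*.w1.weight", ".experts.*.w3.weight"],
+        target_patterns=".experts.gate_up_proj",
+        operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
+    ),
+]
+```
+
+Because every transform is reversible, the same list also describes how to write the module back out in its original checkpoint layout.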
+ +## Key files reference + +| File | Purpose | +|------|---------| +| `src/transformers/core_model_loading.py` | Core loading logic, WeightConverter, ConversionOps | +| `src/transformers/conversion_mapping.py` | Built-in conversion patterns for all models | +| `src/transformers/integrations/tensor_parallel.py` | TP sharding classes and utilities | +| `src/transformers/quantizers/base.py` | Quantization hooks and base class | diff --git a/pyproject.toml b/pyproject.toml index c138b905cd21..710f64032aa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ markers = [ "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", "generate: marks tests that use the GenerationTesterMixin", "is_training_test: marks tests that use the TrainingTesterMixin (deselect with '-m \"not is_training_test\"')", + "is_tensor_parallel_test: marks tests that use the TensorParallelTesterMixin (deselect with '-m \"not is_tensor_parallel_test\"')", ] log_cli = 1 log_cli_level = "WARNING" diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index b8af3b0e7c8c..57d715dca697 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -352,9 +352,21 @@ def _build_checkpoint_conversion_mapping(): ), ] - mapping["ernie4_5_moe"] = mapping["qwen2_moe"].copy() - mapping["ernie4_5_moe"] += [ - WeightRenaming("mlp.moe_statics.e_score_correction_bias", "mlp.gate.moe_statics.e_score_correction_bias") + mapping["ernie4_5_moe"] = [ + WeightRenaming("mlp.moe_statics.e_score_correction_bias", "mlp.gate.moe_statics.e_score_correction_bias"), + WeightConverter( + source_patterns=[ + "mlp.experts.*.gate_proj.weight", + "mlp.experts.*.up_proj.weight", + ], + target_patterns="mlp.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1), Force16BytesAlignment()], + ), + WeightConverter( + source_patterns="mlp.experts.*.down_proj.weight", + target_patterns="mlp.experts.down_proj", + operations=[MergeModulelist(dim=0), Force16BytesAlignment()], + ), ] mapping["minimax_m2"] = mapping["mixtral"].copy() mapping["minimax_m2"] += [ @@ -363,6 +375,22 @@ def _build_checkpoint_conversion_mapping(): mapping["exaone_moe"] = mapping["qwen2_moe"].copy() mapping["exaone_moe"] += [WeightRenaming("mlp.e_score_correction_bias", "mlp.gate.e_score_correction_bias")] + mapping["solar_open"] = [ + WeightConverter( + source_patterns=[ + "mlp.experts.*.gate_proj.weight", + "mlp.experts.*.up_proj.weight", + ], + target_patterns="mlp.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1), Force16BytesAlignment()], + ), + WeightConverter( + source_patterns="mlp.experts.*.down_proj.weight", + target_patterns="mlp.experts.down_proj", + operations=[MergeModulelist(dim=0), Force16BytesAlignment()], + ), + ] + mapping["qwen3_5_moe_text"] = mapping["qwen3_5_text"].copy() mapping["qwen3_5_moe_text"] += mapping["qwen2_moe"].copy() diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 6d6c0f4de52c..debe3d69b262 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -460,6 +460,7 @@ def backward(ctx, grad_output): device_mesh = ctx.device_mesh if device_mesh.size() == 1: return grad_output, None + grad_output = grad_output.contiguous() dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=device_mesh.get_group()) return grad_output, None @@ -658,7 +659,7 @@ def shard_tensor( ) -> 
torch.Tensor: raise NotImplementedError - def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: + def prepare_module_tp(self, module: nn.Module, device_mesh, **kwargs) -> nn.Module: distribute_module( module, device_mesh, @@ -724,6 +725,86 @@ def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) - return tuple(shape) +class ReplicatedWithGradAllReduce(TensorParallelLayer): + """ + Replicated parameter with gradient all-reduce. + + For parameters like q_norm/k_norm that sit between colwise and rowwise + layers. The parameter is replicated (not sharded), but its gradient + accumulates from local heads only in TP mode. This class registers a + backward hook to all-reduce the parameter gradient. + """ + + @staticmethod + def _prepare_input_fn(mod, inputs, device_mesh): + return inputs + + @staticmethod + def _prepare_output_fn(mod, outputs, device_mesh): + return outputs + + def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None): + return param[...].to(device=device, dtype=dtype) + + def prepare_module_tp(self, module, device_mesh, **kwargs): + # Use a module-level backward hook (not param.register_hook) because parameters are replaced during weight loading after this method runs. + # Module hooks survive parameter replacement. + def _backward_hook(mod, grad_input, grad_output, mesh=device_mesh): + for param in mod.parameters(): + if param.grad is not None: + all_reduce_forward(param.grad, mesh) + + module.register_full_backward_hook(_backward_hook) + + +class MlaKvAProjParallel(TensorParallelLayer): + """ + For MLA attention used in DeepSeek-V2 style models (deepseek_v2, longcat_flash, glm_moe_dsa, glm4_moe_lite): + kv_a_proj_with_mqa output is [kv_lora_rank + qk_rope_head_dim] (can have different naming but important thing + to understand is that it is split) + Example below (from modeling_longcat_flash.py): + + kv_a_proj_with_mqa + | + split + / \ + k_pass k_rot <-- "bypasses kv_b_proj" + | | (goes straight to attention, + kv_a_layernorm | never touches kv_b_proj) + | | + kv_b_proj | + (colwise) | + | | + k_pass k_rot + \\ / + cat + | + key_states + + k_pass is passed to kv_b_proj (colwise) which has built-in all_reduce_backward so we don't have a partial gradient for it. + However, k_rot goes straight to attention, never touches kv_b_proj. So we need to average gradient across all ranks otherwise we only get gradient for one rank (partial gradient). + """ + + def _prepare_output_fn(self, mod, output, device_mesh): + if not hasattr(mod.config, "qk_rope_head_dim"): + raise AttributeError( + f"Config for {type(mod).__name__} does not have `qk_rope_head_dim`. " + "MlaKvAProjParallel requires `qk_rope_head_dim` to be defined in the model config. " + "Please add it to the model's config or update the TP plan mapping." + ) + rope_dim = mod.config.qk_rope_head_dim + pass_output, rope_output = output.split([output.shape[-1] - rope_dim, rope_dim], dim=-1) + rope_output = all_reduce_backward(rope_output, device_mesh) + return torch.cat([pass_output, rope_output], dim=-1) + + def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None): + return param[...].to(device=device, dtype=dtype) + + def prepare_module_tp(self, module, device_mesh, config=None, **kwargs): + module.config = config + distribute_module(module, device_mesh, output_fn=self._prepare_output_fn) + + class RowwiseParallel(TensorParallelLayer): """ Row-wise parallel: weight is sharded on dim -1 (input features). 
@@ -1087,6 +1168,29 @@ def shard_tensor( return param[...].to(device=device, dtype=dtype) +class MoeIdentityExpertParallel(TensorParallelLayer): + """ + TP class for zero/identity experts in MoE layers. + + Under TP, the parent MoeTensorParalellExperts does all_reduce_forward (sum) + on the expert module output. Identity experts produce the same output on + every rank, so the sum gives world_size * output. This class divides the + input by world_size to compensate. + """ + + @staticmethod + def _prepare_input_fn(mod, inputs, device_mesh): + input_tensor = inputs[0] if inputs else inputs + # TODO(fmom): when 2D-device mesh, need to select a //-ism axis to divide the input tensor by. + return input_tensor / device_mesh.size() + + def shard_tensor(self, param, tensor_idx=None, device=None, dtype=None): + return param[...].to(device=device, dtype=dtype) + + def prepare_module_tp(self, module, device_mesh, **kwargs): + distribute_module(module, device_mesh, input_fn=self._prepare_input_fn) + + class ParallelInterface(GeneralInterface): # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if # a new instance is created (in order to locally override a given entry) @@ -1103,6 +1207,9 @@ class ParallelInterface(GeneralInterface): "grouped_gemm": GroupedGemmParallel(), "ep_router": RouterParallel(), "moe_tp_experts": MoeTensorParalellExperts(), + "moe_identity_expert": MoeIdentityExpertParallel(), + "replicated_with_grad_allreduce": ReplicatedWithGradAllReduce(), + "mla_kv_a_proj": MlaKvAProjParallel(), } if is_torch_available() and _torch_distributed_available else {} @@ -1120,6 +1227,8 @@ class ParallelInterface(GeneralInterface): "packed_rowwise": -1, "embedding_rowwise": 0, "sequence_parallel": None, + "replicated_with_grad_allreduce": None, + "mla_kv_a_proj": None, } # Bias sharding: colwise shards bias, rowwise doesn't (bias is replicated and all-reduced) @@ -1132,6 +1241,8 @@ class ParallelInterface(GeneralInterface): "packed_rowwise": None, "embedding_rowwise": None, "sequence_parallel": None, + "replicated_with_grad_allreduce": None, + "mla_kv_a_proj": None, } @@ -1258,13 +1369,14 @@ def add_tensor_parallel_hooks_to_module( if current_module_plan is not None: tp_layer = ALL_PARALLEL_STYLES[current_module_plan] try: - tp_layer.prepare_module_tp(module, device_mesh) + tp_layer.prepare_module_tp(module, device_mesh, config=model.config) except NotImplementedError as e: print( f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}" ) module._hf_tp_plan = current_module_plan + module._hf_device_mesh = device_mesh module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}" diff --git a/src/transformers/integrations/torchao.py b/src/transformers/integrations/torchao.py index 63cc7a324275..6db371a4f599 100644 --- a/src/transformers/integrations/torchao.py +++ b/src/transformers/integrations/torchao.py @@ -148,6 +148,12 @@ def convert( quantize_(module, c, (lambda x, fqn: True)) missing_keys.discard(full_layer_name) module._is_hf_initialized = True + # torchao quantizes weights into a module but some models access the weight directly + # (e.g. module.o_proj.weight). The _is_hf_initialized flag is set at the module + # level only, so we also set it on each parameter to prevent _init_weights from + # calling normal_() on already-quantized Float8Tensors. 
+ for param in module.parameters(recurse=False): + param._is_hf_initialized = True return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {} else: # need to apply to custom param name @@ -155,6 +161,8 @@ def convert( quantize_(module, custom_param_fqn_config, filter_fn=None) missing_keys.discard(full_layer_name) module._is_hf_initialized = True + for param in module.parameters(recurse=False): + param._is_hf_initialized = True return {} return {full_layer_name: value} @@ -189,6 +197,8 @@ def convert( quantize_(module, c, filter_fn=lambda x, fqn: True) missing_keys.discard(full_layer_name) module._is_hf_initialized = True + for param in module.parameters(recurse=False): + param._is_hf_initialized = True return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {} return {full_layer_name: value} @@ -198,6 +208,8 @@ def convert( quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass()) missing_keys.discard(full_layer_name) module._is_hf_initialized = True + for param in module.parameters(recurse=False): + param._is_hf_initialized = True return {"lm_head.weight": lm_head} if is_embedding_param and untie_embedding_weights else {} diff --git a/src/transformers/models/apertus/configuration_apertus.py b/src/transformers/models/apertus/configuration_apertus.py index 1271a8e9af00..5b35fea84953 100644 --- a/src/transformers/models/apertus/configuration_apertus.py +++ b/src/transformers/models/apertus/configuration_apertus.py @@ -101,6 +101,8 @@ class ApertusConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/apertus/modular_apertus.py b/src/transformers/models/apertus/modular_apertus.py index 6caebcd27666..e27531502ea2 100644 --- a/src/transformers/models/apertus/modular_apertus.py +++ b/src/transformers/models/apertus/modular_apertus.py @@ -121,6 +121,8 @@ class ApertusConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py b/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py index 211b7e322708..a831587ca332 100644 --- a/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py @@ -123,12 +123,19 @@ class DeepseekV2Config(PreTrainedConfig): base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.q_a_proj": "colwise", "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj", "layers.*.self_attn.kv_b_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + 
"layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index ea37bb77a1e9..989a954811ff 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -360,6 +360,7 @@ def forward( k_nope, value_states = torch.split(k_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + q_pe, k_pe = apply_rotary_emb(q_pe, k_pe, position_embeddings.to(q_pe.device)) k_pe = k_pe.expand(*k_nope.shape[:-1], -1) diff --git a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py index 30e9c91116e2..e256c1e4d14a 100644 --- a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py @@ -140,12 +140,19 @@ class DeepseekV2Config(LlamaConfig): base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.q_a_proj": "colwise", "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj", "layers.*.self_attn.kv_b_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", } model_type = "deepseek_v2" @@ -384,6 +391,7 @@ def forward( k_nope, value_states = torch.split(k_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + q_pe, k_pe = apply_rotary_emb(q_pe, k_pe, position_embeddings.to(q_pe.device)) k_pe = k_pe.expand(*k_nope.shape[:-1], -1) diff --git a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py index 1baa1956b2c8..ea843f047e6a 100644 --- a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py @@ -129,8 +129,9 @@ class DeepseekV3Config(PreTrainedConfig): model_type = "deepseek_v3" keys_to_ignore_at_inference = ["past_key_values"] base_model_tp_plan = { - "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlp.shared_experts.gate_proj": "colwise", "layers.*.mlp.shared_experts.up_proj": "colwise", "layers.*.mlp.shared_experts.down_proj": "rowwise", diff --git a/src/transformers/models/dots1/configuration_dots1.py b/src/transformers/models/dots1/configuration_dots1.py index aa4da940220a..7cfc5c0c82cf 100644 --- a/src/transformers/models/dots1/configuration_dots1.py +++ b/src/transformers/models/dots1/configuration_dots1.py @@ -118,8 +118,11 @@ class Dots1Config(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", 
"layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlp.shared_experts.gate_proj": "colwise", "layers.*.mlp.shared_experts.up_proj": "colwise", "layers.*.mlp.shared_experts.down_proj": "rowwise", diff --git a/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py index fc70e9632e8e..78103f55740c 100644 --- a/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py @@ -121,8 +121,9 @@ class Ernie4_5_MoeConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlp.shared_experts.gate_proj": "colwise", "layers.*.mlp.shared_experts.up_proj": "colwise", "layers.*.mlp.shared_experts.down_proj": "rowwise", diff --git a/src/transformers/models/exaone4/configuration_exaone4.py b/src/transformers/models/exaone4/configuration_exaone4.py index bdff9525d671..14dee5b9bf17 100644 --- a/src/transformers/models/exaone4/configuration_exaone4.py +++ b/src/transformers/models/exaone4/configuration_exaone4.py @@ -115,6 +115,8 @@ class Exaone4Config(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", diff --git a/src/transformers/models/exaone4/modular_exaone4.py b/src/transformers/models/exaone4/modular_exaone4.py index 6c8f98a5cb57..10f125139e75 100644 --- a/src/transformers/models/exaone4/modular_exaone4.py +++ b/src/transformers/models/exaone4/modular_exaone4.py @@ -149,6 +149,8 @@ class Exaone4Config(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", diff --git a/src/transformers/models/exaone_moe/configuration_exaone_moe.py b/src/transformers/models/exaone_moe/configuration_exaone_moe.py index 41c7c8fb86ae..7b19868ef58a 100644 --- a/src/transformers/models/exaone_moe/configuration_exaone_moe.py +++ b/src/transformers/models/exaone_moe/configuration_exaone_moe.py @@ -135,6 +135,8 @@ class ExaoneMoeConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", 
"layers.*.mlp.up_proj": "colwise", diff --git a/src/transformers/models/flex_olmo/configuration_flex_olmo.py b/src/transformers/models/flex_olmo/configuration_flex_olmo.py index 1314b540041c..f5c8b878d207 100644 --- a/src/transformers/models/flex_olmo/configuration_flex_olmo.py +++ b/src/transformers/models/flex_olmo/configuration_flex_olmo.py @@ -114,8 +114,9 @@ class FlexOlmoConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k "layers.*.self_attn.v_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k "layers.*.self_attn.o_proj": "rowwise_split_input", # input is replicated due to the added norm on q and k - "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/flex_olmo/modular_flex_olmo.py b/src/transformers/models/flex_olmo/modular_flex_olmo.py index 54e59a2785fa..8da6b3e61ccf 100644 --- a/src/transformers/models/flex_olmo/modular_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modular_flex_olmo.py @@ -126,8 +126,9 @@ class FlexOlmoConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k "layers.*.self_attn.v_proj": "colwise_gather_output", # we need to replicate here due to the added norm on q and k "layers.*.self_attn.o_proj": "rowwise_split_input", # input is replicated due to the added norm on q and k - "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py index 9985653fb66e..3c4908b30d22 100644 --- a/src/transformers/models/gemma3/configuration_gemma3.py +++ b/src/transformers/models/gemma3/configuration_gemma3.py @@ -118,6 +118,8 @@ class Gemma3TextConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 756356d85ea4..653004478d70 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -143,6 +143,17 @@ class Gemma3TextConfig(Gemma2Config, PreTrainedConfig): """ model_type = "gemma3_text" + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } default_theta = {"global": 1_000_000.0, "local": 
10_000.0} def __init__( diff --git a/src/transformers/models/gemma3n/configuration_gemma3n.py b/src/transformers/models/gemma3n/configuration_gemma3n.py index 37e076f5861e..780d7279e390 100644 --- a/src/transformers/models/gemma3n/configuration_gemma3n.py +++ b/src/transformers/models/gemma3n/configuration_gemma3n.py @@ -147,6 +147,9 @@ class Gemma3nTextConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.v_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index a97cc2823c7b..edde17a95955 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -171,6 +171,18 @@ class Gemma3nTextConfig(Gemma2Config, PreTrainedConfig): """ model_type = "gemma3n_text" + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.v_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } default_theta = {"global": 1_000_000.0, "local": 10_000.0} def __init__( diff --git a/src/transformers/models/glm4_moe/configuration_glm4_moe.py b/src/transformers/models/glm4_moe/configuration_glm4_moe.py index 2269275cd5ef..7d3565f7cf10 100644 --- a/src/transformers/models/glm4_moe/configuration_glm4_moe.py +++ b/src/transformers/models/glm4_moe/configuration_glm4_moe.py @@ -123,8 +123,12 @@ class Glm4MoeConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", # NOTE(3outeille): This needs to be right after down_proj in the dict. Otherwise, the pattern model.layers.*.mlp.experts will have priority over model.layers.*.mlp.experts.down_proj which will assign a wrong TP plan. 
+ "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/glm4_moe/modular_glm4_moe.py b/src/transformers/models/glm4_moe/modular_glm4_moe.py index a4236905e05c..e14e386b69a3 100644 --- a/src/transformers/models/glm4_moe/modular_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modular_glm4_moe.py @@ -137,8 +137,12 @@ class Glm4MoeConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", # NOTE(3outeille): This needs to be right after down_proj in the dict. Otherwise, the pattern model.layers.*.mlp.experts will have priority over model.layers.*.mlp.experts.down_proj which will assign a wrong TP plan. + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py b/src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py index 63f83a68d94d..afaeb116e893 100644 --- a/src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +++ b/src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py @@ -129,6 +129,9 @@ class Glm4MoeLiteConfig(PreTrainedConfig): model_type = "glm4_moe_lite" keys_to_ignore_at_inference = ["past_key_values"] base_model_tp_plan = { + "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj", + "layers.*.self_attn.kv_b_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", diff --git a/src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py b/src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py index f4bb3f364e30..0131975fc9da 100644 --- a/src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +++ b/src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py @@ -138,6 +138,9 @@ class Glm4MoeLiteConfig(PreTrainedConfig): model_type = "glm4_moe_lite" keys_to_ignore_at_inference = ["past_key_values"] base_model_tp_plan = { + "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj", + "layers.*.self_attn.kv_b_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", diff --git a/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py index 1cc7ff40312b..22b944c3e1ec 100644 --- a/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py @@ -126,11 +126,18 @@ class GlmMoeDsaConfig(PreTrainedConfig): model_type = "glm_moe_dsa" keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_b_proj": "colwise", + 
"layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj", + "layers.*.self_attn.kv_b_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py index a29dee5ff8c4..6632bb38f8df 100644 --- a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py @@ -355,7 +355,7 @@ def forward( else: q_resid = self.q_a_layernorm(self.q_a_proj(hidden_states)) # [B, S, q_lora_rank] query_states = self.q_b_proj(q_resid) - query_states = query_states.view(batch_size, seq_length, self.num_heads, self.qk_head_dim).transpose(1, 2) + query_states = query_states.view(batch_size, seq_length, -1, self.qk_head_dim).transpose(1, 2) # Split nope/rope, apply RoPE, recombine — layout: [B, H, S, D] q_nope, q_pe = torch.split(query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_pe = apply_rotary_pos_emb(q_pe, cos, sin, unsqueeze_dim=1) # BHSD format @@ -367,7 +367,7 @@ def forward( # Expand KV through kv_b_proj kv_expanded = self.kv_b_proj(k_compressed) # [B, S, H * (nope_D + v_D)] - kv_expanded = kv_expanded.view(batch_size, seq_length, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + kv_expanded = kv_expanded.view(batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) k_nope, value_states = torch.split(kv_expanded, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_nope = k_nope.transpose(1, 2) # [B, H, S, nope_D] value_states = value_states.transpose(1, 2) # [B, H, S, v_D] @@ -375,7 +375,7 @@ def forward( # RoPE on k_pe (single-head rope stream) k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim) # [B, 1, S, rope_D] k_pe = apply_rotary_pos_emb(k_pe, cos, sin, unsqueeze_dim=1) # BHSD format - k_pe = k_pe.expand(-1, self.num_heads, -1, -1) # [B, H, S, rope_D] + k_pe = k_pe.expand(-1, k_nope.shape[1], -1, -1) # [B, H, S, rope_D] # Assemble full Q and K query_states = torch.cat([q_nope, q_pe], dim=-1) # [B, H, S, qk_head_dim] diff --git a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py index 65a986013669..f6cdc5e6b305 100644 --- a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py @@ -177,11 +177,18 @@ class GlmMoeDsaConfig(PreTrainedConfig): model_type = "glm_moe_dsa" keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.kv_a_proj_with_mqa": "mla_kv_a_proj", + "layers.*.self_attn.kv_b_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": 
"rowwise", @@ -510,7 +517,7 @@ def forward( else: q_resid = self.q_a_layernorm(self.q_a_proj(hidden_states)) # [B, S, q_lora_rank] query_states = self.q_b_proj(q_resid) - query_states = query_states.view(batch_size, seq_length, self.num_heads, self.qk_head_dim).transpose(1, 2) + query_states = query_states.view(batch_size, seq_length, -1, self.qk_head_dim).transpose(1, 2) # Split nope/rope, apply RoPE, recombine — layout: [B, H, S, D] q_nope, q_pe = torch.split(query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_pe = apply_rotary_pos_emb(q_pe, cos, sin, unsqueeze_dim=1) # BHSD format @@ -522,7 +529,7 @@ def forward( # Expand KV through kv_b_proj kv_expanded = self.kv_b_proj(k_compressed) # [B, S, H * (nope_D + v_D)] - kv_expanded = kv_expanded.view(batch_size, seq_length, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + kv_expanded = kv_expanded.view(batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) k_nope, value_states = torch.split(kv_expanded, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_nope = k_nope.transpose(1, 2) # [B, H, S, nope_D] value_states = value_states.transpose(1, 2) # [B, H, S, v_D] @@ -530,7 +537,7 @@ def forward( # RoPE on k_pe (single-head rope stream) k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim) # [B, 1, S, rope_D] k_pe = apply_rotary_pos_emb(k_pe, cos, sin, unsqueeze_dim=1) # BHSD format - k_pe = k_pe.expand(-1, self.num_heads, -1, -1) # [B, H, S, rope_D] + k_pe = k_pe.expand(-1, k_nope.shape[1], -1, -1) # [B, H, S, rope_D] # Assemble full Q and K query_states = torch.cat([q_nope, q_pe], dim=-1) # [B, H, S, qk_head_dim] diff --git a/src/transformers/models/gpt_oss/configuration_gpt_oss.py b/src/transformers/models/gpt_oss/configuration_gpt_oss.py index a93481fc7d26..f352dab55edb 100644 --- a/src/transformers/models/gpt_oss/configuration_gpt_oss.py +++ b/src/transformers/models/gpt_oss/configuration_gpt_oss.py @@ -31,12 +31,7 @@ class GptOssConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } - base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.self_attn.sinks": "colwise", + base_model_ep_plan = { "layers.*.mlp.router": "ep_router", "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", "layers.*.mlp.experts.gate_up_proj_bias": "grouped_gemm", diff --git a/src/transformers/models/longcat_flash/configuration_longcat_flash.py b/src/transformers/models/longcat_flash/configuration_longcat_flash.py index 05e0c98ea72a..7c7ea3ee6af9 100644 --- a/src/transformers/models/longcat_flash/configuration_longcat_flash.py +++ b/src/transformers/models/longcat_flash/configuration_longcat_flash.py @@ -122,13 +122,16 @@ class LongcatFlashConfig(PreTrainedConfig): default_theta = 10000000.0 base_model_tp_plan = { "layers.*.self_attn.*.q_b_proj": "colwise", + "layers.*.self_attn.*.kv_a_proj_with_mqa": "mla_kv_a_proj", "layers.*.self_attn.*.kv_b_proj": "colwise", "layers.*.self_attn.*.o_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts.identity_expert": "moe_identity_expert", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlps.*.gate_proj": "colwise", "layers.*.mlps.*.up_proj": "colwise", "layers.*.mlps.*.down_proj": "rowwise", - "layers.*.mlp.experts.gate_up_proj": "rowwise", 
- "layers.*.mlp.experts.down_proj": "rowwise", } base_model_pp_plan = { diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py index 77e550503131..8ede9f347280 100644 --- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py @@ -183,6 +183,7 @@ def __init__(self, config): self.zero_expert_num = config.zero_expert_num or 0 self.total_experts = self.num_routed_experts + self.zero_expert_num self.act_fn = ACT2FN[config.hidden_act] + self.identity_expert = nn.Identity() if self.num_routed_experts > 0: self.gate_up_proj = nn.Parameter( @@ -211,7 +212,7 @@ def forward(self, hidden_states, top_k_index, top_k_weights): current_state = hidden_states[token_idx] if expert_idx >= self.num_routed_experts or self.gate_up_proj is None: - current_hidden_states = current_state + current_hidden_states = self.identity_expert(current_state) else: gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up @@ -557,6 +558,9 @@ class LongcatFlashPreTrainedModel(PreTrainedModel): "attentions": LongcatFlashMLA, } _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"] + _keep_in_fp32_modules = [ + "classifier.weight" + ] # TODO let's make sure orignal code base has this, for now it fixes quantization @torch.no_grad() def _init_weights(self, module): diff --git a/src/transformers/models/longcat_flash/modular_longcat_flash.py b/src/transformers/models/longcat_flash/modular_longcat_flash.py index 08c35591a78f..4f2af0841f31 100644 --- a/src/transformers/models/longcat_flash/modular_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modular_longcat_flash.py @@ -103,6 +103,7 @@ def __init__(self, config): self.zero_expert_num = config.zero_expert_num or 0 self.total_experts = self.num_routed_experts + self.zero_expert_num self.act_fn = ACT2FN[config.hidden_act] + self.identity_expert = nn.Identity() if self.num_routed_experts > 0: self.gate_up_proj = nn.Parameter( @@ -131,7 +132,7 @@ def forward(self, hidden_states, top_k_index, top_k_weights): current_state = hidden_states[token_idx] if expert_idx >= self.num_routed_experts or self.gate_up_proj is None: - current_hidden_states = current_state + current_hidden_states = self.identity_expert(current_state) else: gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up @@ -341,6 +342,9 @@ class LongcatFlashPreTrainedModel(PreTrainedModel): "attentions": LongcatFlashMLA, } _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"] + _keep_in_fp32_modules = [ + "classifier.weight" + ] # TODO let's make sure orignal code base has this, for now it fixes quantization @torch.no_grad() def _init_weights(self, module): diff --git a/src/transformers/models/minimax/configuration_minimax.py b/src/transformers/models/minimax/configuration_minimax.py index 1802066c56f2..8e049489f81b 100644 --- a/src/transformers/models/minimax/configuration_minimax.py +++ b/src/transformers/models/minimax/configuration_minimax.py @@ -134,9 +134,9 @@ class MiniMaxConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate": "colwise_gather_output", # we need to replicate here to correctly route experts "layers.*.mlp.experts.gate_up_proj": "packed_colwise", 
"layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 15de27f09dd3..176669e4b301 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -162,9 +162,9 @@ class MiniMaxConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate": "colwise_gather_output", # we need to replicate here to correctly route experts "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/minimax_m2/configuration_minimax_m2.py b/src/transformers/models/minimax_m2/configuration_minimax_m2.py index 6a2805cb3534..46644f182bab 100644 --- a/src/transformers/models/minimax_m2/configuration_minimax_m2.py +++ b/src/transformers/models/minimax_m2/configuration_minimax_m2.py @@ -111,11 +111,10 @@ class MiniMaxM2Config(PreTrainedConfig): model_type = "minimax_m2" keys_to_ignore_at_inference = ["past_key_values"] base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise_rep", - "layers.*.self_attn.k_proj": "colwise_rep", - "layers.*.self_attn.v_proj": "colwise_rep", - "layers.*.self_attn.o_proj": "rowwise_rep", - "layers.*.mlp.gate": "colwise_rep", # we need to replicate here to correctly route experts + "layers.*.self_attn.q_proj": "colwise_gather_output", + "layers.*.self_attn.k_proj": "colwise_gather_output", + "layers.*.self_attn.v_proj": "colwise_gather_output", + "layers.*.self_attn.o_proj": "rowwise_split_input", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", "layers.*.mlp.experts": "moe_tp_experts", diff --git a/src/transformers/models/minimax_m2/modular_minimax_m2.py b/src/transformers/models/minimax_m2/modular_minimax_m2.py index 13a656c7f218..de997c94b01d 100644 --- a/src/transformers/models/minimax_m2/modular_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modular_minimax_m2.py @@ -132,11 +132,10 @@ class MiniMaxM2Config(PreTrainedConfig): model_type = "minimax_m2" keys_to_ignore_at_inference = ["past_key_values"] base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise_rep", - "layers.*.self_attn.k_proj": "colwise_rep", - "layers.*.self_attn.v_proj": "colwise_rep", - "layers.*.self_attn.o_proj": "rowwise_rep", - "layers.*.mlp.gate": "colwise_rep", # we need to replicate here to correctly route experts + "layers.*.self_attn.q_proj": "colwise_gather_output", + "layers.*.self_attn.k_proj": "colwise_gather_output", + "layers.*.self_attn.v_proj": "colwise_gather_output", + "layers.*.self_attn.o_proj": "rowwise_split_input", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", "layers.*.mlp.experts": "moe_tp_experts", diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index d53d9ae6ff32..f588a81e8c73 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -372,7 +372,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens 
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output - expert_output += shared_expert_output + expert_output = expert_output + shared_expert_output expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) return expert_output diff --git a/src/transformers/models/qwen2_moe/modular_qwen2_moe.py b/src/transformers/models/qwen2_moe/modular_qwen2_moe.py index 473ca3fdd0a5..bbefd9ea8d77 100644 --- a/src/transformers/models/qwen2_moe/modular_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modular_qwen2_moe.py @@ -125,7 +125,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output - expert_output += shared_expert_output + expert_output = expert_output + shared_expert_output expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) return expert_output diff --git a/src/transformers/models/qwen3/configuration_qwen3.py b/src/transformers/models/qwen3/configuration_qwen3.py index bf503a4e55d7..c95537a5b27a 100644 --- a/src/transformers/models/qwen3/configuration_qwen3.py +++ b/src/transformers/models/qwen3/configuration_qwen3.py @@ -111,6 +111,8 @@ class Qwen3Config(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", diff --git a/src/transformers/models/qwen3_5/configuration_qwen3_5.py b/src/transformers/models/qwen3_5/configuration_qwen3_5.py index c4a0518d393c..9237759388de 100644 --- a/src/transformers/models/qwen3_5/configuration_qwen3_5.py +++ b/src/transformers/models/qwen3_5/configuration_qwen3_5.py @@ -115,6 +115,8 @@ class Qwen3_5TextConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 12d6ed010554..0726c770ab5f 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -150,6 +150,8 @@ class Qwen3_5TextConfig(Qwen3NextConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py index 13ee7ba77a42..0fddf3845855 100644 --- a/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py @@ -126,8 +126,11 @@ class Qwen3_5MoeTextConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", 
"layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlp.shared_expert.gate_proj": "colwise", "layers.*.mlp.shared_expert.up_proj": "colwise", "layers.*.mlp.shared_expert.down_proj": "rowwise", diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 20e4cc92f512..b9183911408d 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -884,7 +884,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output - expert_output += shared_expert_output + expert_output = expert_output + shared_expert_output expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) return expert_output diff --git a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py index 3369cf363ee9..af0a567ff27b 100644 --- a/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modular_qwen3_5_moe.py @@ -159,8 +159,11 @@ class Qwen3_5MoeTextConfig(Qwen3NextConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlp.shared_expert.gate_proj": "colwise", "layers.*.mlp.shared_expert.up_proj": "colwise", "layers.*.mlp.shared_expert.down_proj": "rowwise", diff --git a/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py b/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py index 67f029ed5199..cd2a275bb575 100644 --- a/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py @@ -127,9 +127,12 @@ class Qwen3MoeConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/qwen3_next/configuration_qwen3_next.py b/src/transformers/models/qwen3_next/configuration_qwen3_next.py index e73f5cf4acf3..0f63bad9ed39 100644 --- a/src/transformers/models/qwen3_next/configuration_qwen3_next.py +++ b/src/transformers/models/qwen3_next/configuration_qwen3_next.py @@ -135,12 +135,15 @@ class Qwen3NextConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + 
"layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", "layers.*.mlp.shared_expert.gate_proj": "colwise", "layers.*.mlp.shared_expert.up_proj": "colwise", "layers.*.mlp.shared_expert.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 2190ec03f83b..b078d6ff9a77 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -894,7 +894,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output - expert_output += shared_expert_output + expert_output = expert_output + shared_expert_output expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) return expert_output diff --git a/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py index 128a5622dade..e86850df778f 100644 --- a/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py @@ -543,6 +543,8 @@ class Qwen3OmniMoeTalkerCodePredictorConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", @@ -729,9 +731,12 @@ class Qwen3OmniMoeTalkerTextConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 4a4d8a5029be..346b69bc989f 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -2868,7 +2868,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output - expert_output += shared_expert_output + expert_output = expert_output + shared_expert_output expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) return expert_output diff --git a/src/transformers/models/t5gemma2/configuration_t5gemma2.py b/src/transformers/models/t5gemma2/configuration_t5gemma2.py 
index 87dd6cdc9aa7..154fd237b5de 100644 --- a/src/transformers/models/t5gemma2/configuration_t5gemma2.py +++ b/src/transformers/models/t5gemma2/configuration_t5gemma2.py @@ -104,6 +104,8 @@ class T5Gemma2TextConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", @@ -383,6 +385,8 @@ class T5Gemma2DecoderConfig(PreTrainedConfig): "layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce", + "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce", "layers.*.self_attn.o_proj": "rowwise", "layers.*.mlp.gate_proj": "colwise", "layers.*.mlp.up_proj": "colwise", diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index ac2ce80bd8cc..10679b745601 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -272,6 +272,7 @@ def parse_int_from_env(key, default=None): _run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True) _run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False) _run_training_tests = parse_flag_from_env("RUN_TRAINING_TESTS", default=True) +_run_tensor_parallel_tests = parse_flag_from_env("RUN_TENSOR_PARALLEL_TESTS", default=True) def is_staging_test(test_case): @@ -338,6 +339,22 @@ def is_training_test(test_case): return pytest.mark.is_training_test()(test_case) +def is_tensor_parallel_test(test_case): + """ + Decorator marking a test as a tensor parallel test. If RUN_TENSOR_PARALLEL_TESTS is set to a falsy value, those + tests will be skipped. + """ + if not _run_tensor_parallel_tests: + return unittest.skip(reason="test is tensor parallel test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_tensor_parallel_test()(test_case) + + def slow(test_case): """ Decorator marking a test as slow. 
diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py index de6d9dd3bb51..b3398f13c393 100644 --- a/tests/causal_lm_tester.py +++ b/tests/causal_lm_tester.py @@ -38,6 +38,7 @@ torch_device, ) from .test_pipeline_mixin import PipelineTesterMixin +from .test_tensor_parallel_mixin import TensorParallelTesterMixin from .test_training_mixin import TrainingTesterMixin @@ -305,7 +306,9 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, TrainingTesterMixin): +class CausalLMModelTest( + ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, TrainingTesterMixin, TensorParallelTesterMixin +): model_tester_class = None all_model_classes = None pipeline_model_mapping = None diff --git a/tests/models/apertus/test_modeling_apertus.py b/tests/models/apertus/test_modeling_apertus.py index b89ccae80010..9bffa19e3cfc 100644 --- a/tests/models/apertus/test_modeling_apertus.py +++ b/tests/models/apertus/test_modeling_apertus.py @@ -40,6 +40,13 @@ class ApertusModelTester(CausalLMModelTester): if is_torch_available(): base_model_class = ApertusModel + def __init__(self, parent): + super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. + self.attention_probs_dropout_prob = 0.0 + @require_torch class ApertusModelTest(CausalLMModelTest, unittest.TestCase): diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py index 796a1b1c51f6..5cbaef4b57ae 100644 --- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py +++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py @@ -32,6 +32,7 @@ from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin +from ...test_tensor_parallel_mixin import TensorParallelTesterMixin if is_torch_available(): @@ -47,6 +48,9 @@ class DeepseekV3ModelTester: + if is_torch_available(): + causal_lm_class = DeepseekV3ForCausalLM + def __init__( self, parent, @@ -80,7 +84,10 @@ def __init__( hidden_act="silu", max_position_embeddings=512, initializer_range=0.02, - attention_probs_dropout_prob=0.1, + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. 
+ attention_probs_dropout_prob=0.0, type_vocab_size=16, type_sequence_label_size=2, num_labels=3, @@ -207,7 +214,9 @@ def prepare_config_and_inputs_for_common(self): @require_torch -class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class DeepseekV3ModelTest( + ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase, TensorParallelTesterMixin +): all_model_classes = ( ( DeepseekV3Model, diff --git a/tests/models/exaone4/test_modeling_exaone4.py b/tests/models/exaone4/test_modeling_exaone4.py index b2326321cc0f..3ad081ca529d 100644 --- a/tests/models/exaone4/test_modeling_exaone4.py +++ b/tests/models/exaone4/test_modeling_exaone4.py @@ -47,12 +47,23 @@ class Exaone4ModelTester(CausalLMModelTester): if is_torch_available(): base_model_class = Exaone4Model + def __init__(self, parent): + super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. + self.attention_probs_dropout_prob = 0.0 + @require_torch class Exaone4ModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Exaone4ModelTester model_split_percents = [0.5, 0.6] + @unittest.skip("Exaone4 TP + quantized generation test needs fixing") + def test_tp_generation_quantized(self): + pass + @require_torch class Exaone4IntegrationTest(unittest.TestCase): diff --git a/tests/models/exaone_moe/test_modeling_exaone_moe.py b/tests/models/exaone_moe/test_modeling_exaone_moe.py index 95c7ccb50d51..f410637ba806 100644 --- a/tests/models/exaone_moe/test_modeling_exaone_moe.py +++ b/tests/models/exaone_moe/test_modeling_exaone_moe.py @@ -52,6 +52,10 @@ class ExaoneMoeModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = ExaoneMoeModelTester model_split_percents = [0.5, 0.8, 0.9] + @unittest.skip("ExaoneMoe TP + quantized generation test needs fixing") + def test_tp_generation_quantized(self): + pass + @slow @require_torch diff --git a/tests/models/flex_olmo/test_modeling_flex_olmo.py b/tests/models/flex_olmo/test_modeling_flex_olmo.py index 222d010e2ead..6d9a84dab6d1 100644 --- a/tests/models/flex_olmo/test_modeling_flex_olmo.py +++ b/tests/models/flex_olmo/test_modeling_flex_olmo.py @@ -41,6 +41,13 @@ class FlexOlmoModelTester(CausalLMModelTester): if is_torch_available(): base_model_class = FlexOlmoModel + def __init__(self, parent): + super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. 
+ self.attention_probs_dropout_prob = 0.0 + @require_torch class FlexOlmoModelTest(CausalLMModelTest, unittest.TestCase): diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 26d431f650d1..0fbbe93f159c 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -61,6 +61,10 @@ class Gemma2ModelTest(CausalLMModelTest, unittest.TestCase): model_split_percents = [0.5, 0.6] model_tester_class = Gemma2ModelTester + @unittest.skip("Gemma2 tanh soft-capping amplifies TP numerical noise beyond 80% match threshold") + def test_tp_generation_quantized(self): + pass + @slow @require_torch_accelerator diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index eabee28c4ff9..42953c6052d5 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -70,6 +70,13 @@ class Gemma3TextModelTester(CausalLMModelTester): causal_lm_class = Gemma3ForCausalLM sequence_classification_class = Gemma3TextForSequenceClassification + def __init__(self, parent): + super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. + self.attention_probs_dropout_prob = 0.0 + @require_torch class Gemma3TextModelTest(CausalLMModelTest, unittest.TestCase): @@ -77,6 +84,10 @@ class Gemma3TextModelTest(CausalLMModelTest, unittest.TestCase): _is_stateful = True model_split_percents = [0.5, 0.6] + @unittest.skip("Gemma3 tanh soft-capping amplifies TP numerical noise beyond 80% match threshold") + def test_tp_generation_quantized(self): + pass + @unittest.skip("Gemma3 applies key/query norm which doesn't work with packing") def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py index 298e22bfb81c..a4da945939f3 100644 --- a/tests/models/gemma3n/test_modeling_gemma3n.py +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -270,7 +270,7 @@ def __init__( num_attention_heads=2, num_key_value_heads=2, altup_num_inputs=2, - intermediate_size=21, + intermediate_size=22, hidden_activation="gelu_pytorch_tanh", max_position_embeddings=512, type_vocab_size=16, @@ -314,6 +314,10 @@ def __init__( self.eos_token_id = eos_token_id self.head_dim = self.hidden_size // self.num_attention_heads self.is_decoder = is_decoder + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. 
+ self.attention_probs_dropout_prob = 0.0 @require_torch @@ -321,6 +325,7 @@ class Gemma3nTextModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Gemma3nTextModelTester _is_stateful = True model_split_percents = [0.5, 0.6] + training_overfit_steps = 400 def _check_hidden_states_for_generate( self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False diff --git a/tests/models/glm/test_modeling_glm.py b/tests/models/glm/test_modeling_glm.py index f38cd44dd9dc..60fbbcd1dd9e 100644 --- a/tests/models/glm/test_modeling_glm.py +++ b/tests/models/glm/test_modeling_glm.py @@ -43,6 +43,13 @@ class GlmModelTester(CausalLMModelTester): if is_torch_available(): base_model_class = GlmModel + def __init__(self, parent): + super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. + self.attention_dropout = 0.0 + @require_torch class GlmModelTest(CausalLMModelTest, unittest.TestCase): diff --git a/tests/models/glm4/test_modeling_glm4.py b/tests/models/glm4/test_modeling_glm4.py index ed4611f9cbde..a8f8057ee420 100644 --- a/tests/models/glm4/test_modeling_glm4.py +++ b/tests/models/glm4/test_modeling_glm4.py @@ -43,6 +43,13 @@ class Glm4ModelTester(CausalLMModelTester): if is_torch_available(): base_model_class = Glm4Model + def __init__(self, parent): + super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. 
+ self.attention_dropout = 0.0 + @require_torch class Glm4ModelTest(CausalLMModelTest, unittest.TestCase): diff --git a/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py b/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py index 471d47002554..648557c6f299 100644 --- a/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py +++ b/tests/models/glm4_moe_lite/test_modeling_glm4_moe_lite.py @@ -63,6 +63,10 @@ class Glm4MoeModelTest(CausalLMModelTest, unittest.TestCase): test_all_params_have_gradient = False model_split_percents = [0.5, 0.7, 0.8] + @unittest.skip("MoE topk routing is too sensitive to Float8 quantization numerical noise") + def test_tp_generation_quantized(self): + pass + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): """Needs to be overridden as GLM-4.7-Flash has special MLA cache format (though we don't really use the MLA)""" self.assertIsInstance(past_key_values, Cache) diff --git a/tests/models/glm_moe_dsa/test_modeling_glm_moe_dsa.py b/tests/models/glm_moe_dsa/test_modeling_glm_moe_dsa.py index f3857895ff8e..9a1e26800eec 100644 --- a/tests/models/glm_moe_dsa/test_modeling_glm_moe_dsa.py +++ b/tests/models/glm_moe_dsa/test_modeling_glm_moe_dsa.py @@ -76,6 +76,10 @@ class GlmMoeDsaModelTest(CausalLMModelTest, unittest.TestCase): test_all_params_have_gradient = False model_split_percents = [0.5, 0.7, 0.8] + @unittest.skip("Float8 quantization + TP numerical noise exceeds match threshold") + def test_tp_generation_quantized(self): + pass + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): """Needs to be overridden as GLM-4.7-Flash has special MLA cache format (though we don't really use the MLA)""" self.assertIsInstance(past_key_values, Cache) diff --git a/tests/models/jais2/test_modeling_jais2.py b/tests/models/jais2/test_modeling_jais2.py index b1b3bbc72e4c..a224e702b006 100644 --- a/tests/models/jais2/test_modeling_jais2.py +++ b/tests/models/jais2/test_modeling_jais2.py @@ -53,6 +53,11 @@ class Jais2ModelTester(CausalLMModelTester): @require_torch class Jais2ModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Jais2ModelTester + + @unittest.skip("Float8 quantization + TP numerical noise exceeds match threshold") + def test_tp_generation_quantized(self): + pass + all_model_classes = ( ( Jais2Model, diff --git a/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py b/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py index dec32a4d6813..37b2750a632d 100644 --- a/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py +++ b/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py @@ -64,6 +64,10 @@ class OlmoHybridModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = OlmoHybridModelTester rotary_embedding_layer = OlmoHybridRotaryEmbedding if is_torch_available() else None + @unittest.skip("Float8 quantization + TP numerical noise exceeds match threshold") + def test_tp_generation_quantized(self): + pass + # === Cache helper methods (same pattern as Qwen3Next) === def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): """OlmoHybrid has a special Cache as it alternates with gated deltanet layers""" diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index 6a1b337cc10b..efcb22dc3137 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -87,6 +87,13 @@ class Phi3ModelTester(CausalLMModelTester): if 
is_torch_available(): base_model_class = Phi3Model + def __init__(self, parent): + super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. + self.attention_dropout = 0.0 + @require_torch class Phi3ModelTest(CausalLMModelTest, unittest.TestCase): diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index aa9c1efa90df..6fa304662caf 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -46,6 +46,13 @@ class Qwen3ModelTester(CausalLMModelTester): if is_torch_available(): base_model_class = Qwen3Model + def __init__(self, parent): + super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. + self.attention_probs_dropout_prob = 0.0 + @require_torch class Qwen3ModelTest(CausalLMModelTest, unittest.TestCase): diff --git a/tests/models/qwen3_next/test_modeling_qwen3_next.py b/tests/models/qwen3_next/test_modeling_qwen3_next.py index 0f90b24d073f..29e5f51705de 100644 --- a/tests/models/qwen3_next/test_modeling_qwen3_next.py +++ b/tests/models/qwen3_next/test_modeling_qwen3_next.py @@ -43,6 +43,10 @@ class Qwen3NextModelTester(CausalLMModelTester): def __init__(self, parent): super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. + self.attention_probs_dropout_prob = 0.0 self.layer_types = ["linear_attention", "full_attention"] self.linear_conv_kernel_dim = 2 self.linear_key_head_dim = 16 diff --git a/tests/models/seed_oss/test_modeling_seed_oss.py b/tests/models/seed_oss/test_modeling_seed_oss.py index 1884e3c03b16..83aa6d013150 100644 --- a/tests/models/seed_oss/test_modeling_seed_oss.py +++ b/tests/models/seed_oss/test_modeling_seed_oss.py @@ -42,6 +42,15 @@ class SeedOssModelTester(CausalLMModelTester): if is_torch_available(): base_model_class = SeedOssModel + def __init__(self, parent): + super().__init__(parent=parent) + # NOTE(3outeille): must be 0.0 for TP backward tests. In train mode, non-zero dropout causes + # different RNG states between the non-TP and TP model forward passes (they run sequentially), + # leading to different dropout masks and mismatched losses. 
+ self.attention_probs_dropout_prob = 0.0 + self.attention_dropout = 0.0 + self.residual_dropout = 0.0 + @require_torch class SeedOssModelTest(CausalLMModelTest, unittest.TestCase): diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index 2b799bfd9046..665c3d3b8f96 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -50,6 +50,10 @@ class Starcoder2ModelTester(CausalLMModelTester): class Starcoder2ModelTest(CausalLMModelTest, unittest.TestCase): model_tester_class = Starcoder2ModelTester + @unittest.skip("Float8 quantization + TP numerical noise exceeds match threshold") + def test_tp_generation_quantized(self): + pass + @slow @require_torch_accelerator diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index c52422f0a7ae..e173edf18e87 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,108 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -# Run dense tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "dense" -# Run MoE tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "moe" -# Collect tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py --collect-only -import os -import tempfile import warnings -import pytest -from safetensors import safe_open +import torch -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, is_torch_available, set_seed +from transformers import AutoModelForCausalLM from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights -from transformers.testing_utils import ( - TestCasePlus, - backend_device_count, - get_torch_dist_unique_port, - require_huggingface_hub_greater_or_equal, - require_torch_multi_accelerator, - torch_device, -) -from transformers.utils import is_torch_greater_or_equal - - -# Tensor parallel tests require torch >= 2.9 for proper torch.compile support with distributed collectives -# Newer versions of PyTorch has torch.library.register_autograd in https://github.com/pytorch/pytorch/blob/8bcedd6e6029cce5f3a3731dd59be4941414c731/torch/distributed/_functional_collectives.py#L630 -# that fix the warning "autograd kernel was not registered to the Autograd key(s) but we are trying to backprop through it" -# NOTE(3outeille): need to double check if it works with older version of torch -pytestmark = pytest.mark.skipif( - not is_torch_greater_or_equal("2.9"), - reason="Tensor parallel tests require torch >= 2.9 for torch.compile support with distributed collectives", -) - - -if is_torch_available(): - import torch - import torch.distributed as dist - import torch.multiprocessing as mp - - -def get_packed_grad_shard(grad, world_size, rank, dim): - """Get the correct shard of a packed gradient (matching get_packed_weights interleaved logic). 
- - Packed weights like gate_up_proj are sharded with interleaving: - Original: [G0 G1 G2 G3 | U0 U1 U2 U3] (gate | up) - Rank 0: [G0 G1 | U0 U1] - Rank 1: [G2 G3 | U2 U3] - """ - total_size = grad.shape[dim] - # Packed weights have 2 blocks (gate and up) - block_size = total_size // 2 - shard_block_size = block_size // world_size - - # Build interleaved indices - indices = [] - for block_idx in range(2): # gate block, then up block - block_offset = block_idx * block_size - start = block_offset + rank * shard_block_size - stop = block_offset + (rank + 1) * shard_block_size - indices.extend(range(start, stop)) - - # Select along the sharded dimension - return grad.index_select(dim, torch.tensor(indices, device=grad.device)) - - -def global_wrapper(rank, func, tp, port, func_args, func_kwargs): - def setup_dist_env(rank, world_size, port): - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["RANK"] = str(rank) - os.environ["LOCAL_RANK"] = str(rank) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(port) - - world_size = tp - setup_dist_env(rank, world_size, port) - - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) - - func(rank, *func_args, **func_kwargs) - - dist.barrier() - dist.destroy_process_group() - - -def init_distributed(tp: int): - def _init_distributed(func): - def wrapper(*args, **kwargs): - world_size = tp - port = get_torch_dist_unique_port() - spawn_args = (func, tp, port, args, kwargs) - mp.spawn(global_wrapper, args=spawn_args, nprocs=world_size) - - return wrapper - - return _init_distributed - - -def skip_if_insufficient_devices(nproc_per_node): - """Skip test if there aren't enough devices available.""" - if backend_device_count(torch_device) < nproc_per_node: - pytest.skip(f"Need at least {nproc_per_node} devices, have {backend_device_count(torch_device)}") +from transformers.testing_utils import TestCasePlus class TestTensorParallelUtils(TestCasePlus): @@ -247,642 +152,3 @@ def test_tp_plan_none_handling(self): # Test setting a plan after None model.tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"} self.assertEqual(model.tp_plan, {"model.layers.*.self_attn.q_proj": "colwise"}) - - -# ====== TEST FUNCTIONS ====== -def _test_model_dense_forward_impl(rank, mode, dtype=torch.float32): - """Implementation for comparing TP and non-TP model outputs.""" - model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM" - - set_seed(42) - - atol, rtol = (1e-5, 1e-5) - - # Load tokenizer and prepare inputs - same for both models - tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) - prompt = "Can I help" - inputs = tokenizer(prompt, return_tensors="pt") - - # Load TP model first to determine device - model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, tp_plan="auto") - dist.barrier() - if mode == "eval": - model_tp.eval() - else: - model_tp.train() - - # Load non-TP model and move to same device as TP model - device = model_tp.device - model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype) - model = model.to(device) - - if mode == "eval": - model.eval() - else: - model.train() - - # Prepare inputs on the same device - input_ids = inputs.input_ids.to(device) - - with torch.no_grad(): - outputs = model(input_ids) - logits = outputs.logits - - outputs_tp = model_tp(input_ids) - logits_tp = outputs_tp.logits - - diff = (logits - logits_tp).abs() - assert torch.allclose(logits, logits_tp, atol=atol, rtol=rtol), ( - f"TP and non-TP model outputs differ (dtype={dtype}). 
" - f"Max diff: {diff.max().item()} | Min diff: {diff.min().item()}" - ) - - dist.barrier() - - -def _test_model_dense_backward_pass_impl(rank, dtype=torch.float32): - """Implementation for comparing TP and non-TP model backward passes.""" - model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM" - - set_seed(42) - - # Set tolerance based on dtype - atol, rtol = (1e-5, 1e-5) - - model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, tp_plan="auto") - dist.barrier() - model_tp.train() - - device = model_tp.device - model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype) - model = model.to(device) - model.train() - - batch_size, seq_length = 2, 1024 - set_seed(42) - input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length)).to(device) - labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length)).to(device) - - outputs = model(input_ids, labels=labels) - loss = outputs.loss - loss.backward() - - outputs_tp = model_tp(input_ids, labels=labels) - loss_tp = outputs_tp.loss - loss_tp.backward() - - assert torch.allclose(loss, loss_tp, atol=atol, rtol=rtol), ( - f"TP and non-TP model losses differ (dtype={dtype}). Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}" - ) - - # Compare gradients for matching parameters - # Note: TP model may have sharded parameters, so we slice the reference gradient to match - world_size = dist.get_world_size() - for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()): - if param.grad is not None and param_tp.grad is not None: - grad = param.grad - grad_tp = param_tp.grad - - # Slice reference gradient to match local shard if parameter is sharded - if grad.shape != grad_tp.shape: - # Find the dimension that differs and slice accordingly - for dim in range(grad.ndim): - if grad.size(dim) != grad_tp.size(dim): - # Packed weights (gate_up_proj) use interleaved sharding - if "gate_up_proj" in name: - grad = get_packed_grad_shard(grad, world_size, rank, dim) - else: - # Regular weights use simple chunking - shard_size = grad_tp.size(dim) - start = rank * shard_size - grad = grad.narrow(dim, start, shard_size) - break - - assert torch.allclose(grad.cpu(), grad_tp.cpu(), atol=atol, rtol=rtol), ( - f"Gradients differ for parameter {name} (dtype={dtype}). 
Max diff: {(grad.cpu() - grad_tp.cpu()).abs().max().item()} | Min diff: {(grad.cpu() - grad_tp.cpu()).abs().min().item()}" - ) - - dist.barrier() - - -def _test_model_dense_forward_compile_impl(rank, mode, dtype=torch.float32): - """Implementation for comparing TP and non-TP model outputs with torch.compile.""" - model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM" - - set_seed(42) - - # Set tolerance based on dtype - atol, rtol = (1e-5, 1e-5) - - tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) - prompt = "Can I help" - inputs = tokenizer(prompt, return_tensors="pt") - - model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, tp_plan="auto") - dist.barrier() - if mode == "eval": - model_tp.eval() - else: - model_tp.train() - - device = model_tp.device - model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype) - model = model.to(device) - - if mode == "eval": - model.eval() - else: - model.train() - - # Compile both models - model.forward = torch.compile(model.forward) - model_tp.forward = torch.compile(model_tp.forward) - - input_ids = inputs.input_ids.to(device) - - with torch.no_grad(): - outputs = model(input_ids) - logits = outputs.logits - - outputs_tp = model_tp(input_ids) - logits_tp = outputs_tp.logits - - assert torch.allclose(logits, logits_tp, atol=atol, rtol=rtol), ( - f"TP and non-TP model outputs differ (dtype={dtype}). Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}" - ) - - dist.barrier() - - -def _test_model_dense_backward_compile_impl(rank, dtype=torch.float32): - """Implementation for comparing TP and non-TP model backward passes with torch.compile.""" - model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM" - - set_seed(42) - - # Set tolerance based on dtype - atol, rtol = (1e-5, 1e-5) - - model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, tp_plan="auto") - dist.barrier() - model_tp.train() - - device = model_tp.device - model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype) - model = model.to(device) - model.train() - - # Compile both models - model.forward = torch.compile(model.forward) - model_tp.forward = torch.compile(model_tp.forward) - - batch_size, seq_length = 2, 1024 - set_seed(42) - input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length)).to(device) - labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length)).to(device) - - outputs = model(input_ids, labels=labels) - loss = outputs.loss - loss.backward() - - outputs_tp = model_tp(input_ids, labels=labels) - loss_tp = outputs_tp.loss - loss_tp.backward() - - assert torch.allclose(loss, loss_tp, atol=atol, rtol=rtol), ( - f"TP and non-TP model losses differ (dtype={dtype}). 
Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}" - ) - - # Compare gradients for matching parameters - world_size = dist.get_world_size() - for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()): - if param.grad is not None and param_tp.grad is not None: - grad = param.grad - grad_tp = param_tp.grad - - # Slice reference gradient to match local shard if parameter is sharded - if grad.shape != grad_tp.shape: - for dim in range(grad.ndim): - if grad.size(dim) != grad_tp.size(dim): - # Packed weights (gate_up_proj) use interleaved sharding - if "gate_up_proj" in name: - grad = get_packed_grad_shard(grad, world_size, rank, dim) - else: - # Regular weights use simple chunking - shard_size = grad_tp.size(dim) - start = rank * shard_size - grad = grad.narrow(dim, start, shard_size) - break - - assert torch.allclose(grad.cpu(), grad_tp.cpu(), atol=atol, rtol=rtol), ( - f"Gradients differ for parameter {name} (dtype={dtype}). Max diff: {(grad.cpu() - grad_tp.cpu()).abs().max().item()}" - ) - - dist.barrier() - - -def _test_model_dense_save_impl(rank, tmp_dir): - """Implementation of test_model_save for distributed execution.""" - model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM" - - if dist.is_initialized(): - kwargs = {"tp_plan": "auto"} - result_dir = f"{tmp_dir}/tp" - else: - kwargs = {} - result_dir = f"{tmp_dir}/nontp" - - model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) - model.save_pretrained(result_dir) - - -# ====== DENSE MODEL TESTS ====== -@pytest.mark.parametrize("nproc_per_node", [2]) -@pytest.mark.parametrize("mode", ["train", "eval"]) -@require_torch_multi_accelerator -def test_model_dense_forward(nproc_per_node, mode): - """Test that TP and non-TP models produce the same outputs.""" - skip_if_insufficient_devices(nproc_per_node) - init_distributed(tp=nproc_per_node)(_test_model_dense_forward_impl)(mode, torch.float32) - - -@pytest.mark.parametrize("nproc_per_node", [2]) -@require_torch_multi_accelerator -def test_model_dense_backward_pass(nproc_per_node): - """Test that TP and non-TP models produce the same gradients.""" - skip_if_insufficient_devices(nproc_per_node) - init_distributed(tp=nproc_per_node)(_test_model_dense_backward_pass_impl)(torch.float32) - - -@pytest.mark.parametrize("nproc_per_node", [2]) -@pytest.mark.parametrize("mode", ["train", "eval"]) -@require_torch_multi_accelerator -def test_model_dense_forward_compile(nproc_per_node, mode): - """Test that TP and non-TP models produce the same outputs with torch.compile.""" - skip_if_insufficient_devices(nproc_per_node) - init_distributed(tp=nproc_per_node)(_test_model_dense_forward_compile_impl)(mode, torch.float32) - - -@pytest.mark.parametrize("nproc_per_node", [2]) -@require_torch_multi_accelerator -def test_model_dense_backward_compile(nproc_per_node): - """Test that TP and non-TP models produce the same gradients with torch.compile.""" - skip_if_insufficient_devices(nproc_per_node) - init_distributed(tp=nproc_per_node)(_test_model_dense_backward_compile_impl)(torch.float32) - - -@pytest.mark.parametrize("nproc_per_node", [2]) -@require_huggingface_hub_greater_or_equal("0.31.4") -@require_torch_multi_accelerator -def test_model_dense_save(nproc_per_node): - """Test that TP model can be saved and matches non-TP version.""" - skip_if_insufficient_devices(nproc_per_node) - - with tempfile.TemporaryDirectory() as tmp_dir: - # First run with TP (distributed) - 
init_distributed(tp=nproc_per_node)(_test_model_dense_save_impl)(tmp_dir) - - # Then run without TP (non-distributed) - _test_model_dense_save_impl(0, tmp_dir) - - non_tp_model_path = os.path.join(tmp_dir, "nontp") - tp_model_path = os.path.join(tmp_dir, "tp") - - for filename in os.listdir(non_tp_model_path): - if not filename.endswith(".safetensors"): - continue - - non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt") - tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt") - for non_tp_key in non_tp_model.keys(): - non_tp_tensor = non_tp_model.get_tensor(non_tp_key) - tp_tensor = tp_model.get_tensor(non_tp_key) - assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match" - del non_tp_tensor, tp_tensor - - -def _test_model_moe_forward_impl(rank, mode, dtype=torch.float32): - """Implementation for comparing TP and non-TP MoE model outputs.""" - model_id = "hf-internal-testing/tiny-random-MixtralForCausalLM" - - set_seed(42) - - # Set tolerance based on dtype - atol, rtol = (1e-5, 1e-5) - - tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) - prompt = "Can I help" - inputs = tokenizer(prompt, return_tensors="pt") - - model_tp = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, tp_plan="auto") - dist.barrier() - if mode == "eval": - model_tp.eval() - else: - model_tp.train() - - device = model_tp.device - model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype) - model = model.to(device) - - if mode == "eval": - model.eval() - else: - model.train() - - input_ids = inputs.input_ids.to(device) - - with torch.no_grad(): - outputs = model(input_ids) - logits = outputs.logits - - outputs_tp = model_tp(input_ids) - logits_tp = outputs_tp.logits - - diff = (logits - logits_tp).abs() - assert torch.allclose(logits, logits_tp, atol=atol, rtol=rtol), ( - f"TP and non-TP MoE model outputs differ (dtype={dtype}). " - f"Max diff: {diff.max().item()} | Min diff: {diff.min().item()}" - ) - - dist.barrier() - - -def _test_model_moe_backward_pass_impl(rank, dtype=torch.float32): - """Implementation for comparing TP and non-TP MoE model backward passes.""" - model_id = "hf-internal-testing/tiny-random-MixtralForCausalLM" - - set_seed(42) - - atol, rtol = (1e-5, 1e-5) - - config = AutoConfig.from_pretrained(model_id) - - model_tp = AutoModelForCausalLM.from_pretrained(model_id, config=config, dtype=dtype, tp_plan="auto") - dist.barrier() - model_tp.train() - - device = model_tp.device - model = AutoModelForCausalLM.from_pretrained(model_id, config=config, dtype=dtype) - model = model.to(device) - model.train() - - batch_size, seq_length = 2, 1024 - set_seed(42) - input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device) - labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length), device=device) - - outputs = model(input_ids, labels=labels) - loss = outputs.loss - loss.backward() - - outputs_tp = model_tp(input_ids, labels=labels) - loss_tp = outputs_tp.loss - loss_tp.backward() - - assert torch.allclose(loss, loss_tp, atol=atol, rtol=rtol), ( - f"TP and non-TP MoE model losses differ (dtype={dtype}). 
Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}" - ) - - # Compare gradients for matching parameters - world_size = dist.get_world_size() - - for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()): - if param.grad is not None and param_tp.grad is not None: - grad = param.grad - grad_tp = param_tp.grad - - # Slice reference gradient to match local shard if parameter is sharded - if grad.shape != grad_tp.shape: - for dim in range(grad.ndim): - if grad.size(dim) != grad_tp.size(dim): - if "gate_up_proj" in name: - grad = get_packed_grad_shard(grad, world_size, rank, dim) - else: - shard_size = grad_tp.size(dim) - start = rank * shard_size - grad = grad.narrow(dim, start, shard_size) - break - - assert torch.allclose(grad.cpu(), grad_tp.cpu(), atol=atol, rtol=rtol), ( - f"Gradients differ for parameter {name} (dtype={dtype}). Max diff: {(grad.cpu() - grad_tp.cpu()).abs().max().item()}" - ) - - dist.barrier() - - -def _test_model_moe_forward_compile_impl(rank, mode, dtype=torch.float32, experts_implementation=None): - """Implementation for comparing TP and non-TP MoE model outputs with torch.compile.""" - model_id = "hf-internal-testing/tiny-random-MixtralForCausalLM" - - set_seed(42) - - if dtype == torch.bfloat16: - atol, rtol = (5e-3, 5e-3) - else: - atol, rtol = (1e-5, 1e-5) - - tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False) - prompt = "Can I help" - inputs = tokenizer(prompt, return_tensors="pt") - - model_tp = AutoModelForCausalLM.from_pretrained( - model_id, dtype=dtype, tp_plan="auto", experts_implementation=experts_implementation - ) - dist.barrier() - if mode == "eval": - model_tp.eval() - else: - model_tp.train() - - device = model_tp.device - model = AutoModelForCausalLM.from_pretrained(model_id, dtype=dtype, experts_implementation=experts_implementation) - model = model.to(device) - - if mode == "eval": - model.eval() - else: - model.train() - - # Compile both models - model.forward = torch.compile(model.forward) - model_tp.forward = torch.compile(model_tp.forward) - - input_ids = inputs.input_ids.to(device) - - with torch.no_grad(): - outputs = model(input_ids) - logits = outputs.logits - - outputs_tp = model_tp(input_ids) - logits_tp = outputs_tp.logits - - assert torch.allclose(logits, logits_tp, atol=atol, rtol=rtol), ( - f"TP and non-TP MoE model outputs differ (dtype={dtype}). 
Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}" - ) - - dist.barrier() - - -def _test_model_moe_backward_compile_impl(rank, dtype=torch.float32, experts_implementation=None): - """Implementation for comparing TP and non-TP MoE model backward passes with torch.compile.""" - model_id = "hf-internal-testing/tiny-random-MixtralForCausalLM" - - set_seed(42) - - # bfloat16 has lower precision - if dtype == torch.bfloat16: - atol, rtol = (1e-3, 1e-3) - else: - atol, rtol = (1e-5, 1e-5) - - config = AutoConfig.from_pretrained(model_id) - - model_tp = AutoModelForCausalLM.from_pretrained( - model_id, config=config, dtype=dtype, tp_plan="auto", experts_implementation=experts_implementation - ) - dist.barrier() - model_tp.train() - - device = model_tp.device - model = AutoModelForCausalLM.from_pretrained( - model_id, config=config, dtype=dtype, experts_implementation=experts_implementation - ) - model = model.to(device) - model.train() - - model.forward = torch.compile(model.forward) - model_tp.forward = torch.compile(model_tp.forward) - - batch_size, seq_length = 2, 1024 - set_seed(42) - input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_length)).to(device) - labels = torch.randint(0, model.config.vocab_size, (batch_size, seq_length)).to(device) - - outputs = model(input_ids, labels=labels) - loss = outputs.loss - loss.backward() - - outputs_tp = model_tp(input_ids, labels=labels) - loss_tp = outputs_tp.loss - loss_tp.backward() - - assert torch.allclose(loss, loss_tp, atol=atol, rtol=rtol), ( - f"TP and non-TP MoE model losses differ (dtype={dtype}). Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}" - ) - - # Compare gradients for matching parameters - world_size = dist.get_world_size() - - for (name, param), (name_tp, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()): - if param.grad is not None and param_tp.grad is not None: - grad = param.grad - grad_tp = param_tp.grad - - # Slice reference gradient to match local shard if parameter is sharded - if grad.shape != grad_tp.shape: - for dim in range(grad.ndim): - if grad.size(dim) != grad_tp.size(dim): - if "gate_up_proj" in name: - grad = get_packed_grad_shard(grad, world_size, rank, dim) - else: - shard_size = grad_tp.size(dim) - start = rank * shard_size - grad = grad.narrow(dim, start, shard_size) - break - - assert torch.allclose(grad.cpu(), grad_tp.cpu(), atol=atol, rtol=rtol), ( - f"Gradients differ for parameter {name} (dtype={dtype}). 
Max diff: {(grad.cpu() - grad_tp.cpu()).abs().max().item()}" - ) - - dist.barrier() - - -def _test_model_moe_save_impl(rank, tmp_dir): - """Implementation of test_model_save for MoE model distributed execution.""" - model_id = "hf-internal-testing/tiny-random-MixtralForCausalLM" - - if dist.is_initialized(): - kwargs = {"tp_plan": "auto"} - result_dir = f"{tmp_dir}/tp" - else: - kwargs = {} - result_dir = f"{tmp_dir}/nontp" - - model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", **kwargs) - model.save_pretrained(result_dir) - - -# ====== MOE MODEL TESTS ====== -@pytest.mark.parametrize("nproc_per_node", [2]) -@pytest.mark.parametrize("mode", ["train", "eval"]) -@require_torch_multi_accelerator -def test_model_moe_forward(nproc_per_node, mode): - """Test that TP and non-TP MoE models produce the same outputs.""" - skip_if_insufficient_devices(nproc_per_node) - init_distributed(tp=nproc_per_node)(_test_model_moe_forward_impl)(mode, torch.float32) - - -@pytest.mark.parametrize("nproc_per_node", [2]) -@require_torch_multi_accelerator -def test_model_moe_backward_pass(nproc_per_node): - """Test that TP and non-TP MoE models produce the same gradients.""" - skip_if_insufficient_devices(nproc_per_node) - init_distributed(tp=nproc_per_node)(_test_model_moe_backward_pass_impl)(torch.float32) - - -@pytest.mark.parametrize("nproc_per_node", [2]) -@pytest.mark.parametrize("mode", ["train", "eval"]) -@pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) -@pytest.mark.parametrize("experts_implementation", ["batched_mm", "grouped_mm"]) -@require_torch_multi_accelerator -def test_model_moe_forward_compile(nproc_per_node, mode, dtype, experts_implementation): - """Test that TP and non-TP MoE models produce the same outputs with torch.compile.""" - skip_if_insufficient_devices(nproc_per_node) - dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32 - init_distributed(tp=nproc_per_node)(_test_model_moe_forward_compile_impl)( - mode, dtype, experts_implementation=experts_implementation - ) - - -@pytest.mark.parametrize("nproc_per_node", [2]) -@pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) -@pytest.mark.parametrize("experts_implementation", ["batched_mm", "grouped_mm"]) -@require_torch_multi_accelerator -def test_model_moe_backward_compile(nproc_per_node, dtype, experts_implementation): - """Test that TP and non-TP MoE models produce the same gradients with torch.compile.""" - skip_if_insufficient_devices(nproc_per_node) - dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32 - init_distributed(tp=nproc_per_node)(_test_model_moe_backward_compile_impl)( - dtype, experts_implementation=experts_implementation - ) - - -@pytest.mark.parametrize("nproc_per_node", [2]) -@require_huggingface_hub_greater_or_equal("0.31.4") -@require_torch_multi_accelerator -def test_model_moe_save(nproc_per_node): - """Test that TP MoE model can be saved and matches non-TP version.""" - skip_if_insufficient_devices(nproc_per_node) - - with tempfile.TemporaryDirectory() as tmp_dir: - # First run with TP (distributed) - init_distributed(tp=nproc_per_node)(_test_model_moe_save_impl)(tmp_dir) - - # Then run without TP (non-distributed) - _test_model_moe_save_impl(0, tmp_dir) - - non_tp_model_path = os.path.join(tmp_dir, "nontp") - tp_model_path = os.path.join(tmp_dir, "tp") - - for filename in os.listdir(non_tp_model_path): - if not filename.endswith(".safetensors"): - continue - - non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt") - 
tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt") - for non_tp_key in non_tp_model.keys(): - non_tp_tensor = non_tp_model.get_tensor(non_tp_key) - tp_tensor = tp_model.get_tensor(non_tp_key) - assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match" - del non_tp_tensor, tp_tensor diff --git a/tests/test_tensor_parallel_mixin.py b/tests/test_tensor_parallel_mixin.py new file mode 100644 index 000000000000..ce4413c264fc --- /dev/null +++ b/tests/test_tensor_parallel_mixin.py @@ -0,0 +1,484 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import socket +import tempfile +from abc import ABC, abstractmethod + +from transformers import TorchAoConfig, set_seed +from transformers.integrations.tensor_parallel import _get_parameter_tp_plan +from transformers.testing_utils import ( + is_tensor_parallel_test, + is_torch_available, +) +from transformers.utils import is_torch_greater_or_equal, is_torchao_available + + +if is_torchao_available(): + from torchao.quantization import Float8WeightOnlyConfig + + +if is_torch_available(): + import torch + import torch.distributed as dist + import torch.multiprocessing as mp + from torch.multiprocessing.spawn import ProcessRaisedException + + +def _find_free_port(): + """Find a free port by binding a socket and releasing it.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("localhost", 0)) + return s.getsockname()[1] + + +def get_packed_grad_shard(grad, world_size, rank, dim): + """Get the correct shard of a packed gradient (matching get_packed_weights interleaved logic). 
+ + Packed weights like gate_up_proj are sharded with interleaving: + Original: [G0 G1 G2 G3 | U0 U1 U2 U3] (gate | up) + Rank 0: [G0 G1 | U0 U1] + Rank 1: [G2 G3 | U2 U3] + """ + total_size = grad.shape[dim] + # Packed weights have 2 blocks (gate and up) + block_size = total_size // 2 + shard_block_size = block_size // world_size + + # Build interleaved indices + indices = [] + for block_idx in range(2): # gate block, then up block + block_offset = block_idx * block_size + start = block_offset + rank * shard_block_size + stop = block_offset + (rank + 1) * shard_block_size + indices.extend(range(start, stop)) + + # Select along the sharded dimension + return grad.index_select(dim, torch.tensor(indices, device=grad.device)) + + +def _global_wrapper(rank, func, tp, port, func_args, func_kwargs): + """Wrapper to set up distributed environment and run the test function.""" + + def setup_dist_env(rank, world_size, port): + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + world_size = tp + setup_dist_env(rank, world_size, port) + + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + + func(rank, *func_args, **func_kwargs) + + dist.barrier() + dist.destroy_process_group() + + +def _init_distributed(tp: int, max_retries: int = 5): + """Decorator to initialize distributed environment and spawn processes.""" + + def _init_distributed_inner(func): + def wrapper(*args, **kwargs): + world_size = tp + for attempt in range(max_retries): + port = _find_free_port() + spawn_args = (func, tp, port, args, kwargs) + try: + mp.spawn(_global_wrapper, args=spawn_args, nprocs=world_size) + return + except ProcessRaisedException as e: + if "EADDRINUSE" in str(e) and attempt < max_retries - 1: + continue + raise + + return wrapper + + return _init_distributed_inner + + +def _load_tp_and_reference_models(model_path, model_class): + """Load TP model and non-TP reference model for comparison. + + Returns: + tuple: (model_tp, model_ref, device) + """ + model_tp = model_class.from_pretrained(model_path, tp_plan="auto") + dist.barrier() + + device = model_tp.device + model_ref = model_class.from_pretrained(model_path) + model_ref = model_ref.to(device) + + return model_tp, model_ref, device + + +def _verify_tp_sharding(rank, model_tp, model_ref): + """Verify TP sharding by comparing parameter shapes between TP and reference models. 
+ + Returns: + list: Names of sharded parameters + """ + world_size = dist.get_world_size() + sharded_params = [] + + for (name, param), (_, param_full) in zip(model_tp.named_parameters(), model_ref.named_parameters()): + if param.shape != param_full.shape: + sharded_params.append(name) + if rank == 0: + print(f"[TP Test Debug] TP sharded: {name} - full: {param_full.shape} -> sharded: {param.shape}") + + # Verify sharding is correct + for dim in range(param.ndim): + if param.size(dim) != param_full.size(dim): + param_plan = _get_parameter_tp_plan(name, model_tp.tp_plan, is_weight=True) + if param_plan in ("packed_colwise",): + expected_size = param_full.size(dim) // world_size + assert param.size(dim) == expected_size, ( + f"Packed weight {name} sharding incorrect: expected {expected_size}, got {param.size(dim)}" + ) + else: + expected_size = (param_full.size(dim) + world_size - 1) // world_size + assert param.size(dim) <= expected_size, ( + f"Weight {name} sharding incorrect: expected <= {expected_size}, got {param.size(dim)}" + ) + break + + return sharded_params + + +def _test_tp_forward_impl(_rank, model_path, model_class, atol, rtol): + """Implementation for comparing TP and non-TP model outputs.""" + set_seed(0) + + model_tp, model, device = _load_tp_and_reference_models(model_path, model_class) + + _verify_tp_sharding(_rank, model_tp, model) + + model_tp.eval() + model.eval() + + vocab_size = model.config.vocab_size + set_seed(0) + input_ids = torch.randint(0, vocab_size, (2, 64)).to(device) + + with torch.no_grad(): + logits = model(input_ids).logits + logits_tp = model_tp(input_ids).logits + + diff = (logits - logits_tp).abs() + assert torch.allclose(logits, logits_tp, atol=atol, rtol=rtol), ( + f"TP and non-TP model outputs differ. Max diff: {diff.max().item()} | Min diff: {diff.min().item()}" + ) + + dist.barrier() + + +def _test_tp_backward_impl(rank, model_path, model_class, atol, rtol): + """Implementation for comparing TP and non-TP model backward passes.""" + set_seed(0) + + model_tp, model, device = _load_tp_and_reference_models(model_path, model_class) + model_tp.train() + model.train() + + vocab_size = model.config.vocab_size + set_seed(0) + input_ids = torch.randint(0, vocab_size, (2, 64)).to(device) + set_seed(0) + labels = torch.randint(0, vocab_size, (2, 64)).to(device) + + loss = model(input_ids, labels=labels).loss + loss.backward() + + loss_tp = model_tp(input_ids, labels=labels).loss + loss_tp.backward() + + assert torch.allclose(loss, loss_tp, atol=atol, rtol=rtol), ( + f"TP and non-TP model losses differ. 
" + f"Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, " + f"Diff: {(loss - loss_tp).abs().item()}" + ) + + # Compare gradients for matching parameters + world_size = dist.get_world_size() + failed_grads = {} + for (name, param), (_, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()): + if param.grad is not None and param_tp.grad is not None: + grad = param.grad + grad_tp = param_tp.grad + + # Slice reference gradient to match local shard if parameter is sharded + if grad.shape != grad_tp.shape: + for dim in range(grad.ndim): + if grad.size(dim) != grad_tp.size(dim): + param_plan = _get_parameter_tp_plan(name, model_tp.tp_plan, is_weight=True) + if param_plan in ("packed_colwise",): + # interleaved slicing + grad = get_packed_grad_shard(grad, world_size, rank, dim) + else: + # regular slicing + shard_size = grad_tp.size(dim) + start = rank * shard_size + grad = grad.narrow(dim, start, shard_size) + break + + if not torch.allclose(grad.cpu(), grad_tp.cpu(), atol=atol, rtol=rtol): + failed_grads[name] = (grad.cpu() - grad_tp.cpu()).abs().max().item() + + assert not failed_grads, f"Gradients differ for {len(failed_grads)} parameter(s):\n" + "\n".join( + f" {name}: max diff = {diff}" for name, diff in failed_grads.items() + ) + + dist.barrier() + + +def _test_tp_generation_impl(_rank, model_path, model_class, atol, rtol, max_new_tokens): + """Implementation for comparing TP and non-TP model generation outputs (direct load path).""" + set_seed(0) + + model_tp, model, device = _load_tp_and_reference_models(model_path, model_class) + model_tp.eval() + model.eval() + + set_seed(0) + vocab_size = model.config.vocab_size + input_ids = torch.randint(0, vocab_size, (1, 10)).to(device) + generation_kwargs = { + "max_new_tokens": max_new_tokens, + "do_sample": False, + "num_beams": 1, + "output_scores": True, + "return_dict_in_generate": True, + "use_cache": True, + } + + with torch.no_grad(): + output = model.generate(input_ids, **generation_kwargs) + output_tp = model_tp.generate(input_ids, **generation_kwargs) + + # Compare logits/scores at each generation step + scores = torch.stack(output.scores) + scores_tp = torch.stack(output_tp.scores) + + diff = (scores - scores_tp).abs() + assert torch.allclose(scores, scores_tp, atol=atol, rtol=rtol), ( + f"TP and non-TP model generation logits differ (direct load path). " + f"Max diff: {diff.max().item()} | Mean diff: {diff.mean().item()}" + ) + + # Compare generated token sequences + assert torch.equal(output.sequences, output_tp.sequences), ( + f"TP and non-TP model generated different token sequences (direct load path). 
" + f"Non-TP: {output.sequences.tolist()} | TP: {output_tp.sequences.tolist()}" + ) + + dist.barrier() + + +def _test_tp_generation_quantized_impl(_rank, model_path, model_class, max_new_tokens): + """Implementation for comparing TP+quantized and non-TP quantized generation (sequence equality).""" + set_seed(0) + + quantization_config = TorchAoConfig(Float8WeightOnlyConfig()) + + model_tp = model_class.from_pretrained(model_path, tp_plan="auto", quantization_config=quantization_config) + dist.barrier() + + device = model_tp.device + model = model_class.from_pretrained(model_path, quantization_config=quantization_config) + model = model.to(device) + + model_tp.eval() + model.eval() + + vocab_size = model.config.vocab_size + set_seed(0) + input_ids = torch.randint(0, vocab_size, (1, 10)).to(device) + + generation_kwargs = { + "max_new_tokens": max_new_tokens, + "do_sample": False, + "num_beams": 1, + "output_scores": True, + "return_dict_in_generate": True, + "use_cache": True, + } + + with torch.no_grad(): + output = model.generate(input_ids, **generation_kwargs) + output_tp = model_tp.generate(input_ids, **generation_kwargs) + + print(f"[Rank {_rank}] Non-TP-quantized model tokens: {output.sequences[0].tolist()}") + print(f"[Rank {_rank}] TP-quantized tokens: {output_tp.sequences[0].tolist()}") + print(f"[Rank {_rank}] Sequences match: {torch.equal(output.sequences, output_tp.sequences)}") + + # Compare generated token sequences (allow up to 25% mismatch due to Float8 quantization + # scale differences between full-weight and sharded-weight quantization) + # NOTE(3outeille): Some models have no perfect match. Investigate better the discrepancy but for now low priority. + seq = output.sequences[0] + seq_tp = output_tp.sequences[0] + min_len = min(len(seq), len(seq_tp)) + match_count = (seq[:min_len] == seq_tp[:min_len]).sum().item() + match_ratio = match_count / max(len(seq), len(seq_tp)) + assert match_ratio >= 0.75, ( + f"non-TP-quantized + TP-quantized model generated too many different tokens " + f"(match ratio: {match_ratio:.2%}, threshold: 75%).\n" + f"Non-TP+quantized: {output.sequences.tolist()} \n TP+quantized: {output_tp.sequences.tolist()}" + ) + + dist.barrier() + + +class TensorParallelTesterMixin(ABC): + """ + Mixin for tensor parallel tests. Add to model test classes alongside ModelTesterMixin. + + The model_tester (e.g., CausalLMModelTester) already provides: + - get_config() -> tiny model config + - causal_lm_class, base_model_class, etc. + + This mixin adds tensor parallel-specific tests using that infrastructure. + """ + + # ============================================================ + # Configuration (can be overridden per model) + # ============================================================ + tensor_parallel_size: int = 2 + tensor_parallel_atol: float = 1e-5 + tensor_parallel_rtol: float = 1e-5 + + @property + @abstractmethod + def model_tester(self): + """The model tester instance (e.g., CausalLMModelTester).""" + ... 
+ + # ============================================================ + # Helper methods + # ============================================================ + def _has_tp_plan(self) -> bool: + """Check if model has a tensor parallel plan defined.""" + config = self.model_tester.get_config() + return hasattr(config, "base_model_tp_plan") and config.base_model_tp_plan is not None + + def _get_tp_model_class(self): + """Get the model class to use for TP tests (prefers *ForCausalLM).""" + if hasattr(self.model_tester, "causal_lm_class") and self.model_tester.causal_lm_class is not None: + return self.model_tester.causal_lm_class + return self.all_model_classes[0] + + def _skip_if_not_supported(self): + """Check and skip test if TP is not supported for this model/environment.""" + if not is_torch_greater_or_equal("2.9"): + self.skipTest("Tensor parallel tests require torch >= 2.9") + + if torch.cuda.is_available(): + self.skipTest("Tensor parallel mixin tests are CPU-only and should not run on GPU machines") + + if os.cpu_count() < self.tensor_parallel_size: + self.skipTest( + f"Tensor parallel tests require at least {self.tensor_parallel_size} CPUs, " + f"but only {os.cpu_count()} available" + ) + + if not hasattr(self.model_tester, "causal_lm_class") or self.model_tester.causal_lm_class is None: + self.skipTest("Model tester does not have causal_lm_class (not using CausalLMModelTester)") + + if not self._has_tp_plan(): + self.skipTest("Model does not have a tensor parallel plan (base_model_tp_plan)") + + # # Skip encoder-decoder models (TP not supported) + # if getattr(self, "is_encoder_decoder", False): + # self.skipTest("TP tests not supported for encoder-decoder models") + + # # Skip VLM models for now + # config = self.model_tester.get_config() + # if hasattr(config, "vision_config") and config.vision_config is not None: + # self.skipTest("VLM models are not yet supported in TP tests") + + @is_tensor_parallel_test + def test_tp_forward(self): + self._skip_if_not_supported() + + config = self.model_tester.get_config() + model_class = self._get_tp_model_class() + atol = self.tensor_parallel_atol + rtol = self.tensor_parallel_rtol + + with tempfile.TemporaryDirectory() as tmp_dir: + set_seed(42) + model = model_class(config) + model.save_pretrained(tmp_dir, save_original_format=True) + + _init_distributed(tp=self.tensor_parallel_size)(_test_tp_forward_impl)(tmp_dir, model_class, atol, rtol) + + @is_tensor_parallel_test + def test_tp_backward(self): + self._skip_if_not_supported() + + config = self.model_tester.get_config() + model_class = self._get_tp_model_class() + atol = self.tensor_parallel_atol + rtol = self.tensor_parallel_rtol + + with tempfile.TemporaryDirectory() as tmp_dir: + set_seed(42) + model = model_class(config) + model.save_pretrained(tmp_dir, save_original_format=True) + + _init_distributed(tp=self.tensor_parallel_size)(_test_tp_backward_impl)(tmp_dir, model_class, atol, rtol) + + @is_tensor_parallel_test + def test_tp_generation(self): + # Test TP generation: unfused checkpoint → conversion mapping (if needed) → TP sharding → model → generate + self._skip_if_not_supported() + + config = self.model_tester.get_config() + + model_class = self._get_tp_model_class() + atol = self.tensor_parallel_atol + rtol = self.tensor_parallel_rtol + max_new_tokens = 25 + + with tempfile.TemporaryDirectory() as tmp_dir: + set_seed(42) + model = model_class(config) + model.save_pretrained(tmp_dir, save_original_format=True) + 
_init_distributed(tp=self.tensor_parallel_size)(_test_tp_generation_impl)( + tmp_dir, model_class, atol, rtol, max_new_tokens + ) + + @is_tensor_parallel_test + def test_tp_generation_quantized(self): + self._skip_if_not_supported() + + if not is_torchao_available(): + self.skipTest("Test requires torchao") + + config = self.model_tester.get_config() + model_class = self._get_tp_model_class() + max_new_tokens = 25 + + with tempfile.TemporaryDirectory() as tmp_dir: + set_seed(42) + model = model_class(config) + model.save_pretrained(tmp_dir, save_original_format=True) + + _init_distributed(tp=self.tensor_parallel_size)(_test_tp_generation_quantized_impl)( + tmp_dir, model_class, max_new_tokens + ) diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 5909e44b5c3f..cf2aebf244d2 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -1069,6 +1069,7 @@ def parse_commit_message(commit_message: str) -> dict[str, bool]: "tests_hub": r"tests/.*", "tests_non_model": r"tests/[^/]*?/test_.*\.py", "tests_training_ci": r"tests/models/.*/test_modeling_.*", + "tests_tensor_parallel_ci": r"tests/models/.*/test_modeling_.*", }
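As a sanity check on the interleaved layout that `get_packed_grad_shard` assumes for `packed_colwise` parameters (and that `_verify_tp_sharding` sizes against), here is a minimal standalone sketch of the index math; `packed_shard_indices` is a throwaway helper written for illustration, not a function from this diff.

```python
# Standalone illustration of the interleaved sharding assumed for packed weights
# such as gate_up_proj; packed_shard_indices is a throwaway helper, not part of this diff.
import torch


def packed_shard_indices(total_size, world_size, rank):
    block_size = total_size // 2               # gate block followed by up block
    shard = block_size // world_size
    indices = []
    for block in range(2):                     # gate rows first, then up rows
        start = block * block_size + rank * shard
        indices.extend(range(start, start + shard))
    return indices


# 8 packed rows = 4 gate rows [0..3] + 4 up rows [4..7], sharded over 2 ranks.
assert packed_shard_indices(8, 2, 0) == [0, 1, 4, 5]   # rank 0: [G0 G1 | U0 U1]
assert packed_shard_indices(8, 2, 1) == [2, 3, 6, 7]   # rank 1: [G2 G3 | U2 U3]

grad = torch.arange(8.0)                       # stand-in for a packed gradient along the sharded dim
print(grad.index_select(0, torch.tensor(packed_shard_indices(8, 2, 0))))  # tensor([0., 1., 4., 5.])

# Regular (non-packed) dims are chunked contiguously instead; _verify_tp_sharding
# accepts the ceil-div upper bound, e.g. a dim of size 7 over 2 ranks -> shards of at most 4.
```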
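And a quick worked check of the 75% match threshold in `_test_tp_generation_quantized_impl`: with the 10-token prompt and `max_new_tokens=25` used there, and assuming no early EOS, the assertion tolerates at most 8 divergent generated tokens.

```python
# Arithmetic behind the 75% sequence-match threshold in the quantized generation test,
# assuming the generated sequences reach full length (no early EOS).
import math

prompt_len, max_new_tokens = 10, 25
total = prompt_len + max_new_tokens      # length of output.sequences: 35
min_matches = math.ceil(0.75 * total)    # 27 positions must agree for match_ratio >= 0.75
max_divergent = total - min_matches
assert max_divergent == 8
# Since the prompt tokens always match, at most 8 of the 25 generated tokens may differ.
```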