From 446a519cb3417cb80aea3f83bb51b5745bdcce38 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Mon, 13 Apr 2026 14:27:08 +0000 Subject: [PATCH 1/7] from_pretrained orchestration + save/load MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add gather_full_state_dict() for DTensor→full tensor saving - Add convert_strided_to_shard() / restore_strided_from_shard() for DCP - Add _redistribute_dtensor() helper - Full distributed_config integration in from_pretrained/save_pretrained - Rename apply_fsdp2 → apply_fully_shard_data_parallel - save_optimizer() / load_optimizer() in distributed/utils - Trainer integration with distributed_config - Updated FSDP and TP tests for new orchestration API - DTensor shard-on-read test updates --- src/transformers/distributed/utils.py | 87 +++++- src/transformers/integrations/__init__.py | 12 +- src/transformers/integrations/fsdp.py | 57 +--- .../integrations/tensor_parallel.py | 78 ++++++ src/transformers/modeling_utils.py | 262 +++++++++--------- src/transformers/trainer.py | 7 +- tests/test_fsdp_mixin.py | 24 +- tests/utils/test_core_model_loading.py | 246 +++++++++++++--- 8 files changed, 520 insertions(+), 253 deletions(-) diff --git a/src/transformers/distributed/utils.py b/src/transformers/distributed/utils.py index 40278c8fcf2a..408bf33d65fb 100644 --- a/src/transformers/distributed/utils.py +++ b/src/transformers/distributed/utils.py @@ -16,14 +16,19 @@ import os from typing import TYPE_CHECKING -from ..utils import is_torch_available, strtobool +from ..utils import is_torch_available, is_torch_greater_or_equal, strtobool if TYPE_CHECKING: import torch.nn as nn + from .configuration_utils import DistributedConfig + if is_torch_available(): import torch + import torch.distributed.checkpoint as dcp + + from ..integrations.tensor_parallel import convert_strided_to_shard, restore_strided_from_shard def is_fsdp_enabled() -> bool: @@ -48,3 +53,83 @@ def is_fsdp_managed_module(module: nn.Module) -> bool: except ImportError: return False return isinstance(module, FullyShardedDataParallel) or getattr(module, "_is_fsdp_managed_module", False) + + +def _ensure_torch_distributed(device_type: str): + """Initialize torch.distributed if not already initialized.""" + if not torch.distributed.is_initialized(): + try: + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"} + backend = backend_map.get(device_type) + + torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size) + current_device = getattr(torch, device_type) + if device_type != "cpu": + current_device.set_device(local_rank) + except Exception as e: + raise OSError( + "We tried to initialize torch.distributed for you, but it failed. Make " + "sure you init torch distributed in your script to use distributed training." 
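+                " For example, launching with `torchrun --nproc_per_node=<num_gpus> script.py` sets the"
+                " RANK, LOCAL_RANK and WORLD_SIZE environment variables this helper reads."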
+            ) from e
+
+
+def init_device_mesh(distributed_config: DistributedConfig) -> torch.distributed.device_mesh.DeviceMesh:
+    if not is_torch_greater_or_equal("2.5"):
+        raise OSError("Distributed training with DistributedConfig requires `torch>=2.5`.")
+
+    device_type = torch._C._get_accelerator().type
+    _ensure_torch_distributed(device_type)
+
+    world_size = torch.distributed.get_world_size()
+    if device_type != "cpu":
+        getattr(torch, device_type).set_device(int(os.environ.get("LOCAL_RANK", 0)))
+
+    tp_size = distributed_config.tp_size
+    fsdp_size = distributed_config.fsdp_size
+
+    if world_size != tp_size * fsdp_size:
+        raise ValueError(
+            f"world_size ({world_size}) must be equal to tp_size ({tp_size}) * fsdp_size ({fsdp_size})"
+        )
+
+    dims, names = [], []
+    if fsdp_size > 1:
+        dims.append(fsdp_size)
+        names.append("fsdp")
+    if tp_size > 1:
+        dims.append(tp_size)
+        names.append("tp")
+
+    # Build from a 1D world mesh via _unflatten so that PyTorch can flatten
+    # sub-dimensions back when needed (e.g. for a single all_reduce across
+    # [fsdp, tp] during grad norm computation instead of 2 sequential ones).
+    world_mesh = torch.distributed.init_device_mesh(device_type, (world_size,), mesh_dim_names=("world",))
+    mesh = world_mesh._unflatten(0, tuple(dims), tuple(names))
+
+    # Pre-create flattened sub-mesh for multi-dimensional meshes so DTensor
+    # can use a single collective instead of sequential per-dimension ones.
+    if len(dims) > 1:
+        mesh._flatten("_".join(names))
+
+    return mesh
+
+
+def save_optimizer(optimizer, checkpoint_dir: str) -> None:
+    # Save optimizer state via DCP, handling _StridedShard placements transparently.
+    osd = optimizer.state_dict()
+    placement_map = convert_strided_to_shard(osd)
+    dcp.save({"optimizer": osd}, checkpoint_id=checkpoint_dir)
+    if placement_map and torch.distributed.get_rank() == 0:
+        torch.save(placement_map, os.path.join(checkpoint_dir, "placement_map.pt"))
+
+
+def load_optimizer(optimizer, checkpoint_dir: str) -> None:
+    # Load optimizer state via DCP, restoring _StridedShard placements transparently.
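+    # Counterpart to save_optimizer above; both are collectives, so every rank must
+    # call them with the same directory. Illustrative round-trip (path is an example,
+    # not part of the API):
+    #   save_optimizer(optimizer, "ckpt/optim")
+    #   load_optimizer(optimizer, "ckpt/optim")  # after rebuilding with the same sharding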
+ osd = optimizer.state_dict() + dcp.load({"optimizer": osd}, checkpoint_id=checkpoint_dir) + pmap_path = os.path.join(checkpoint_dir, "placement_map.pt") + if os.path.exists(pmap_path): + placement_map = torch.load(pmap_path, weights_only=False) + restore_strided_from_shard(osd, placement_map) + optimizer.load_state_dict(osd) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index e3515eab24b1..5e7518dea08b 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -160,11 +160,7 @@ "convert_and_export_with_cache", ] -_import_structure["tensor_parallel"] = [ - "shard_and_distribute_module", - "ALL_PARALLEL_STYLES", - "translate_to_torch_parallel_style", -] +_import_structure["tensor_parallel"] = [] try: if not is_torch_greater_or_equal("2.5"): raise OptionalDependencyNotAvailable() @@ -305,12 +301,6 @@ else: from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache - from .tensor_parallel import ( - ALL_PARALLEL_STYLES, - shard_and_distribute_module, - translate_to_torch_parallel_style, - ) - try: if not is_torch_greater_or_equal("2.5"): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/integrations/fsdp.py b/src/transformers/integrations/fsdp.py index 10936490ee03..368bbfc02854 100644 --- a/src/transformers/integrations/fsdp.py +++ b/src/transformers/integrations/fsdp.py @@ -24,12 +24,8 @@ if is_torch_available() and is_torch_greater_or_equal("2.5"): import torch import torch.distributed as dist - import torch.distributed.checkpoint as dcp from torch.distributed._composable.fsdp import fully_shard - from torch.distributed.checkpoint.hf_storage import HuggingFaceStorageWriter - from torch.distributed.checkpoint.state_dict import get_model_state_dict from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy - from torch.distributed.tensor import DTensor logger = logging.get_logger(__name__) @@ -359,9 +355,9 @@ def _get_manual_plan_modules(fsdp_plan: dict[str, Any]) -> dict[str, list[str]]: return modules -def apply_fsdp2( +def apply_fully_shard_data_parallel( model, - device_mesh, + fsdp_mesh, fsdp_plan: dict[str, Any] | str | None, ): """ @@ -397,10 +393,10 @@ def apply_fsdp2( if not is_torch_greater_or_equal("2.5"): raise OSError("FSDP2 requires torch>=2.5") - if device_mesh is None: - raise ValueError("device_mesh is required for FSDP2") + if fsdp_plan is None: + return model - if isinstance(fsdp_plan, str): + if fsdp_plan == "auto": fsdp_plan = {"mode": fsdp_plan} input_embed = getattr(model, "get_input_embeddings", lambda: None)() @@ -427,17 +423,17 @@ def apply_fsdp2( "Could not auto-detect transformer block classes for FSDP. Applying FSDP only to root module." 
) else: - _auto_shard_input_embedding(input_embed, is_weights_tied, device_mesh, auto_policy_kwargs) + _auto_shard_input_embedding(input_embed, is_weights_tied, fsdp_mesh, auto_policy_kwargs) - _auto_shard_transformer_blocks(model, block_classes, device_mesh, auto_policy_kwargs) + _auto_shard_transformer_blocks(model, block_classes, fsdp_mesh, auto_policy_kwargs) tail_modules = _auto_get_tail_modules( model, decoder_layer_names, input_embed, output_embed, is_weights_tied ) - _auto_shard_tail_modules(tail_modules, device_mesh, auto_policy_kwargs) + _auto_shard_tail_modules(tail_modules, fsdp_mesh, auto_policy_kwargs) # Shard root model - fully_shard(model, mesh=device_mesh, **auto_policy_kwargs) + fully_shard(model, mesh=fsdp_mesh, **auto_policy_kwargs) logger.info( f"FSDP2 applied to model: {len(block_classes)} block type(s), {len(decoder_layer_names)} decoder layers" @@ -468,7 +464,7 @@ def apply_fsdp2( for name, module in _iter_manual_plan_targets(model, pattern, name_to_module, already_sharded_names): if name in already_sharded_names: continue - shard_kwargs = {"mesh": device_mesh, "reshard_after_forward": reshard} + shard_kwargs = {"mesh": fsdp_mesh, "reshard_after_forward": reshard} if mp_policy is not None: shard_kwargs["mp_policy"] = mp_policy if offload_policy is not None: @@ -480,7 +476,7 @@ def apply_fsdp2( # Shard root model with the same policies as sub-modules. # MixedPrecisionPolicy.output_dtype casting happens in post_forward # for every fully_shard-wrapped module, even with no direct parameters. - fully_shard(model, mesh=device_mesh, mp_policy=root_mp_policy, offload_policy=root_offload_policy) + fully_shard(model, mesh=fsdp_mesh, mp_policy=root_mp_policy, offload_policy=root_offload_policy) # Used by generation code to detect FSDP and enable synced_gpus. model._is_fsdp_managed_module = True @@ -495,37 +491,6 @@ def apply_fsdp2( return model - -# TODO(3outeille): probably remove this function. Will be handled when someone tackle PEFT + FSDP. -def save_fsdp_model(model, save_directory): - """Save FSDP2 model weights as HF safetensors via DCP distributed save + consolidation. - - Each rank saves its DTensor shard in parallel, then rank 0 consolidates - into standard HF-compatible safetensors files. - """ - model_sd = get_model_state_dict(model) - - # Clone tensors sharing storage (tied weights) — safetensors refuses aliased tensors - seen_data_ptrs = {} - for key in list(model_sd.keys()): - tensor = model_sd[key] - t = tensor._local_tensor if isinstance(tensor, DTensor) else tensor - ptr = t.data_ptr() - if ptr in seen_data_ptrs: - model_sd[key] = tensor.clone() - else: - seen_data_ptrs[ptr] = key - - dcp.save( - model_sd, - storage_writer=HuggingFaceStorageWriter( - path=save_directory, - save_distributed=True, - enable_consolidation=True, - ), - ) - - # ========================= PEFT compatibility ========================= # TODO(3outeille): make sure new FSDP works with PEFT def get_fsdp_ckpt_kwargs(): diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 9dd0906c439e..31a67e1595f1 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -94,6 +94,84 @@ def _to_cpu_fresh(tensor: torch.Tensor) -> torch.Tensor: return out.contiguous() +def gather_full_state_dict(model) -> dict[str, torch.Tensor]: + """Gather all sharded params to full plain tensors for saving. + + Handles FSDP unshard and TP DTensor gather. 
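+    FSDP parameters come back from `get_model_state_dict` as DTensors and are
+    redistributed to `Replicate` before the CPU copy.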
+ Streams one parameter at a time to avoid holding all full tensors on GPU. + Only rank 0 accumulates the result; other ranks return ``{}``. + """ + tp_size = model.tp_size + is_rank0 = dist.get_rank() == 0 + + # Get state dict — FSDP unshard if needed (returns DTensors, not full tensors) + if getattr(model, "_is_fsdp_managed_module", False): + from torch.distributed.checkpoint.state_dict import get_model_state_dict + + state_dict = get_model_state_dict(model) + else: + state_dict = model.state_dict() + + # No TP — materialize on rank 0 only + if tp_size is None: + if is_rank0: + return {k: _to_cpu_fresh(v) for k, v in state_dict.items()} + return {} + + # Stream: gather one param at a time, only rank 0 keeps the CPU copy + result = {} + for key, tensor in state_dict.items(): + if isinstance(tensor, DTensor): + # All ranks participate in the collective, only rank 0 keeps the result + with torch.no_grad(): + full = tensor.redistribute(placements=[Replicate()] * tensor.device_mesh.ndim, async_op=False).to_local() + if is_rank0: + result[key] = _to_cpu_fresh(full) + del full + elif is_rank0: + result[key] = _to_cpu_fresh(tensor) + + return result + + +def _redistribute_dtensor(tensor: DTensor, target_placements: tuple) -> DTensor: + """Redistribute a DTensor via Replicate as an intermediate step. + + PyTorch doesn't implement all placement conversions (e.g. _StridedShard↔Shard). + Going through Replicate first is always supported. + """ + with torch.no_grad(): + replicated = tensor.redistribute(placements=[Replicate()] * tensor.device_mesh.ndim) + return replicated.redistribute(placements=target_placements) + +def convert_strided_to_shard(state_dict: dict) -> dict[str, tuple]: + # Convert _StridedShard DTensors in a state dict to plain Shard for DCP compatibility. + placement_map: dict[str, tuple] = {} + for key, value in state_dict.items(): + if isinstance(value, dict): + nested = convert_strided_to_shard(value) + for nk, nv in nested.items(): + placement_map[f"{key}.{nk}"] = nv + elif isinstance(value, DTensor) and any(isinstance(p, _StridedShard) for p in value.placements): + placement_map[key] = tuple(value.placements) + shard_placements = tuple(Shard(p.dim) if isinstance(p, _StridedShard) else p for p in value.placements) + state_dict[key] = _redistribute_dtensor(value, shard_placements) + return placement_map + + +def restore_strided_from_shard(state_dict: dict, placement_map: dict[str, tuple]) -> None: + # Restore _StridedShard placements after dcp.load. + def _resolve(d, dotted_key): + parts = dotted_key.split(".", 1) + if len(parts) == 2 and parts[0] in d and isinstance(d[parts[0]], dict): + return _resolve(d[parts[0]], parts[1]) + return d, dotted_key + + for key, original_placements in placement_map.items(): + container, leaf_key = _resolve(state_dict, key) + if leaf_key in container and isinstance(container[leaf_key], DTensor): + container[leaf_key] = _redistribute_dtensor(container[leaf_key], original_placements) + def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str | TPStyle] | None): """ Verify the TP plan of the model, log a warning if the layers that were not sharded and the rules that were not applied. 
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b6feec6f7ba6..0ebd98b2836a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -51,7 +51,7 @@ revert_weight_conversion, ) from .distributed import DistributedConfig -from .distributed.utils import is_fsdp_enabled +from .distributed.utils import init_device_mesh, is_fsdp_enabled from .dynamic_module_utils import custom_object_save from .generation import CompileConfig, GenerationConfig from .integrations import PeftAdapterMixin, deepspeed_config, hub_kernels, is_deepspeed_zero3_enabled @@ -69,18 +69,15 @@ from .integrations.flash_attention import flash_attention_forward from .integrations.flash_paged import paged_attention_forward from .integrations.flex_attention import flex_attention_forward -from .integrations.fsdp import initialize_fsdp +from .integrations.fsdp import apply_fully_shard_data_parallel from .integrations.hub_kernels import allow_all_hub_kernels, is_kernel from .integrations.peft import maybe_load_adapters from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.sdpa_paged import sdpa_attention_paged_forward from .integrations.tensor_parallel import ( - ALL_PARALLEL_STYLES, _get_parameter_tp_plan, - distribute_model, - gather_state_dict_for_save, - initialize_tensor_parallelism, - shard_and_distribute_module, + apply_tensor_parallel, + gather_full_state_dict, verify_tp_plan, ) from .loss.loss_utils import LOSS_MAPPING @@ -139,8 +136,6 @@ from accelerate.utils import extract_model_from_parallel -_torch_distributed_available = torch.distributed.is_available() - if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp from smdistributed.modelparallel import __version__ as SMP_VERSION @@ -1352,14 +1347,6 @@ def tp_plan(self, plan: dict[str, str] | None): if not isinstance(plan, dict): raise ValueError("Can only set a dictionary as `tp_plan`") - # Ensure the styles are all valid - for layer_pattern, parallel_style in plan.items(): - if parallel_style not in ALL_PARALLEL_STYLES: - raise ValueError( - f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. " - f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}" - ) - # Validate that the layer patterns match existing model structure. We check this by getting all parameter # names and seeing if any match the patterns model_param_names = [name for name, _ in self.named_parameters()] @@ -1928,11 +1915,10 @@ def get_correct_attn_implementation(self, requested_attention: str | None, is_in def get_correct_experts_implementation(self, requested_experts: str | None) -> str: applicable_experts = "grouped_mm" if requested_experts is None else requested_experts - if applicable_experts not in ["eager", "grouped_mm", "batched_mm", "deepgemm"]: + if applicable_experts not in ["eager", "grouped_mm", "batched_mm"]: message = ( f'Specified `experts_implementation="{applicable_experts}"` is not supported. The only possible arguments are ' - '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"`, `"experts_implementation=batched_mm"` ' - 'and `"experts_implementation=deepgemm"`.' + '`experts_implementation="eager"`, `"experts_implementation=grouped_mm"` and `"experts_implementation=batched_mm"`.' 
) raise ValueError(message) @@ -2987,7 +2973,6 @@ def _get_resized_lm_head( new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias ) - new_lm_head._is_hf_initialized = True return new_lm_head def _init_added_embeddings_weights_with_mean( @@ -3246,7 +3231,7 @@ def save_pretrained( ) # we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one - if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"): + if self.tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"): raise ImportError( "Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher." ) @@ -3285,64 +3270,58 @@ def save_pretrained( if self._auto_class is not None: custom_object_save(self, save_directory, config=self.config) - # Save the config - if is_main_process: - if not _hf_peft_config_loaded: - model_to_save.config.save_pretrained(save_directory) - if self.can_generate(): - model_to_save.generation_config.save_pretrained(save_directory) + # Don't persist distributed_config in saved config — it's runtime-only + # (otherwise AutoConfig absorbs it on reload, preventing from_pretrained from seeing it as a kwarg). + # Keep a runtime copy around because TP/FSDP save helpers still rely on it after config serialization. + distributed_config = getattr(model_to_save.config, "distributed_config", None) + if distributed_config is not None: + del model_to_save.config.distributed_config - if _hf_peft_config_loaded: - logger.info( - "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved." - ) - state_dict = model_to_save.get_adapter_state_dict(state_dict=state_dict) + # Save the config + try: + if is_main_process: + if not _hf_peft_config_loaded: + model_to_save.config.save_pretrained(save_directory) + if self.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) - if save_peft_format: + if _hf_peft_config_loaded: logger.info( - "To match the expected format of the PEFT library, all keys of the state dict of adapters will be prepended with `base_model.model`." + "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved." ) - peft_state_dict = {} - for key, value in state_dict.items(): - peft_state_dict[f"base_model.model.{key}"] = value - state_dict = peft_state_dict + state_dict = model_to_save.get_adapter_state_dict(state_dict=state_dict) - active_adapter = self.active_adapters() + if save_peft_format: + logger.info( + "To match the expected format of the PEFT library, all keys of the state dict of adapters will be prepended with `base_model.model`." + ) + peft_state_dict = {} + for key, value in state_dict.items(): + peft_state_dict[f"base_model.model.{key}"] = value + state_dict = peft_state_dict - if len(active_adapter) > 1: - raise ValueError( - "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one " - "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`" - ) - active_adapter = active_adapter[0] - - current_peft_config = self.peft_config[active_adapter] - current_peft_config.save_pretrained(save_directory) - - # FSDP2 models: use DCP distributed save + consolidation for safetensors. - # All ranks must call this collectively. Config/generation_config are - # already saved above (guarded by is_main_process). 
- if getattr(self, "_is_fsdp_managed_module", False): - from .integrations.fsdp import save_fsdp_model - - save_fsdp_model(model_to_save, save_directory) - - if push_to_hub: - model_card = create_and_tag_model_card(repo_id, self.model_tags, token=token) - model_card.save(os.path.join(save_directory, "README.md")) - self._upload_modified_files( - save_directory, - repo_id, - files_timestamps, - commit_message=commit_message, - token=token, - create_pr=create_pr, - ) - return + active_adapter = self.active_adapters() + + if len(active_adapter) > 1: + raise ValueError( + "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one " + "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`" + ) + active_adapter = active_adapter[0] + + current_peft_config = self.peft_config[active_adapter] + current_peft_config.save_pretrained(save_directory) + finally: + if distributed_config is not None: + model_to_save.config.distributed_config = distributed_config - # Get the model state_dict + # Get the model state_dict (handles FSDP unshard + TP gather in one call) if state_dict is None: - state_dict = model_to_save.state_dict() + if getattr(self, "device_mesh", None) is not None: + # Pass self (not model_to_save) so device_mesh/tp_size/tp_plan are available + state_dict = gather_full_state_dict(self) + else: + state_dict = model_to_save.state_dict() # if any model parameters are offloaded, we need to know it for later is_offloaded = False @@ -3368,10 +3347,6 @@ def save_pretrained( if ignore_key in state_dict: del state_dict[ignore_key] - # If model was sharded with TP, gather full tensors for saving - if self._tp_size is not None: - state_dict = gather_state_dict_for_save(state_dict, self._tp_plan, self._device_mesh, self._tp_size) - # Remove tied weights as safetensors do not handle them state_dict = remove_tied_weights_from_state_dict(state_dict, model_to_save) @@ -3642,10 +3617,7 @@ def get_init_context( elif is_quantized: init_contexts.extend([torch.device("meta"), set_quantized_state()]) else: - # meta_device_safe_creation_ops patches torch.linspace to default to CPU - # so that custom models calling .item() during __init__ (e.g. drop-path - # schedules) don't crash on meta tensors. - init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()]) + init_contexts.append(torch.device("meta")) return init_contexts @@ -3867,13 +3839,22 @@ def from_pretrained( max_memory (`Dict`, *optional*): A dictionary device identifier to maximum memory if using `device_map`. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. - tp_plan (`Optional[Union[dict, str]]`, *optional*): - A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Use `tp_plan="auto"` to - use the predefined plan based on the model. If it's a dict, then it should match between module names and desired layout. - Note that if you use it, you should launch your script accordingly with `torchrun [args] script.py`. This will be much - faster than using a `device_map`, but has limitations. - tp_size (`str`, *optional*): - A torch tensor parallel degree. If not provided would default to world size. + distributed_config ([`DistributedConfig`], *optional*): + Configuration for native distributed training (FSDP2 + TP) via `torch.distributed`. Mutually + exclusive with `quantization_config` (for now) and `device_map`. 
When set, accelerate is not used for + device placement or dispatch. Launch with `torchrun --nproc_per_node=N script.py`. + + Accepts `tp_size`, `tp_plan`, `fsdp_size`, `fsdp_plan`. When a size is specified without a + plan, the plan defaults to `"auto"`. `tp_plan="auto"` uses the model's predefined tensor + parallel sharding plan. `fsdp_plan="auto"` wraps each transformer layer individually with + FSDP2 (`fully_shard`). Both plans also accept a `dict` for manual control: `tp_plan` maps + parameter names to parallel styles (e.g. `{"model.layers.*.self_attn.q_proj": "colwise"}`), + `fsdp_plan` maps module names to wrap (e.g. `{"model.layers.0": {}, "model.layers.1": {}}`). + + Examples: + - TP-only: `DistributedConfig(tp_size=4)` + - FSDP-only: `DistributedConfig(fsdp_size=4)` + - 2D parallel: `DistributedConfig(tp_size=2, fsdp_size=2)` on 4 GPUs device_mesh (`torch.distributed.DeviceMesh`, *optional*): A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now. If provided, it has to contain dimension named `"tp"` in case it's > 1 dimensional, this dimension will be used for tensor parallelism @@ -3954,9 +3935,6 @@ def from_pretrained( adapter_name = kwargs.pop("adapter_name", "default") generation_config = kwargs.pop("generation_config", None) gguf_file = kwargs.pop("gguf_file", None) - tp_plan = kwargs.pop("tp_plan", None) - tp_size = kwargs.pop("tp_size", None) - fsdp_plan = kwargs.pop("fsdp_plan", None) distributed_config: DistributedConfig = kwargs.pop("distributed_config", None) device_mesh = kwargs.pop("device_mesh", None) trust_remote_code = kwargs.pop("trust_remote_code", None) @@ -3965,8 +3943,17 @@ def from_pretrained( kernel_config = kwargs.pop("kernel_config", None) key_mapping = kwargs.pop("key_mapping", None) - if distributed_config is not None and tp_plan is None: - tp_plan = "auto" + if distributed_config is not None: + if device_map is not None: + raise ValueError( + "`distributed_config` and `device_map` are mutually exclusive. " + "`distributed_config` handles device placement natively via torch.distributed." + ) + # NOTE(3outeille): support quantization (fp4/fp8) with distributed training later + if quantization_config is not None: + raise ValueError( + "Quantization is not currently supported with distributed training. Please disable quantization or distributed_config." + ) # Not used anymore -- remove them from the kwargs for name in ["mirror", "_fast_init", "low_cpu_mem_usage", "from_tf", "from_flax", "offload_state_dict"]: @@ -3997,27 +3984,18 @@ def from_pretrained( "`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies." ) - if device_map == "auto" and int(os.environ.get("WORLD_SIZE", "0")): - logger.info( - "You've set device_map=`auto` while triggering a distributed run with torchrun. This might lead to unexpected behavior. " - "If your plan is to load the model on each device, you should set device_map={" - ": PartialState().process_index} where PartialState comes from accelerate library" - ) - - if fsdp_plan is not None and (tp_plan is not None or tp_size is not None): - raise ValueError("Combining `fsdp_plan` with tensor parallel loading is not supported yet.") + if distributed_config is not None: + device_mesh = init_device_mesh(distributed_config) + else: + # Accelerate path + if device_map == "auto" and int(os.environ.get("WORLD_SIZE", "0")): + logger.info( + "You've set device_map=`auto` while triggering a distributed run with torchrun. 
This might lead to unexpected behavior. " + "If your plan is to load the model on each device, you should set device_map={" + ": PartialState().process_index} where PartialState comes from accelerate library" + ) - if tp_plan is not None or tp_size is not None: # TP warnings, and setup - device_map, device_mesh, tp_size = initialize_tensor_parallelism( - tp_plan, tp_size=tp_size, device_mesh=device_mesh, device_map=device_map - ) - - if fsdp_plan is not None: - device_map, device_mesh, _ = initialize_fsdp( - fsdp_plan=fsdp_plan, - device_mesh=device_mesh, - device_map=device_map, - ) + device_map = check_and_set_device_map(device_map) # validate & normalize (requires accelerate) if gguf_file is not None and not is_accelerate_available(): raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.") @@ -4030,7 +4008,6 @@ def from_pretrained( download_kwargs_with_commit, **adapter_kwargs, ) - device_map = check_and_set_device_map(device_map) # warn, error and fix the device map user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} if from_pipeline is not None: @@ -4140,20 +4117,25 @@ def from_pretrained( # instantiated model, as the flags can be modified by instances sometimes) dtype_plan = model._get_dtype_plan(dtype) - # Obtain the weight conversion mapping for this model if any are registered and apply to all submodels recursively + # Obtain the weight conversion mapping for this model if any are registered weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer) - if _torch_distributed_available and device_mesh is not None and (tp_plan is not None or fsdp_plan is not None): - model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size, fsdp_plan=fsdp_plan) - - # Prepare the full device map - if isinstance(device_map, dict): - device_map = _get_device_map(model, device_map, max_memory, hf_quantizer) - elif device_map is not None: - device_map = {"": device_map} + if distributed_config is not None: + model.config.distributed_config = distributed_config + model.device_mesh = device_mesh + sub_mesh = lambda name: device_mesh[name] if device_mesh.ndim > 1 else device_mesh + mesh_dim_names = device_mesh.mesh_dim_names or () + if "tp" in mesh_dim_names: + model = apply_tensor_parallel(model, sub_mesh("tp"), distributed_config.tp_plan) + if "fsdp" in mesh_dim_names: + model = apply_fully_shard_data_parallel(model, sub_mesh("fsdp"), distributed_config.fsdp_plan) + else: + # Accelerate path: auto device mapping + if device_map is not None: + device_map = _get_device_map(model, device_map, max_memory, hf_quantizer) # Finalize model weight initialization - active_tp_plan = getattr(model, "_tp_plan", None) if tp_size is not None else None + active_tp_plan = getattr(model, "_tp_plan", None) if getattr(distributed_config, "tp_plan", None) else None load_config = LoadStateDictConfig( pretrained_model_name_or_path=pretrained_model_name_or_path, ignore_mismatched_sizes=ignore_mismatched_sizes, @@ -4447,8 +4429,8 @@ def tp_size(self): """ Returns the model's tensor parallelism degree. 
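+        With `distributed_config`, this reads from `config.distributed_config`; `None`
+        means the model was not sharded with tensor parallelism.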
""" - # if None, the model didn't undergo tensor parallel sharding - return self._tp_size + dc = getattr(self.config, "distributed_config", None) + return dc.tp_size if dc is not None else None @property def supports_pp_plan(self): @@ -4555,10 +4537,10 @@ def _move_missing_keys_from_meta_to_device( # In this case we need to move everything back if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: for key, param in self.named_parameters(): - value = torch.zeros_like(param, device="cpu") + value = torch.empty_like(param, device="cpu") _load_parameter_into_model(self, key, value) for key, buffer in self.named_buffers(): - value = torch.zeros_like(buffer, device="cpu") + value = torch.empty_like(buffer, device="cpu") _load_parameter_into_model(self, key, value) return @@ -4567,15 +4549,25 @@ def _move_missing_keys_from_meta_to_device( # will be re-initialized for nothing (which can be quite long) for key in missing_keys - self.all_tied_weights_keys.keys(): param = self.get_parameter_or_buffer(key) - param_device = get_device(device_map, key, valid_torch_device=True) - value = torch.empty_like(param, device=param_device) - # For TP, we may need to shard the param - if device_mesh is not None: - shard_and_distribute_module( - self, value, param, key, None, False, device_mesh.get_local_rank(), device_mesh + from torch.distributed.tensor import DTensor + + if isinstance(param, DTensor): + # DTensor from parallelize_module on meta — materialize on actual device + local_value = torch.empty( + param._local_tensor.shape, + dtype=param.dtype, + device=torch.device(param.device_mesh.device_type, torch.cuda.current_device()), + ) + new_dtensor = DTensor.from_local( + local_value, param.device_mesh, param.placements, + run_check=False, shape=param.shape, stride=tuple(param.stride()), ) - # Otherwise, just move it to device + with torch.no_grad(): + new_param = torch.nn.Parameter(new_dtensor, requires_grad=param.requires_grad) + torch.utils.swap_tensors(param, new_param) else: + param_device = get_device(device_map, key, valid_torch_device=True) + value = torch.empty_like(param, device=param_device) _load_parameter_into_model(self, key, value) # We need to move back non-persistent buffers as well, as they are not part of loaded weights anyway for key, buffer in self.named_non_persistent_buffers(): @@ -4654,7 +4646,7 @@ def mark_tied_weights_as_initialized(self, loading_info): later as they will be tied (overwritten) anyway. This is very important as most embeddings are tied, and they are huge params (vocabularies are often 256k), so running inits on them is very costly.""" - for tied_param in getattr(self, "all_tied_weights_keys", {}).keys(): + for tied_param in self.all_tied_weights_keys.keys(): param = self.get_parameter(tied_param) param._is_hf_initialized = True diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 235189fe8320..380ed28a9e3a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2391,9 +2391,10 @@ def get_cp_size(self) -> int: def get_tp_size(self) -> int: """Get the tensor parallel size from either the model or DeepSpeed config.""" - # 1. Check model.tp_size first - if (model_tp := getattr(self.model, "_tp_size", None)) is not None: - return model_tp + # TODO: adapt it cleaner with distributed config once distributed api is stable + dc = getattr(getattr(self.model, "config", None), "distributed_config", None) + if dc is not None and dc.tp_size is not None: + return dc.tp_size # 2. 
Fall back to DeepSpeed config if enabled if self.is_deepspeed_enabled and (deepspeed_config := getattr(self.args, "hf_deepspeed_config", None)): diff --git a/tests/test_fsdp_mixin.py b/tests/test_fsdp_mixin.py index d7a0a4ca3340..f6f7f7e6e2ab 100644 --- a/tests/test_fsdp_mixin.py +++ b/tests/test_fsdp_mixin.py @@ -49,9 +49,10 @@ from torch.distributed.tensor import DTensor from torch.nn.parallel import DistributedDataParallel as DDP + from transformers.distributed import DistributedConfig from transformers.integrations.fsdp import ( _find_final_norm, - apply_fsdp2, + apply_fully_shard_data_parallel, get_transformer_block_classes, initialize_fsdp, ) @@ -373,12 +374,11 @@ def train_fsdp2( ): # -- Phase 1: Pre-checkpoint run -- train only the first `checkpoint_step` steps, then save _set_determinism(SEED) - _, device_mesh, _ = initialize_fsdp(fsdp_plan=fsdp_plan) + distributed_config = DistributedConfig(fsdp_plan=fsdp_plan) pre_ckpt_model = AutoModelForCausalLM.from_pretrained( init_model_dir, torch_dtype=dtype, - fsdp_plan=fsdp_plan, - device_mesh=device_mesh, + distributed_config=distributed_config, attn_implementation="eager", ) pre_ckpt_model.train() @@ -415,8 +415,7 @@ def train_fsdp2( resumed_model = AutoModelForCausalLM.from_pretrained( model_dir, torch_dtype=dtype, - fsdp_plan=fsdp_plan, - device_mesh=device_mesh, + distributed_config=distributed_config, attn_implementation="eager", ) resumed_model.train() @@ -461,16 +460,14 @@ def _test_fsdp2_save_load_impl(rank, config_class, config_dict): batches = _build_repeated_training_batches(config, device, 3) - auto_plan = {"mode": "auto"} + distributed_config = DistributedConfig(fsdp_plan="auto") init_tmpdir, init_tmpdir_obj = _save_init_pretrained(rank, config, torch.float32) try: - _, device_mesh, _ = initialize_fsdp(fsdp_plan=auto_plan) _set_determinism(SEED) model = AutoModelForCausalLM.from_pretrained( init_tmpdir, - fsdp_plan=auto_plan, - device_mesh=device_mesh, + distributed_config=distributed_config, attn_implementation="eager", ) dist.barrier() @@ -495,8 +492,7 @@ def _test_fsdp2_save_load_impl(rank, config_class, config_dict): new_model = AutoModelForCausalLM.from_pretrained( tmpdir, - fsdp_plan=auto_plan, - device_mesh=device_mesh, + distributed_config=distributed_config, attn_implementation="eager", ) dist.barrier() @@ -522,7 +518,7 @@ def _test_fsdp2_save_load_impl(rank, config_class, config_dict): def _test_fsdp2_sharding_structure_impl(rank, config_class, config_dict, tie_word_embeddings): """ - Verify that apply_fsdp2(fsdp_plan={"mode": "auto"}) wraps exactly the right modules. + Verify that apply_fully_shard_data_parallel(fsdp_plan={"mode": "auto"}) wraps exactly the right modules. Expected FSDP targets: UNTIED TIED @@ -570,7 +566,7 @@ def _test_fsdp2_sharding_structure_impl(rank, config_class, config_dict, tie_wor if not weights_tied: expected_targets |= {output_name} - model = apply_fsdp2(model, device_mesh, fsdp_plan=auto_plan) + model = apply_fully_shard_data_parallel(model, device_mesh, fsdp_plan=auto_plan) actual_targets = {name for name, module in model.named_modules() if type(module).__name__.startswith("FSDP")} diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index 942dcdc99b11..7a0b7c1b8ddc 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -13,17 +13,19 @@ # limitations under the License. 
import unittest from types import SimpleNamespace +from unittest.mock import patch import torch import torch.nn as nn +from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard from transformers import PretrainedConfig from transformers.conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping from transformers.core_model_loading import ( Chunk, Concatenate, + DtensorShardOperation, ErnieFuseAndSplitTextVisionExperts, - FSDPShardOperation, MergeModulelist, PermuteForRope, WeightConverter, @@ -217,24 +219,67 @@ def __init__(self, add_extra_moe=False): class FakeMesh: - def __init__(self, world_size: int, rank: int): - self.shape = (world_size,) - self._rank = rank + """Fake multi-dimensional device mesh for testing DtensorShardOperation.""" + + def __init__(self, shape, rank, dim_names=None): + if isinstance(shape, int): + shape = (shape,) + self.shape = tuple(shape) + self.ndim = len(self.shape) + self.mesh_dim_names = dim_names or tuple(f"dim{i}" for i in range(self.ndim)) + # Compute nD coordinate (row-major: last dim changes fastest) + self._coord = [] + r = rank + for s in reversed(self.shape): + self._coord.insert(0, r % s) + r //= s def get_local_rank(self): - return self._rank + return self._coord[0] def get_coordinate(self): - return (self._rank,) + return tuple(self._coord) + + def size(self): + result = 1 + for s in self.shape: + result *= s + return result + + def _is_current_rank_part_of_mesh(self): + return True + + def _sym_get_coordinate(self, dim): + return self._coord[dim] + + def __getitem__(self, name): + idx = self.mesh_dim_names.index(name) + return FakeMesh( + shape=(self.shape[idx],), + rank=self._coord[idx], + dim_names=(name,), + ) + + +def _make_dtensor_shard_op(mesh, placements, param_shape, local_shape): + """Build a DtensorShardOperation without requiring a real DTensor / distributed init.""" + op = object.__new__(DtensorShardOperation) + op.device_mesh = mesh + op.placements = tuple(placements) + ns = SimpleNamespace(shape=torch.Size(param_shape), ndim=len(param_shape)) + ns.dim = lambda: len(param_shape) + op.param = ns + op.local_shape = tuple(local_shape) + return op class TestConvertAndLoadStateDict(unittest.TestCase): - def test_fsdp_shard_aware_mixtral_conversion_uses_only_local_experts(self): - shard_op = FSDPShardOperation( - device_mesh=FakeMesh(world_size=2, rank=0), - rank=0, - empty_param=torch.empty((2, 4, 2)), - placements=(torch.distributed.tensor.placement_types.Shard(0),), + def test_dtensor_shard_aware_mixtral_conversion_uses_only_local_experts(self): + shard_op = _make_dtensor_shard_op( + FakeMesh(shape=(2,), rank=0), + [Shard(0)], + param_shape=(2, 4, 2), + local_shape=(1, 4, 2), ) converter = WeightConverter( ["experts.*.w1.weight", "experts.*.w3.weight"], @@ -242,40 +287,44 @@ def test_fsdp_shard_aware_mixtral_conversion_uses_only_local_experts(self): operations=[MergeModulelist(dim=0), Concatenate(dim=1)], ) - for idx, tensor in enumerate( - [ - torch.tensor([[0.0, 1.0], [2.0, 3.0]]), - torch.tensor([[10.0, 11.0], [12.0, 13.0]]), - ] - ): - converter.add_tensor( - "model.layers.0.experts.gate_up_proj.weight", - f"model.layers.0.experts.{idx}.w1.weight", - "experts.*.w1.weight", - spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None), - ) - - for idx, tensor in enumerate( - [ - torch.tensor([[4.0, 5.0], [6.0, 7.0]]), - torch.tensor([[14.0, 15.0], [16.0, 17.0]]), - ] + with patch( + 
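+            # patched so shard offsets can be computed without a real DTensor or an initialized mesh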
"transformers.core_model_loading.compute_local_shape_and_global_offset", + return_value=(torch.Size([1, 4, 2]), torch.Size([0, 0, 0])), ): - converter.add_tensor( - "model.layers.0.experts.gate_up_proj.weight", - f"model.layers.0.experts.{idx}.w3.weight", - "experts.*.w3.weight", - spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None), + for idx, tensor in enumerate( + [ + torch.tensor([[0.0, 1.0], [2.0, 3.0]]), + torch.tensor([[10.0, 11.0], [12.0, 13.0]]), + ] + ): + converter.add_tensor( + "model.layers.0.experts.gate_up_proj.weight", + f"model.layers.0.experts.{idx}.w1.weight", + "experts.*.w1.weight", + spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None), + ) + + for idx, tensor in enumerate( + [ + torch.tensor([[4.0, 5.0], [6.0, 7.0]]), + torch.tensor([[14.0, 15.0], [16.0, 17.0]]), + ] + ): + converter.add_tensor( + "model.layers.0.experts.gate_up_proj.weight", + f"model.layers.0.experts.{idx}.w3.weight", + "experts.*.w3.weight", + spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None), + ) + + converted = converter.convert("model.layers.0.experts.gate_up_proj.weight") + + self.assertEqual(list(converted), ["model.layers.0.experts.gate_up_proj.weight"]) + torch.testing.assert_close( + converted["model.layers.0.experts.gate_up_proj.weight"], + torch.tensor([[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]]), ) - converted = converter.convert("model.layers.0.experts.gate_up_proj.weight") - - self.assertEqual(list(converted), ["model.layers.0.experts.gate_up_proj.weight"]) - torch.testing.assert_close( - converted["model.layers.0.experts.gate_up_proj.weight"], - torch.tensor([[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]]), - ) - def test_moe_and_qkv_conversion(self): model = DummyRoot() model.config = PretrainedConfig() @@ -785,6 +834,117 @@ def test_ernie4_5_vl_moe_conversion_reversed(self): self.assertTrue(compare_state_dicts(reversed_state_dict, state_dict)) +class TestDtensorShardOperation(unittest.TestCase): + """One test per code path in DtensorShardOperation.shard_tensor.""" + + def test_no_shard_returns_full_tensor(self): + """Replicate-only → full copy.""" + mesh = FakeMesh(shape=(2,), rank=0) + op = _make_dtensor_shard_op(mesh, [Replicate()], param_shape=(4, 4), local_shape=(4, 4)) + tensor = torch.arange(16).reshape(4, 4).float() + torch.testing.assert_close(op.shard_tensor(tensor), tensor) + + def test_1d_shard_fast_path(self): + #TODO(3outeille): double check fast path + tensor = torch.arange(16).reshape(4, 4).float() + for rank, expected in [(0, tensor[:2]), (1, tensor[2:])]: + mesh = FakeMesh(shape=(2,), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(4, 4), local_shape=(2, 4)) + torch.testing.assert_close(op.shard_tensor(tensor), expected, msg=f"rank {rank}") + + def test_nd_contiguous_single_slice(self): + """nD Shard on different dims → single slice read per rank.""" + tensor = torch.arange(64).reshape(8, 8).float() + expected = {0: tensor[:4, :4], 1: tensor[:4, 4:], 2: tensor[4:, :4], 3: tensor[4:, 4:]} + for rank in range(4): + mesh = FakeMesh(shape=(2, 2), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(0), Shard(1)], param_shape=(8, 8), local_shape=(4, 4)) + torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}") + + def test_nd_strided_shard_disjoint_ranges(self): + """_StridedShard on its own dim → multiple slice reads + cat.""" + tensor = torch.arange(64).reshape(8, 8).float() + # Shard(0) splits rows; 
_StridedShard(1, split_factor=2) produces disjoint col ranges + expected = { + 0: torch.cat([tensor[:4, :2], tensor[:4, 4:6]], dim=1), + 1: torch.cat([tensor[:4, 2:4], tensor[:4, 6:8]], dim=1), + 2: torch.cat([tensor[4:, :2], tensor[4:, 4:6]], dim=1), + 3: torch.cat([tensor[4:, 2:4], tensor[4:, 6:8]], dim=1), + } + for rank in range(4): + mesh = FakeMesh(shape=(2, 2), rank=rank) + op = _make_dtensor_shard_op( + mesh, [Shard(0), _StridedShard(dim=1, split_factor=2)], + param_shape=(8, 8), local_shape=(4, 4), + ) + torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}") + + def test_nd_strided_plus_shard_same_dim_fallback(self): + """_StridedShard + Shard on same dim → materialize-then-split fallback.""" + tensor = torch.arange(16).reshape(4, 4).float() + expected = {0: tensor[[0]], 1: tensor[[2]], 2: tensor[[1]], 3: tensor[[3]]} + for rank in range(4): + mesh = FakeMesh(shape=(2, 2), rank=rank) + op = _make_dtensor_shard_op( + mesh, [_StridedShard(dim=0, split_factor=2), Shard(0)], + param_shape=(4, 4), local_shape=(1, 4), + ) + torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}") + + def test_prepacked_strided_shard_uses_contiguous_source_slice(self): + """Pre-concat w1/w3 tensors should shard contiguously before gate/up packing.""" + tensor = torch.arange(8).reshape(4, 2).float() + for rank, expected in [(0, tensor[:2]), (1, tensor[2:])]: + mesh = FakeMesh(shape=(2,), rank=rank) + op = _make_dtensor_shard_op( + mesh, + [_StridedShard(dim=1, split_factor=2)], + param_shape=(8, 8, 2), + local_shape=(8, 4, 2), + ) + torch.testing.assert_close(op.shard_tensor(tensor, tensor_idx=0), expected, msg=f"rank {rank}") + + def test_expert_filtering(self): + """Mixtral-style experts: skip non-owned, return owned.""" + mesh = FakeMesh(shape=(2,), rank=1) + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(4, 2, 2), local_shape=(2, 2, 2)) + expert_tensor = torch.ones(2, 2) + with patch( + "transformers.core_model_loading.compute_local_shape_and_global_offset", + return_value=(torch.Size([2, 2, 2]), torch.Size([2, 0, 0])), + ): + # rank 1 owns experts 2,3 (offset=2) + self.assertIsNone(op.shard_tensor(expert_tensor, tensor_idx=0)) + torch.testing.assert_close(op.shard_tensor(expert_tensor, tensor_idx=2), expert_tensor) + + def test_expert_filtering_preserves_inner_sharding(self): + """MoE expert ownership checks should still apply TP sharding on inner dims.""" + tensor = torch.arange(8).reshape(4, 2).float() + expected = { + 0: tensor[:2], + 1: tensor[2:], + 2: None, + 3: None, + } + for rank in range(4): + mesh = FakeMesh(shape=(2, 2), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(0), Shard(1)], param_shape=(4, 4, 2), local_shape=(2, 2, 2)) + + def fake_local_shape_and_offset(*args, **kwargs): + expert_rank, tp_rank = mesh.get_coordinate() + return torch.Size([2, 2, 2]), torch.Size([2 * expert_rank, 2 * tp_rank, 0]) + + with patch( + "transformers.core_model_loading.compute_local_shape_and_global_offset", + side_effect=fake_local_shape_and_offset, + ): + shard = op.shard_tensor(tensor, tensor_idx=1) + if expected[rank] is None: + self.assertIsNone(shard) + else: + torch.testing.assert_close(shard, expected[rank], msg=f"rank {rank}") + + class TestConversionMapping(unittest.TestCase): def test_register_checkpoint_conversion_mapping(self): register_checkpoint_conversion_mapping( From b1e9179862fd5089d64c5718676f1a5cfd78c33b Mon Sep 17 00:00:00 2001 From: 3outeille Date: Tue, 14 Apr 2026 14:45:04 +0000 Subject: [PATCH 2/7] 
revert distributed utils

---
 src/transformers/distributed/utils.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/transformers/distributed/utils.py b/src/transformers/distributed/utils.py
index 408bf33d65fb..95aa17d7f015 100644
--- a/src/transformers/distributed/utils.py
+++ b/src/transformers/distributed/utils.py
@@ -44,15 +44,21 @@ def is_fsdp_enabled() -> bool:
 
 
 def is_fsdp_managed_module(module: nn.Module) -> bool:
+    """Check if a module is managed by FSDP (1 or 2)."""
     if not is_torch_available():
         return False
 
     if not torch.distributed.is_available():
         return False
+
+    # FSDP2: attribute set by apply_fully_shard_data_parallel()
+    if getattr(module, "_is_fsdp_managed_module", False):
+        return True
+
+    # FSDP1: wrapped by FullyShardedDataParallel
     try:
         from torch.distributed.fsdp import FullyShardedDataParallel
     except ImportError:
         return False
-    return isinstance(module, FullyShardedDataParallel) or getattr(module, "_is_fsdp_managed_module", False)
+    return isinstance(module, FullyShardedDataParallel)

From 11b4d67841e6f11c24f94957cea23c2fef5e048f Mon Sep 17 00:00:00 2001
From: 3outeille
Date: Tue, 14 Apr 2026 14:50:25 +0000
Subject: [PATCH 3/7] document is_fsdp_enabled as Accelerate/FSDP1-only

---
 src/transformers/distributed/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/distributed/utils.py b/src/transformers/distributed/utils.py
index 95aa17d7f015..b909b6de3683 100644
--- a/src/transformers/distributed/utils.py
+++ b/src/transformers/distributed/utils.py
@@ -32,6 +32,7 @@
 
 
 def is_fsdp_enabled() -> bool:
+    """Check if FSDP is active via Accelerate (env var based) — covers FSDP1 only."""
     if not is_torch_available():
         return False

From 5948a1d39f924e889ab40c4083a9d97fce3a4fb6 Mon Sep 17 00:00:00 2001
From: 3outeille
Date: Tue, 14 Apr 2026 15:06:37 +0000
Subject: [PATCH 4/7] all tests for core modeling are passing

---
 src/transformers/core_model_loading.py |  11 ++
 tests/utils/test_core_model_loading.py | 206 ++++++++++++++++++-------
 2 files changed, 159 insertions(+), 58 deletions(-)

diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 08907a21b454..7d972bfd413c 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -118,6 +118,17 @@ def reverse_op(self) -> ConversionOps:
         raise NotImplementedError
 
 
+class _IdentityOp(ConversionOps):
+    """Pass-through reverse op for dequantize operations.
+
+    Dequantized weights are already in their target dtype and should be
+    saved as-is without any conversion.
+    """
+
+    def convert(self, input_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
+        return input_dict
+
+
 class Chunk(ConversionOps):
     """Split a tensor along ``dim`` into equally sized chunks."""

diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py
index 7a0b7c1b8ddc..5e7dc3aa7b2e 100644
--- a/tests/utils/test_core_model_loading.py
+++ b/tests/utils/test_core_model_loading.py
@@ -13,7 +13,6 @@
 # limitations under the License.
import unittest from types import SimpleNamespace -from unittest.mock import patch import torch import torch.nn as nn @@ -34,7 +33,7 @@ convert_and_load_state_dict_in_model, rename_source_key, revert_weight_conversion, - spawn_parallel_materialize, + spawn_materialize, ) from transformers.modeling_utils import LoadStateDictConfig from transformers.utils.import_utils import is_triton_available @@ -275,6 +274,44 @@ def _make_dtensor_shard_op(mesh, placements, param_shape, local_shape): class TestConvertAndLoadStateDict(unittest.TestCase): def test_dtensor_shard_aware_mixtral_conversion_uses_only_local_experts(self): + """ + The problem: Mixtral has 8 experts. The checkpoint stores them separately: + experts.0.w1.weight (2x2) + experts.0.w3.weight (2x2) + experts.1.w1.weight (2x2) + experts.1.w3.weight (2x2) + + The model stores them packed into one tensor: + experts.gate_up_proj.weight (2, 4, 2) + ^ ^ ^ + | | └─ features + | └─ w1 (2) + w3 (2) concatenated + └─ num_experts + + The conversion (without FSDP) is: load all expert w1/w3 tensors → MergeModulelist(dim=0) stacks experts → Concatenate(dim=1) joins w1+w3. + + Example — Mixtral experts with FSDP Shard(0) on the expert dim: + + checkpoint files shard_tensor rank 0 gets + ──────────────── ──────────── ─────────── + experts.0.w1 [[0,1],[2,3]] idx=0 → kept [[0,1],[2,3]] + experts.1.w1 [[10,11],...] idx=1 → None (not owned) + experts.0.w3 [[4,5],[6,7]] idx=0 → kept [[4,5],[6,7]] + experts.1.w3 [[14,15],...] idx=1 → None (not owned) + + WeightConverter then stacks + concatenates only the kept tensors: gate_up_proj = [[[0,1],[2,3],[4,5],[6,7]]] shape (1,4,2) + + MergeModulelist(dim=0): [[0,1],[2,3]] → [[[0,1],[2,3]]] (1 expert, shape 1x2x2) + [[4,5],[6,7]] → [[[4,5],[6,7]]] (1 expert, shape 1x2x2) + + Concatenate(dim=1): cat along dim 1 → [[[0,1],[2,3],[4,5],[6,7]]] (shape 1x4x2) + ~~~~~~~~~~~ ~~~~~~~~~~~ + w1 w3 + + The key point: DtensorShardOperation.shard_tensor(tensor_idx=1) returns None for rank 0, so the + converter never even processes expert 1's data. This saves memory during loading. 
this should explain as well the + other tests + """ shard_op = _make_dtensor_shard_op( FakeMesh(shape=(2,), rank=0), [Shard(0)], @@ -287,44 +324,40 @@ def test_dtensor_shard_aware_mixtral_conversion_uses_only_local_experts(self): operations=[MergeModulelist(dim=0), Concatenate(dim=1)], ) - with patch( - "transformers.core_model_loading.compute_local_shape_and_global_offset", - return_value=(torch.Size([1, 4, 2]), torch.Size([0, 0, 0])), + for idx, tensor in enumerate( + [ + torch.tensor([[0.0, 1.0], [2.0, 3.0]]), + torch.tensor([[10.0, 11.0], [12.0, 13.0]]), + ] ): - for idx, tensor in enumerate( - [ - torch.tensor([[0.0, 1.0], [2.0, 3.0]]), - torch.tensor([[10.0, 11.0], [12.0, 13.0]]), - ] - ): - converter.add_tensor( - "model.layers.0.experts.gate_up_proj.weight", - f"model.layers.0.experts.{idx}.w1.weight", - "experts.*.w1.weight", - spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None), - ) - - for idx, tensor in enumerate( - [ - torch.tensor([[4.0, 5.0], [6.0, 7.0]]), - torch.tensor([[14.0, 15.0], [16.0, 17.0]]), - ] - ): - converter.add_tensor( - "model.layers.0.experts.gate_up_proj.weight", - f"model.layers.0.experts.{idx}.w3.weight", - "experts.*.w3.weight", - spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None), - ) - - converted = converter.convert("model.layers.0.experts.gate_up_proj.weight") - - self.assertEqual(list(converted), ["model.layers.0.experts.gate_up_proj.weight"]) - torch.testing.assert_close( - converted["model.layers.0.experts.gate_up_proj.weight"], - torch.tensor([[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]]), + converter.add_tensor( + "model.layers.0.experts.gate_up_proj.weight", + f"model.layers.0.experts.{idx}.w1.weight", + "experts.*.w1.weight", + spawn_materialize(None, tensor, device="cpu", dtype=None, sharding_op=shard_op, tensor_idx=idx), + ) + + for idx, tensor in enumerate( + [ + torch.tensor([[4.0, 5.0], [6.0, 7.0]]), + torch.tensor([[14.0, 15.0], [16.0, 17.0]]), + ] + ): + converter.add_tensor( + "model.layers.0.experts.gate_up_proj.weight", + f"model.layers.0.experts.{idx}.w3.weight", + "experts.*.w3.weight", + spawn_materialize(None, tensor, device="cpu", dtype=None, sharding_op=shard_op, tensor_idx=idx), ) + converted = converter.convert("model.layers.0.experts.gate_up_proj.weight") + + self.assertEqual(list(converted), ["model.layers.0.experts.gate_up_proj.weight"]) + torch.testing.assert_close( + converted["model.layers.0.experts.gate_up_proj.weight"], + torch.tensor([[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]]), + ) + def test_moe_and_qkv_conversion(self): model = DummyRoot() model.config = PretrainedConfig() @@ -835,7 +868,32 @@ def test_ernie4_5_vl_moe_conversion_reversed(self): class TestDtensorShardOperation(unittest.TestCase): - """One test per code path in DtensorShardOperation.shard_tensor.""" + """Unit tests for DtensorShardOperation.shard_tensor — one test per code path. 
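+
+    All tests run single-process: FakeMesh and _make_dtensor_shard_op stand in for a
+    real device mesh, so no torch.distributed initialization is required.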
+ + Branch coverage map: + + shard_tensor() + ├── A: no sharding placements → full copy [test_no_shard_returns_full_tensor] + ├── B: expert path (tensor_idx set, ndim mismatch) + │ ├── B1: has_expert_sharding=False → fall through to C [test_expert_shaped_tp_only_no_expert_sharding] + │ ├── B2: not owns_local_expert → None [test_expert_filtering] + │ ├── B3: owned, no inner placements → full copy [test_expert_filtering] + │ └── B4: owned, with inner placements → _shard_nd [test_expert_filtering_preserves_inner_sharding] + └── C: _shard_nd() + ├── C1: _can_shard_on_read=False → _materialize_and_split [test_nd_strided_plus_shard_same_dim_fallback] + ├── C2: has_strided=False → contiguous slice + │ ├── 1D mesh [test_1d_shard_fast_path] + │ ├── 2D mesh [test_nd_contiguous_single_slice] + │ ├── negative dim [test_negative_dim_normalizes_correctly] + │ └── uneven division [test_contiguous_shard_uneven_division] + └── C3: has_strided=True → _compute_dim_ranges + _slice_and_read + ├── _StridedShard → _strided_ranges [test_nd_strided_shard_disjoint_ranges] + └── _source_tensor_needs_packing → contiguous [test_prepacked_strided_shard_uses_contiguous_source_slice] + + _slice_and_read (tested directly) + ├── all single ranges → simple slice [test_slice_and_read_all_single_ranges] + └── two multi-range dims → ValueError [test_slice_and_read_raises_on_two_multi_range_dims] + """ def test_no_shard_returns_full_tensor(self): """Replicate-only → full copy.""" @@ -904,18 +962,23 @@ def test_prepacked_strided_shard_uses_contiguous_source_slice(self): ) torch.testing.assert_close(op.shard_tensor(tensor, tensor_idx=0), expected, msg=f"rank {rank}") + def test_expert_shaped_tp_only_no_expert_sharding(self): + """Expert-shaped param with TP on dim 1 but no expert sharding on dim 0 → regular _shard_nd path.""" + tensor = torch.arange(8).reshape(4, 2).float() + # Shard(1) on 3D param maps to dim 0 of the 2D checkpoint tensor (ndim_diff=1) + for rank, expected in [(0, tensor[:2]), (1, tensor[2:])]: + mesh = FakeMesh(shape=(2,), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(1)], param_shape=(4, 4, 2), local_shape=(4, 2, 2)) + torch.testing.assert_close(op.shard_tensor(tensor, tensor_idx=0), expected, msg=f"rank {rank}") + def test_expert_filtering(self): """Mixtral-style experts: skip non-owned, return owned.""" mesh = FakeMesh(shape=(2,), rank=1) op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(4, 2, 2), local_shape=(2, 2, 2)) expert_tensor = torch.ones(2, 2) - with patch( - "transformers.core_model_loading.compute_local_shape_and_global_offset", - return_value=(torch.Size([2, 2, 2]), torch.Size([2, 0, 0])), - ): - # rank 1 owns experts 2,3 (offset=2) - self.assertIsNone(op.shard_tensor(expert_tensor, tensor_idx=0)) - torch.testing.assert_close(op.shard_tensor(expert_tensor, tensor_idx=2), expert_tensor) + # rank 1 owns experts 2,3 (offset=2) + self.assertIsNone(op.shard_tensor(expert_tensor, tensor_idx=0)) + torch.testing.assert_close(op.shard_tensor(expert_tensor, tensor_idx=2), expert_tensor) def test_expert_filtering_preserves_inner_sharding(self): """MoE expert ownership checks should still apply TP sharding on inner dims.""" @@ -929,20 +992,47 @@ def test_expert_filtering_preserves_inner_sharding(self): for rank in range(4): mesh = FakeMesh(shape=(2, 2), rank=rank) op = _make_dtensor_shard_op(mesh, [Shard(0), Shard(1)], param_shape=(4, 4, 2), local_shape=(2, 2, 2)) + shard = op.shard_tensor(tensor, tensor_idx=1) + if expected[rank] is None: + self.assertIsNone(shard) + else: + 
torch.testing.assert_close(shard, expected[rank], msg=f"rank {rank}") + + def test_negative_dim_normalizes_correctly(self): + """Shard(-1) on a 2D tensor should shard the last dimension.""" + tensor = torch.arange(16).reshape(4, 4).float() + for rank, expected in [(0, tensor[:, :2]), (1, tensor[:, 2:])]: + mesh = FakeMesh(shape=(2,), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(-1)], param_shape=(4, 4), local_shape=(4, 2)) + torch.testing.assert_close(op.shard_tensor(tensor), expected, msg=f"rank {rank}") - def fake_local_shape_and_offset(*args, **kwargs): - expert_rank, tp_rank = mesh.get_coordinate() - return torch.Size([2, 2, 2]), torch.Size([2 * expert_rank, 2 * tp_rank, 0]) - - with patch( - "transformers.core_model_loading.compute_local_shape_and_global_offset", - side_effect=fake_local_shape_and_offset, - ): - shard = op.shard_tensor(tensor, tensor_idx=1) - if expected[rank] is None: - self.assertIsNone(shard) - else: - torch.testing.assert_close(shard, expected[rank], msg=f"rank {rank}") + def test_contiguous_shard_uneven_division(self): + """Shard(0) on 5 rows across 2 ranks → rank 0 gets 3 rows, rank 1 gets 2.""" + tensor = torch.arange(20).reshape(5, 4).float() + expected = {0: tensor[:3], 1: tensor[3:]} + for rank in range(2): + mesh = FakeMesh(shape=(2,), rank=rank) + local_rows = 3 if rank == 0 else 2 + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(5, 4), local_shape=(local_rows, 4)) + torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}") + + def test_slice_and_read_all_single_ranges(self): + """When every dim has exactly one range, _slice_and_read takes the simple slice path (no concat).""" + tensor = torch.arange(64).reshape(8, 8).float() + mesh = FakeMesh(shape=(2,), rank=0) + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(8, 8), local_shape=(4, 4)) + dim_ranges = {0: [(0, 4)], 1: [(2, 6)]} + result = op._slice_and_read(tensor, [8, 8], dim_ranges, None, None) + torch.testing.assert_close(result, tensor[0:4, 2:6]) + + def test_slice_and_read_raises_on_two_multi_range_dims(self): + """Multiple disjoint ranges on two different dims → ValueError.""" + tensor = torch.arange(64).reshape(8, 8).float() + mesh = FakeMesh(shape=(2,), rank=0) + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(8, 8), local_shape=(4, 4)) + dim_ranges = {0: [(0, 2), (4, 6)], 1: [(0, 2), (4, 6)]} + with self.assertRaises(ValueError): + op._slice_and_read(tensor, [8, 8], dim_ranges, None, None) class TestConversionMapping(unittest.TestCase): From 01311a61043ea6b256bc554349bc26068e52ccf9 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Tue, 14 Apr 2026 15:10:52 +0000 Subject: [PATCH 5/7] populate import from init for tp --- src/transformers/integrations/__init__.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 411a3fa9d1c6..d274b31837e2 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -160,7 +160,14 @@ "convert_and_export_with_cache", ] -_import_structure["tensor_parallel"] = [] +_import_structure["tensor_parallel"] = [ + "TPStyle", + "apply_tensor_parallel", + "convert_strided_to_shard", + "gather_full_state_dict", + "restore_strided_from_shard", + "verify_tp_plan", +] try: if not is_torch_greater_or_equal("2.5"): raise OptionalDependencyNotAvailable() @@ -291,6 +298,14 @@ from .quanto import replace_with_quanto_layers from .sinq import 
SinqDeserialize, SinqQuantize from .spqr import replace_with_spqr_linear + from .tensor_parallel import ( + TPStyle, + apply_tensor_parallel, + convert_strided_to_shard, + gather_full_state_dict, + restore_strided_from_shard, + verify_tp_plan, + ) from .vptq import replace_with_vptq_linear try: From 2e0045cad82a82138e24c8e7216982335c0bed10 Mon Sep 17 00:00:00 2001 From: 3outeille Date: Tue, 14 Apr 2026 15:57:25 +0000 Subject: [PATCH 6/7] ruff --- src/transformers/modeling_utils.py | 7 ++- tests/utils/test_core_model_loading.py | 82 ++++++++++++++------------ 2 files changed, 49 insertions(+), 40 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0ebd98b2836a..48453d47579b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4123,12 +4123,13 @@ def from_pretrained( if distributed_config is not None: model.config.distributed_config = distributed_config model.device_mesh = device_mesh - sub_mesh = lambda name: device_mesh[name] if device_mesh.ndim > 1 else device_mesh mesh_dim_names = device_mesh.mesh_dim_names or () if "tp" in mesh_dim_names: - model = apply_tensor_parallel(model, sub_mesh("tp"), distributed_config.tp_plan) + tp_mesh = device_mesh["tp"] if device_mesh.ndim > 1 else device_mesh + model = apply_tensor_parallel(model, tp_mesh, distributed_config.tp_plan) if "fsdp" in mesh_dim_names: - model = apply_fully_shard_data_parallel(model, sub_mesh("fsdp"), distributed_config.fsdp_plan) + fsdp_mesh = device_mesh["fsdp"] if device_mesh.ndim > 1 else device_mesh + model = apply_fully_shard_data_parallel(model, fsdp_mesh, distributed_config.fsdp_plan) else: # Accelerate path: auto device mapping if device_map is not None: diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index 5e7dc3aa7b2e..f7fa2d1ad10c 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -274,43 +274,51 @@ def _make_dtensor_shard_op(mesh, placements, param_shape, local_shape): class TestConvertAndLoadStateDict(unittest.TestCase): def test_dtensor_shard_aware_mixtral_conversion_uses_only_local_experts(self): - """ - The problem: Mixtral has 8 experts. The checkpoint stores them separately: - experts.0.w1.weight (2x2) - experts.0.w3.weight (2x2) - experts.1.w1.weight (2x2) - experts.1.w3.weight (2x2) - - The model stores them packed into one tensor: - experts.gate_up_proj.weight (2, 4, 2) - ^ ^ ^ - | | └─ features - | └─ w1 (2) + w3 (2) concatenated - └─ num_experts - - The conversion (without FSDP) is: load all expert w1/w3 tensors → MergeModulelist(dim=0) stacks experts → Concatenate(dim=1) joins w1+w3. - - Example — Mixtral experts with FSDP Shard(0) on the expert dim: - - checkpoint files shard_tensor rank 0 gets - ──────────────── ──────────── ─────────── - experts.0.w1 [[0,1],[2,3]] idx=0 → kept [[0,1],[2,3]] - experts.1.w1 [[10,11],...] idx=1 → None (not owned) - experts.0.w3 [[4,5],[6,7]] idx=0 → kept [[4,5],[6,7]] - experts.1.w3 [[14,15],...] 
idx=1 → None (not owned) - - WeightConverter then stacks + concatenates only the kept tensors: gate_up_proj = [[[0,1],[2,3],[4,5],[6,7]]] shape (1,4,2) - - MergeModulelist(dim=0): [[0,1],[2,3]] → [[[0,1],[2,3]]] (1 expert, shape 1x2x2) - [[4,5],[6,7]] → [[[4,5],[6,7]]] (1 expert, shape 1x2x2) - - Concatenate(dim=1): cat along dim 1 → [[[0,1],[2,3],[4,5],[6,7]]] (shape 1x4x2) - ~~~~~~~~~~~ ~~~~~~~~~~~ - w1 w3 - - The key point: DtensorShardOperation.shard_tensor(tensor_idx=1) returns None for rank 0, so the - converter never even processes expert 1's data. This saves memory during loading. this should explain as well the - other tests + """Integration test: FSDP-sharded expert loading + WeightConverter. + + The problem: Mixtral has 8 experts. The checkpoint stores them separately:: + + experts.0.w1.weight (2x2) + experts.0.w3.weight (2x2) + experts.1.w1.weight (2x2) + experts.1.w3.weight (2x2) + + The model stores them packed into one tensor:: + + experts.gate_up_proj.weight (2, 4, 2) + ^ ^ ^ + | | +-- features + | +-- w1 (2) + w3 (2) concatenated + +-- num_experts + + The conversion (without FSDP) is: load all expert w1/w3 tensors, + MergeModulelist(dim=0) stacks experts, Concatenate(dim=1) joins w1+w3. + + With FSDP, Shard(0) splits the expert dim across ranks. Rank 0 owns + expert 0, rank 1 owns expert 1. So rank 0 should skip loading expert 1 + entirely -- not load it then discard it. + + What the test checks:: + + checkpoint files shard_tensor rank 0 gets + ---------------- ------------ ----------- + experts.0.w1 [[0,1],[2,3]] idx=0 -> kept [[0,1],[2,3]] + experts.1.w1 [[10,11],...] idx=1 -> None (not owned) + experts.0.w3 [[4,5],[6,7]] idx=0 -> kept [[4,5],[6,7]] + experts.1.w3 [[14,15],...] idx=1 -> None (not owned) + + WeightConverter then combines only the kept tensors:: + + MergeModulelist(dim=0): stack owned experts -> shape (1, 2, 2) each + Concatenate(dim=1): cat w1 + w3 along dim 1 + + gate_up_proj = [[[0,1],[2,3],[4,5],[6,7]]] shape (1, 4, 2) + ~~~~~~~~~~ ~~~~~~~~~~ + w1 w3 + + The key point: DtensorShardOperation.shard_tensor(tensor_idx=1) returns + None for rank 0, so the converter never even processes expert 1's data. + This saves memory during loading. 
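+
+        Equivalently, in plain torch (a sketch of what the two operations do to
+        rank 0's kept tensors, not the converter's actual code path)::
+
+            w1 = torch.tensor([[0.0, 1.0], [2.0, 3.0]])         # expert 0, w1
+            w3 = torch.tensor([[4.0, 5.0], [6.0, 7.0]])         # expert 0, w3
+            merged_w1 = torch.stack([w1], dim=0)                # (1, 2, 2)
+            merged_w3 = torch.stack([w3], dim=0)                # (1, 2, 2)
+            gate_up = torch.cat([merged_w1, merged_w3], dim=1)  # (1, 4, 2)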
""" shard_op = _make_dtensor_shard_op( FakeMesh(shape=(2,), rank=0), From 39bea22677282b2f3d552c8f582ea8daeebf5d6c Mon Sep 17 00:00:00 2001 From: 3outeille Date: Tue, 14 Apr 2026 16:01:54 +0000 Subject: [PATCH 7/7] ruff --- src/transformers/distributed/utils.py | 4 +++- src/transformers/integrations/fsdp.py | 1 + src/transformers/integrations/tensor_parallel.py | 6 +++++- src/transformers/modeling_utils.py | 11 +++++++---- tests/utils/test_core_model_loading.py | 14 +++++++++----- 5 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/transformers/distributed/utils.py b/src/transformers/distributed/utils.py index b909b6de3683..fa4d554f130d 100644 --- a/src/transformers/distributed/utils.py +++ b/src/transformers/distributed/utils.py @@ -98,7 +98,9 @@ def init_device_mesh(distributed_config: DistributedConfig) -> torch.distributed tp_size = distributed_config.tp_size fsdp_size = distributed_config.fsdp_size - assert world_size == tp_size * fsdp_size, f"world_size ({world_size}) must be equal to tp_size ({tp_size}) * fsdp_size ({fsdp_size})" + assert world_size == tp_size * fsdp_size, ( + f"world_size ({world_size}) must be equal to tp_size ({tp_size}) * fsdp_size ({fsdp_size})" + ) dims, names = [], [] if fsdp_size > 1: diff --git a/src/transformers/integrations/fsdp.py b/src/transformers/integrations/fsdp.py index 71f637df044d..128cba7d253f 100644 --- a/src/transformers/integrations/fsdp.py +++ b/src/transformers/integrations/fsdp.py @@ -492,6 +492,7 @@ def apply_fully_shard_data_parallel( return model + # ========================= PEFT compatibility ========================= # TODO(3outeille): make sure new FSDP works with PEFT def get_fsdp_ckpt_kwargs(): diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 207adc293167..759280defe8f 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -123,7 +123,9 @@ def gather_full_state_dict(model) -> dict[str, torch.Tensor]: if isinstance(tensor, DTensor): # All ranks participate in the collective, only rank 0 keeps the result with torch.no_grad(): - full = tensor.redistribute(placements=[Replicate()] * tensor.device_mesh.ndim, async_op=False).to_local() + full = tensor.redistribute( + placements=[Replicate()] * tensor.device_mesh.ndim, async_op=False + ).to_local() if is_rank0: result[key] = _to_cpu_fresh(full) del full @@ -143,6 +145,7 @@ def _redistribute_dtensor(tensor: DTensor, target_placements: tuple) -> DTensor: replicated = tensor.redistribute(placements=[Replicate()] * tensor.device_mesh.ndim) return replicated.redistribute(placements=target_placements) + def convert_strided_to_shard(state_dict: dict) -> dict[str, tuple]: # Convert _StridedShard DTensors in a state dict to plain Shard for DCP compatibility. placement_map: dict[str, tuple] = {} @@ -171,6 +174,7 @@ def _resolve(d, dotted_key): if leaf_key in container and isinstance(container[leaf_key], DTensor): container[leaf_key] = _redistribute_dtensor(container[leaf_key], original_placements) + def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str | TPStyle] | None): """ Verify the TP plan of the model, log a warning if the layers that were not sharded and the rules that were not applied. 
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 48453d47579b..1629c4ca4d9b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -38,6 +38,7 @@ from safetensors import safe_open from safetensors.torch import save_file as safe_save_file from torch import Tensor, nn +from torch.distributed.tensor import DTensor from torch.distributions import constraints from torch.utils.checkpoint import checkpoint @@ -4550,8 +4551,6 @@ def _move_missing_keys_from_meta_to_device( # will be re-initialized for nothing (which can be quite long) for key in missing_keys - self.all_tied_weights_keys.keys(): param = self.get_parameter_or_buffer(key) - from torch.distributed.tensor import DTensor - if isinstance(param, DTensor): # DTensor from parallelize_module on meta — materialize on actual device local_value = torch.empty( @@ -4560,8 +4559,12 @@ def _move_missing_keys_from_meta_to_device( device=torch.device(param.device_mesh.device_type, torch.cuda.current_device()), ) new_dtensor = DTensor.from_local( - local_value, param.device_mesh, param.placements, - run_check=False, shape=param.shape, stride=tuple(param.stride()), + local_value, + param.device_mesh, + param.placements, + run_check=False, + shape=param.shape, + stride=tuple(param.stride()), ) with torch.no_grad(): new_param = torch.nn.Parameter(new_dtensor, requires_grad=param.requires_grad) diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index f7fa2d1ad10c..787cf7b903ad 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -911,7 +911,7 @@ def test_no_shard_returns_full_tensor(self): torch.testing.assert_close(op.shard_tensor(tensor), tensor) def test_1d_shard_fast_path(self): - #TODO(3outeille): double check fast path + # TODO(3outeille): double check fast path tensor = torch.arange(16).reshape(4, 4).float() for rank, expected in [(0, tensor[:2]), (1, tensor[2:])]: mesh = FakeMesh(shape=(2,), rank=rank) @@ -940,8 +940,10 @@ def test_nd_strided_shard_disjoint_ranges(self): for rank in range(4): mesh = FakeMesh(shape=(2, 2), rank=rank) op = _make_dtensor_shard_op( - mesh, [Shard(0), _StridedShard(dim=1, split_factor=2)], - param_shape=(8, 8), local_shape=(4, 4), + mesh, + [Shard(0), _StridedShard(dim=1, split_factor=2)], + param_shape=(8, 8), + local_shape=(4, 4), ) torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}") @@ -952,8 +954,10 @@ def test_nd_strided_plus_shard_same_dim_fallback(self): for rank in range(4): mesh = FakeMesh(shape=(2, 2), rank=rank) op = _make_dtensor_shard_op( - mesh, [_StridedShard(dim=0, split_factor=2), Shard(0)], - param_shape=(4, 4), local_shape=(1, 4), + mesh, + [_StridedShard(dim=0, split_factor=2), Shard(0)], + param_shape=(4, 4), + local_shape=(1, 4), ) torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}")
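
[Editor's note] End-to-end, the user-facing flow this series builds up is
roughly the following. This is a hedged sketch: the DistributedConfig import
path and the `distributed_config` keyword on from_pretrained are assumptions
read off the modeling_utils hunks; tp_size/fsdp_size, tp_plan/fsdp_plan, and
the fsdp-outer/tp-inner mesh layout come from the diffs above. Launch under
torchrun so RANK, LOCAL_RANK, and WORLD_SIZE are set.

    import torch
    from transformers import AutoModelForCausalLM
    from transformers.distributed import DistributedConfig  # assumed import path

    # 2 FSDP ranks x 2 TP ranks -> init_device_mesh asserts WORLD_SIZE == 4
    dist_config = DistributedConfig(fsdp_size=2, tp_size=2)

    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-v0.1",   # any checkpoint works; Mixtral also
        dtype=torch.bfloat16,            # exercises the expert shard-on-read path
        distributed_config=dist_config,  # builds the ("fsdp", "tp") mesh, then
    )                                    # applies TP and fully-sharded DP in turn

Per the PATCH 6/7 modeling_utils hunk, from_pretrained then records the config
on model.config.distributed_config and the mesh on model.device_mesh before
handing the model back.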