From 739332cdcc3352a77d5e7daa61aa4c5ab15441f4 Mon Sep 17 00:00:00 2001
From: 3outeille
Date: Mon, 13 Apr 2026 14:14:20 +0000
Subject: [PATCH 1/2] DistributedConfig + shard-on-read loading

- FSDPShardOperation for shard-on-read offset/range math
- spawn_parallel_materialize() generalizes spawn_tp_materialize()
- from_pretrained wiring for distributed config
- Shard operation helpers in tensor_parallel
- Shard-on-read and LoadStateDictConfig tests
---
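Usage sketch (ignored by git-am): the `fsdp_plan={"mode": "auto"}` value below
mirrors the new tests; the checkpoint id is illustrative only.

    # Launch under torchrun so torch.distributed is initialized on every rank;
    # each rank then reads only its own shard of each weight (shard-on-read).
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-v0.1",  # illustrative checkpoint
        fsdp_plan={"mode": "auto"},
    )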
 src/transformers/core_model_loading.py       | 342 +++++++++++++++---
 .../integrations/tensor_parallel.py          |  10 +-
 src/transformers/modeling_utils.py           |  36 +-
 tests/utils/test_core_model_loading.py       |  84 ++++-
 tests/utils/test_modeling_utils.py           |  37 ++
 5 files changed, 427 insertions(+), 82 deletions(-)

diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index e0310c4abfeb..2b528bd0a829 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -32,7 +32,7 @@
 import torch
 
 from .integrations.accelerate import get_device, offload_weight
-from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
+from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, get_tensor_shard
 from .utils import is_env_variable_true
 from .utils.loading_report import LoadStateDictInfo
 from .utils.logging import get_logger, tqdm
@@ -41,9 +41,11 @@
 _torch_distributed_available = torch.distributed.is_available()
 
 if TYPE_CHECKING:
-    from .integrations.tensor_parallel import TensorParallelLayer
     from .modeling_utils import LoadStateDictConfig, PreTrainedModel
     from .quantizers import HfQuantizer
+elif _torch_distributed_available:
+    from torch.distributed.tensor import DTensor
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 
 logger = get_logger(__name__)
@@ -81,6 +83,21 @@ def build_glob_alternation(
     return alternation, src_group_to_glob, tgt_group_to_glob
 
 
+def resolve_target_wildcards(source_pattern: str, target_pattern: str, source_key: str) -> str:
+    if "*" not in target_pattern or "*" not in source_pattern:
+        return target_pattern
+
+    wildcard_regex = re.escape(source_pattern).replace(r"\*", r"(.*?)")
+    match = re.fullmatch(wildcard_regex, source_key)
+    if match is None:
+        return target_pattern
+
+    resolved_target = target_pattern
+    for wildcard_value in match.groups():
+        resolved_target = resolved_target.replace("*", wildcard_value, 1)
+    return resolved_target
+
+
 class ConversionOps:
     """Base class for weight conversion operations."""
 
@@ -316,7 +333,7 @@ def __init__(self):
 
     def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
         dim1, dim2 = tensor.shape
-        n_heads = self.config.getattr("num_attention_heads", 1)
+        n_heads = getattr(self.config, "num_attention_heads", 1)
 
         tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
         tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
@@ -332,11 +349,10 @@ def convert(
         **kwargs,
     ) -> dict[str, list[torch.Tensor]]:
         self.config = config
-        output: dict[str, list[torch.Tensor]] = {}
+        output = {}
         for key, tensors in input_dict.items():
-            if len(tensors) != 1:
-                raise ValueError("PermuteForRope expects a single tensor per key.")
-            output[key] = [self._apply(tensors[0])]
+            tensor = tensors[0] if isinstance(tensors, list) else tensors
+            output[key] = self._apply(tensor)
         return output
 
 
@@ -519,7 +535,7 @@ class WeightTransform:
     target_patterns: str | list[str] = field(init=True)
     compiled_sources: re.Pattern = field(init=False)
 
-    distributed_operation: TensorParallelLayer | None = None
+    distributed_operation: Any | None = None
     quantization_operation: ConversionOps | None = None
     collected_tensors: dict[str, list[Future]] = field(default_factory=lambda: defaultdict(list), init=False)
@@ -612,6 +628,7 @@ def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
             source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
         # If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
         replacement = self.target_patterns[0]
+        replacement = resolve_target_wildcards(source_pattern_that_matched, replacement, source_key)
         # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
         if r"\1" in replacement:
             # The index of the internal group we need to replace is the index of the matched named group as it comes
@@ -659,7 +676,7 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
                 tensors = [future.result() for future in tensors if future.result() is not None]
             # Sync loading
             elif callable(tensors[0]):
-                tensors = [func() for func in tensors]
+                tensors = [tensor for func in tensors if (tensor := func()) is not None]
             # Add them to the new dictionary
             collected_tensors[key] = tensors
 
@@ -766,18 +783,41 @@ def convert(
                 pass
 
         if hf_quantizer is not None and self.quantization_operation is not None:
-            with log_conversion_errors(
-                layer_name, loading_info, (len(collected_tensors), layer_name), self.quantization_operation
-            ):
-                collected_tensors = self.quantization_operation.convert(
-                    collected_tensors,
-                    source_patterns=self.source_patterns,
-                    target_patterns=self.target_patterns,
-                    full_layer_name=layer_name,
-                    config=config,
-                    model=model,
-                    missing_keys=loading_info.missing_keys if loading_info else None,
-                )
+            if len(collected_tensors) > 1 and model is not None:
+                quantized_tensors = {}
+                for target_key, tensor in collected_tensors.items():
+                    if not hf_quantizer.param_needs_quantization(model, target_key):
+                        quantized_tensors[target_key] = tensor
+                        continue
+                    quantize_input = tensor if isinstance(tensor, list) else [tensor]
+                    with log_conversion_errors(
+                        target_key, loading_info, (len(quantize_input), target_key), self.quantization_operation
+                    ):
+                        quantized_tensors.update(
+                            self.quantization_operation.convert(
+                                {target_key: quantize_input},
+                                source_patterns=self.source_patterns,
+                                target_patterns=[target_key],
+                                full_layer_name=target_key,
+                                config=config,
+                                model=model,
+                                missing_keys=loading_info.missing_keys if loading_info else None,
+                            )
+                        )
+                collected_tensors = quantized_tensors
+            else:
+                with log_conversion_errors(
+                    layer_name, loading_info, (len(collected_tensors), layer_name), self.quantization_operation
+                ):
+                    collected_tensors = self.quantization_operation.convert(
+                        collected_tensors,
+                        source_patterns=self.source_patterns,
+                        target_patterns=self.target_patterns,
+                        full_layer_name=layer_name,
+                        config=config,
+                        model=model,
+                        missing_keys=loading_info.missing_keys if loading_info else None,
+                    )
 
         return collected_tensors
 
@@ -815,10 +855,15 @@ def _job():
     return _job
 
 
-def spawn_tp_materialize(
-    thread_pool: ThreadPoolExecutor | None, tensor: torch.Tensor, sharding_method, tensor_idx, device=None, dtype=None
+def spawn_parallel_materialize(
+    thread_pool: ThreadPoolExecutor | None,
+    tensor: torch.Tensor,
+    sharding_method,
+    tensor_idx,
+    device=None,
+    dtype=None,
 ) -> Future | Callable:
-    """Materialize and shard a tensor (according to the TP-plan) from file asynchronously if `thread_pool` is provided, or
+    """Materialize and shard a tensor according to the active parallelism strategy if `thread_pool` is provided, or
     return a Callable that will load the tensor synchronously when called."""
 
     def _job():
@@ -832,6 +877,133 @@ def _job():
     return _job
 
 
+@dataclass(slots=True)
+class ParallelMaterializationContext:
+    distributed_operation: Any
+    tensor_idx: int | None
+    device: Any
+
+
+def is_dtensor_like(value: Any) -> bool:
+    return all(hasattr(value, attr) for attr in ("device_mesh", "placements", "to_local"))
+
+
+@dataclass(slots=True)
+class FSDPShardOperation:
+    device_mesh: Any
+    rank: int
+    empty_param: Any
+    placements: tuple[Any, ...]
+    shard_placement: Any | None = field(init=False, default=None)
+    local_shape: tuple[int, ...] = field(init=False)
+
+    def __post_init__(self):
+        shard_placements = [placement for placement in self.placements if placement.is_shard()]
+        if len(shard_placements) > 1:
+            raise NotImplementedError(
+                f"FSDP shard-on-read does not support multiple shard placements yet: {self.placements}"
+            )
+        self.shard_placement = shard_placements[0] if shard_placements else None
+        if self.shard_placement is not None and len(self.placements) != 1:
+            raise NotImplementedError(
+                f"FSDP shard-on-read only supports a single placement today. Got placements={self.placements}."
+            )
+        self.local_shape = self.get_expected_sharded_shape(self.empty_param.shape)
+
+    @classmethod
+    def from_param(cls, param: Any) -> "FSDPShardOperation":
+        return cls(
+            device_mesh=param.device_mesh,
+            rank=param.device_mesh.get_local_rank(),
+            empty_param=param,
+            placements=tuple(param.placements),
+        )
+
+    def shard_tensor(
+        self, param: torch.Tensor, tensor_idx: int | None = None, device=None, dtype=None
+    ) -> torch.Tensor | None:
+        if self.shard_placement is None:
+            local_tensor = param[...]
+        else:
+            param_shape = list(param.shape) if isinstance(param, torch.Tensor) else param.get_shape()
+            # Mixtral-style converted expert weights first stack individual expert tensors along dim 0 before
+            # concatenating. Only materialize the experts owned by this rank.
+            if (
+                tensor_idx is not None
+                and len(self.empty_param.shape) == len(param_shape) + 1
+                and self.shard_placement.dim == 0
+            ):
+                local_expert_count = self.local_shape[0]
+                expert_offset = compute_local_shape_and_global_offset(
+                    self.empty_param.shape, self.device_mesh, self.placements
+                )[1][0]
+                if tensor_idx < expert_offset or tensor_idx >= expert_offset + local_expert_count:
+                    return None
+                local_tensor = param[...]
+            else:
+                local_tensor = get_tensor_shard(
+                    param,
+                    self.empty_param,
+                    self.device_mesh,
+                    self.rank,
+                    self.shard_placement.dim,
+                    tensor_idx=tensor_idx,
+                )
+        if local_tensor is None:
+            return None
+        return local_tensor.to(device=device, dtype=dtype)
+
+    def get_expected_sharded_shape(self, full_shape: tuple[int, ...] | torch.Size) -> tuple[int, ...]:
+        local_shape, _ = compute_local_shape_and_global_offset(full_shape, self.device_mesh, self.placements)
+        return tuple(local_shape)
+
+    def update_module_attributes(self, module: torch.nn.Module):
+        return None
+
+
+def get_parallel_materialization_context(
+    mapping: WeightTransform,
+    renamed_key: str,
+    source_pattern: str,
+    empty_param: Any,
+    device_mesh: Any,
+    parallel_plan: dict[str, Any],
+    parallel_pattern_matcher: re.Pattern | None,
+    parallel_pattern_by_group_name: dict[str, str] | None,
+    device_map: dict[str, Any],
+) -> ParallelMaterializationContext | None:
+    tensor_idx = (
+        len(mapping.collected_tensors.get(source_pattern, []))
+        if isinstance(mapping, WeightConverter) and isinstance(mapping.operations[0], MergeModulelist)
+        else None
+    )
+
+    if (
+        device_mesh
+        and parallel_plan
+        and parallel_pattern_matcher is not None
+        and parallel_pattern_by_group_name is not None
+    ):
+        if matched_parallel_pattern := parallel_pattern_matcher.search(renamed_key):
+            matched_parallel_pattern = parallel_pattern_by_group_name[matched_parallel_pattern.lastgroup]
+            if getattr(mapping, "distributed_operation", None) is None:
+                parallel_layer = ALL_PARALLEL_STYLES[parallel_plan[matched_parallel_pattern]].__class__
+                mapping.distributed_operation = parallel_layer(
+                    device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone()
+                )
+            return ParallelMaterializationContext(mapping.distributed_operation, tensor_idx, device_map[""])
+
+    if is_dtensor_like(empty_param):
+        if getattr(mapping, "distributed_operation", None) is None:
+            mapping.distributed_operation = FSDPShardOperation.from_param(empty_param)
+        return ParallelMaterializationContext(
+            mapping.distributed_operation,
+            tensor_idx,
+            get_device(device_map, renamed_key, valid_torch_device=True),
+        )
+
+    return None
+
 def dot_natural_key(s: str):
     """Sort key for state-dict names: split on ``"."`` and sort digits numerically and strings alphabetically.
     We emit a tuple at each point to sort ints
@@ -900,7 +1072,7 @@ def set_param_for_module(
     target_name: str,
     param_value: torch.Tensor,
     loading_info: LoadStateDictInfo,
-    distributed_operation: TensorParallelLayer | None,
+    distributed_operation: Any | None,
     hf_quantizer: HfQuantizer,
 ):
     module_path, _, param_name = target_name.rpartition(".")
@@ -915,15 +1087,18 @@ def set_param_for_module(
     if ref is None:
         loading_info.unexpected_keys.add(target_name)
     else:
-        if not isinstance(param_value, torch.nn.Parameter):
+        if not isinstance(param_value, torch.nn.Parameter) and not is_dtensor_like(ref):
             if param_name not in module_obj._buffers:
                 param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
 
         # Remove from missing keys (it's either mismatched, or all good)
         loading_info.missing_keys.discard(target_name)
 
-        # Determine expected shape: for TP, use sharded shape; otherwise, use full shape
-        if distributed_operation is not None:
+        # Determine expected shape: for TP/FSDP shard-on-read, use the local shard shape; otherwise, use full shape
+        if is_dtensor_like(ref):
+            local_shape, _ = compute_local_shape_and_global_offset(ref.shape, ref.device_mesh, ref.placements)
+            expected_shape = torch.Size(local_shape)
+        elif distributed_operation is not None:
             expected_shape = torch.Size(distributed_operation.get_expected_sharded_shape(ref.shape))
         else:
             expected_shape = ref.shape
@@ -931,11 +1106,29 @@ def set_param_for_module(
         if ref is not None and param_value.shape != expected_shape and hf_quantizer is None:
             loading_info.mismatched_keys.add((target_name, param_value.shape, expected_shape))
         else:
-            # super important otherwise _init_weight will re-init the param
-            param_value._is_hf_initialized = True
-            setattr(module_obj, param_name, param_value)
-            if distributed_operation is not None:
-                distributed_operation.update_module_attributes(module_obj)
+            if is_dtensor_like(ref):
+                local_param = param_value.detach() if isinstance(param_value, torch.nn.Parameter) else param_value
+                fsdp_param = DTensor.from_local(
+                    local_param.contiguous(),
+                    ref.device_mesh,
+                    ref.placements,
+                    run_check=False,
+                    shape=ref.shape,
+                    stride=tuple(ref.stride()),
+                )
+                with torch.no_grad():
+                    if ref.is_meta:
+                        fsdp_param = torch.nn.Parameter(fsdp_param, requires_grad=ref.requires_grad)
+                        torch.utils.swap_tensors(ref, fsdp_param)
+                    else:
+                        ref.copy_(fsdp_param)
+                ref._is_hf_initialized = True
+            else:
+                # super important otherwise _init_weight will re-init the param
+                param_value._is_hf_initialized = True
+                setattr(module_obj, param_name, param_value)
+                if distributed_operation is not None:
+                    distributed_operation.update_module_attributes(module_obj)
 
 
 def offload_and_maybe_resave_param(
@@ -1002,6 +1195,30 @@ def rename_source_key(
     return renamed_key, source_pattern
 
 
+def concretize_target_patterns(
+    converter: WeightConverter,
+    source_key: str,
+    source_pattern: str,
+    prefix: str | None,
+    meta_state_dict: dict | None,
+) -> WeightConverter:
+    concrete_targets = []
+    for target_pattern in converter.target_patterns:
+        concrete_target = resolve_target_wildcards(source_pattern, target_pattern, source_key)
+        if prefix is not None and meta_state_dict is not None:
+            if (
+                concrete_target.startswith(prefix)
+                and meta_state_dict.get(re.sub(f"^{prefix}.", "", concrete_target, count=1)) is not None
+            ):
+                concrete_target = re.sub(f"^{prefix}.", "", concrete_target, count=1)
+            elif meta_state_dict.get(f"{prefix}.{concrete_target}") is not None:
+                concrete_target = f"{prefix}.{concrete_target}"
+        concrete_targets.append(concrete_target)
+
+    object.__setattr__(converter, "target_patterns", concrete_targets)
+    return converter
+
+
 def convert_and_load_state_dict_in_model(
     model: PreTrainedModel,
     state_dict: dict[str, Any],
@@ -1156,10 +1373,25 @@ def convert_and_load_state_dict_in_model(
 
             # 2. finally, collect the tensor into the proper converter
             if renamed_key in meta_model_state_dict:
-                empty_param = meta_model_state_dict.get(renamed_key)
+                empty_param = meta_model_state_dict[renamed_key]
+                try:
+                    empty_param = model.get_parameter_or_buffer(renamed_key)
+                except (AttributeError, KeyError):
+                    if getattr(model, "_is_fsdp_managed_module", False):
+                        raise RuntimeError(
+                            f"FSDP shard-on-read requires the live parameter for {renamed_key!r}, "
+                            f"but get_parameter_or_buffer() failed."
+                        )
                 # If we enter here, we have a WeightConverter operation to perform
                 if source_pattern is not None:
                     new_converter = deepcopy(pattern_to_converter[source_pattern])
+                    new_converter = concretize_target_patterns(
+                        new_converter,
+                        original_key,
+                        source_pattern,
+                        prefix,
+                        meta_model_state_dict,
+                    )
                     # each target key gets its own converter instance
                     mapping = param_name_to_load.setdefault(renamed_key, new_converter)
                 # Otherwise, only potential renaming
@@ -1200,29 +1432,27 @@ def convert_and_load_state_dict_in_model(
             elif empty_param is not None and empty_param.dtype != _dtype:
                 _dtype = empty_param.dtype  # usually correct when initializing
 
-            # 4. Handle TP sharding or device_map placement
+            # 4. Handle parallel shard-on-read or device_map placement
             future_or_tensor = None
-            if device_mesh and tp_plan:
-                if matched_tp_pattern := tp_plan_alt.search(renamed_key):
-                    matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup]
-                    if getattr(mapping, "distributed_operation", None) is None:
-                        tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
-                        mapping.distributed_operation = tp_layer(
-                            device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone()
-                        )
-                    shard_index = (
-                        len(mapping.collected_tensors.get(source_pattern, []))
-                        if isinstance(mapping, WeightConverter) and isinstance(mapping.operations[0], MergeModulelist)
-                        else None
-                    )
-                    future_or_tensor = spawn_tp_materialize(
-                        thread_pool,
-                        tensor,
-                        mapping.distributed_operation,
-                        shard_index,
-                        device_map[""],
-                        _dtype,
-                    )
+            if parallel_context := get_parallel_materialization_context(
+                mapping=mapping,
+                renamed_key=renamed_key,
+                source_pattern=source_pattern,
+                empty_param=empty_param,
+                device_mesh=device_mesh,
+                parallel_plan=tp_plan,
+                parallel_pattern_matcher=tp_plan_alt if tp_plan else None,
+                parallel_pattern_by_group_name=tp_plan_by_group_name if tp_plan else None,
+                device_map=device_map,
+            ):
+                future_or_tensor = spawn_parallel_materialize(
+                    thread_pool,
+                    tensor,
+                    parallel_context.distributed_operation,
+                    parallel_context.tensor_idx,
+                    parallel_context.device,
+                    _dtype,
+                )
 
             if future_or_tensor is None:
                 param_device = get_device(device_map, renamed_key, valid_torch_device=True)
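Illustration (not part of the diff): resolve_target_wildcards() fills the
target's wildcards with the source captures in order; the target pattern
below is hypothetical.

    resolve_target_wildcards(
        "model.layers.*.experts.*.w1.weight",  # source pattern
        "layers.*.experts.gate_up_proj.*",     # hypothetical target pattern
        "model.layers.0.experts.3.w1.weight",  # concrete source key
    )
    # -> "layers.0.experts.gate_up_proj.3"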
diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py
index 39a2e696941b..c378ebbac227 100644
--- a/src/transformers/integrations/tensor_parallel.py
+++ b/src/transformers/integrations/tensor_parallel.py
@@ -1506,8 +1506,8 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
         logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
 
 
-def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
-    """Distribute a model according to the TP plan."""
+def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size, fsdp_plan=None):
+    """Attach distributed runtime hooks before checkpoint loading."""
     model._tp_size = tp_size
     model._device_mesh = device_mesh
     if distributed_config is not None:
@@ -1517,7 +1517,7 @@ def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
     # Set the new requested tp_plan on the model
     if isinstance(tp_plan, dict):
         model.tp_plan = tp_plan
-    model_plan = model.tp_plan
+    model_plan = model.tp_plan if tp_plan is not None or tp_size is not None else None
     if model_plan is not None and _torch_distributed_available:
         for v in model_plan.values():
             if v not in ALL_PARALLEL_STYLES:
@@ -1533,4 +1533,8 @@ def distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size):
                 device_mesh,
             )
             module._is_hooked = True
+    if fsdp_plan is not None:
+        from .fsdp import apply_fsdp2
+
+        model = apply_fsdp2(model, device_mesh, fsdp_plan)
     return model
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2a7304650640..b6feec6f7ba6 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -69,7 +69,7 @@
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flash_paged import paged_attention_forward
 from .integrations.flex_attention import flex_attention_forward
-from .integrations.fsdp import apply_fsdp2
+from .integrations.fsdp import initialize_fsdp
 from .integrations.hub_kernels import allow_all_hub_kernels, is_kernel
 from .integrations.peft import maybe_load_adapters
 from .integrations.sdpa_attention import sdpa_attention_forward
@@ -178,6 +178,7 @@ class LoadStateDictConfig:
     dtype_plan: dict = field(default_factory=dict)
     hf_quantizer: HfQuantizer | None = None
     device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None
+    tp_plan: dict[str, str] | None = None
     weights_only: bool = True
     weight_mapping: list[WeightConverter | WeightRenaming] | None = None
@@ -4003,11 +4004,21 @@ def from_pretrained(
                     ": PartialState().process_index} where PartialState comes from accelerate library"
                 )
 
+        if fsdp_plan is not None and (tp_plan is not None or tp_size is not None):
+            raise ValueError("Combining `fsdp_plan` with tensor parallel loading is not supported yet.")
+
         if tp_plan is not None or tp_size is not None:
             # TP warnings, and setup
             device_map, device_mesh, tp_size = initialize_tensor_parallelism(
                 tp_plan, tp_size=tp_size, device_mesh=device_mesh, device_map=device_map
             )
 
+        if fsdp_plan is not None:
+            device_map, device_mesh, _ = initialize_fsdp(
+                fsdp_plan=fsdp_plan,
+                device_mesh=device_mesh,
+                device_map=device_map,
+            )
+
         if gguf_file is not None and not is_accelerate_available():
             raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.")
 
@@ -4132,14 +4143,17 @@ def from_pretrained(
         # Obtain the weight conversion mapping for this model if any are registered and apply to all submodels recursively
         weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
 
-        if _torch_distributed_available and device_mesh is not None:  # add hooks to nn.Modules: no weights
-            model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
+        if _torch_distributed_available and device_mesh is not None and (tp_plan is not None or fsdp_plan is not None):
+            model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size, fsdp_plan=fsdp_plan)
 
         # Prepare the full device map
-        if device_map is not None:
+        if isinstance(device_map, dict):
             device_map = _get_device_map(model, device_map, max_memory, hf_quantizer)
+        elif device_map is not None:
+            device_map = {"": device_map}
 
         # Finalize model weight initialization
+        active_tp_plan = getattr(model, "_tp_plan", None) if tp_size is not None else None
         load_config = LoadStateDictConfig(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             ignore_mismatched_sizes=ignore_mismatched_sizes,
@@ -4151,6 +4165,7 @@ def from_pretrained(
             dtype_plan=dtype_plan,
             hf_quantizer=hf_quantizer,
             device_mesh=device_mesh,
+            tp_plan=active_tp_plan,
             weights_only=weights_only,
             weight_mapping=weight_conversions,
             use_safetensors=use_safetensors,
@@ -4161,15 +4176,6 @@ def from_pretrained(
         model.eval()  # Set model in evaluation mode to deactivate Dropout modules by default
         model.set_use_kernels(use_kernels, kernel_config)
 
-        # Apply FSDP2 if configured (must be after weight loading)
-        if fsdp_plan is not None:
-            if device_mesh is None:
-                raise ValueError(
-                    "`fsdp_plan` was provided but no device mesh is available. "
-                    "Pass `device_mesh` to `from_pretrained`."
-                )
-            model = apply_fsdp2(model, device_mesh, fsdp_plan)
-
         # If it is a model with generation capabilities, attempt to load generation files (generation config,
         # custom generate function)
         if model.can_generate() and hasattr(model, "adjust_generation_fn") and not gguf_file:
@@ -4226,7 +4232,7 @@ def _load_pretrained_model(
         expected_keys = list(model.state_dict().keys()) if expected_keys is None else expected_keys
 
         if logger.level >= logging.WARNING:
-            verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))
+            verify_tp_plan(expected_keys, load_config.tp_plan)
 
         # This offload index if for params explicitly on the "disk" in the device_map
         disk_offload_index = None
@@ -4289,7 +4295,7 @@ def _load_pretrained_model(
             model=model,
             state_dict=merged_state_dict,
             load_config=load_config,
-            tp_plan=model._tp_plan,
+            tp_plan=load_config.tp_plan,
             disk_offload_index=disk_offload_index,
         )
 
diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py
index 3e8c18b1d351..942dcdc99b11 100644
--- a/tests/utils/test_core_model_loading.py
+++ b/tests/utils/test_core_model_loading.py
@@ -23,6 +23,7 @@
     Chunk,
     Concatenate,
     ErnieFuseAndSplitTextVisionExperts,
+    FSDPShardOperation,
     MergeModulelist,
     PermuteForRope,
     WeightConverter,
@@ -31,6 +32,7 @@
     convert_and_load_state_dict_in_model,
     rename_source_key,
     revert_weight_conversion,
+    spawn_parallel_materialize,
 )
 from transformers.modeling_utils import LoadStateDictConfig
 from transformers.utils.import_utils import is_triton_available
@@ -214,7 +216,66 @@ def __init__(self, add_extra_moe=False):
         self.mlp = DummyMLP()
 
 
+class FakeMesh:
+    def __init__(self, world_size: int, rank: int):
+        self.shape = (world_size,)
+        self._rank = rank
+
+    def get_local_rank(self):
+        return self._rank
+
+    def get_coordinate(self):
+        return (self._rank,)
+
+
 class TestConvertAndLoadStateDict(unittest.TestCase):
+    def test_fsdp_shard_aware_mixtral_conversion_uses_only_local_experts(self):
+        shard_op = FSDPShardOperation(
+            device_mesh=FakeMesh(world_size=2, rank=0),
+            rank=0,
+            empty_param=torch.empty((2, 4, 2)),
+            placements=(torch.distributed.tensor.placement_types.Shard(0),),
+        )
+        converter = WeightConverter(
+            ["experts.*.w1.weight", "experts.*.w3.weight"],
+            "experts.gate_up_proj.weight",
+            operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
+        )
+
+        for idx, tensor in enumerate(
+            [
+                torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
+                torch.tensor([[10.0, 11.0], [12.0, 13.0]]),
+            ]
+        ):
+            converter.add_tensor(
+                "model.layers.0.experts.gate_up_proj.weight",
+                f"model.layers.0.experts.{idx}.w1.weight",
+                "experts.*.w1.weight",
+                spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None),
+            )
+
+        for idx, tensor in enumerate(
+            [
+                torch.tensor([[4.0, 5.0], [6.0, 7.0]]),
+                torch.tensor([[14.0, 15.0], [16.0, 17.0]]),
+            ]
+        ):
+            converter.add_tensor(
+                "model.layers.0.experts.gate_up_proj.weight",
+                f"model.layers.0.experts.{idx}.w3.weight",
+                "experts.*.w3.weight",
+                spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None),
+            )
+
+        converted = converter.convert("model.layers.0.experts.gate_up_proj.weight")
+
+        self.assertEqual(list(converted), ["model.layers.0.experts.gate_up_proj.weight"])
+        torch.testing.assert_close(
+            converted["model.layers.0.experts.gate_up_proj.weight"],
+            torch.tensor([[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]]),
+        )
+
     def test_moe_and_qkv_conversion(self):
         model = DummyRoot()
         model.config = PretrainedConfig()
@@ -467,6 +528,10 @@ def __init__(self):
                     self, "quantization_config", SimpleNamespace(weight_block_size=bs)
                 ),
                 "param_needs_quantization": lambda self, _model, param_name: param_name.endswith("q_proj.weight"),
+                "get_quantize_ops": lambda self: __import__(
+                    "transformers.integrations.finegrained_fp8",
+                    fromlist=["Fp8Quantize"],
+                ).Fp8Quantize(self),
                 "pre_quantized": False,
             },
         )
@@ -499,11 +564,11 @@ def __init__(self):
         model_state = model.state_dict()
 
         self.assertFalse(torch.allclose(raw_k, expected_k))
-        torch.testing.assert_close(model_state["model.layers.0.self_attn.k_proj.weight"], expected_k)
-        torch.testing.assert_close(model_state["model.layers.0.self_attn.v_proj.weight"], expected_v)
+        torch.testing.assert_close(model_state["layers.0.self_attn.k_proj.weight"], expected_k)
+        torch.testing.assert_close(model_state["layers.0.self_attn.v_proj.weight"], expected_v)
 
-        q_weight_key = "model.layers.0.self_attn.q_proj.weight"
-        scale_key = "model.layers.0.self_attn.q_proj.weight_scale_inv"
+        q_weight_key = "layers.0.self_attn.q_proj.weight"
+        scale_key = "layers.0.self_attn.q_proj.weight_scale_inv"
         self.assertIn(scale_key, model_state)
         expected_dtype = torch.float8_e4m3fn if hasattr(torch, "float8_e4m3fn") else torch.int8
         self.assertEqual(model_state[q_weight_key].dtype, expected_dtype)
@@ -514,11 +579,14 @@ def __init__(self):
             torch.Size((out_dim // block_size[0], in_dim // block_size[1])),
         )
 
-        dequant = Fp8Dequantize(block_size=block_size)
+        dequant = Fp8Dequantize(quantizer)
         dequantized_q = dequant.convert(
-            [model_state[q_weight_key], model_state[scale_key]],
-            context={"quantization_config": quantizer.quantization_config},
-        )
+            {
+                "weight$": [model_state[q_weight_key]],
+                "weight_scale_inv": [model_state[scale_key]],
+            },
+            full_layer_name=q_weight_key,
+        )[q_weight_key]
         torch.testing.assert_close(dequantized_q, expected_q, rtol=1e-2, atol=1e-2)
 
     def test_ernie4_5_vl_moe_conversion(self):
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 7366845c4d78..59177fec5061 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -431,6 +431,43 @@ def test_get_total_byte_count_does_not_require_process_group(self):
         self.assertIn(torch.device("cpu"), total_byte_count)
         self.assertGreater(total_byte_count[torch.device("cpu")], 0)
 
+    def test_model_from_pretrained_fsdp_distributes_before_loading(self):
+        model = GPT2LMHeadModel(GPT2Config(n_layer=1, n_head=2, n_embd=8, n_positions=8, n_ctx=8, vocab_size=32))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            call_order = []
+
+            def fake_distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size, fsdp_plan=None):
+                call_order.append("distribute")
+                self.assertEqual(fsdp_plan, {"mode": "auto"})
+                model._tp_plan = {"model.layers.*.mlp.experts.gate_up_proj": "packed_colwise"}
+                model._is_fsdp_managed_module = True
+                return model
+
+            def fake_load_pretrained_model(model, state_dict, checkpoint_files, load_config, expected_keys=None):
+                call_order.append("load")
+                self.assertEqual(load_config.device_mesh, "fake-mesh")
+                self.assertEqual(load_config.device_map, {"": torch.device("cpu")})
+                self.assertIsNone(load_config.tp_plan)
+                return mock.Mock(), None
+
+            with (
+                patch(
+                    "transformers.modeling_utils.initialize_fsdp", return_value=(torch.device("cpu"), "fake-mesh", 2)
+                ),
+                patch("transformers.modeling_utils.distribute_model", side_effect=fake_distribute_model),
+                patch.object(GPT2LMHeadModel, "_load_pretrained_model", side_effect=fake_load_pretrained_model),
+                patch.object(
+                    GPT2LMHeadModel,
+                    "_finalize_model_loading",
+                    side_effect=lambda model, load_config, loading_info: loading_info,
+                ),
+            ):
+                GPT2LMHeadModel.from_pretrained(tmp_dir, fsdp_plan={"mode": "auto"})
+
+        self.assertEqual(call_order, ["distribute", "load"])
+
     def test_hub_retry(self):
         @hub_retry(max_attempts=2)
         def test_func():

From 21f05610b20caf50e6e017e23f94f8d027f5cefd Mon Sep 17 00:00:00 2001
From: 3outeille
Date: Tue, 14 Apr 2026 14:22:30 +0000
Subject: [PATCH 2/2] Fix ruff formatting in core_model_loading.py

---
 src/transformers/core_model_loading.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 2b528bd0a829..63ff8fac7d7c 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -1004,6 +1004,7 @@ def get_parallel_materialization_context(
     return None
 
+
 def dot_natural_key(s: str):
     """Sort key for state-dict names: split on ``"."`` and sort digits numerically and strings alphabetically.
     We emit a tuple at each point to sort ints
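Postscript example (plain arithmetic, no DTensor required): for a single
Shard(0) placement with an even split across ranks, the expert-ownership
test inside FSDPShardOperation.shard_tensor reduces to the range check below.

    def owns_expert(tensor_idx: int, num_experts: int, world_size: int, rank: int) -> bool:
        local_expert_count = num_experts // world_size  # local_shape[0]
        expert_offset = rank * local_expert_count       # global offset on dim 0
        return expert_offset <= tensor_idx < expert_offset + local_expert_count

    # With 8 experts over 4 ranks, rank 1 materializes experts 2 and 3 only;
    # shard_tensor() returns None for every other expert index, so those
    # tensors are never read from the checkpoint.
    assert [i for i in range(8) if owns_expert(i, 8, 4, 1)] == [2, 3]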