diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index cd0710649c91..79f1b0a6bf44 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -583,7 +583,8 @@ class WeightTransform: __slots__ = ( "source_patterns", "target_patterns", - "compiled_sources", + "_source_regex_str", + "_compiled_sources", "distributed_operation", "quantization_operation", "collected_tensors", @@ -647,13 +648,40 @@ def __init__(self, source_patterns: str | list[str], target_patterns: str | list pattern = process_source_pattern(pattern, self._original_target_patterns[i]) self.source_patterns[i] = pattern - # Construct the regex we will use to rename keys from the sources to the targets + # Build the regex source string, but compile lazily via `compiled_sources` below. During loading, any key + # that does not match a weight conversion op gets wrapped in a fresh per-weight `WeightRenaming` for + # convenience so it can reuse the same conversion/loading path. Those fallback wrappers never need to call + # `rename_source_key`, so eagerly compiling their regex would just waste work — and it dominates + # `from_pretrained` for models with many parameters. branches = [] for i, source_pattern in enumerate(self.source_patterns): group_name = f"g{i}" pattern = source_pattern.replace(".*.", r"\..*\.") branches.append(f"(?P<{group_name}>{pattern})") - self.compiled_sources = re.compile("|".join(branches)) + self._source_regex_str = "|".join(branches) + self._compiled_sources = None + + @property + def compiled_sources(self) -> re.Pattern: + if self._compiled_sources is None: + self._compiled_sources = re.compile(self._source_regex_str) + return self._compiled_sources + + def __deepcopy__(self, memo): + # A fresh-per-target copy is needed because `collected_tensors`, `layer_targets`, and `_was_used` accumulate + # state during loading. The compiled regex is stateless, so we share it across copies — avoiding a hidden + # `re.compile` that would otherwise run on every per-weight pickle/unpickle round-trip. 
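+        # (`re.Pattern` objects are immutable, so a single compiled instance can safely be shared across copies.)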
+ cls = self.__class__ + new = cls.__new__(cls) + memo[id(self)] = new + for slot in chain.from_iterable(getattr(c, "__slots__", ()) for c in cls.__mro__): + if not hasattr(self, slot): + continue + if slot == "_compiled_sources": + new._compiled_sources = self._compiled_sources + else: + object.__setattr__(new, slot, deepcopy(getattr(self, slot), memo)) + return new def __repr__(self): return f"{self.__class__.__name__}(source_patterns={self.source_patterns}, target_patterns={self.target_patterns})" @@ -735,14 +763,13 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]: for key in list(self.collected_tensors.keys()): # Remove from internal attribute tensors = self.collected_tensors.pop(key) - # Async loading - if isinstance(tensors[0], Future): - tensors = [future.result() for future in tensors if future.result() is not None] - # Sync loading - elif callable(tensors[0]): - tensors = [func() for func in tensors] + resolved_tensors = [] + for tensor_or_future in tensors: + resolved_tensor = _resolve_pending_tensor(tensor_or_future) + if resolved_tensor is not None: + resolved_tensors.append(resolved_tensor) # Add them to the new dictionary - collected_tensors[key] = tensors + collected_tensors[key] = resolved_tensors return collected_tensors @@ -933,6 +960,15 @@ def convert( GLOBAL_WORKERS = min(4, os.cpu_count() or 4) +def _resolve_pending_tensor(tensor_or_future: Future | Callable | torch.Tensor) -> torch.Tensor | None: + if isinstance(tensor_or_future, Future): + return tensor_or_future.result() + elif callable(tensor_or_future): + return tensor_or_future() + else: + return tensor_or_future + + def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Tensor: # This slicing is what actually loads the tensor from the safetensors slice object tensor = tensor[...] @@ -1048,9 +1084,16 @@ def set_param_for_module( loading_info: LoadStateDictInfo, distributed_operation: TensorParallelLayer | None, hf_quantizer: HfQuantizer, + module_cache: dict[str, torch.nn.Module] | None = None, ): module_path, _, param_name = target_name.rpartition(".") - module_obj = model.get_submodule(module_path) if module_path else model + if module_cache is not None: + module_obj = module_cache.get(module_path) + if module_obj is None: + module_obj = model.get_submodule(module_path) if module_path else model + module_cache[module_path] = module_obj + else: + module_obj = model.get_submodule(module_path) if module_path else model if param_name == torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX: module_obj.set_extra_state(param_value) @@ -1088,7 +1131,7 @@ def offload_and_maybe_resave_param( target_name: str, param: torch.Tensor, loading_info: LoadStateDictInfo, - disk_offload_folder: str, + disk_offload_folder: str | None, disk_offload_index: dict, applied_ops: WeightConverter | WeightRenaming, ) -> dict: @@ -1136,18 +1179,58 @@ def rename_source_key( break # 3. check if we need to add or remove prefix if necessary (only during loading, not saving) - if prefix is not None and meta_state_dict is not None: - if ( - renamed_key.startswith(prefix) - and meta_state_dict.get(re.sub(f"^{prefix}.", "", renamed_key, count=1)) is not None - ): - renamed_key = re.sub(f"^{prefix}.", "", renamed_key, count=1) - elif meta_state_dict.get(f"{prefix}.{renamed_key}") is not None: - renamed_key = f"{prefix}.{renamed_key}" + if prefix not in (None, "") and meta_state_dict is not None: + prefixed_key = f"{prefix}.{renamed_key}" + prefix_with_separator = f"{prefix}." 
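+        # Strip the prefix when the unprefixed key exists on the meta model, otherwise add it when the prefixed key does.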
+ if renamed_key.startswith(prefix_with_separator): + unprefixed_key = renamed_key[len(prefix_with_separator) :] + if meta_state_dict.get(unprefixed_key) is not None: + renamed_key = unprefixed_key + elif meta_state_dict.get(prefixed_key) is not None: + renamed_key = prefixed_key return renamed_key, source_pattern +def _assign_or_offload_param( + model: PreTrainedModel, + target_name: str, + param: torch.Tensor, + loading_info: LoadStateDictInfo, + device_map: dict | None, + model_buffers: set[str], + offload_buffers: bool, + disk_offload_folder: str | None, + disk_offload_index: dict | None, + distributed_operation: TensorParallelLayer | None, + hf_quantizer: HfQuantizer | None, + module_cache: dict[str, torch.nn.Module], + applied_ops: WeightConverter | WeightRenaming | None = None, +) -> dict | None: + param_device = get_device(device_map, target_name) + if param_device == "disk" and (target_name not in model_buffers or offload_buffers): + current_disk_offload_index = {} if disk_offload_index is None else disk_offload_index + if applied_ops is None: + loading_info.missing_keys.discard(target_name) + if target_name not in current_disk_offload_index: + return offload_weight(param, target_name, disk_offload_folder, current_disk_offload_index) + else: + return offload_and_maybe_resave_param( + target_name, param, loading_info, disk_offload_folder, current_disk_offload_index, applied_ops + ) + else: + set_param_for_module( + model, + target_name, + param, + loading_info, + distributed_operation, + hf_quantizer, + module_cache=module_cache, + ) + return disk_offload_index + + def convert_and_load_state_dict_in_model( model: PreTrainedModel, state_dict: dict[str, Any], @@ -1279,7 +1362,9 @@ def convert_and_load_state_dict_in_model( renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)] converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)] + direct_param_loads: list[tuple[str, Future | Callable | torch.Tensor, TensorParallelLayer | None]] = [] param_name_to_load: dict[str, WeightRenaming | WeightConverter] = {} + module_cache: dict[str, torch.nn.Module] = {"": model} # build '(?P.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'} # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched. @@ -1303,15 +1388,6 @@ def convert_and_load_state_dict_in_model( # 2. finally, collect the tensor into the proper converter if renamed_key in meta_model_state_dict: empty_param = meta_model_state_dict.get(renamed_key) - # If we enter here, we have a WeightConverter operation to perform - if source_pattern is not None: - new_converter = deepcopy(pattern_to_converter[source_pattern]) - # each target key gets its own converter instance - mapping = param_name_to_load.setdefault(renamed_key, new_converter) - # Otherwise, only potential renaming - else: - mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key)) - source_pattern = original_key # 3. Handle dtype casting needs_quantization = ( @@ -1319,6 +1395,21 @@ def convert_and_load_state_dict_in_model( and not hf_quantizer.pre_quantized and hf_quantizer.param_needs_quantization(model, renamed_key) ) + mapping = None + if source_pattern is not None: + # each target key gets its own converter instance (deepcopy is lazy: skipped if target already seen, + # e.g. 
many-to-one/one-to-many converters where several sources land on the same target) + mapping = param_name_to_load.get(renamed_key) + if mapping is None: + mapping = deepcopy(pattern_to_converter[source_pattern]) + param_name_to_load[renamed_key] = mapping + elif needs_quantization: + mapping = param_name_to_load.get(renamed_key) + if mapping is None: + mapping = WeightRenaming(original_key, renamed_key) + param_name_to_load[renamed_key] = mapping + source_pattern = original_key + if needs_quantization: mapping.quantization_operation = hf_quantizer.get_quantize_ops() @@ -1348,14 +1439,24 @@ def convert_and_load_state_dict_in_model( # 4. Handle TP sharding or device_map placement future_or_tensor = None + distributed_operation = None if device_mesh and tp_plan: if matched_tp_pattern := tp_plan_alt.search(renamed_key): matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup] - if getattr(mapping, "distributed_operation", None) is None: - tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__ - mapping.distributed_operation = tp_layer( + tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__ + if mapping is None: + distributed_operation = tp_layer( device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone() ) + else: + distributed_operation = getattr(mapping, "distributed_operation", None) + if distributed_operation is None: + distributed_operation = tp_layer( + device_mesh=device_mesh, + rank=device_mesh.get_local_rank(), + empty_param=empty_param.clone(), + ) + mapping.distributed_operation = distributed_operation shard_index = ( len(mapping.collected_tensors.get(source_pattern, [])) if isinstance(mapping, WeightConverter) and isinstance(mapping.operations[0], MergeModulelist) @@ -1364,7 +1465,7 @@ def convert_and_load_state_dict_in_model( future_or_tensor = spawn_tp_materialize( thread_pool, tensor, - mapping.distributed_operation, + distributed_operation, shard_index, device_map[""], _dtype, @@ -1374,7 +1475,12 @@ def convert_and_load_state_dict_in_model( param_device = get_device(device_map, renamed_key, valid_torch_device=True) future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype) - mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor) + if mapping is None: + # Fast path for untouched or purely renamed parameters: avoid instantiating a per-weight + # `WeightRenaming` wrapper when we can load the tensor directly. 
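+                # These direct entries are resolved and assigned first in the loading loop below, before converter-backed ones.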
+ direct_param_loads.append((renamed_key, future_or_tensor, distributed_operation)) + else: + mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor) elif source_pattern is not None: # add all target keys as unexpected mapping = pattern_to_converter[source_pattern] for k in mapping.target_patterns: @@ -1383,38 +1489,63 @@ def convert_and_load_state_dict_in_model( loading_info.unexpected_keys.add(renamed_key) try: - for first_param_name, mapping in tqdm(param_name_to_load.items(), desc="Loading weights"): - try: - realized_value = mapping.convert( - first_param_name, - model=model, - config=model.config, - hf_quantizer=hf_quantizer, - loading_info=loading_info, - ) - for target_name, param in realized_value.items(): - param = param[0] if isinstance(param, list) else param - param_device = get_device(device_map, target_name) - # Offloading support - if param_device == "disk" and (target_name not in model_buffers or offload_buffers): - disk_offload_index = offload_and_maybe_resave_param( - target_name, param, loading_info, disk_offload_folder, disk_offload_index, mapping - ) - else: - set_param_for_module( + with tqdm(total=len(direct_param_loads) + len(param_name_to_load), desc="Loading weights") as progress_bar: + for target_name, pending_param, distributed_operation in direct_param_loads: + try: + param = _resolve_pending_tensor(pending_param) + if param is None: + continue + disk_offload_index = _assign_or_offload_param( + model, + target_name, + param, + loading_info, + device_map, + model_buffers, + offload_buffers, + disk_offload_folder, + disk_offload_index, + distributed_operation, + hf_quantizer, + module_cache, + ) + finally: + progress_bar.update() + + for first_param_name, mapping in param_name_to_load.items(): + try: + realized_value = mapping.convert( + first_param_name, + model=model, + config=model.config, + hf_quantizer=hf_quantizer, + loading_info=loading_info, + ) + for target_name, param in realized_value.items(): + param = param[0] if isinstance(param, list) else param + disk_offload_index = _assign_or_offload_param( model, target_name, param, loading_info, + device_map, + model_buffers, + offload_buffers, + disk_offload_folder, + disk_offload_index, mapping.distributed_operation, hf_quantizer, + module_cache, + mapping, ) - # Cleanup all the tensors that were gathered before next iteration - del realized_value + # Cleanup all the tensors that were gathered before next iteration + del realized_value - except SkipParameters: - continue + except SkipParameters: + continue + finally: + progress_bar.update() # Close the pool, independently of whether the code was interrupted or finished successfully finally: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bc8f65891445..35bccebc9286 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4677,8 +4677,18 @@ def test_bc_torch_dtype(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - # Check that it works for all dtypes - for dtype in ["float16", "bfloat16", "float32", "auto", torch.float16, torch.bfloat16, torch.float32]: + # Check a random-looking but reproducible subset of dtypes per model class. 
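+            # Seed the RNG with the test name and model class so the 3 sampled dtypes stay stable across runs.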
+ supported_dtypes = [ + "float16", + "bfloat16", + "float32", + "auto", + torch.float16, + torch.bfloat16, + torch.float32, + ] + dtype_rng = random.Random(f"test_bc_torch_dtype:{model_class.__name__}") + for dtype in dtype_rng.sample(supported_dtypes, 3): model_torch_dtype = model_class.from_pretrained(tmpdirname, torch_dtype=dtype) model_dtype = model_class.from_pretrained(tmpdirname, dtype=dtype) diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index a358822d19f8..11780aeb2ab2 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -220,6 +220,48 @@ def __init__(self, add_extra_moe=False, with_mlp=True): class TestConvertAndLoadStateDict(unittest.TestCase): + def test_direct_and_renamed_weights_load_without_conversion_wrappers(self): + model = DummyRoot() + model.config = PretrainedConfig() + + state_dict = { + "model.layers.0.self_attn.q_proj.weight": torch.tensor([[1.0, 2.0]]), + "model.layers.1.self_attn.q_proj.weight": torch.tensor([[3.0, 4.0]]), + "mlp.w2.weight": torch.tensor([[5.0, 6.0], [7.0, 8.0]]), + } + loading_info, _ = convert_and_load_state_dict_in_model( + model, + state_dict, + LoadStateDictConfig(weight_mapping=[WeightRenaming("mlp.w2.weight", "mlp.down_proj.weight")]), + tp_plan=None, + ) + + self.assertEqual( + loading_info.missing_keys, + { + "model.layers.0.experts.down_proj.weight", + "model.layers.0.experts.gate_up_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + "model.layers.1.experts.down_proj.weight", + "model.layers.1.experts.gate_up_proj.weight", + "model.layers.1.self_attn.k_proj.weight", + "model.layers.1.self_attn.v_proj.weight", + }, + ) + self.assertEqual(loading_info.unexpected_keys, set()) + self.assertEqual(loading_info.mismatched_keys, set()) + self.assertEqual(loading_info.conversion_errors, {}) + + model_state = model.state_dict() + torch.testing.assert_close( + model_state["model.layers.0.self_attn.q_proj.weight"], state_dict["model.layers.0.self_attn.q_proj.weight"] + ) + torch.testing.assert_close( + model_state["model.layers.1.self_attn.q_proj.weight"], state_dict["model.layers.1.self_attn.q_proj.weight"] + ) + torch.testing.assert_close(model_state["mlp.down_proj.weight"], state_dict["mlp.w2.weight"]) + def test_moe_and_qkv_conversion(self): model = DummyRoot() model.config = PretrainedConfig()
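Below is a minimal standalone sketch (standard library only) of the lazy-compile / share-on-deepcopy pattern that `WeightTransform` adopts in this diff. `LazyPattern` and its attribute names are hypothetical and only illustrate the idea: compile the regex on first use, and let `__deepcopy__` copy the mutable state while sharing the immutable compiled pattern.

import re
from copy import deepcopy
from itertools import chain


class LazyPattern:
    """Compile the regex on first use and share the compiled object across deep copies."""

    __slots__ = ("_regex_str", "_compiled")

    def __init__(self, regex_str):
        self._regex_str = regex_str
        self._compiled = None  # compiled lazily on first access

    @property
    def compiled(self):
        if self._compiled is None:
            self._compiled = re.compile(self._regex_str)
        return self._compiled

    def __deepcopy__(self, memo):
        new = self.__class__.__new__(self.__class__)
        memo[id(self)] = new
        for slot in chain.from_iterable(getattr(c, "__slots__", ()) for c in self.__class__.__mro__):
            if not hasattr(self, slot):
                continue
            if slot == "_compiled":
                # re.Pattern is immutable, so the compiled object can be shared by every copy
                new._compiled = self._compiled
            else:
                object.__setattr__(new, slot, deepcopy(getattr(self, slot), memo))
        return new


pattern = LazyPattern(r"model\.layers\.\d+\.self_attn")
copy_before = deepcopy(pattern)          # no re.compile happens here
assert copy_before._compiled is None

_ = pattern.compiled                     # first access compiles the regex
copy_after = deepcopy(pattern)
assert copy_after._compiled is pattern._compiled  # shared, never recompiled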