245 changes: 188 additions & 57 deletions src/transformers/core_model_loading.py
@@ -583,7 +583,8 @@ class WeightTransform:
__slots__ = (
"source_patterns",
"target_patterns",
"compiled_sources",
"_source_regex_str",
"_compiled_sources",
"distributed_operation",
"quantization_operation",
"collected_tensors",
@@ -647,13 +648,40 @@ def __init__(self, source_patterns: str | list[str], target_patterns: str | list
pattern = process_source_pattern(pattern, self._original_target_patterns[i])
self.source_patterns[i] = pattern

# Construct the regex we will use to rename keys from the sources to the targets
# Build the regex source string, but compile lazily via `compiled_sources` below. During loading, any key
# that does not match a weight conversion op gets wrapped in a fresh per-weight `WeightRenaming` for
# convenience so it can reuse the same conversion/loading path. Those fallback wrappers never need to call
# `rename_source_key`, so eagerly compiling their regex would just waste work — and it dominates
# `from_pretrained` for models with many parameters.
branches = []
for i, source_pattern in enumerate(self.source_patterns):
group_name = f"g{i}"
pattern = source_pattern.replace(".*.", r"\..*\.")
branches.append(f"(?P<{group_name}>{pattern})")
self.compiled_sources = re.compile("|".join(branches))
self._source_regex_str = "|".join(branches)
self._compiled_sources = None

@property
def compiled_sources(self) -> re.Pattern:
if self._compiled_sources is None:
self._compiled_sources = re.compile(self._source_regex_str)
return self._compiled_sources

def __deepcopy__(self, memo):
# A fresh-per-target copy is needed because `collected_tensors`, `layer_targets`, and `_was_used` accumulate
# state during loading. The compiled regex is stateless, so we share it across copies — avoiding a hidden
# `re.compile` that would otherwise run on every per-weight pickle/unpickle round-trip.
cls = self.__class__
new = cls.__new__(cls)
memo[id(self)] = new
for slot in chain.from_iterable(getattr(c, "__slots__", ()) for c in cls.__mro__):
if not hasattr(self, slot):
continue
if slot == "_compiled_sources":
new._compiled_sources = self._compiled_sources
else:
object.__setattr__(new, slot, deepcopy(getattr(self, slot), memo))
return new
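
For reference, the lazy-compile-and-share idea above, reduced to a standalone sketch (the `LazyPattern` class and its fields are illustrative, not part of the PR): the regex string is stored eagerly, `re.compile` runs only on first property access, and `__deepcopy__` shares the compiled object while deep-copying mutable state.

import re
from copy import deepcopy


class LazyPattern:
    __slots__ = ("_regex_str", "_compiled", "collected")

    def __init__(self, regex_str: str):
        self._regex_str = regex_str  # cheap: stored eagerly
        self._compiled = None        # expensive: compiled on first access
        self.collected = []          # mutable per-copy loading state

    @property
    def compiled(self) -> re.Pattern:
        if self._compiled is None:
            self._compiled = re.compile(self._regex_str)
        return self._compiled

    def __deepcopy__(self, memo):
        new = type(self).__new__(type(self))
        memo[id(self)] = new
        new._regex_str = self._regex_str
        new._compiled = self._compiled  # stateless, safe to share across copies
        new.collected = deepcopy(self.collected, memo)  # fresh mutable state
        return new


p = LazyPattern(r"(?P<g0>model\..*\.weight)")
copies = [deepcopy(p) for _ in range(3)]  # no re.compile per copy
assert p.compiled.match("model.layers.0.weight")  # compiled once, on first use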

def __repr__(self):
return f"{self.__class__.__name__}(source_patterns={self.source_patterns}, target_patterns={self.target_patterns})"
Expand Down Expand Up @@ -735,14 +763,13 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
for key in list(self.collected_tensors.keys()):
# Remove from internal attribute
tensors = self.collected_tensors.pop(key)
# Async loading
if isinstance(tensors[0], Future):
tensors = [future.result() for future in tensors if future.result() is not None]
# Sync loading
elif callable(tensors[0]):
tensors = [func() for func in tensors]
resolved_tensors = []
for tensor_or_future in tensors:
resolved_tensor = _resolve_pending_tensor(tensor_or_future)
if resolved_tensor is not None:
resolved_tensors.append(resolved_tensor)
# Add them to the new dictionary
collected_tensors[key] = tensors
collected_tensors[key] = resolved_tensors

return collected_tensors

Expand Down Expand Up @@ -933,6 +960,15 @@ def convert(
GLOBAL_WORKERS = min(4, os.cpu_count() or 4)


def _resolve_pending_tensor(tensor_or_future: Future | Callable | torch.Tensor) -> torch.Tensor | None:
    # Normalize the three forms a pending tensor takes during loading: an async
    # Future, a lazy materialization callable, or an already-materialized tensor.
    if isinstance(tensor_or_future, Future):
        return tensor_or_future.result()
    elif callable(tensor_or_future):
        return tensor_or_future()
    else:
        return tensor_or_future
Comment on lines +963 to +969
Member: Probably not needed to have an outer function here.
Collaborator (author): It's used in two spots; that's why I factored it out here.
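
As a standalone illustration (not part of the PR) of the three pending shapes the helper accepts, assuming `_resolve_pending_tensor` above is in scope:

from concurrent.futures import ThreadPoolExecutor

import torch

pool = ThreadPoolExecutor(max_workers=1)
eager = torch.zeros(2)  # already-materialized tensor
lazy = lambda: torch.ones(2)  # sync deferred load
future = pool.submit(torch.full, (2,), 2.0)  # async deferred load
resolved = [_resolve_pending_tensor(t) for t in (eager, lazy, future)]
# -> [tensor([0., 0.]), tensor([1., 1.]), tensor([2., 2.])]
pool.shutdown()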



def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Tensor:
# This slicing is what actually loads the tensor from the safetensors slice object
tensor = tensor[...]
Expand Down Expand Up @@ -1048,9 +1084,16 @@ def set_param_for_module(
loading_info: LoadStateDictInfo,
distributed_operation: TensorParallelLayer | None,
hf_quantizer: HfQuantizer,
module_cache: dict[str, torch.nn.Module] | None = None,
):
module_path, _, param_name = target_name.rpartition(".")
module_obj = model.get_submodule(module_path) if module_path else model
if module_cache is not None:
module_obj = module_cache.get(module_path)
if module_obj is None:
module_obj = model.get_submodule(module_path) if module_path else model
module_cache[module_path] = module_obj
else:
module_obj = model.get_submodule(module_path) if module_path else model

if param_name == torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
module_obj.set_extra_state(param_value)
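
The cache pays off because sibling parameters (e.g. `weight` and `bias`) share a module path, so `get_submodule`'s attribute-path walk runs once per module rather than once per parameter. A toy sketch of the memoization (the model and `cached_submodule` are illustrative, not from the PR):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
module_cache: dict[str, torch.nn.Module] = {"": model}


def cached_submodule(path: str) -> torch.nn.Module:
    module_obj = module_cache.get(path)
    if module_obj is None:  # miss: one attribute-path walk, then memoized
        module_obj = model.get_submodule(path) if path else model
        module_cache[path] = module_obj
    return module_obj


# "0.weight" and "0.bias" both resolve module path "0"; the second call is a dict hit.
assert cached_submodule("0") is cached_submodule("0")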
Expand Down Expand Up @@ -1088,7 +1131,7 @@ def offload_and_maybe_resave_param(
target_name: str,
param: torch.Tensor,
loading_info: LoadStateDictInfo,
disk_offload_folder: str,
disk_offload_folder: str | None,
disk_offload_index: dict,
applied_ops: WeightConverter | WeightRenaming,
) -> dict:
Expand Down Expand Up @@ -1136,18 +1179,58 @@ def rename_source_key(
break

# 3. add or remove the base-model prefix if needed (only during loading, not saving)
if prefix is not None and meta_state_dict is not None:
if (
renamed_key.startswith(prefix)
and meta_state_dict.get(re.sub(f"^{prefix}.", "", renamed_key, count=1)) is not None
):
renamed_key = re.sub(f"^{prefix}.", "", renamed_key, count=1)
elif meta_state_dict.get(f"{prefix}.{renamed_key}") is not None:
renamed_key = f"{prefix}.{renamed_key}"
if prefix not in (None, "") and meta_state_dict is not None:
prefixed_key = f"{prefix}.{renamed_key}"
prefix_with_separator = f"{prefix}."
if renamed_key.startswith(prefix_with_separator):
unprefixed_key = renamed_key[len(prefix_with_separator) :]
if meta_state_dict.get(unprefixed_key) is not None:
renamed_key = unprefixed_key
elif meta_state_dict.get(prefixed_key) is not None:
renamed_key = prefixed_key

return renamed_key, source_pattern
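
The rewritten branch handles both mismatch directions: a checkpoint saved with the base-model prefix loaded into a bare model, and the reverse. A standalone distillation of just that logic (`reconcile_prefix` is an illustrative name, not in the PR):

def reconcile_prefix(key: str, prefix: str, known_keys: set[str]) -> str:
    # Strip the base-model prefix when the model exposes the bare key...
    if key.startswith(f"{prefix}."):
        bare = key[len(prefix) + 1 :]
        if bare in known_keys:
            return bare
    # ...or add it when only the prefixed form exists.
    elif f"{prefix}.{key}" in known_keys:
        return f"{prefix}.{key}"
    return key


assert reconcile_prefix("model.embed_tokens.weight", "model", {"embed_tokens.weight"}) == "embed_tokens.weight"
assert reconcile_prefix("lm_head.weight", "model", {"model.lm_head.weight"}) == "model.lm_head.weight"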


def _assign_or_offload_param(
model: PreTrainedModel,
target_name: str,
param: torch.Tensor,
loading_info: LoadStateDictInfo,
device_map: dict | None,
model_buffers: set[str],
offload_buffers: bool,
disk_offload_folder: str | None,
disk_offload_index: dict | None,
distributed_operation: TensorParallelLayer | None,
hf_quantizer: HfQuantizer | None,
module_cache: dict[str, torch.nn.Module],
applied_ops: WeightConverter | WeightRenaming | None = None,
) -> dict | None:
    # Route a materialized parameter either to disk offload or onto its target
    # module, returning the (possibly updated) disk offload index.
    param_device = get_device(device_map, target_name)
if param_device == "disk" and (target_name not in model_buffers or offload_buffers):
current_disk_offload_index = {} if disk_offload_index is None else disk_offload_index
if applied_ops is None:
loading_info.missing_keys.discard(target_name)
if target_name not in current_disk_offload_index:
return offload_weight(param, target_name, disk_offload_folder, current_disk_offload_index)
else:
return offload_and_maybe_resave_param(
target_name, param, loading_info, disk_offload_folder, current_disk_offload_index, applied_ops
)
else:
set_param_for_module(
model,
target_name,
param,
loading_info,
distributed_operation,
hf_quantizer,
module_cache=module_cache,
)
return disk_offload_index
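
The disk gate at the top of the helper condenses to a small predicate; a sketch of just that decision (`route` is an illustrative name, not in the PR):

def route(param_device: str, is_buffer: bool, offload_buffers: bool) -> str:
    # Mirrors the condition above: "disk" placement triggers offload,
    # except for buffers when offload_buffers is disabled.
    if param_device == "disk" and (not is_buffer or offload_buffers):
        return "disk-offload"
    return "set-on-module"


assert route("disk", is_buffer=False, offload_buffers=False) == "disk-offload"
assert route("disk", is_buffer=True, offload_buffers=False) == "set-on-module"
assert route("cuda:0", is_buffer=True, offload_buffers=False) == "set-on-module"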


def convert_and_load_state_dict_in_model(
model: PreTrainedModel,
state_dict: dict[str, Any],
Expand Down Expand Up @@ -1279,7 +1362,9 @@ def convert_and_load_state_dict_in_model(

renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
direct_param_loads: list[tuple[str, Future | Callable | torch.Tensor, TensorParallelLayer | None]] = []
param_name_to_load: dict[str, WeightRenaming | WeightConverter] = {}
module_cache: dict[str, torch.nn.Module] = {"": model}

# build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
# and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
@@ -1303,22 +1388,28 @@
# 2. finally, collect the tensor into the proper converter
if renamed_key in meta_model_state_dict:
empty_param = meta_model_state_dict.get(renamed_key)
# If we enter here, we have a WeightConverter operation to perform
if source_pattern is not None:
new_converter = deepcopy(pattern_to_converter[source_pattern])
# each target key gets its own converter instance
mapping = param_name_to_load.setdefault(renamed_key, new_converter)
# Otherwise, only potential renaming
else:
mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
source_pattern = original_key

# 3. Handle dtype casting
needs_quantization = (
hf_quantizer
and not hf_quantizer.pre_quantized
and hf_quantizer.param_needs_quantization(model, renamed_key)
)
mapping = None
if source_pattern is not None:
# each target key gets its own converter instance (deepcopy is lazy: skipped if target already seen,
# e.g. many-to-one/one-to-many converters where several sources land on the same target)
mapping = param_name_to_load.get(renamed_key)
if mapping is None:
mapping = deepcopy(pattern_to_converter[source_pattern])
param_name_to_load[renamed_key] = mapping
elif needs_quantization:
mapping = param_name_to_load.get(renamed_key)
if mapping is None:
mapping = WeightRenaming(original_key, renamed_key)
param_name_to_load[renamed_key] = mapping
source_pattern = original_key

if needs_quantization:
mapping.quantization_operation = hf_quantizer.get_quantize_ops()

@@ -1348,14 +1439,24 @@

# 4. Handle TP sharding or device_map placement
future_or_tensor = None
distributed_operation = None
if device_mesh and tp_plan:
if matched_tp_pattern := tp_plan_alt.search(renamed_key):
matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup]
if getattr(mapping, "distributed_operation", None) is None:
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
mapping.distributed_operation = tp_layer(
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
if mapping is None:
distributed_operation = tp_layer(
device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone()
)
else:
distributed_operation = getattr(mapping, "distributed_operation", None)
if distributed_operation is None:
distributed_operation = tp_layer(
device_mesh=device_mesh,
rank=device_mesh.get_local_rank(),
empty_param=empty_param.clone(),
)
mapping.distributed_operation = distributed_operation
shard_index = (
len(mapping.collected_tensors.get(source_pattern, []))
if isinstance(mapping, WeightConverter) and isinstance(mapping.operations[0], MergeModulelist)
@@ -1364,7 +1465,7 @@
future_or_tensor = spawn_tp_materialize(
thread_pool,
tensor,
mapping.distributed_operation,
distributed_operation,
shard_index,
device_map[""],
_dtype,
@@ -1374,7 +1475,12 @@
param_device = get_device(device_map, renamed_key, valid_torch_device=True)
future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype)

mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
if mapping is None:
# Fast path for untouched or purely renamed parameters: avoid instantiating a per-weight
# `WeightRenaming` wrapper when we can load the tensor directly.
direct_param_loads.append((renamed_key, future_or_tensor, distributed_operation))
else:
mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
elif source_pattern is not None: # add all target keys as unexpected
mapping = pattern_to_converter[source_pattern]
for k in mapping.target_patterns:
@@ -1383,38 +1489,63 @@
loading_info.unexpected_keys.add(renamed_key)

try:
for first_param_name, mapping in tqdm(param_name_to_load.items(), desc="Loading weights"):
try:
realized_value = mapping.convert(
first_param_name,
model=model,
config=model.config,
hf_quantizer=hf_quantizer,
loading_info=loading_info,
)
for target_name, param in realized_value.items():
param = param[0] if isinstance(param, list) else param
param_device = get_device(device_map, target_name)
# Offloading support
if param_device == "disk" and (target_name not in model_buffers or offload_buffers):
disk_offload_index = offload_and_maybe_resave_param(
target_name, param, loading_info, disk_offload_folder, disk_offload_index, mapping
)
else:
set_param_for_module(
with tqdm(total=len(direct_param_loads) + len(param_name_to_load), desc="Loading weights") as progress_bar:
for target_name, pending_param, distributed_operation in direct_param_loads:
try:
param = _resolve_pending_tensor(pending_param)
if param is None:
continue
disk_offload_index = _assign_or_offload_param(
model,
target_name,
param,
loading_info,
device_map,
model_buffers,
offload_buffers,
disk_offload_folder,
disk_offload_index,
distributed_operation,
hf_quantizer,
module_cache,
)
finally:
progress_bar.update()

for first_param_name, mapping in param_name_to_load.items():
try:
realized_value = mapping.convert(
first_param_name,
model=model,
config=model.config,
hf_quantizer=hf_quantizer,
loading_info=loading_info,
)
for target_name, param in realized_value.items():
param = param[0] if isinstance(param, list) else param
disk_offload_index = _assign_or_offload_param(
model,
target_name,
param,
loading_info,
device_map,
model_buffers,
offload_buffers,
disk_offload_folder,
disk_offload_index,
mapping.distributed_operation,
hf_quantizer,
module_cache,
mapping,
)

# Cleanup all the tensors that were gathered before next iteration
del realized_value
# Cleanup all the tensors that were gathered before next iteration
del realized_value

except SkipParameters:
continue
except SkipParameters:
continue
finally:
progress_bar.update()

# Close the pool, independently of whether the code was interrupted or finished successfully
finally:
14 changes: 12 additions & 2 deletions tests/test_modeling_common.py
@@ -4677,8 +4677,18 @@ def test_bc_torch_dtype(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

# Check that it works for all dtypes
for dtype in ["float16", "bfloat16", "float32", "auto", torch.float16, torch.bfloat16, torch.float32]:
# Check a random-looking but reproducible subset of dtypes per model class.
supported_dtypes = [
"float16",
"bfloat16",
"float32",
"auto",
torch.float16,
torch.bfloat16,
torch.float32,
]
dtype_rng = random.Random(f"test_bc_torch_dtype:{model_class.__name__}")
for dtype in dtype_rng.sample(supported_dtypes, 3):
model_torch_dtype = model_class.from_pretrained(tmpdirname, torch_dtype=dtype)
model_dtype = model_class.from_pretrained(tmpdirname, dtype=dtype)
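
A quick standalone check of the seeding scheme (the class name is just an example): seeding `random.Random` with a string is deterministic across runs, so each model class always tests the same three dtypes while different classes cover different subsets.

import random

assert (
    random.Random("test_bc_torch_dtype:BertModel").sample(range(7), 3)
    == random.Random("test_bc_torch_dtype:BertModel").sample(range(7), 3)
)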
