qa: speed up dtype regex weight load + reduce dtype tests to 3 random #45635
Open
tarekziade wants to merge 3 commits into main from tarek-loadweight-hotspot
+242 −59
```diff
@@ -583,7 +583,8 @@ class WeightTransform:
     __slots__ = (
         "source_patterns",
         "target_patterns",
-        "compiled_sources",
+        "_source_regex_str",
+        "_compiled_sources",
         "distributed_operation",
         "quantization_operation",
         "collected_tensors",
```
```diff
@@ -647,13 +648,40 @@ def __init__(self, source_patterns: str | list[str], target_patterns: str | list[str]
                 pattern = process_source_pattern(pattern, self._original_target_patterns[i])
                 self.source_patterns[i] = pattern

-        # Construct the regex we will use to rename keys from the sources to the targets
+        # Build the regex source string, but compile lazily via `compiled_sources` below. During loading, any key
+        # that does not match a weight conversion op gets wrapped in a fresh per-weight `WeightRenaming` for
+        # convenience so it can reuse the same conversion/loading path. Those fallback wrappers never need to call
+        # `rename_source_key`, so eagerly compiling their regex would just waste work — and it dominates
+        # `from_pretrained` for models with many parameters.
         branches = []
         for i, source_pattern in enumerate(self.source_patterns):
             group_name = f"g{i}"
             pattern = source_pattern.replace(".*.", r"\..*\.")
             branches.append(f"(?P<{group_name}>{pattern})")
-        self.compiled_sources = re.compile("|".join(branches))
+        self._source_regex_str = "|".join(branches)
+        self._compiled_sources = None
+
+    @property
+    def compiled_sources(self) -> re.Pattern:
+        if self._compiled_sources is None:
+            self._compiled_sources = re.compile(self._source_regex_str)
+        return self._compiled_sources
+
+    def __deepcopy__(self, memo):
+        # A fresh per-target copy is needed because `collected_tensors`, `layer_targets`, and `_was_used` accumulate
+        # state during loading. The compiled regex is stateless, so we share it across copies — avoiding a hidden
+        # `re.compile` that would otherwise run on every per-weight pickle/unpickle round-trip.
+        cls = self.__class__
+        new = cls.__new__(cls)
+        memo[id(self)] = new
+        for slot in chain.from_iterable(getattr(c, "__slots__", ()) for c in cls.__mro__):
+            if not hasattr(self, slot):
+                continue
+            if slot == "_compiled_sources":
+                new._compiled_sources = self._compiled_sources
+            else:
+                object.__setattr__(new, slot, deepcopy(getattr(self, slot), memo))
+        return new

     def __repr__(self):
         return f"{self.__class__.__name__}(source_patterns={self.source_patterns}, target_patterns={self.target_patterns})"
```
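Two ideas carry this hunk: compile the regex only when `rename_source_key` actually needs it, and have `__deepcopy__` hand the compiled pattern to every copy instead of re-deriving it. A minimal standalone sketch of the combined pattern (class and attribute names are illustrative, not the transformers source):

```python
import re
from copy import deepcopy
from itertools import chain


class LazyPattern:
    __slots__ = ("_regex_str", "_compiled", "collected")

    def __init__(self, regex_str: str):
        self._regex_str = regex_str
        self._compiled = None  # compiled on first access, not in __init__
        self.collected = []    # mutable per-copy state, like `collected_tensors`

    @property
    def compiled(self) -> re.Pattern:
        if self._compiled is None:
            self._compiled = re.compile(self._regex_str)
        return self._compiled

    def __deepcopy__(self, memo):
        cls = self.__class__
        new = cls.__new__(cls)
        memo[id(self)] = new
        for slot in chain.from_iterable(getattr(c, "__slots__", ()) for c in cls.__mro__):
            if not hasattr(self, slot):
                continue
            if slot == "_compiled":
                new._compiled = self._compiled  # share: compiled patterns are immutable
            else:
                object.__setattr__(new, slot, deepcopy(getattr(self, slot), memo))
        return new


p = LazyPattern(r"(?P<g0>.*\.block_sparse_moe\..*)")
_ = p.compiled                             # first access pays the single re.compile
clone = deepcopy(p)
assert clone.compiled is p.compiled        # regex shared across copies
assert clone.collected is not p.collected  # per-copy state still independent
```

Because `re.Pattern` objects are immutable and thread-safe for matching, sharing one instance across all per-weight copies is safe.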
```diff
@@ -735,14 +763,13 @@ def materialize_tensors(self) -> dict[str, list[torch.Tensor]]:
         for key in list(self.collected_tensors.keys()):
             # Remove from internal attribute
             tensors = self.collected_tensors.pop(key)
-            # Async loading
-            if isinstance(tensors[0], Future):
-                tensors = [future.result() for future in tensors if future.result() is not None]
-            # Sync loading
-            elif callable(tensors[0]):
-                tensors = [func() for func in tensors]
+            resolved_tensors = []
+            for tensor_or_future in tensors:
+                resolved_tensor = _resolve_pending_tensor(tensor_or_future)
+                if resolved_tensor is not None:
+                    resolved_tensors.append(resolved_tensor)
             # Add them to the new dictionary
-            collected_tensors[key] = tensors
+            collected_tensors[key] = resolved_tensors

         return collected_tensors
```
```diff
@@ -933,6 +960,15 @@ def convert(
 GLOBAL_WORKERS = min(4, os.cpu_count() or 4)


+def _resolve_pending_tensor(tensor_or_future: Future | Callable | torch.Tensor) -> torch.Tensor | None:
+    if isinstance(tensor_or_future, Future):
+        return tensor_or_future.result()
+    elif callable(tensor_or_future):
+        return tensor_or_future()
+    else:
+        return tensor_or_future
```
Comment on lines +963 to +969

Member: Probably not needed to have an outer function here.

Collaborator (Author): It's used in two spots; that's why I factored it out here.
```diff
 def _materialize_copy(tensor: torch.Tensor, device=None, dtype=None) -> torch.Tensor:
     # This slicing is what actually loads the tensor from the safetensors slice object
     tensor = tensor[...]
```
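As context for the review thread above: the loader hands tensors around in three shapes (already materialized, a zero-argument callable for sync loading, a `Future` for async loading), and the helper collapses all three at the point of use. A runnable sketch mirroring that dispatch (the `resolve_pending` name and toy tensors are illustrative):

```python
from concurrent.futures import Future, ThreadPoolExecutor

import torch


def resolve_pending(tensor_or_future):
    # Same dispatch as the diff's `_resolve_pending_tensor`: Futures are awaited,
    # zero-argument callables are invoked, plain tensors pass through unchanged.
    if isinstance(tensor_or_future, Future):
        return tensor_or_future.result()
    elif callable(tensor_or_future):
        return tensor_or_future()
    return tensor_or_future


with ThreadPoolExecutor(max_workers=2) as pool:
    eager = torch.zeros(2)                        # already materialized
    lazy = lambda: torch.ones(2)                  # sync deferred load
    pending = pool.submit(torch.full, (2,), 3.0)  # async deferred load
    print([resolve_pending(x) for x in (eager, lazy, pending)])
```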
```diff
@@ -1048,9 +1084,16 @@ def set_param_for_module(
     loading_info: LoadStateDictInfo,
     distributed_operation: TensorParallelLayer | None,
     hf_quantizer: HfQuantizer,
+    module_cache: dict[str, torch.nn.Module] | None = None,
 ):
     module_path, _, param_name = target_name.rpartition(".")
-    module_obj = model.get_submodule(module_path) if module_path else model
+    if module_cache is not None:
+        module_obj = module_cache.get(module_path)
+        if module_obj is None:
+            module_obj = model.get_submodule(module_path) if module_path else model
+            module_cache[module_path] = module_obj
+    else:
+        module_obj = model.get_submodule(module_path) if module_path else model

     if param_name == torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
         module_obj.set_extra_state(param_value)
```
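Why the cache pays off: `get_submodule` re-splits the dotted path and walks one attribute per segment on every call, while a state dict visits each module several times (weight, bias, extra state). A small sketch against an assumed toy model:

```python
import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))


model = Toy()
module_cache: dict[str, torch.nn.Module] = {"": model}  # seed with the root, as the PR does

for target_name in ("encoder.0.weight", "encoder.0.bias", "encoder.1.weight"):
    module_path, _, param_name = target_name.rpartition(".")
    module_obj = module_cache.get(module_path)
    if module_obj is None:  # only the first parameter of each module pays the path walk
        module_obj = model.get_submodule(module_path)
        module_cache[module_path] = module_obj
    print(module_path, param_name, type(module_obj).__name__)
```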
```diff
@@ -1088,7 +1131,7 @@ def offload_and_maybe_resave_param(
     target_name: str,
     param: torch.Tensor,
     loading_info: LoadStateDictInfo,
-    disk_offload_folder: str,
+    disk_offload_folder: str | None,
     disk_offload_index: dict,
     applied_ops: WeightConverter | WeightRenaming,
 ) -> dict:
```
```diff
@@ -1136,18 +1179,58 @@ def rename_source_key(
                 break

     # 3. check if we need to add or remove prefix if necessary (only during loading, not saving)
-    if prefix is not None and meta_state_dict is not None:
-        if (
-            renamed_key.startswith(prefix)
-            and meta_state_dict.get(re.sub(f"^{prefix}.", "", renamed_key, count=1)) is not None
-        ):
-            renamed_key = re.sub(f"^{prefix}.", "", renamed_key, count=1)
-        elif meta_state_dict.get(f"{prefix}.{renamed_key}") is not None:
-            renamed_key = f"{prefix}.{renamed_key}"
+    if prefix not in (None, "") and meta_state_dict is not None:
+        prefixed_key = f"{prefix}.{renamed_key}"
+        prefix_with_separator = f"{prefix}."
+        if renamed_key.startswith(prefix_with_separator):
+            unprefixed_key = renamed_key[len(prefix_with_separator) :]
+            if meta_state_dict.get(unprefixed_key) is not None:
+                renamed_key = unprefixed_key
+        elif meta_state_dict.get(prefixed_key) is not None:
+            renamed_key = prefixed_key

     return renamed_key, source_pattern


+def _assign_or_offload_param(
+    model: PreTrainedModel,
+    target_name: str,
+    param: torch.Tensor,
+    loading_info: LoadStateDictInfo,
+    device_map: dict | None,
+    model_buffers: set[str],
+    offload_buffers: bool,
+    disk_offload_folder: str | None,
+    disk_offload_index: dict | None,
+    distributed_operation: TensorParallelLayer | None,
+    hf_quantizer: HfQuantizer | None,
+    module_cache: dict[str, torch.nn.Module],
+    applied_ops: WeightConverter | WeightRenaming | None = None,
+) -> dict | None:
+    param_device = get_device(device_map, target_name)
+    if param_device == "disk" and (target_name not in model_buffers or offload_buffers):
+        current_disk_offload_index = {} if disk_offload_index is None else disk_offload_index
+        if applied_ops is None:
+            loading_info.missing_keys.discard(target_name)
+            if target_name not in current_disk_offload_index:
+                return offload_weight(param, target_name, disk_offload_folder, current_disk_offload_index)
+        else:
+            return offload_and_maybe_resave_param(
+                target_name, param, loading_info, disk_offload_folder, current_disk_offload_index, applied_ops
+            )
+    else:
+        set_param_for_module(
+            model,
+            target_name,
+            param,
+            loading_info,
+            distributed_operation,
+            hf_quantizer,
+            module_cache=module_cache,
+        )
+    return disk_offload_index
+
+
 def convert_and_load_state_dict_in_model(
     model: PreTrainedModel,
     state_dict: dict[str, Any],
```
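The prefix fix-up now runs on literal string operations instead of two anchored `re.sub` calls per key. Beyond skipping the regex engine on a hot path, slicing treats the prefix literally, while the old pattern `f"^{prefix}."` left the prefix unescaped and let its trailing `.` match any character. A hedged sketch with made-up keys:

```python
# Illustrative keys; same control flow as the rewritten prefix check.
prefix = "model"
meta_state_dict = {"embed_tokens.weight": "meta"}  # keys the meta model actually has

renamed_key = "model.embed_tokens.weight"  # checkpoint key still carries the prefix
prefix_with_separator = f"{prefix}."
if renamed_key.startswith(prefix_with_separator):
    unprefixed_key = renamed_key[len(prefix_with_separator):]
    if meta_state_dict.get(unprefixed_key) is not None:
        renamed_key = unprefixed_key  # drop the prefix -> "embed_tokens.weight"
elif meta_state_dict.get(f"{prefix}.{renamed_key}") is not None:
    renamed_key = f"{prefix}.{renamed_key}"  # or add it, when the model expects one

print(renamed_key)  # embed_tokens.weight
```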
```diff
@@ -1279,7 +1362,9 @@ def convert_and_load_state_dict_in_model(
     renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
     converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
+    direct_param_loads: list[tuple[str, Future | Callable | torch.Tensor, TensorParallelLayer | None]] = []
     param_name_to_load: dict[str, WeightRenaming | WeightConverter] = {}
+    module_cache: dict[str, torch.nn.Module] = {"": model}

     # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
     # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
```
```diff
@@ -1303,22 +1388,28 @@ def convert_and_load_state_dict_in_model(
         # 2. finally, collect the tensor into the proper converter
         if renamed_key in meta_model_state_dict:
             empty_param = meta_model_state_dict.get(renamed_key)
-            # If we enter here, we have a WeightConverter operation to perform
-            if source_pattern is not None:
-                new_converter = deepcopy(pattern_to_converter[source_pattern])
-                # each target key gets its own converter instance
-                mapping = param_name_to_load.setdefault(renamed_key, new_converter)
-            # Otherwise, only potential renaming
-            else:
-                mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
-                source_pattern = original_key
-
-            # 3. Handle dtype casting
             needs_quantization = (
                 hf_quantizer
                 and not hf_quantizer.pre_quantized
                 and hf_quantizer.param_needs_quantization(model, renamed_key)
             )
+            mapping = None
+            if source_pattern is not None:
+                # each target key gets its own converter instance (deepcopy is lazy: skipped if target already seen,
+                # e.g. many-to-one/one-to-many converters where several sources land on the same target)
+                mapping = param_name_to_load.get(renamed_key)
+                if mapping is None:
+                    mapping = deepcopy(pattern_to_converter[source_pattern])
+                    param_name_to_load[renamed_key] = mapping
+            elif needs_quantization:
+                mapping = param_name_to_load.get(renamed_key)
+                if mapping is None:
+                    mapping = WeightRenaming(original_key, renamed_key)
+                    param_name_to_load[renamed_key] = mapping
+                source_pattern = original_key

             if needs_quantization:
                 mapping.quantization_operation = hf_quantizer.get_quantize_ops()
```
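The reordering above also fixes a quiet cost: `dict.setdefault(key, default)` evaluates its default eagerly, so the old code deep-copied a converter for every source key even when the target had already been seen (exactly the many-to-one case the new comment calls out). A toy demonstration:

```python
from copy import deepcopy

copies = 0


class Template:
    def __deepcopy__(self, memo):
        global copies
        copies += 1  # count how many deep copies actually happen
        return Template()


template = Template()
store: dict[str, Template] = {}

for key in ("w", "w", "w"):  # same target seen three times (e.g. many-to-one)
    store.setdefault(key, deepcopy(template))  # eager: default built every iteration
print(copies)  # 3

copies = 0
store.clear()
for key in ("w", "w", "w"):
    mapping = store.get(key)
    if mapping is None:  # lazy: copy only on first sighting
        store[key] = deepcopy(template)
print(copies)  # 1
```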
```diff
@@ -1348,14 +1439,24 @@ def convert_and_load_state_dict_in_model(
             # 4. Handle TP sharding or device_map placement
             future_or_tensor = None
+            distributed_operation = None
             if device_mesh and tp_plan:
                 if matched_tp_pattern := tp_plan_alt.search(renamed_key):
                     matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup]
-                    if getattr(mapping, "distributed_operation", None) is None:
-                        tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
-                        mapping.distributed_operation = tp_layer(
+                    tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
+                    if mapping is None:
+                        distributed_operation = tp_layer(
                             device_mesh=device_mesh, rank=device_mesh.get_local_rank(), empty_param=empty_param.clone()
                         )
+                    else:
+                        distributed_operation = getattr(mapping, "distributed_operation", None)
+                        if distributed_operation is None:
+                            distributed_operation = tp_layer(
+                                device_mesh=device_mesh,
+                                rank=device_mesh.get_local_rank(),
+                                empty_param=empty_param.clone(),
+                            )
+                            mapping.distributed_operation = distributed_operation
                     shard_index = (
                         len(mapping.collected_tensors.get(source_pattern, []))
                         if isinstance(mapping, WeightConverter) and isinstance(mapping.operations[0], MergeModulelist)
```
```diff
@@ -1364,7 +1465,7 @@ def convert_and_load_state_dict_in_model(
                     future_or_tensor = spawn_tp_materialize(
                         thread_pool,
                         tensor,
-                        mapping.distributed_operation,
+                        distributed_operation,
                         shard_index,
                         device_map[""],
                         _dtype,
```
```diff
@@ -1374,7 +1475,12 @@ def convert_and_load_state_dict_in_model(
                 param_device = get_device(device_map, renamed_key, valid_torch_device=True)
                 future_or_tensor = spawn_materialize(thread_pool, tensor, param_device, _dtype)

-            mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
+            if mapping is None:
+                # Fast path for untouched or purely renamed parameters: avoid instantiating a per-weight
+                # `WeightRenaming` wrapper when we can load the tensor directly.
+                direct_param_loads.append((renamed_key, future_or_tensor, distributed_operation))
+            else:
+                mapping.add_tensor(renamed_key, original_key, source_pattern, future_or_tensor)
         elif source_pattern is not None:  # add all target keys as unexpected
             mapping = pattern_to_converter[source_pattern]
             for k in mapping.target_patterns:
```
```diff
@@ -1383,38 +1489,63 @@ def convert_and_load_state_dict_in_model(
                 loading_info.unexpected_keys.add(renamed_key)

     try:
-        for first_param_name, mapping in tqdm(param_name_to_load.items(), desc="Loading weights"):
-            try:
-                realized_value = mapping.convert(
-                    first_param_name,
-                    model=model,
-                    config=model.config,
-                    hf_quantizer=hf_quantizer,
-                    loading_info=loading_info,
-                )
-                for target_name, param in realized_value.items():
-                    param = param[0] if isinstance(param, list) else param
-                    param_device = get_device(device_map, target_name)
-                    # Offloading support
-                    if param_device == "disk" and (target_name not in model_buffers or offload_buffers):
-                        disk_offload_index = offload_and_maybe_resave_param(
-                            target_name, param, loading_info, disk_offload_folder, disk_offload_index, mapping
-                        )
-                    else:
-                        set_param_for_module(
+        with tqdm(total=len(direct_param_loads) + len(param_name_to_load), desc="Loading weights") as progress_bar:
+            for target_name, pending_param, distributed_operation in direct_param_loads:
+                try:
+                    param = _resolve_pending_tensor(pending_param)
+                    if param is None:
+                        continue
+                    disk_offload_index = _assign_or_offload_param(
+                        model,
+                        target_name,
+                        param,
+                        loading_info,
+                        device_map,
+                        model_buffers,
+                        offload_buffers,
+                        disk_offload_folder,
+                        disk_offload_index,
+                        distributed_operation,
+                        hf_quantizer,
+                        module_cache,
+                    )
+                finally:
+                    progress_bar.update()
+
+            for first_param_name, mapping in param_name_to_load.items():
+                try:
+                    realized_value = mapping.convert(
+                        first_param_name,
+                        model=model,
+                        config=model.config,
+                        hf_quantizer=hf_quantizer,
+                        loading_info=loading_info,
+                    )
+                    for target_name, param in realized_value.items():
+                        param = param[0] if isinstance(param, list) else param
+                        disk_offload_index = _assign_or_offload_param(
+                            model,
+                            target_name,
+                            param,
+                            loading_info,
+                            device_map,
+                            model_buffers,
+                            offload_buffers,
+                            disk_offload_folder,
+                            disk_offload_index,
+                            mapping.distributed_operation,
+                            hf_quantizer,
+                            module_cache,
+                            mapping,
+                        )

-                # Cleanup all the tensors that were gathered before next iteration
-                del realized_value
+                    # Cleanup all the tensors that were gathered before next iteration
+                    del realized_value

-            except SkipParameters:
-                continue
+                except SkipParameters:
+                    continue
+                finally:
+                    progress_bar.update()

     # Close the pool, independently of whether the code was interrupted or finished successfully
     finally:
```
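One bar now spans both queues: `total` is fixed up front and `update()` sits in a `finally:` block, so parameters that are skipped (`SkipParameters`, or pending tensors resolving to `None`) still advance the count. A minimal sketch of the pattern, with placeholder work items:

```python
from tqdm import tqdm

fast_items = ["a", "b"]        # stands in for direct_param_loads
slow_items = ["c", "d", "e"]   # stands in for param_name_to_load

with tqdm(total=len(fast_items) + len(slow_items), desc="Loading weights") as progress_bar:
    for item in fast_items:
        try:
            pass  # cheap per-item work
        finally:
            progress_bar.update()  # advances even if the item was skipped or raised
    for item in slow_items:
        try:
            pass  # heavier per-item work
        finally:
            progress_bar.update()
```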