diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 336db3773f76..d274b31837e2 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -161,9 +161,12 @@ ] _import_structure["tensor_parallel"] = [ - "shard_and_distribute_module", - "ALL_PARALLEL_STYLES", - "translate_to_torch_parallel_style", + "TPStyle", + "apply_tensor_parallel", + "convert_strided_to_shard", + "gather_full_state_dict", + "restore_strided_from_shard", + "verify_tp_plan", ] try: if not is_torch_greater_or_equal("2.5"): @@ -295,6 +298,14 @@ from .quanto import replace_with_quanto_layers from .sinq import SinqDeserialize, SinqQuantize from .spqr import replace_with_spqr_linear + from .tensor_parallel import ( + TPStyle, + apply_tensor_parallel, + convert_strided_to_shard, + gather_full_state_dict, + restore_strided_from_shard, + verify_tp_plan, + ) from .vptq import replace_with_vptq_linear try: @@ -305,12 +316,6 @@ else: from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache - from .tensor_parallel import ( - ALL_PARALLEL_STYLES, - shard_and_distribute_module, - translate_to_torch_parallel_style, - ) - try: if not is_torch_greater_or_equal("2.5"): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 082b294fb41f..1629c4ca4d9b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -38,6 +38,7 @@ from safetensors import safe_open from safetensors.torch import save_file as safe_save_file from torch import Tensor, nn +from torch.distributed.tensor import DTensor from torch.distributions import constraints from torch.utils.checkpoint import checkpoint @@ -4123,15 +4124,13 @@ def from_pretrained( if distributed_config is not None: model.config.distributed_config = distributed_config model.device_mesh = device_mesh - - def sub_mesh(name): - return device_mesh[name] if device_mesh.ndim > 1 else device_mesh - mesh_dim_names = device_mesh.mesh_dim_names or () if "tp" in mesh_dim_names: - model = apply_tensor_parallel(model, sub_mesh("tp"), distributed_config.tp_plan) + tp_mesh = device_mesh["tp"] if device_mesh.ndim > 1 else device_mesh + model = apply_tensor_parallel(model, tp_mesh, distributed_config.tp_plan) if "fsdp" in mesh_dim_names: - model = apply_fully_shard_data_parallel(model, sub_mesh("fsdp"), distributed_config.fsdp_plan) + fsdp_mesh = device_mesh["fsdp"] if device_mesh.ndim > 1 else device_mesh + model = apply_fully_shard_data_parallel(model, fsdp_mesh, distributed_config.fsdp_plan) else: # Accelerate path: auto device mapping if device_map is not None: @@ -4552,8 +4551,6 @@ def _move_missing_keys_from_meta_to_device( # will be re-initialized for nothing (which can be quite long) for key in missing_keys - self.all_tied_weights_keys.keys(): param = self.get_parameter_or_buffer(key) - from torch.distributed.tensor import DTensor - if isinstance(param, DTensor): # DTensor from parallelize_module on meta — materialize on actual device local_value = torch.empty( diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1b8dacb632cc..57042d1aad92 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2392,9 +2392,10 @@ def get_cp_size(self) -> int: def get_tp_size(self) -> int: """Get the tensor parallel size from either the model or DeepSpeed config.""" - # 1. 
Check model.tp_size first
-        if (model_tp := getattr(self.model, "_tp_size", None)) is not None:
-            return model_tp
+        # TODO: adapt this more cleanly to the distributed config once the distributed API is stable
+        dc = getattr(getattr(self.model, "config", None), "distributed_config", None)
+        if dc is not None and dc.tp_size is not None:
+            return dc.tp_size
 
         # 2. Fall back to DeepSpeed config if enabled
         if self.is_deepspeed_enabled and (deepspeed_config := getattr(self.args, "hf_deepspeed_config", None)):
diff --git a/tests/test_fsdp_mixin.py b/tests/test_fsdp_mixin.py
index d7a0a4ca3340..f6f7f7e6e2ab 100644
--- a/tests/test_fsdp_mixin.py
+++ b/tests/test_fsdp_mixin.py
@@ -49,9 +49,10 @@
     from torch.distributed.tensor import DTensor
     from torch.nn.parallel import DistributedDataParallel as DDP
 
+    from transformers.distributed import DistributedConfig
     from transformers.integrations.fsdp import (
         _find_final_norm,
-        apply_fsdp2,
+        apply_fully_shard_data_parallel,
         get_transformer_block_classes,
         initialize_fsdp,
     )
@@ -373,12 +374,11 @@ def train_fsdp2(
 ):
     # -- Phase 1: Pre-checkpoint run -- train only the first `checkpoint_step` steps, then save
     _set_determinism(SEED)
-    _, device_mesh, _ = initialize_fsdp(fsdp_plan=fsdp_plan)
+    distributed_config = DistributedConfig(fsdp_plan=fsdp_plan)
 
     pre_ckpt_model = AutoModelForCausalLM.from_pretrained(
         init_model_dir,
         torch_dtype=dtype,
-        fsdp_plan=fsdp_plan,
-        device_mesh=device_mesh,
+        distributed_config=distributed_config,
         attn_implementation="eager",
     )
     pre_ckpt_model.train()
@@ -415,8 +415,7 @@ def train_fsdp2(
     resumed_model = AutoModelForCausalLM.from_pretrained(
         model_dir,
         torch_dtype=dtype,
-        fsdp_plan=fsdp_plan,
-        device_mesh=device_mesh,
+        distributed_config=distributed_config,
         attn_implementation="eager",
     )
     resumed_model.train()
@@ -461,16 +460,14 @@ def _test_fsdp2_save_load_impl(rank, config_class, config_dict):
 
     batches = _build_repeated_training_batches(config, device, 3)
 
-    auto_plan = {"mode": "auto"}
+    distributed_config = DistributedConfig(fsdp_plan="auto")
 
     init_tmpdir, init_tmpdir_obj = _save_init_pretrained(rank, config, torch.float32)
     try:
-        _, device_mesh, _ = initialize_fsdp(fsdp_plan=auto_plan)
         _set_determinism(SEED)
         model = AutoModelForCausalLM.from_pretrained(
             init_tmpdir,
-            fsdp_plan=auto_plan,
-            device_mesh=device_mesh,
+            distributed_config=distributed_config,
             attn_implementation="eager",
         )
         dist.barrier()
@@ -495,8 +492,7 @@ def _test_fsdp2_save_load_impl(rank, config_class, config_dict):
 
         new_model = AutoModelForCausalLM.from_pretrained(
             tmpdir,
-            fsdp_plan=auto_plan,
-            device_mesh=device_mesh,
+            distributed_config=distributed_config,
             attn_implementation="eager",
         )
         dist.barrier()
@@ -522,7 +518,7 @@ def _test_fsdp2_sharding_structure_impl(rank, config_class, config_dict, tie_word_embeddings):
     """
-    Verify that apply_fsdp2(fsdp_plan={"mode": "auto"}) wraps exactly the right modules.
+    Verify that apply_fully_shard_data_parallel(fsdp_plan={"mode": "auto"}) wraps exactly the right modules.
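+
+    A sketch of the call under test, mirroring the body below (auto_plan is the
+    {"mode": "auto"} plan)::
+
+        model = apply_fully_shard_data_parallel(model, device_mesh, fsdp_plan=auto_plan)
+        actual_targets = {name for name, module in model.named_modules() if type(module).__name__.startswith("FSDP")}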
 
     Expected FSDP targets:
 
                                    UNTIED      TIED
@@ -570,7 +566,7 @@ def _test_fsdp2_sharding_structure_impl(rank, config_class, config_dict, tie_wor
     if not weights_tied:
         expected_targets |= {output_name}
 
-    model = apply_fsdp2(model, device_mesh, fsdp_plan=auto_plan)
+    model = apply_fully_shard_data_parallel(model, device_mesh, fsdp_plan=auto_plan)
 
     actual_targets = {name for name, module in model.named_modules() if type(module).__name__.startswith("FSDP")}
 
diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py
index 942dcdc99b11..787cf7b903ad 100644
--- a/tests/utils/test_core_model_loading.py
+++ b/tests/utils/test_core_model_loading.py
@@ -16,14 +16,15 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.tensor.placement_types import Replicate, Shard, _StridedShard
 
 from transformers import PretrainedConfig
 from transformers.conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping
 from transformers.core_model_loading import (
     Chunk,
     Concatenate,
+    DtensorShardOperation,
     ErnieFuseAndSplitTextVisionExperts,
-    FSDPShardOperation,
     MergeModulelist,
     PermuteForRope,
     WeightConverter,
@@ -32,7 +33,7 @@
     convert_and_load_state_dict_in_model,
     rename_source_key,
     revert_weight_conversion,
-    spawn_parallel_materialize,
+    spawn_materialize,
 )
 from transformers.modeling_utils import LoadStateDictConfig
 from transformers.utils.import_utils import is_triton_available
@@ -217,24 +218,113 @@ def __init__(self, add_extra_moe=False):
 
 
 class FakeMesh:
-    def __init__(self, world_size: int, rank: int):
-        self.shape = (world_size,)
-        self._rank = rank
+    """Fake multi-dimensional device mesh for testing DtensorShardOperation."""
+
+    def __init__(self, shape, rank, dim_names=None):
+        if isinstance(shape, int):
+            shape = (shape,)
+        self.shape = tuple(shape)
+        self.ndim = len(self.shape)
+        self.mesh_dim_names = dim_names or tuple(f"dim{i}" for i in range(self.ndim))
+        # Compute nD coordinate (row-major: last dim changes fastest)
+        self._coord = []
+        r = rank
+        for s in reversed(self.shape):
+            self._coord.insert(0, r % s)
+            r //= s
 
     def get_local_rank(self):
-        return self._rank
+        return self._coord[0]
 
     def get_coordinate(self):
-        return (self._rank,)
+        return tuple(self._coord)
+
+    def size(self):
+        result = 1
+        for s in self.shape:
+            result *= s
+        return result
+
+    def _is_current_rank_part_of_mesh(self):
+        return True
+
+    def _sym_get_coordinate(self, dim):
+        return self._coord[dim]
+
+    def __getitem__(self, name):
+        idx = self.mesh_dim_names.index(name)
+        return FakeMesh(
+            shape=(self.shape[idx],),
+            rank=self._coord[idx],
+            dim_names=(name,),
+        )
+
+
+def _make_dtensor_shard_op(mesh, placements, param_shape, local_shape):
+    """Build a DtensorShardOperation without requiring a real DTensor / distributed init."""
+    from types import SimpleNamespace  # local import: only this helper needs it
+
+    op = object.__new__(DtensorShardOperation)
+    op.device_mesh = mesh
+    op.placements = tuple(placements)
+    ns = SimpleNamespace(shape=torch.Size(param_shape), ndim=len(param_shape))
+    ns.dim = lambda: len(param_shape)
+    op.param = ns
+    op.local_shape = tuple(local_shape)
+    return op
 
 
 class TestConvertAndLoadStateDict(unittest.TestCase):
-    def test_fsdp_shard_aware_mixtral_conversion_uses_only_local_experts(self):
-        shard_op = FSDPShardOperation(
-            device_mesh=FakeMesh(world_size=2, rank=0),
-            rank=0,
-            empty_param=torch.empty((2, 4, 2)),
-            placements=(torch.distributed.tensor.placement_types.Shard(0),),
+    def test_dtensor_shard_aware_mixtral_conversion_uses_only_local_experts(self):
+        """Integration test: FSDP-sharded expert loading + 
WeightConverter.
+
+        The problem: Mixtral has 8 experts (2 in this scaled-down test). The checkpoint stores them separately::
+
+            experts.0.w1.weight   (2x2)
+            experts.0.w3.weight   (2x2)
+            experts.1.w1.weight   (2x2)
+            experts.1.w3.weight   (2x2)
+
+        The model stores them packed into one tensor::
+
+            experts.gate_up_proj.weight   (2, 4, 2)
+                                           ^  ^  ^
+                                           |  |  +-- features
+                                           |  +-- w1 (2) + w3 (2) concatenated
+                                           +-- num_experts
+
+        The conversion (without FSDP) loads all expert w1/w3 tensors, then
+        MergeModulelist(dim=0) stacks the experts and Concatenate(dim=1) joins w1+w3.
+
+        With FSDP, Shard(0) splits the expert dim across ranks. Rank 0 owns
+        expert 0, rank 1 owns expert 1. So rank 0 should skip loading expert 1
+        entirely -- not load it then discard it.
+
+        What the test checks::
+
+            checkpoint files    shard_tensor     rank 0 gets
+            ----------------    ------------     -----------
+            experts.0.w1        [[0,1],[2,3]]    idx=0 -> kept [[0,1],[2,3]]
+            experts.1.w1        [[10,11],...]    idx=1 -> None (not owned)
+            experts.0.w3        [[4,5],[6,7]]    idx=0 -> kept [[4,5],[6,7]]
+            experts.1.w3        [[14,15],...]    idx=1 -> None (not owned)
+
+        WeightConverter then combines only the kept tensors::
+
+            MergeModulelist(dim=0): stack owned experts -> shape (1, 2, 2) each
+            Concatenate(dim=1):     cat w1 + w3 along dim 1
+
+            gate_up_proj = [[[0,1],[2,3],[4,5],[6,7]]]   shape (1, 4, 2)
+                              ~~~~~~~~~~  ~~~~~~~~~~
+                                  w1          w3
+
+        The key point: DtensorShardOperation.shard_tensor(tensor_idx=1) returns
+        None for rank 0, so the converter never even processes expert 1's data.
+        This saves memory during loading.
+        """
+        shard_op = _make_dtensor_shard_op(
+            FakeMesh(shape=(2,), rank=0),
+            [Shard(0)],
+            param_shape=(2, 4, 2),
+            local_shape=(1, 4, 2),
         )
         converter = WeightConverter(
             ["experts.*.w1.weight", "experts.*.w3.weight"],
@@ -252,7 +342,7 @@ def test_fsdp_shard_aware_mixtral_conversion_uses_only_local_experts(self):
                 "model.layers.0.experts.gate_up_proj.weight",
                 f"model.layers.0.experts.{idx}.w1.weight",
                 "experts.*.w1.weight",
-                spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None),
+                spawn_materialize(None, tensor, device="cpu", dtype=None, sharding_op=shard_op, tensor_idx=idx),
             )
             for idx, tensor in enumerate(
@@ -265,7 +355,7 @@ def test_fsdp_shard_aware_mixtral_conversion_uses_only_local_experts(self):
                 "model.layers.0.experts.gate_up_proj.weight",
                 f"model.layers.0.experts.{idx}.w3.weight",
                 "experts.*.w3.weight",
-                spawn_parallel_materialize(None, tensor, shard_op, idx, device="cpu", dtype=None),
+                spawn_materialize(None, tensor, device="cpu", dtype=None, sharding_op=shard_op, tensor_idx=idx),
             )
 
         converted = converter.convert("model.layers.0.experts.gate_up_proj.weight")
@@ -785,6 +875,178 @@ def test_ernie4_5_vl_moe_conversion_reversed(self):
         self.assertTrue(compare_state_dicts(reversed_state_dict, state_dict))
 
 
+class TestDtensorShardOperation(unittest.TestCase):
+    """Unit tests for DtensorShardOperation.shard_tensor — one test per code path. 
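+
+    A sketch of how each test drives the unit under test (values taken from
+    test_1d_shard_fast_path below; FakeMesh and _make_dtensor_shard_op are the
+    test helpers defined above)::
+
+        op = _make_dtensor_shard_op(FakeMesh(shape=(2,), rank=0), [Shard(0)],
+                                    param_shape=(4, 4), local_shape=(2, 4))
+        local = op.shard_tensor(torch.arange(16).reshape(4, 4).float())  # rank 0 keeps rows 0..1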
+ + Branch coverage map: + + shard_tensor() + ├── A: no sharding placements → full copy [test_no_shard_returns_full_tensor] + ├── B: expert path (tensor_idx set, ndim mismatch) + │ ├── B1: has_expert_sharding=False → fall through to C [test_expert_shaped_tp_only_no_expert_sharding] + │ ├── B2: not owns_local_expert → None [test_expert_filtering] + │ ├── B3: owned, no inner placements → full copy [test_expert_filtering] + │ └── B4: owned, with inner placements → _shard_nd [test_expert_filtering_preserves_inner_sharding] + └── C: _shard_nd() + ├── C1: _can_shard_on_read=False → _materialize_and_split [test_nd_strided_plus_shard_same_dim_fallback] + ├── C2: has_strided=False → contiguous slice + │ ├── 1D mesh [test_1d_shard_fast_path] + │ ├── 2D mesh [test_nd_contiguous_single_slice] + │ ├── negative dim [test_negative_dim_normalizes_correctly] + │ └── uneven division [test_contiguous_shard_uneven_division] + └── C3: has_strided=True → _compute_dim_ranges + _slice_and_read + ├── _StridedShard → _strided_ranges [test_nd_strided_shard_disjoint_ranges] + └── _source_tensor_needs_packing → contiguous [test_prepacked_strided_shard_uses_contiguous_source_slice] + + _slice_and_read (tested directly) + ├── all single ranges → simple slice [test_slice_and_read_all_single_ranges] + └── two multi-range dims → ValueError [test_slice_and_read_raises_on_two_multi_range_dims] + """ + + def test_no_shard_returns_full_tensor(self): + """Replicate-only → full copy.""" + mesh = FakeMesh(shape=(2,), rank=0) + op = _make_dtensor_shard_op(mesh, [Replicate()], param_shape=(4, 4), local_shape=(4, 4)) + tensor = torch.arange(16).reshape(4, 4).float() + torch.testing.assert_close(op.shard_tensor(tensor), tensor) + + def test_1d_shard_fast_path(self): + # TODO(3outeille): double check fast path + tensor = torch.arange(16).reshape(4, 4).float() + for rank, expected in [(0, tensor[:2]), (1, tensor[2:])]: + mesh = FakeMesh(shape=(2,), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(4, 4), local_shape=(2, 4)) + torch.testing.assert_close(op.shard_tensor(tensor), expected, msg=f"rank {rank}") + + def test_nd_contiguous_single_slice(self): + """nD Shard on different dims → single slice read per rank.""" + tensor = torch.arange(64).reshape(8, 8).float() + expected = {0: tensor[:4, :4], 1: tensor[:4, 4:], 2: tensor[4:, :4], 3: tensor[4:, 4:]} + for rank in range(4): + mesh = FakeMesh(shape=(2, 2), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(0), Shard(1)], param_shape=(8, 8), local_shape=(4, 4)) + torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}") + + def test_nd_strided_shard_disjoint_ranges(self): + """_StridedShard on its own dim → multiple slice reads + cat.""" + tensor = torch.arange(64).reshape(8, 8).float() + # Shard(0) splits rows; _StridedShard(1, split_factor=2) produces disjoint col ranges + expected = { + 0: torch.cat([tensor[:4, :2], tensor[:4, 4:6]], dim=1), + 1: torch.cat([tensor[:4, 2:4], tensor[:4, 6:8]], dim=1), + 2: torch.cat([tensor[4:, :2], tensor[4:, 4:6]], dim=1), + 3: torch.cat([tensor[4:, 2:4], tensor[4:, 6:8]], dim=1), + } + for rank in range(4): + mesh = FakeMesh(shape=(2, 2), rank=rank) + op = _make_dtensor_shard_op( + mesh, + [Shard(0), _StridedShard(dim=1, split_factor=2)], + param_shape=(8, 8), + local_shape=(4, 4), + ) + torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}") + + def test_nd_strided_plus_shard_same_dim_fallback(self): + """_StridedShard + Shard on same dim → 
materialize-then-split fallback.""" + tensor = torch.arange(16).reshape(4, 4).float() + expected = {0: tensor[[0]], 1: tensor[[2]], 2: tensor[[1]], 3: tensor[[3]]} + for rank in range(4): + mesh = FakeMesh(shape=(2, 2), rank=rank) + op = _make_dtensor_shard_op( + mesh, + [_StridedShard(dim=0, split_factor=2), Shard(0)], + param_shape=(4, 4), + local_shape=(1, 4), + ) + torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}") + + def test_prepacked_strided_shard_uses_contiguous_source_slice(self): + """Pre-concat w1/w3 tensors should shard contiguously before gate/up packing.""" + tensor = torch.arange(8).reshape(4, 2).float() + for rank, expected in [(0, tensor[:2]), (1, tensor[2:])]: + mesh = FakeMesh(shape=(2,), rank=rank) + op = _make_dtensor_shard_op( + mesh, + [_StridedShard(dim=1, split_factor=2)], + param_shape=(8, 8, 2), + local_shape=(8, 4, 2), + ) + torch.testing.assert_close(op.shard_tensor(tensor, tensor_idx=0), expected, msg=f"rank {rank}") + + def test_expert_shaped_tp_only_no_expert_sharding(self): + """Expert-shaped param with TP on dim 1 but no expert sharding on dim 0 → regular _shard_nd path.""" + tensor = torch.arange(8).reshape(4, 2).float() + # Shard(1) on 3D param maps to dim 0 of the 2D checkpoint tensor (ndim_diff=1) + for rank, expected in [(0, tensor[:2]), (1, tensor[2:])]: + mesh = FakeMesh(shape=(2,), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(1)], param_shape=(4, 4, 2), local_shape=(4, 2, 2)) + torch.testing.assert_close(op.shard_tensor(tensor, tensor_idx=0), expected, msg=f"rank {rank}") + + def test_expert_filtering(self): + """Mixtral-style experts: skip non-owned, return owned.""" + mesh = FakeMesh(shape=(2,), rank=1) + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(4, 2, 2), local_shape=(2, 2, 2)) + expert_tensor = torch.ones(2, 2) + # rank 1 owns experts 2,3 (offset=2) + self.assertIsNone(op.shard_tensor(expert_tensor, tensor_idx=0)) + torch.testing.assert_close(op.shard_tensor(expert_tensor, tensor_idx=2), expert_tensor) + + def test_expert_filtering_preserves_inner_sharding(self): + """MoE expert ownership checks should still apply TP sharding on inner dims.""" + tensor = torch.arange(8).reshape(4, 2).float() + expected = { + 0: tensor[:2], + 1: tensor[2:], + 2: None, + 3: None, + } + for rank in range(4): + mesh = FakeMesh(shape=(2, 2), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(0), Shard(1)], param_shape=(4, 4, 2), local_shape=(2, 2, 2)) + shard = op.shard_tensor(tensor, tensor_idx=1) + if expected[rank] is None: + self.assertIsNone(shard) + else: + torch.testing.assert_close(shard, expected[rank], msg=f"rank {rank}") + + def test_negative_dim_normalizes_correctly(self): + """Shard(-1) on a 2D tensor should shard the last dimension.""" + tensor = torch.arange(16).reshape(4, 4).float() + for rank, expected in [(0, tensor[:, :2]), (1, tensor[:, 2:])]: + mesh = FakeMesh(shape=(2,), rank=rank) + op = _make_dtensor_shard_op(mesh, [Shard(-1)], param_shape=(4, 4), local_shape=(4, 2)) + torch.testing.assert_close(op.shard_tensor(tensor), expected, msg=f"rank {rank}") + + def test_contiguous_shard_uneven_division(self): + """Shard(0) on 5 rows across 2 ranks → rank 0 gets 3 rows, rank 1 gets 2.""" + tensor = torch.arange(20).reshape(5, 4).float() + expected = {0: tensor[:3], 1: tensor[3:]} + for rank in range(2): + mesh = FakeMesh(shape=(2,), rank=rank) + local_rows = 3 if rank == 0 else 2 + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(5, 4), local_shape=(local_rows, 4)) + 
torch.testing.assert_close(op.shard_tensor(tensor), expected[rank], msg=f"rank {rank}") + + def test_slice_and_read_all_single_ranges(self): + """When every dim has exactly one range, _slice_and_read takes the simple slice path (no concat).""" + tensor = torch.arange(64).reshape(8, 8).float() + mesh = FakeMesh(shape=(2,), rank=0) + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(8, 8), local_shape=(4, 4)) + dim_ranges = {0: [(0, 4)], 1: [(2, 6)]} + result = op._slice_and_read(tensor, [8, 8], dim_ranges, None, None) + torch.testing.assert_close(result, tensor[0:4, 2:6]) + + def test_slice_and_read_raises_on_two_multi_range_dims(self): + """Multiple disjoint ranges on two different dims → ValueError.""" + tensor = torch.arange(64).reshape(8, 8).float() + mesh = FakeMesh(shape=(2,), rank=0) + op = _make_dtensor_shard_op(mesh, [Shard(0)], param_shape=(8, 8), local_shape=(4, 4)) + dim_ranges = {0: [(0, 2), (4, 6)], 1: [(0, 2), (4, 6)]} + with self.assertRaises(ValueError): + op._slice_and_read(tensor, [8, 8], dim_ranges, None, None) + + class TestConversionMapping(unittest.TestCase): def test_register_checkpoint_conversion_mapping(self): register_checkpoint_conversion_mapping(