Merged
23 changes: 14 additions & 9 deletions src/transformers/integrations/__init__.py
@@ -161,9 +161,12 @@
 ]

 _import_structure["tensor_parallel"] = [
-    "shard_and_distribute_module",
-    "ALL_PARALLEL_STYLES",
-    "translate_to_torch_parallel_style",
+    "TPStyle",
+    "apply_tensor_parallel",
+    "convert_strided_to_shard",
+    "gather_full_state_dict",
+    "restore_strided_from_shard",
+    "verify_tp_plan",
 ]
 try:
     if not is_torch_greater_or_equal("2.5"):
@@ -295,6 +298,14 @@
     from .quanto import replace_with_quanto_layers
     from .sinq import SinqDeserialize, SinqQuantize
     from .spqr import replace_with_spqr_linear
+    from .tensor_parallel import (
+        TPStyle,
+        apply_tensor_parallel,
+        convert_strided_to_shard,
+        gather_full_state_dict,
+        restore_strided_from_shard,
+        verify_tp_plan,
+    )
     from .vptq import replace_with_vptq_linear

     try:
@@ -305,12 +316,6 @@
     else:
         from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache

-    from .tensor_parallel import (
-        ALL_PARALLEL_STYLES,
-        shard_and_distribute_module,
-        translate_to_torch_parallel_style,
-    )
-
     try:
         if not is_torch_greater_or_equal("2.5"):
             raise OptionalDependencyNotAvailable()
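For orientation, a sketch of how the newly exported surface might be used directly. The signature of apply_tensor_parallel matches its call site in modeling_utils.py below; the gather_full_state_dict call and the checkpoint name are assumptions, not documented API:

```python
# Hedged sketch, not documented API: apply_tensor_parallel(model, mesh, plan)
# matches the call in modeling_utils.py below; gather_full_state_dict's
# signature is assumed from its name.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from transformers import AutoModelForCausalLM
from transformers.integrations import apply_tensor_parallel, gather_full_state_dict

dist.init_process_group("nccl")
tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # hypothetical checkpoint
model = apply_tensor_parallel(model, tp_mesh, model.config.base_model_tp_plan)

full_sd = gather_full_state_dict(model)  # assumed: collects DTensor shards into plain tensors
```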
13 changes: 5 additions & 8 deletions src/transformers/modeling_utils.py
@@ -38,6 +38,7 @@
 from safetensors import safe_open
 from safetensors.torch import save_file as safe_save_file
 from torch import Tensor, nn
+from torch.distributed.tensor import DTensor
 from torch.distributions import constraints
 from torch.utils.checkpoint import checkpoint

@@ -4123,15 +4124,13 @@ def from_pretrained(
         if distributed_config is not None:
             model.config.distributed_config = distributed_config
             model.device_mesh = device_mesh
-
-            def sub_mesh(name):
-                return device_mesh[name] if device_mesh.ndim > 1 else device_mesh
-
             mesh_dim_names = device_mesh.mesh_dim_names or ()
             if "tp" in mesh_dim_names:
-                model = apply_tensor_parallel(model, sub_mesh("tp"), distributed_config.tp_plan)
+                tp_mesh = device_mesh["tp"] if device_mesh.ndim > 1 else device_mesh
+                model = apply_tensor_parallel(model, tp_mesh, distributed_config.tp_plan)
             if "fsdp" in mesh_dim_names:
-                model = apply_fully_shard_data_parallel(model, sub_mesh("fsdp"), distributed_config.fsdp_plan)
+                fsdp_mesh = device_mesh["fsdp"] if device_mesh.ndim > 1 else device_mesh
+                model = apply_fully_shard_data_parallel(model, fsdp_mesh, distributed_config.fsdp_plan)
         else:
             # Accelerate path: auto device mapping
             if device_map is not None:
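The inlined guard exists because a 1-D mesh is used directly, while a multi-dimensional mesh is sliced by dimension name. A standalone illustration of that slicing with torch's public DeviceMesh API:

```python
# Standalone illustration of the mesh slicing above (torch public API;
# requires an initialized process group, e.g. via torchrun on 8 ranks).
from torch.distributed.device_mesh import init_device_mesh

# 2-D mesh over 8 ranks: outer "fsdp" dimension of 2, inner "tp" dimension of 4.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("fsdp", "tp"))

tp_mesh = mesh["tp"]      # 1-D sub-mesh of 4 ranks, one per FSDP group
fsdp_mesh = mesh["fsdp"]  # 1-D sub-mesh of 2 ranks

# With a 1-D mesh (tp-only or fsdp-only runs), the code uses the mesh itself,
# hence the `device_mesh.ndim > 1` guard.
```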
@@ -4552,8 +4551,6 @@ def _move_missing_keys_from_meta_to_device(
         # will be re-initialized for nothing (which can be quite long)
         for key in missing_keys - self.all_tied_weights_keys.keys():
             param = self.get_parameter_or_buffer(key)
-            from torch.distributed.tensor import DTensor
-
             if isinstance(param, DTensor):
                 # DTensor from parallelize_module on meta — materialize on actual device
                 local_value = torch.empty(
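The hunk truncates just as the local shard is allocated. For orientation, a sketch of the general pattern for materializing a meta-device DTensor; the PR's exact continuation is not shown, so treat this as an assumption:

```python
# Hedged sketch: the diff cuts off after `torch.empty(`; this shows the
# general DTensor-materialization pattern, not the PR's exact code.
import torch
from torch.distributed.tensor import DTensor

def materialize_meta_dtensor(param: DTensor, device: torch.device) -> DTensor:
    # Allocate only this rank's shard on the real device...
    local_value = torch.empty(
        param.to_local().shape, dtype=param.dtype, device=device
    )
    # ...and rewrap it with the same mesh/placements as the meta DTensor.
    return DTensor.from_local(
        local_value, param.device_mesh, param.placements, run_check=False
    )
```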
7 changes: 4 additions & 3 deletions src/transformers/trainer.py
@@ -2392,9 +2392,10 @@ def get_cp_size(self) -> int:
     def get_tp_size(self) -> int:
         """Get the tensor parallel size from either the model or DeepSpeed config."""

-        # 1. Check model.tp_size first
-        if (model_tp := getattr(self.model, "_tp_size", None)) is not None:
-            return model_tp
+        # TODO: adapt it cleaner with distributed config once distributed api is stable
+        dc = getattr(getattr(self.model, "config", None), "distributed_config", None)
+        if dc is not None and dc.tp_size is not None:
+            return dc.tp_size

         # 2. Fall back to DeepSpeed config if enabled
         if self.is_deepspeed_enabled and (deepspeed_config := getattr(self.args, "hf_deepspeed_config", None)):
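The lookup now goes through model.config.distributed_config rather than the old private _tp_size. The fields touched across this PR suggest a config shaped roughly like the sketch below; this is an inference, not the actual class definition:

```python
# Assumed shape of DistributedConfig, inferred only from the fields this PR
# reads (tp_size, tp_plan, fsdp_plan); the real class may differ.
from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class DistributedConfig:
    tp_size: Optional[int] = None
    tp_plan: Optional[Union[str, dict]] = None
    fsdp_plan: Optional[Union[str, dict]] = None
```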
24 changes: 10 additions & 14 deletions tests/test_fsdp_mixin.py
@@ -49,9 +49,10 @@
 from torch.distributed.tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel as DDP

+from transformers.distributed import DistributedConfig
 from transformers.integrations.fsdp import (
     _find_final_norm,
-    apply_fsdp2,
+    apply_fully_shard_data_parallel,
     get_transformer_block_classes,
     initialize_fsdp,
 )
@@ -373,12 +374,11 @@ def train_fsdp2(
 ):
     # -- Phase 1: Pre-checkpoint run -- train only the first `checkpoint_step` steps, then save
     _set_determinism(SEED)
-    _, device_mesh, _ = initialize_fsdp(fsdp_plan=fsdp_plan)
+    distributed_config = DistributedConfig(fsdp_plan=fsdp_plan)
     pre_ckpt_model = AutoModelForCausalLM.from_pretrained(
         init_model_dir,
         torch_dtype=dtype,
-        fsdp_plan=fsdp_plan,
-        device_mesh=device_mesh,
+        distributed_config=distributed_config,
         attn_implementation="eager",
     )
     pre_ckpt_model.train()
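This is the pattern the whole test file migrates to: build one DistributedConfig and hand it to from_pretrained, instead of calling initialize_fsdp and threading a mesh through by hand. Condensed from the diff, with a hypothetical checkpoint path:

```python
# The migrated loading pattern, condensed from this test (checkpoint path is
# hypothetical; keyword names are exactly those used in the diff).
import torch
from transformers import AutoModelForCausalLM
from transformers.distributed import DistributedConfig

distributed_config = DistributedConfig(fsdp_plan="auto")
model = AutoModelForCausalLM.from_pretrained(
    "path/to/checkpoint",  # hypothetical
    torch_dtype=torch.float32,
    distributed_config=distributed_config,
    attn_implementation="eager",
)
```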
@@ -415,8 +415,7 @@ def train_fsdp2(
     resumed_model = AutoModelForCausalLM.from_pretrained(
         model_dir,
         torch_dtype=dtype,
-        fsdp_plan=fsdp_plan,
-        device_mesh=device_mesh,
+        distributed_config=distributed_config,
         attn_implementation="eager",
     )
     resumed_model.train()
@@ -461,16 +460,14 @@ def _test_fsdp2_save_load_impl(rank, config_class, config_dict):

     batches = _build_repeated_training_batches(config, device, 3)

-    auto_plan = {"mode": "auto"}
+    distributed_config = DistributedConfig(fsdp_plan="auto")

     init_tmpdir, init_tmpdir_obj = _save_init_pretrained(rank, config, torch.float32)
     try:
-        _, device_mesh, _ = initialize_fsdp(fsdp_plan=auto_plan)
         _set_determinism(SEED)
         model = AutoModelForCausalLM.from_pretrained(
             init_tmpdir,
-            fsdp_plan=auto_plan,
-            device_mesh=device_mesh,
+            distributed_config=distributed_config,
             attn_implementation="eager",
         )
         dist.barrier()
@@ -495,8 +492,7 @@ def _test_fsdp2_save_load_impl(rank, config_class, config_dict):

         new_model = AutoModelForCausalLM.from_pretrained(
             tmpdir,
-            fsdp_plan=auto_plan,
-            device_mesh=device_mesh,
+            distributed_config=distributed_config,
             attn_implementation="eager",
         )
         dist.barrier()
@@ -522,7 +518,7 @@ def _test_fsdp2_save_load_impl(rank, config_class, config_dict):

 def _test_fsdp2_sharding_structure_impl(rank, config_class, config_dict, tie_word_embeddings):
     """
-    Verify that apply_fsdp2(fsdp_plan={"mode": "auto"}) wraps exactly the right modules.
+    Verify that apply_fully_shard_data_parallel(fsdp_plan={"mode": "auto"}) wraps exactly the right modules.

     Expected FSDP targets:
        UNTIED                      TIED
@@ -570,7 +566,7 @@ def _test_fsdp2_sharding_structure_impl(rank, config_class, config_dict, tie_word_embeddings):
     if not weights_tied:
         expected_targets |= {output_name}

-    model = apply_fsdp2(model, device_mesh, fsdp_plan=auto_plan)
+    model = apply_fully_shard_data_parallel(model, device_mesh, fsdp_plan=auto_plan)

     actual_targets = {name for name, module in model.named_modules() if type(module).__name__.startswith("FSDP")}
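The assertion in this last hunk leans on a naming convention: torch's composable fully_shard swaps a wrapped module's class for an FSDP-prefixed subclass, which is what the set comprehension keys on. A minimal single-process illustration, assuming torch >= 2.6 (where fully_shard lives under torch.distributed.fsdp) and that a single-rank gloo/CPU group suffices for the demo:

```python
# Minimal illustration of the FSDP class-name convention the test checks.
# Assumptions: torch >= 2.6 exposes fully_shard under torch.distributed.fsdp,
# and a single-rank gloo group is enough to wrap a module on CPU.
import os
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

block = nn.Linear(8, 8)
fully_shard(block)  # in place: type(block).__name__ becomes "FSDPLinear"
assert type(block).__name__.startswith("FSDP")

dist.destroy_process_group()
```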