src/twinkle/model/transformers/strategy/sequence_parallel.py (2 changes: 0 additions & 2 deletions)

@@ -844,8 +844,6 @@ def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -
         return 1
     if getattr(device_mesh, "ulysses_size", None) is not None:
         return int(device_mesh.ulysses_size)
-    if getattr(device_mesh, "has_dim", None) and device_mesh.has_dim("sp"):
-        return device_mesh.get_dim_size("sp")
     return 1
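The net effect of the deletion: a mesh's Ulysses degree is now read only from its `ulysses_size` attribute, and a mesh that merely exposes an "sp" dimension through `has_dim`/`get_dim_size` no longer enables sequence parallelism. Below is a minimal sketch of the post-change resolution order; it is not part of the PR, `UlyssesMesh` and `SpDimMesh` are hypothetical stand-ins, and the `device_mesh is None` guard is an assumption about the unshown lines above the hunk.

# Minimal sketch (not from the PR) of how _get_ulysses_size resolves after this change.
from typing import Any, Dict, Optional


def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -> int:
    if device_mesh is None:  # assumed guard; the hunk only shows its `return 1`
        return 1
    if getattr(device_mesh, "ulysses_size", None) is not None:
        return int(device_mesh.ulysses_size)
    return 1  # the "sp" mesh-dim fallback is gone


class UlyssesMesh:  # hypothetical stand-in exposing the attribute the code now requires
    ulysses_size = 4


class SpDimMesh:  # hypothetical stand-in exposing only the interface the deleted fallback used
    def has_dim(self, name: str) -> bool:
        return name == "sp"

    def get_dim_size(self, name: str) -> int:
        return 2


print(_get_ulysses_size(UlyssesMesh()))  # 4
print(_get_ulysses_size(SpDimMesh()))    # 1 after this change (was 2 before)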
src/twinkle/model/transformers/transformers.py (31 changes: 19 additions & 12 deletions)

@@ -193,6 +193,8 @@ def __init__(self, # noqa
         else:
             model_id = HubOperation.download_model(model_id)
         self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
+        # Construct sequence-parallel strategy lazily during wrapping to reduce init-time side effects.
+        self.sp_strategy = None
         self._model_wrapped = False
         self.optimizer_group: Dict[str, OptimizerGroup] = {_default_adapter_name: self._construct_default_optimizer_group()}

@@ -212,21 +214,25 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
             self.strategy = AccelerateStrategy(mixed_precision=self.mixed_precision, ddp_config=self._ddp_config,
                                                fsdp_config=self._fsdp_config, device_mesh=self.device_mesh)

-        enable_sp = False
+        # Sequence parallel ("ulysses") is derived from dp/fsdp ranks; it does not change world size.
+        # We construct `sp_strategy` after the underlying HF model is initialized (see __init__).
+        self._enable_sp = False
         if self.device_mesh is not None:
-            sp_size = self.device_mesh.ulysses_size
-            enable_sp = bool(sp_size and sp_size > 1)
+            sp_size = getattr(self.device_mesh, "ulysses_size", None)
+            self._enable_sp = bool(sp_size and sp_size > 1)

+    def _ensure_sp_strategy(self) -> None:
+        if not getattr(self, "_enable_sp", False):
+            return
+        if self.sp_strategy is not None:
+            return
         from .strategy.sequence_parallel import SequenceParallelStrategy
-        self.sp_strategy = (
-            SequenceParallelStrategy(
-                self.device_mesh,
-                {},
-                model=self.model,
-                tokenizer_id=self.tokenizer_id,
-            )
-            if enable_sp
-            else None
+
+        self.sp_strategy = SequenceParallelStrategy(
+            self.device_mesh,
+            {},
+            model=self.model,
+            tokenizer_id=self.tokenizer_id,
         )

     def _get_default_group(self):
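Taken together, this hunk splits what was a single construct-at-decide-time step into two parts: cheap eager detection (`_enable_sp`, set in `_decide_strategy`) and idempotent lazy construction (`_ensure_sp_strategy`). A condensed sketch of that pattern, using hypothetical stand-in classes rather than the project's real ones and folding detection into `__init__` for brevity:

# Condensed sketch of the detect-eagerly, construct-lazily pattern
# (stand-in types, not the project's real classes).
class FakeSPStrategy:
    def __init__(self, device_mesh):
        self.device_mesh = device_mesh

    def initialize(self):
        print("SP initialized, ulysses_size =", self.device_mesh.ulysses_size)


class Model:
    def __init__(self, device_mesh=None):
        self.device_mesh = device_mesh
        self.sp_strategy = None  # construction deferred, as in __init__ above
        sp_size = getattr(device_mesh, "ulysses_size", None)  # cheap eager detection
        self._enable_sp = bool(sp_size and sp_size > 1)

    def _ensure_sp_strategy(self):
        if not self._enable_sp or self.sp_strategy is not None:
            return  # disabled, or already constructed: safe to call repeatedly
        self.sp_strategy = FakeSPStrategy(self.device_mesh)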
@@ -247,6 +253,7 @@ def _lazy_wrap_model(self):
         if not self._model_wrapped:
             optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None]
             self._maybe_apply_expert_parallel()
+            self._ensure_sp_strategy()
             if self.sp_strategy is not None:
                 self.sp_strategy.initialize()
             if len(optimizer_groups) == 1:
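Continuing the sketch above: after this last hunk, `_lazy_wrap_model` constructs the strategy just in time and the whole sequence is a no-op on repeat calls. `_model_wrapped` below is a simplified stand-in for the real guard, which also wraps optimizer groups and applies expert parallel first.

# Wrap-time ordering after this change (continues the Model sketch above).
class WrappedModel(Model):
    _model_wrapped = False

    def _lazy_wrap_model(self):
        if not self._model_wrapped:
            self._ensure_sp_strategy()         # construct if enabled and not built yet
            if self.sp_strategy is not None:
                self.sp_strategy.initialize()  # one-time SP setup
            self._model_wrapped = True


class Mesh:  # hypothetical stand-in
    ulysses_size = 2


m = WrappedModel(Mesh())
m._lazy_wrap_model()  # prints: SP initialized, ulysses_size = 2
m._lazy_wrap_model()  # no-op: already wrapped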