From 422b79853146aa503c18d10c6d84c6cee50e129c Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 4 Feb 2026 15:52:03 +0800 Subject: [PATCH 1/3] wip --- cookbook/sft/fsdp_qwen3_moe.py | 102 ++++++++++++++++++ .../transformers/strategy/native_fsdp.py | 11 +- .../model/transformers/transformers.py | 1 + src/twinkle/utils/platform.py | 4 +- 4 files changed, 113 insertions(+), 5 deletions(-) create mode 100644 cookbook/sft/fsdp_qwen3_moe.py diff --git a/cookbook/sft/fsdp_qwen3_moe.py b/cookbook/sft/fsdp_qwen3_moe.py new file mode 100644 index 00000000..83e0f19b --- /dev/null +++ b/cookbook/sft/fsdp_qwen3_moe.py @@ -0,0 +1,102 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import os + +from transformers import AutoConfig + +import twinkle +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel + +logger = get_logger() + +MODEL_ID = os.environ.get( + "QWEN3_MODEL_ID", "ms://Qwen/Qwen3-30B-A3B-Instruct-2507" +) +DATASET_ID = os.environ.get("QWEN3_DATASET_ID", "/path/to/alpaca/dataset") +TEMPLATE_ID = os.environ.get("QWEN3_TEMPLATE_ID", "Template") +PROCESSOR_ID = os.environ.get("QWEN3_PROCESSOR_ID", "AlpacaProcessor") + +WORLD_SIZE = int(os.environ.get("QWEN3_WORLD_SIZE", "2")) +NUM_LAYERS = int(os.environ.get("QWEN3_NUM_LAYERS", "4")) +BATCH_SIZE = int(os.environ.get("QWEN3_BATCH_SIZE", "4")) +GRAD_ACCUM_STEPS = int(os.environ.get("QWEN3_GRAD_ACCUM_STEPS", "4")) +SAVE_INTERVAL = int(os.environ.get("QWEN3_SAVE_INTERVAL", "50")) + +# EP is disabled: mesh only has fsdp/dp, and fsdp_size = world_size. +device_mesh = DeviceMesh.from_sizes( + world_size=WORLD_SIZE, + fsdp_size=WORLD_SIZE, + dp_size=1, + device_type=Platform.get_platform().device_prefix(), +) + +os.environ.setdefault("RAY_DEDUP_LOGS", "0") +if Platform.get_world_size() != WORLD_SIZE: + raise RuntimeError( + f"QWEN3_WORLD_SIZE={WORLD_SIZE} but distributed world size is " + f"{Platform.get_world_size()}. Use torchrun with matching nproc_per_node." + ) +twinkle.initialize( + mode="local", + global_device_mesh=device_mesh, +) + + +def train(): + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + if hasattr(config, "num_hidden_layers"): + config.num_hidden_layers = min(NUM_LAYERS, config.num_hidden_layers) + if hasattr(config, "use_cache"): + config.use_cache = False + + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) + try: + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + except ValueError: + dataset.set_template("Template", model_id=MODEL_ID) + + processor = PROCESSOR_ID + if PROCESSOR_ID.lower() == "alpaca": + processor = "AlpacaProcessor" + + dataset.map(processor) + dataset.encode(batched=True) + dataloader = DataLoader( + dataset=dataset, + batch_size=BATCH_SIZE, + device_mesh=device_mesh, + ) + + model = TransformersModel( + model_id=MODEL_ID, + config=config, + strategy="native_fsdp", + device_mesh=device_mesh, + ) + model.set_optimizer("AdamW") + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + + for step, batch in enumerate(dataloader): + if callable(batch): + batch = batch() + model.forward_backward( + inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS + ) + model.clip_grad_and_step(gradient_accumulation_steps=GRAD_ACCUM_STEPS) + if step % GRAD_ACCUM_STEPS == 0: + metric = model.calculate_metric(is_training=True) + if callable(metric): + metric = metric() + logger.info( + f"Current is step {step // GRAD_ACCUM_STEPS}, metric: {metric}" + ) + if step % SAVE_INTERVAL == 0: + model.save("./output") + + +if __name__ == "__main__": + train() diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 101f66c7..f2f34534 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -15,10 +15,12 @@ class NativeFSDPStrategy: def __init__(self, device_mesh: Optional[DeviceMesh] = None, mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16', - fsdp_config: Dict[str, Any] = None): + fsdp_config: Dict[str, Any] = None, + enable_ep: bool = True): self.device_mesh = device_mesh self.mixed_precision = mixed_precision self.fsdp_config = fsdp_config or {} + self.enable_ep = enable_ep def wrap_model(self, model, optimizer=None): if self.device_mesh is None: @@ -26,11 +28,12 @@ def wrap_model(self, model, optimizer=None): fsdp_mesh = _build_fsdp_mesh(self.device_mesh) if fsdp_mesh is not None: - _ensure_moe_patched_if_needed(model, self.device_mesh) - _place_ep_experts_on_local_device(model, self.device_mesh) + if self.enable_ep: + _ensure_moe_patched_if_needed(model, self.device_mesh) + _place_ep_experts_on_local_device(model, self.device_mesh) mp_policy = _build_mp_policy(self.mixed_precision) reshard_after_forward = self.fsdp_config.get("reshard_after_forward", True) - ignored_params = _collect_expert_params(model) + ignored_params = _collect_expert_params(model) if self.enable_ep else None _maybe_shard_layers( model, diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 336ee187..8c72edc0 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -206,6 +206,7 @@ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']): mixed_precision=self.mixed_precision, fsdp_config=self._fsdp_config, device_mesh=self.device_mesh, + enable_ep=self._enable_expert_parallel, ) else: self.strategy = AccelerateStrategy(mixed_precision=self.mixed_precision, ddp_config=self._ddp_config, diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/platform.py index 1dd1015f..9862ee83 100644 --- a/src/twinkle/utils/platform.py +++ b/src/twinkle/utils/platform.py @@ -348,12 +348,14 @@ def data_world_size(self) -> int: dp_world_size = self.dp_world_size fsdp_world_size = self.fsdp_world_size if fsdp_world_size is not None and fsdp_world_size > 1: - if dp_world_size is not None: + if dp_world_size is not None and dp_world_size > 0: return dp_world_size * fsdp_world_size else: return fsdp_world_size ulysses_size = self.ulysses_size or 1 + if dp_world_size is None or dp_world_size == 0: + return 1 assert dp_world_size % ulysses_size == 0, f'dp_world_size: {dp_world_size} cannot be divided by ulysses_size: {ulysses_size}.' return dp_world_size // ulysses_size From 1cebf0ae2ae15f64357e4a713266453ccf605d36 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 5 Feb 2026 09:35:01 +0800 Subject: [PATCH 2/3] wip --- cookbook/sft/fsdp_qwen3_moe.py | 39 ++++---- .../model/transformers/transformers.py | 91 ++++++++++++++++++- 2 files changed, 107 insertions(+), 23 deletions(-) diff --git a/cookbook/sft/fsdp_qwen3_moe.py b/cookbook/sft/fsdp_qwen3_moe.py index 83e0f19b..a8cd10fa 100644 --- a/cookbook/sft/fsdp_qwen3_moe.py +++ b/cookbook/sft/fsdp_qwen3_moe.py @@ -12,19 +12,19 @@ logger = get_logger() MODEL_ID = os.environ.get( - "QWEN3_MODEL_ID", "ms://Qwen/Qwen3-30B-A3B-Instruct-2507" + "MODEL_ID", "ms://Qwen/Qwen3-30B-A3B-Instruct-2507" ) -DATASET_ID = os.environ.get("QWEN3_DATASET_ID", "/path/to/alpaca/dataset") -TEMPLATE_ID = os.environ.get("QWEN3_TEMPLATE_ID", "Template") -PROCESSOR_ID = os.environ.get("QWEN3_PROCESSOR_ID", "AlpacaProcessor") +DATASET_ID = os.environ.get("DATASET_ID", "/path/to/alpaca/dataset") +TEMPLATE_ID = os.environ.get("TEMPLATE_ID", "Template") +STRATEGY = os.environ.get("TRAIN_STRATEGY", "native_fsdp") + +WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "2")) +NUM_LAYERS = int(os.environ.get("NUM_LAYERS", "4")) +BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4")) +GRAD_ACCUM_STEPS = int(os.environ.get("GRAD_ACCUM_STEPS", "4")) +SAVE_INTERVAL = int(os.environ.get("SAVE_INTERVAL", "50")) -WORLD_SIZE = int(os.environ.get("QWEN3_WORLD_SIZE", "2")) -NUM_LAYERS = int(os.environ.get("QWEN3_NUM_LAYERS", "4")) -BATCH_SIZE = int(os.environ.get("QWEN3_BATCH_SIZE", "4")) -GRAD_ACCUM_STEPS = int(os.environ.get("QWEN3_GRAD_ACCUM_STEPS", "4")) -SAVE_INTERVAL = int(os.environ.get("QWEN3_SAVE_INTERVAL", "50")) -# EP is disabled: mesh only has fsdp/dp, and fsdp_size = world_size. device_mesh = DeviceMesh.from_sizes( world_size=WORLD_SIZE, fsdp_size=WORLD_SIZE, @@ -32,12 +32,7 @@ device_type=Platform.get_platform().device_prefix(), ) -os.environ.setdefault("RAY_DEDUP_LOGS", "0") -if Platform.get_world_size() != WORLD_SIZE: - raise RuntimeError( - f"QWEN3_WORLD_SIZE={WORLD_SIZE} but distributed world size is " - f"{Platform.get_world_size()}. Use torchrun with matching nproc_per_node." - ) + twinkle.initialize( mode="local", global_device_mesh=device_mesh, @@ -57,10 +52,7 @@ def train(): except ValueError: dataset.set_template("Template", model_id=MODEL_ID) - processor = PROCESSOR_ID - if PROCESSOR_ID.lower() == "alpaca": - processor = "AlpacaProcessor" - + processor = "AlpacaProcessor" dataset.map(processor) dataset.encode(batched=True) dataloader = DataLoader( @@ -72,8 +64,11 @@ def train(): model = TransformersModel( model_id=MODEL_ID, config=config, - strategy="native_fsdp", + strategy=STRATEGY, device_mesh=device_mesh, + fsdp_config={ + "transformer_cls_names_to_wrap": ["Qwen3MoeSparseMoeBlock"], + }, ) model.set_optimizer("AdamW") @@ -94,7 +89,7 @@ def train(): logger.info( f"Current is step {step // GRAD_ACCUM_STEPS}, metric: {metric}" ) - if step % SAVE_INTERVAL == 0: + if step > 1 and step % SAVE_INTERVAL == 0: model.save("./output") diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 8c72edc0..1133ec01 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -37,6 +37,10 @@ from twinkle.model.transformers.strategy import AccelerateStrategy, NativeFSDPStrategy from twinkle.metric import LossMetric, Accuracy, TrainMetric +_debug_logger = twinkle.get_logger() +_debug_enabled = os.environ.get("TWINKLE_TRANSFORMERS_DEBUG", "0") == "1" +_debug_every = max(1, int(os.environ.get("TWINKLE_TRANSFORMERS_DEBUG_EVERY", "1"))) + @dataclass class OptimizerGroup: @@ -117,6 +121,12 @@ def accumulate_metrics(self, is_training): metric.accumulate(self.inputs, {**self.outputs, 'lr': self._get_lr(), 'step': self.cur_step-1, 'gradient_accumulation_steps': self.gradient_accumulation_steps}) def calculate_metrics(self, is_training): + if _debug_enabled: + _debug_logger.info( + f"[TRANSFORMERS_DEBUG][rank={Platform.get_rank()} local_rank={Platform.get_local_rank()}/" + f"{Platform.get_world_size()}] cur_step={self.cur_step} " + f"calculate_metrics accumulate begin is_training={is_training}" + ) self.accumulate_metrics(is_training) if is_training: metrics = self.train_metrics @@ -124,7 +134,20 @@ def calculate_metrics(self, is_training): metrics = self.eval_metrics results = {} for metric in metrics: + if _debug_enabled: + _debug_logger.info( + f"[TRANSFORMERS_DEBUG][rank={Platform.get_rank()} local_rank={Platform.get_local_rank()}/" + f"{Platform.get_world_size()}] cur_step={self.cur_step} " + f"metric.calculate begin metric={metric.__class__.__name__} " + f"pg_none={getattr(metric, 'process_group', None) is None}" + ) results.update(metric.calculate()) + if _debug_enabled: + _debug_logger.info( + f"[TRANSFORMERS_DEBUG][rank={Platform.get_rank()} local_rank={Platform.get_local_rank()}/" + f"{Platform.get_world_size()}] cur_step={self.cur_step} " + f"metric.calculate done metric={metric.__class__.__name__}" + ) self.inputs = None self.outputs = None return results @@ -247,6 +270,23 @@ def _lazy_wrap_model(self): self.model, optimizer = self.strategy.wrap_model(self.model, optimizer) optimizer_group.optimizer = optimizer self._model_wrapped = True + self._debug_log("model wrapped by strategy", force=True) + + def _debug_log(self, message: str, *, adapter_name: str = _default_adapter_name, force: bool = False): + if not _debug_enabled: + return + cur_step = -1 + if adapter_name in self.optimizer_group: + cur_step = self.optimizer_group[adapter_name].cur_step + if not force and cur_step >= 0 and (cur_step % _debug_every != 0): + return + rank = Platform.get_rank() + local_rank = Platform.get_local_rank() + world_size = Platform.get_world_size() + _debug_logger.info( + f"[TRANSFORMERS_DEBUG][rank={rank} local_rank={local_rank}/{world_size}] " + f"cur_step={cur_step} {message}" + ) @staticmethod def _should_enable_expert_parallel(expert_parallel_config: Optional[Dict[str, Any]], @@ -425,10 +465,15 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr Returns: The output of the model forward. """ + adapter_name = kwargs.get('adapter_name', _default_adapter_name) + self._debug_log("forward_backward begin", adapter_name=adapter_name) output = self.forward(inputs=inputs, **kwargs) + self._debug_log("forward done", adapter_name=adapter_name) loss = self.calculate_loss(**kwargs) + self._debug_log(f"calculate_loss done, loss={loss}", adapter_name=adapter_name) output['loss'] = loss self.backward(**kwargs) + self._debug_log("backward done", adapter_name=adapter_name) return loss @remote_function() @@ -446,6 +491,7 @@ def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): + self._debug_log("clip_grad_norm skipped(do_grad_sync=False)", adapter_name=adapter_name) return optimizer = optimizer_config.optimizer @@ -462,9 +508,21 @@ def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): optimizer_config._ensure_dp_group() num_tokens = optimizer_config.num_tokens + self._debug_log( + f"clip_grad_norm before gather_object local_num_tokens={num_tokens}", + adapter_name=adapter_name, + ) num_tokens = torch_util.gather_object([num_tokens], self.device_mesh, optimizer_config._dp_group) num_tokens = sum(num_tokens) + self._debug_log( + f"clip_grad_norm after gather_object total_num_tokens={num_tokens}", + adapter_name=adapter_name, + ) parameters = list(self._get_trainable_parameters(adapter_name).values()) + self._debug_log( + f"clip_grad_norm before normalize_and_clip_grad_norm params={len(parameters)}", + adapter_name=adapter_name, + ) grad_norm = normalize_and_clip_grad_norm( parameters, num_tokens=num_tokens, @@ -474,14 +532,21 @@ def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): ) outputs['grad_norm'] = grad_norm optimizer_config.num_tokens = 0 + self._debug_log(f"clip_grad_norm done grad_norm={grad_norm}", adapter_name=adapter_name) return grad_norm @remote_function(dispatch='all') def clip_grad_and_step(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): + adapter_name = kwargs.get('adapter_name', _default_adapter_name) + self._debug_log("clip_grad_and_step begin", adapter_name=adapter_name) grad_norm = self.clip_grad_norm(max_grad_norm, norm_type, **kwargs) + self._debug_log("clip_grad_norm returned", adapter_name=adapter_name) self.step(**kwargs) + self._debug_log("optimizer step done", adapter_name=adapter_name) self.zero_grad(**kwargs) + self._debug_log("zero_grad done", adapter_name=adapter_name) self.lr_step(**kwargs) + self._debug_log("lr_step done", adapter_name=adapter_name) return grad_norm def _create_param_group(self, adapter_name: str, lr: float=DEFAULT_LEARNING_RATE, weight_decay:float=DEFAULT_WEIGHT_DECAY, **kwargs): @@ -852,7 +917,31 @@ def get_state_dict(self, **kwargs): def calculate_metric(self, is_training, **kwargs): adapter_name = kwargs.pop('adapter_name', _default_adapter_name) optimizer_config = self.optimizer_group[adapter_name] - return optimizer_config.calculate_metrics(is_training) + metric_names = [] + metric_pg_none = [] + metrics = optimizer_config.train_metrics if is_training else optimizer_config.eval_metrics + for metric in metrics: + metric_names.append(metric.__class__.__name__) + metric_pg_none.append(getattr(metric, "process_group", None) is None) + self._debug_log( + f"calculate_metric begin is_training={is_training} metrics={metric_names} " + f"pg_none={metric_pg_none}", + adapter_name=adapter_name, + force=True, + ) + optimizer_config._ensure_dp_group() + self._debug_log( + f"calculate_metric dp_group_is_none={optimizer_config._dp_group is None}", + adapter_name=adapter_name, + force=True, + ) + results = optimizer_config.calculate_metrics(is_training) + self._debug_log( + f"calculate_metric done keys={list(results.keys())}", + adapter_name=adapter_name, + force=True, + ) + return results def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str], train_group: str, **kwargs): assert adapter_name, 'Use a different adapter_name, current is empty.' From 1e6f95b21a9b5a521b4038b6a8d6eee1cc1e43f4 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 6 Feb 2026 10:56:02 +0800 Subject: [PATCH 3/3] wip --- cookbook/sft/fsdp_qwen3_moe.py | 97 ------------------- .../model/transformers/transformers.py | 61 ------------ 2 files changed, 158 deletions(-) delete mode 100644 cookbook/sft/fsdp_qwen3_moe.py diff --git a/cookbook/sft/fsdp_qwen3_moe.py b/cookbook/sft/fsdp_qwen3_moe.py deleted file mode 100644 index a8cd10fa..00000000 --- a/cookbook/sft/fsdp_qwen3_moe.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import os - -from transformers import AutoConfig - -import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel - -logger = get_logger() - -MODEL_ID = os.environ.get( - "MODEL_ID", "ms://Qwen/Qwen3-30B-A3B-Instruct-2507" -) -DATASET_ID = os.environ.get("DATASET_ID", "/path/to/alpaca/dataset") -TEMPLATE_ID = os.environ.get("TEMPLATE_ID", "Template") -STRATEGY = os.environ.get("TRAIN_STRATEGY", "native_fsdp") - -WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "2")) -NUM_LAYERS = int(os.environ.get("NUM_LAYERS", "4")) -BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4")) -GRAD_ACCUM_STEPS = int(os.environ.get("GRAD_ACCUM_STEPS", "4")) -SAVE_INTERVAL = int(os.environ.get("SAVE_INTERVAL", "50")) - - -device_mesh = DeviceMesh.from_sizes( - world_size=WORLD_SIZE, - fsdp_size=WORLD_SIZE, - dp_size=1, - device_type=Platform.get_platform().device_prefix(), -) - - -twinkle.initialize( - mode="local", - global_device_mesh=device_mesh, -) - - -def train(): - config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) - if hasattr(config, "num_hidden_layers"): - config.num_hidden_layers = min(NUM_LAYERS, config.num_hidden_layers) - if hasattr(config, "use_cache"): - config.use_cache = False - - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) - try: - dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) - except ValueError: - dataset.set_template("Template", model_id=MODEL_ID) - - processor = "AlpacaProcessor" - dataset.map(processor) - dataset.encode(batched=True) - dataloader = DataLoader( - dataset=dataset, - batch_size=BATCH_SIZE, - device_mesh=device_mesh, - ) - - model = TransformersModel( - model_id=MODEL_ID, - config=config, - strategy=STRATEGY, - device_mesh=device_mesh, - fsdp_config={ - "transformer_cls_names_to_wrap": ["Qwen3MoeSparseMoeBlock"], - }, - ) - model.set_optimizer("AdamW") - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - - for step, batch in enumerate(dataloader): - if callable(batch): - batch = batch() - model.forward_backward( - inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS - ) - model.clip_grad_and_step(gradient_accumulation_steps=GRAD_ACCUM_STEPS) - if step % GRAD_ACCUM_STEPS == 0: - metric = model.calculate_metric(is_training=True) - if callable(metric): - metric = metric() - logger.info( - f"Current is step {step // GRAD_ACCUM_STEPS}, metric: {metric}" - ) - if step > 1 and step % SAVE_INTERVAL == 0: - model.save("./output") - - -if __name__ == "__main__": - train() diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 5366df1e..3960d33d 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -38,10 +38,6 @@ from twinkle.model.transformers.strategy import AccelerateStrategy, NativeFSDPStrategy from twinkle.metric import LossMetric, Accuracy, TrainMetric -_debug_logger = twinkle.get_logger() -_debug_enabled = os.environ.get("TWINKLE_TRANSFORMERS_DEBUG", "0") == "1" -_debug_every = max(1, int(os.environ.get("TWINKLE_TRANSFORMERS_DEBUG_EVERY", "1"))) - @dataclass class OptimizerGroup: @@ -122,12 +118,6 @@ def accumulate_metrics(self, is_training): metric.accumulate(self.inputs, {**self.outputs, 'lr': self._get_lr(), 'step': self.cur_step-1, 'gradient_accumulation_steps': self.gradient_accumulation_steps}) def calculate_metrics(self, is_training): - if _debug_enabled: - _debug_logger.info( - f"[TRANSFORMERS_DEBUG][rank={Platform.get_rank()} local_rank={Platform.get_local_rank()}/" - f"{Platform.get_world_size()}] cur_step={self.cur_step} " - f"calculate_metrics accumulate begin is_training={is_training}" - ) self.accumulate_metrics(is_training) if is_training: metrics = self.train_metrics @@ -135,20 +125,7 @@ def calculate_metrics(self, is_training): metrics = self.eval_metrics results = {} for metric in metrics: - if _debug_enabled: - _debug_logger.info( - f"[TRANSFORMERS_DEBUG][rank={Platform.get_rank()} local_rank={Platform.get_local_rank()}/" - f"{Platform.get_world_size()}] cur_step={self.cur_step} " - f"metric.calculate begin metric={metric.__class__.__name__} " - f"pg_none={getattr(metric, 'process_group', None) is None}" - ) results.update(metric.calculate()) - if _debug_enabled: - _debug_logger.info( - f"[TRANSFORMERS_DEBUG][rank={Platform.get_rank()} local_rank={Platform.get_local_rank()}/" - f"{Platform.get_world_size()}] cur_step={self.cur_step} " - f"metric.calculate done metric={metric.__class__.__name__}" - ) self.inputs = None self.outputs = None return results @@ -290,23 +267,6 @@ def _lazy_wrap_model(self): # maybe forward_only, no optimizer_group available self.model = self.strategy.wrap_model(self.model) self._model_wrapped = True - self._debug_log("model wrapped by strategy", force=True) - - def _debug_log(self, message: str, *, adapter_name: str = _default_adapter_name, force: bool = False): - if not _debug_enabled: - return - cur_step = -1 - if adapter_name in self.optimizer_group: - cur_step = self.optimizer_group[adapter_name].cur_step - if not force and cur_step >= 0 and (cur_step % _debug_every != 0): - return - rank = Platform.get_rank() - local_rank = Platform.get_local_rank() - world_size = Platform.get_world_size() - _debug_logger.info( - f"[TRANSFORMERS_DEBUG][rank={rank} local_rank={local_rank}/{world_size}] " - f"cur_step={cur_step} {message}" - ) @staticmethod def _should_enable_expert_parallel(expert_parallel_config: Optional[Dict[str, Any]], @@ -491,7 +451,6 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr self.forward(inputs=inputs, **kwargs) loss = self.calculate_loss(**kwargs) self.backward(**kwargs) - self._debug_log("backward done", adapter_name=adapter_name) return loss @remote_function() @@ -509,7 +468,6 @@ def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): adapter_name = kwargs.pop('adapter_name', self._get_default_group()) optimizer_config = self.optimizer_group[adapter_name] if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')): - self._debug_log("clip_grad_norm skipped(do_grad_sync=False)", adapter_name=adapter_name) return optimizer = optimizer_config.optimizer @@ -526,21 +484,9 @@ def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): optimizer_config._ensure_dp_group() num_tokens = optimizer_config.num_tokens - self._debug_log( - f"clip_grad_norm before gather_object local_num_tokens={num_tokens}", - adapter_name=adapter_name, - ) num_tokens = torch_util.gather_object([num_tokens], self.device_mesh, optimizer_config._dp_group) num_tokens = sum(num_tokens) - self._debug_log( - f"clip_grad_norm after gather_object total_num_tokens={num_tokens}", - adapter_name=adapter_name, - ) parameters = list(self._get_trainable_parameters(adapter_name).values()) - self._debug_log( - f"clip_grad_norm before normalize_and_clip_grad_norm params={len(parameters)}", - adapter_name=adapter_name, - ) grad_norm = normalize_and_clip_grad_norm( parameters, num_tokens=num_tokens, @@ -550,21 +496,14 @@ def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): ) outputs['grad_norm'] = grad_norm optimizer_config.num_tokens = 0 - self._debug_log(f"clip_grad_norm done grad_norm={grad_norm}", adapter_name=adapter_name) return grad_norm @remote_function(dispatch='all') def clip_grad_and_step(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): - adapter_name = kwargs.get('adapter_name', _default_adapter_name) - self._debug_log("clip_grad_and_step begin", adapter_name=adapter_name) grad_norm = self.clip_grad_norm(max_grad_norm, norm_type, **kwargs) - self._debug_log("clip_grad_norm returned", adapter_name=adapter_name) self.step(**kwargs) - self._debug_log("optimizer step done", adapter_name=adapter_name) self.zero_grad(**kwargs) - self._debug_log("zero_grad done", adapter_name=adapter_name) self.lr_step(**kwargs) - self._debug_log("lr_step done", adapter_name=adapter_name) return grad_norm def _create_param_group(self, adapter_name: str, lr: float=DEFAULT_LEARNING_RATE, weight_decay:float=DEFAULT_WEIGHT_DECAY, **kwargs):