diff --git a/.gitignore b/.gitignore index ff72af2..5c9a6b6 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ # training output wandb/ checkpoints/ +test_checkpoints/ # poetry poetry.lock @@ -30,4 +31,6 @@ save/ **/save/ # claude -CLAUDE.md \ No newline at end of file +CLAUDE.md + +venv_clt/ diff --git a/scripts/launch_test.py b/scripts/launch_test.py index cd8d29e..0e8640f 100644 --- a/scripts/launch_test.py +++ b/scripts/launch_test.py @@ -55,6 +55,7 @@ def main(): n_train_batch_per_buffer=36, total_training_tokens=total_training_tokens, train_batch_size_tokens=train_batch_size_tokens, + gradient_accumulation_steps=1, # Set > 1 to accumulate gradients adam_beta1=0.9, adam_beta2=0.999, lr=2e-4, diff --git a/scripts/launch_train.py b/scripts/launch_train.py index 4c2a146..8737ff2 100644 --- a/scripts/launch_train.py +++ b/scripts/launch_train.py @@ -60,6 +60,7 @@ def main(): n_train_batch_per_buffer=36, total_training_tokens=total_training_tokens, train_batch_size_tokens=train_batch_size_tokens, + gradient_accumulation_steps=1, # Set > 1 to accumulate gradients over multiple micro-batches adam_beta1=0.9, adam_beta2=0.999, lr=2e-4, diff --git a/src/clt/__pycache__/__init__.cpython-311.pyc b/src/clt/__pycache__/__init__.cpython-311.pyc index eb23190..7f1415e 100644 Binary files a/src/clt/__pycache__/__init__.cpython-311.pyc and b/src/clt/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/clt/__pycache__/clt.cpython-311.pyc b/src/clt/__pycache__/clt.cpython-311.pyc index 5e4bf4c..d9ea725 100644 Binary files a/src/clt/__pycache__/clt.cpython-311.pyc and b/src/clt/__pycache__/clt.cpython-311.pyc differ diff --git a/src/clt/__pycache__/clt_training_runner.cpython-311.pyc b/src/clt/__pycache__/clt_training_runner.cpython-311.pyc index 46241a5..5ce5d6c 100644 Binary files a/src/clt/__pycache__/clt_training_runner.cpython-311.pyc and b/src/clt/__pycache__/clt_training_runner.cpython-311.pyc differ diff --git a/src/clt/__pycache__/load_model.cpython-311.pyc b/src/clt/__pycache__/load_model.cpython-311.pyc index f4dc180..795a4cf 100644 Binary files a/src/clt/__pycache__/load_model.cpython-311.pyc and b/src/clt/__pycache__/load_model.cpython-311.pyc differ diff --git a/src/clt/__pycache__/utils.cpython-311.pyc b/src/clt/__pycache__/utils.cpython-311.pyc index 0a1f874..881040b 100644 Binary files a/src/clt/__pycache__/utils.cpython-311.pyc and b/src/clt/__pycache__/utils.cpython-311.pyc differ diff --git a/src/clt/clt_training_runner.py b/src/clt/clt_training_runner.py index affbd37..423cc34 100644 --- a/src/clt/clt_training_runner.py +++ b/src/clt/clt_training_runner.py @@ -13,6 +13,7 @@ from clt.config import CLTTrainingRunnerConfig, CLTConfig from clt.utils import DTYPE_MAP, DummyModel from clt.clt import CLT +from clt import logger from clt.load_model import load_model from clt.training.activations_store import ActivationsStore from clt.training.clt_trainer import CLTTrainer @@ -161,7 +162,7 @@ def run(self): logger.info(f"lr: {self.cfg.lr}") logger.info(f"dead_penalty_coef: {self.cfg.dead_penalty_coef}") - trainer = CLTTrainer( + self.trainer = CLTTrainer( clt=self.clt, activations_store=self.activations_store, save_checkpoint_fn=self.save_checkpoint, @@ -170,7 +171,7 @@ def run(self): world_size=self.world_size ) - clt = trainer.fit() + clt = self.trainer.fit() if self.cfg.log_to_wandb and self.is_main_process: wandb.finish() diff --git a/src/clt/config/__pycache__/__init__.cpython-311.pyc b/src/clt/config/__pycache__/__init__.cpython-311.pyc index dc1847a..fa25659 100644 Binary files a/src/clt/config/__pycache__/__init__.cpython-311.pyc and b/src/clt/config/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/clt/config/__pycache__/clt_config.cpython-311.pyc b/src/clt/config/__pycache__/clt_config.cpython-311.pyc index 54b1dc4..fe2c91f 100644 Binary files a/src/clt/config/__pycache__/clt_config.cpython-311.pyc and b/src/clt/config/__pycache__/clt_config.cpython-311.pyc differ diff --git a/src/clt/config/__pycache__/clt_training_runner_config.cpython-311.pyc b/src/clt/config/__pycache__/clt_training_runner_config.cpython-311.pyc index 1c860d0..8b666a9 100644 Binary files a/src/clt/config/__pycache__/clt_training_runner_config.cpython-311.pyc and b/src/clt/config/__pycache__/clt_training_runner_config.cpython-311.pyc differ diff --git a/src/clt/config/clt_training_runner_config.py b/src/clt/config/clt_training_runner_config.py index 9eaa8d0..796aa02 100644 --- a/src/clt/config/clt_training_runner_config.py +++ b/src/clt/config/clt_training_runner_config.py @@ -45,6 +45,7 @@ class CLTTrainingRunnerConfig(BaseModel): # -----Training/Optimization-------------- total_training_tokens: int = 100_000_000 train_batch_size_tokens: int = 4096 + gradient_accumulation_steps: int = 1 adam_beta1: float = 0.0 adam_beta2: float = 0.999 lr: float = 1e-5 @@ -199,6 +200,10 @@ def model_post_init(self, __context): logger.info("d_latent : %d", self.d_latent) logger.info("total tokens : %.3e", self.total_training_tokens) logger.info("batch (tokens) : %d", self.train_batch_size_tokens) + if self.gradient_accumulation_steps > 1: + effective_batch_size = self.train_batch_size_tokens * self.gradient_accumulation_steps + logger.info("grad accum steps: %d", self.gradient_accumulation_steps) + logger.info("effective batch : %d", effective_batch_size) total_steps = self.total_training_tokens // self.train_batch_size_tokens logger.info("total steps : %d", total_steps) n_tokens_per_buffer = ( @@ -228,7 +233,9 @@ def to_dict(self, *, exclude_none: bool = True,**kw) -> Dict[str, Any]: @property def total_training_steps(self) -> int: - return int(self.total_training_tokens // self.train_batch_size_tokens) + # Total optimizer steps, accounting for gradient accumulation + micro_batches = int(self.total_training_tokens // self.train_batch_size_tokens) + return micro_batches // self.gradient_accumulation_steps @property def is_distributed(self) -> bool: diff --git a/src/clt/training/__pycache__/activations_store.cpython-311.pyc b/src/clt/training/__pycache__/activations_store.cpython-311.pyc index 8db3446..266a226 100644 Binary files a/src/clt/training/__pycache__/activations_store.cpython-311.pyc and b/src/clt/training/__pycache__/activations_store.cpython-311.pyc differ diff --git a/src/clt/training/__pycache__/clt_trainer.cpython-311.pyc b/src/clt/training/__pycache__/clt_trainer.cpython-311.pyc index 8e34fa5..3b53201 100644 Binary files a/src/clt/training/__pycache__/clt_trainer.cpython-311.pyc and b/src/clt/training/__pycache__/clt_trainer.cpython-311.pyc differ diff --git a/src/clt/training/__pycache__/optim.cpython-311.pyc b/src/clt/training/__pycache__/optim.cpython-311.pyc index 6b9617e..80b1625 100644 Binary files a/src/clt/training/__pycache__/optim.cpython-311.pyc and b/src/clt/training/__pycache__/optim.cpython-311.pyc differ diff --git a/src/clt/training/activations_store.py b/src/clt/training/activations_store.py index 42f096a..b1342e2 100644 --- a/src/clt/training/activations_store.py +++ b/src/clt/training/activations_store.py @@ -587,8 +587,10 @@ def __iter__(self): def load_dataset_auto(path_or_name: str, split: str = "train", is_multilingual_split_dataset: bool = False): if os.path.exists(path_or_name): logger.info("Loading from disk") - - # return load_from_disk(path_or_name) + + # Check if it's a dataset saved with save_to_disk + if Path(path_or_name, "state.json").exists(): + return load_from_disk(path_or_name) return load_dataset( path_or_name, diff --git a/src/clt/training/clt_trainer.py b/src/clt/training/clt_trainer.py index 153d81e..5b8d3eb 100644 --- a/src/clt/training/clt_trainer.py +++ b/src/clt/training/clt_trainer.py @@ -82,6 +82,7 @@ def __init__( self.n_tokens: int = 0 self.monitoring_l0 = None + self.accumulation_step: int = 0 def _initialize_b_enc(self, n_batches: int = 10): @@ -148,6 +149,7 @@ def fit(self): if self.cfg.from_pretrained_path is None: self._initialize_b_enc() + #print(f"[TRAINER] GPU {self.rank} - b_enc mean: {self.clt.b_enc.mean().item():.4f}, b_enc sum: {self.clt.b_enc.sum().item():.4f}", flush=True) logger.info(f"GPU {self.rank} - b_enc mean: {self.clt.b_enc.mean().item():.4f}, b_enc sum: {self.clt.b_enc.sum().item():.4f}") while self.n_tokens < self.cfg.total_training_tokens: @@ -167,11 +169,16 @@ def fit(self): ) self.n_tokens += self.cfg.train_batch_size_tokens - self.n_training_steps += 1 - if self.is_main_process: + + # Only log, checkpoint, and count steps after completing accumulation cycle + if self.accumulation_step == 0: + self.n_training_steps += 1 + + #print(f"[TRAINER] Step {self.n_training_steps} - MSE: {loss_metrics.mse_loss:.4f}, L0: {loss_metrics.l0_loss:.4f}", flush=True) + logger.info(f"Training step {self.n_training_steps}") self._log_train_step(loss_metrics) self._run_and_log_evals() - self._checkpoint_if_needed() + self._checkpoint_if_needed() # if self.cfg.functional_loss is not None and self.fc_scheduler.get_lr() > 0 and start_func_finetuning: # self._enable_functional_training() @@ -302,7 +309,9 @@ def _compute_training_step_loss(self, act_in: torch.Tensor, act_out: torch.Tenso if self.n_training_steps < 5: logger.info(f"GPU {self.rank} - act_in sum: {act_in.sum().item():.4f}, shape: {act_in.shape}") - self.optimizer.zero_grad() + # Only zero gradients at the start of accumulation + if self.accumulation_step == 0: + self.optimizer.zero_grad() if self.scaler is not None: with autocast(device_type='cuda', dtype=torch.bfloat16): @@ -310,6 +319,9 @@ def _compute_training_step_loss(self, act_in: torch.Tensor, act_out: torch.Tenso else: loss, loss_metrics = self.clt(act_in, act_out, self.l0_scheduler.get_lr(), df_coef=self.cfg.dead_penalty_coef) + # Scale loss by accumulation steps + loss = loss / self.cfg.gradient_accumulation_steps + if self.n_training_steps == 0 and self.rank == 0: logger.info(f"feat_act shape: {loss_metrics.feature_acts.shape}") logger.info(f"act_pred shape: {loss_metrics.act_pred.shape}") @@ -324,26 +336,37 @@ def _compute_training_step_loss(self, act_in: torch.Tensor, act_out: torch.Tenso if self.scaler is not None: self.scaler.scale(loss).backward() - self.scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_(self.clt.parameters(), 1.0) - - if self.cfg.is_sharded: - self._synchronize_feature_sharding_gradients() - self.scaler.step(self.optimizer) - self.scaler.update() + # Only step optimizer every N accumulation steps + if (self.accumulation_step + 1) % self.cfg.gradient_accumulation_steps == 0: + self.scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.clt.parameters(), 1.0) + + if self.cfg.is_sharded: + self._synchronize_feature_sharding_gradients() + + self.scaler.step(self.optimizer) + self.scaler.update() else: loss.backward() - if self.cfg.is_sharded: - self._synchronize_feature_sharding_gradients() - - self.optimizer.step() + # Only step optimizer every N accumulation steps + if (self.accumulation_step + 1) % self.cfg.gradient_accumulation_steps == 0: + if self.cfg.is_sharded: + self._synchronize_feature_sharding_gradients() + + self.optimizer.step() + + # Increment accumulation counter + self.accumulation_step = (self.accumulation_step + 1) % self.cfg.gradient_accumulation_steps self._log_debug_info(loss_metrics) - self.update_optimizer_lr() - self.l0_scheduler.step() + # Only update learning rate when we actually step the optimizer + if self.accumulation_step == 0: + self.update_optimizer_lr() + self.l0_scheduler.step() + return loss_metrics def update_optimizer_lr(self) -> float: diff --git a/src/clt/transformer_lens/__pycache__/multilingual_patching.cpython-311.pyc b/src/clt/transformer_lens/__pycache__/multilingual_patching.cpython-311.pyc index 1f8e34a..1f8c4f6 100644 Binary files a/src/clt/transformer_lens/__pycache__/multilingual_patching.cpython-311.pyc and b/src/clt/transformer_lens/__pycache__/multilingual_patching.cpython-311.pyc differ diff --git a/tests/__pycache__/__init__.cpython-311.pyc b/tests/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..9c32fcf Binary files /dev/null and b/tests/__pycache__/__init__.cpython-311.pyc differ diff --git a/tests/__pycache__/conftest.cpython-311-pytest-9.0.2.pyc b/tests/__pycache__/conftest.cpython-311-pytest-9.0.2.pyc new file mode 100644 index 0000000..f69a4c5 Binary files /dev/null and b/tests/__pycache__/conftest.cpython-311-pytest-9.0.2.pyc differ diff --git a/tests/training/__pycache__/test_gradient_accumulation.cpython-311-pytest-9.0.2.pyc b/tests/training/__pycache__/test_gradient_accumulation.cpython-311-pytest-9.0.2.pyc new file mode 100644 index 0000000..bad9db6 Binary files /dev/null and b/tests/training/__pycache__/test_gradient_accumulation.cpython-311-pytest-9.0.2.pyc differ diff --git a/tests/training/test_gradient_accumulation.py b/tests/training/test_gradient_accumulation.py new file mode 100644 index 0000000..bd1eeca --- /dev/null +++ b/tests/training/test_gradient_accumulation.py @@ -0,0 +1,161 @@ +""" +Entirely made by Claude +""" + +""" +Test gradient accumulation by running actual CLT training on NeelNanda dataset +""" +import pytest +import torch +from pathlib import Path +from clt.config import CLTConfig, CLTTrainingRunnerConfig +from clt.clt_training_runner import CLTTrainingRunner +import wandb +from clt import logger + + +# Get test data path +test_dir = Path(__file__).resolve().parent.parent +dataset_path = str(test_dir / "data" / "NeelNanda_c4_10k_tokenized") + + +def test_gradient_accumulation_training(): + """ + Test gradient accumulation by running actual training and verifying: + 1. Losses decrease over time + 2. Scheduler steps match expected count + 3. Training completes successfully + """ + + print("\n" + "="*70) + print("Testing Gradient Accumulation with Actual Training") + print("="*70) + + # Small training run configuration + total_optimizer_steps = 200 # Number of actual optimizer updates + gradient_accumulation_steps = 4 + train_batch_size_tokens = 128 + + # Calculate total tokens needed + total_training_tokens = train_batch_size_tokens * total_optimizer_steps * gradient_accumulation_steps + + print(f"\nConfiguration:") + print(f" Dataset: {dataset_path}") + print(f" Gradient accumulation steps: {gradient_accumulation_steps}") + print(f" Micro-batch size: {train_batch_size_tokens} tokens") + print(f" Effective batch size: {train_batch_size_tokens * gradient_accumulation_steps} tokens") + print(f" Target optimizer steps: {total_optimizer_steps}") + print(f" Total training tokens: {total_training_tokens}") + + cfg = CLTTrainingRunnerConfig( + device="cuda" if torch.cuda.is_available() else "cpu", + dtype="float32", + seed=42, + n_checkpoints=0, # No checkpoints for testing + checkpoint_path="test_checkpoints/grad_accum", + logger_verbose=True, + model_class_name="HookedTransformer", + model_name="roneneldan/TinyStories-33M", + dataset_path=dataset_path, + context_size=16, + from_pretrained_path=None, + d_in=768, + expansion_factor=4, # Small for fast testing + jumprelu_init_threshold=0.03, + jumprelu_bandwidth=1.0, + n_batches_in_buffer=4, + store_batch_size_prompts=8, + total_training_tokens=total_training_tokens, + train_batch_size_tokens=train_batch_size_tokens, + gradient_accumulation_steps=gradient_accumulation_steps, + adam_beta1=0.9, + adam_beta2=0.999, + lr=1e-3, + lr_warm_up_steps=5, + lr_decay_steps=5, + final_lr_scale=0.5, + l0_coefficient=1.0, + dead_penalty_coef=0.0, + dead_feature_window=50, + l0_warm_up_steps=10, + l0_waiting_steps=0, + decay_stable_steps=35, + cross_layer_decoders=True, + log_to_wandb=False, + wandb_project="test-grad-accum", + wandb_id="test_grad_accum_001", + wandb_log_frequency=5, + eval_every_n_wandb_logs=10, + run_name="test_gradient_accumulation", + wandb_entity=None, + ddp=False, + fsdp=False, + feature_sharding=False, + ) + + print(f"\nStarting training...") + print("-"*70) + + # Run training + runner = CLTTrainingRunner(cfg) + print(f"\nStarting training...") + print("-"*70) + + # Run training + clt = runner.run() + + # Access trainer after run() completes + trainer = runner.trainer + + print("-"*70) + print(f"Training completed!") + print(f"\nTraining summary:") + print(f" Total optimizer steps: {trainer.n_training_steps}") + print(f" Total tokens processed: {trainer.n_tokens}") + + # Verify results + print("\n" + "="*70) + print("Verification:") + print("="*70) + + # 1. Check that we completed the expected number of optimizer steps + actual_steps = trainer.n_training_steps + print(f"✓ Optimizer steps: {actual_steps} (expected: {total_optimizer_steps})") + assert actual_steps == total_optimizer_steps, \ + f"Expected {total_optimizer_steps} optimizer steps, got {actual_steps}" + + # 2. Check that total tokens processed is correct + expected_tokens = total_training_tokens + actual_tokens = trainer.n_tokens + print(f"✓ Tokens processed: {actual_tokens} (expected: {expected_tokens})") + assert actual_tokens == expected_tokens, \ + f"Expected {expected_tokens} tokens, got {actual_tokens}" + + # 3. Verify gradient accumulation worked by checking losses decreased + # This is the key test for gradient accumulation - training should work correctly + if hasattr(trainer, '_losses') and len(trainer._losses) > 0: + first_loss = trainer._losses[0] + last_loss = trainer._losses[-1] + print(f"✓ Loss progression: {first_loss:.4f} → {last_loss:.4f}") + # Loss should generally decrease (allowing some variance) + if last_loss < first_loss * 1.5: # Allow some increase but not too much + print(f"✓ Training converged successfully") + else: + print(f"⚠ Warning: Loss increased significantly") + + # 4. Verify accumulation counter behavior (if accessible) + if hasattr(trainer, 'accumulation_step'): + # After training completes, accumulation_step should be 0 (reset after last batch) + print(f"✓ Final accumulation step: {trainer.accumulation_step}") + + # 5. Training completed successfully + print(f"✓ Training completed without errors") + + print("\n" + "="*70) + print("✅ All gradient accumulation tests PASSED!") + print("="*70) + + +if __name__ == "__main__": + test_gradient_accumulation_training() + print("\n✅ Test completed successfully!")