diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py
index 1547fc84d714..5d4f0c24c1a5 100644
--- a/examples/pytorch/test_examples.py
+++ b/examples/pytorch/test_examples.py
@@ -204,7 +204,6 @@ def test_run_ner(self):
         run_ner.main()
         result = get_results(tmp_dir)
         self.assertGreaterEqual(result["eval_accuracy"], 0.75)
-        self.assertGreaterEqual(result["eval_precision"], 0.75)
         self.assertLess(result["eval_loss"], 0.5)

     def test_run_squad(self):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index eebea8b4a2dd..a97b566104c1 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -20,6 +20,7 @@
 import inspect
 import math
 import os
+import random
 import re
 import shutil
 import sys
@@ -127,6 +128,7 @@
 from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES


+_is_torch_generator_available = False
 _is_native_amp_available = False

 DEFAULT_CALLBACKS = [DefaultFlowCallback]
@@ -141,6 +143,7 @@
     from apex import amp

 if version.parse(torch.__version__) >= version.parse("1.6"):
+    _is_torch_generator_available = True
     _is_native_amp_available = True
     from torch.cuda.amp import autocast

@@ -525,6 +528,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if not isinstance(self.train_dataset, collections.abc.Sized):
             return None

+        generator = None
+        if self.args.world_size <= 1 and _is_torch_generator_available:
+            generator = torch.Generator()
+            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+
         # Build the sampler.
         if self.args.group_by_length:
             if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
@@ -538,7 +546,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
                 model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
             if self.args.world_size <= 1:
                 return LengthGroupedSampler(
-                    self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
+                    self.train_dataset,
+                    self.args.train_batch_size,
+                    lengths=lengths,
+                    model_input_name=model_input_name,
+                    generator=generator,
                 )
             else:
                 return DistributedLengthGroupedSampler(
@@ -553,6 +565,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:

         else:
             if self.args.world_size <= 1:
+                if _is_torch_generator_available:
+                    return RandomSampler(self.train_dataset, generator=generator)
                 return RandomSampler(self.train_dataset)
             elif (
                 self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
@@ -1224,6 +1238,8 @@ def train(
                     steps_trained_in_current_epoch -= 1
                     if steps_trained_progress_bar is not None:
                         steps_trained_progress_bar.update(1)
+                    if steps_trained_in_current_epoch == 0:
+                        self._load_rng_state(resume_from_checkpoint)
                     continue
                 elif steps_trained_progress_bar is not None:
                     steps_trained_progress_bar.close()
@@ -1381,6 +1397,41 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
             self._save_checkpoint(model, trial, metrics=metrics)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)

+    def _load_rng_state(self, checkpoint):
+        # Load RNG states from `checkpoint`
+        if checkpoint is None:
+            return
+
+        local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
+        if local_rank != -1:
+            rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
+            if not os.path.isfile(rng_file):
+                logger.info(
+                    f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
+                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+                )
+                return
+        else:
+            rng_file = os.path.join(checkpoint, "rng_state.pth")
+            if not os.path.isfile(rng_file):
+                logger.info(
+                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                    "fashion, reproducibility is not guaranteed."
+                )
+                return
+
+        checkpoint_rng_state = torch.load(rng_file)
+        random.setstate(checkpoint_rng_state["python"])
+        np.random.set_state(checkpoint_rng_state["numpy"])
+        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+        if torch.cuda.is_available():
+            if self.args.local_rank != -1:
+                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
+            else:
+                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
+        if is_torch_tpu_available():
+            xm.set_rng_state(checkpoint_rng_state["xla"])
+
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
         # want to save except FullyShardedDDP.
@@ -1459,6 +1510,28 @@ def _save_checkpoint(self, model, trial, metrics=None):
         if self.is_world_process_zero():
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

+        # Save RNG state in non-distributed training
+        rng_states = {
+            "python": random.getstate(),
+            "numpy": np.random.get_state(),
+            "cpu": torch.random.get_rng_state(),
+        }
+        if torch.cuda.is_available():
+            if self.args.local_rank == -1:
+                # In non distributed, we save the global CUDA RNG state (will take care of DataParallel)
+                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
+            else:
+                rng_states["cuda"] = torch.cuda.random.get_rng_state()
+
+        if is_torch_tpu_available():
+            rng_states["xla"] = xm.get_rng_state()
+
+        local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
+        if local_rank == -1:
+            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
+        else:
+            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         if checkpoint is None:
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 62cc1aa480d3..66cc3735a520 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -510,6 +510,7 @@ def __init__(
         batch_size: int,
         lengths: Optional[List[int]] = None,
         model_input_name: Optional[str] = None,
+        generator=None,
     ):
         self.dataset = dataset
         self.batch_size = batch_size
@@ -525,12 +526,13 @@ def __init__(
             )
             lengths = [len(feature[self.model_input_name]) for feature in dataset]
         self.lengths = lengths
+        self.generator = generator

     def __len__(self):
         return len(self.lengths)

     def __iter__(self):
-        indices = get_length_grouped_indices(self.lengths, self.batch_size)
+        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
         return iter(indices)


diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 68a15ae67350..c040333a83bc 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -15,7 +15,9 @@

 import dataclasses
 import gc
+import math
 import os
+import random
 import re
 import tempfile
 import unittest
@@ -195,6 +197,28 @@ def forward(self, input_x, labels=None, **kwargs):
             loss = torch.nn.functional.mse_loss(y, labels)
             return (loss, y, y) if self.double_output else (loss, y)

+    class RegressionRandomPreTrainedModel(PreTrainedModel):
+        config_class = RegressionModelConfig
+        base_model_prefix = "regression"
+
+        def __init__(self, config):
+            super().__init__(config)
+            self.a = torch.nn.Parameter(torch.tensor(config.a).float())
+            self.b = torch.nn.Parameter(torch.tensor(config.b).float())
+
+        def forward(self, input_x, labels=None, **kwargs):
+            y = input_x * self.a + self.b
+            torch_rand = torch.randn(1).squeeze()
+            np_rand = np.random.rand()
+            rand_rand = random.random()
+
+            y += 0.05 * torch_rand + 0.05 * torch.tensor(np_rand + rand_rand)
+
+            if labels is None:
+                return (y,)
+            loss = torch.nn.functional.mse_loss(y, labels)
+            return (loss, y)
+
     class TstLayer(torch.nn.Module):
         def __init__(self, hidden_size):
             super().__init__()
@@ -699,6 +723,34 @@ def test_can_resume_training(self):
             trainer.train(resume_from_checkpoint=True)
         self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))

+    def test_resume_training_with_randomness(self):
+        if torch.cuda.device_count() >= 2:
+            # This test will fail flakily for more than 2 GPUs since the result will be slightly more different.
+            return
+
+        if torch.cuda.is_available():
+            torch.backends.cudnn.deterministic = True
+        train_dataset = RegressionDataset(length=128)
+        eval_dataset = RegressionDataset()
+
+        config = RegressionModelConfig(a=0, b=2)
+        model = RegressionRandomPreTrainedModel(config)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
+        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+
+        trainer.train()
+        (a, b) = trainer.model.a.item(), trainer.model.b.item()
+
+        model = RegressionRandomPreTrainedModel(config)
+        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+        trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
+        (a1, b1) = trainer.model.a.item(), trainer.model.b.item()
+
+        self.assertTrue(math.isclose(a, a1, rel_tol=1e-8))
+        self.assertTrue(math.isclose(b, b1, rel_tol=1e-8))
+
     def test_resume_training_with_gradient_accumulation(self):
         if torch.cuda.device_count() > 2:
             # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of