Reproducible checkpoint #11582
Changes from all commits: 3ee5f67, ee424d7, 4c2220e, d915bbc, 1d5a16c, 10f0b9a, 108989a, 5cad1f4, 75009c8, 4323979, d1da404, 8374498, 1995ff2, 2b0c7d9, 4ec3c99, 4b8f370, 08eb713, f7849d1
```diff
@@ -20,6 +20,7 @@
 import inspect
 import math
 import os
+import random
 import re
 import shutil
 import sys
```
```diff
@@ -127,6 +128,7 @@
 from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES

+_is_torch_generator_available = False
 _is_native_amp_available = False

 DEFAULT_CALLBACKS = [DefaultFlowCallback]
```
```diff
@@ -141,6 +143,7 @@
     from apex import amp

 if version.parse(torch.__version__) >= version.parse("1.6"):
+    _is_torch_generator_available = True
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
```
```diff
@@ -525,6 +528,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if not isinstance(self.train_dataset, collections.abc.Sized):
             return None

+        generator = None
+        if self.args.world_size <= 1 and _is_torch_generator_available:
+            generator = torch.Generator()
+            generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
```
Contributor: I'm very late here, but shouldn't this use …

Collaborator (Author): No, this is actually copied from PyTorch. Since torch has been seeded, this will be deterministic.
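A quick aside on the author's point, as a minimal sketch (not from the PR): `torch.empty((), dtype=torch.int64).random_()` draws from torch's global RNG, so once `torch.manual_seed` has been called the derived seed is reproducible, while the dedicated `torch.Generator` isolates the sampler's shuffling from any later use of the global stream.

```python
import torch

# The derived seed comes from torch's (already seeded) global RNG,
# so it is identical across identically seeded runs.
torch.manual_seed(42)
seed_a = int(torch.empty((), dtype=torch.int64).random_().item())

torch.manual_seed(42)
seed_b = int(torch.empty((), dtype=torch.int64).random_().item())
assert seed_a == seed_b

# The dedicated generator then decouples the sampler's shuffle order
# from later consumption of the global RNG stream.
generator = torch.Generator()
generator.manual_seed(seed_a)
```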
```diff
+
         # Build the sampler.
         if self.args.group_by_length:
             if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
```
```diff
@@ -538,7 +546,11 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
             model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
             if self.args.world_size <= 1:
                 return LengthGroupedSampler(
-                    self.train_dataset, self.args.train_batch_size, lengths=lengths, model_input_name=model_input_name
+                    self.train_dataset,
+                    self.args.train_batch_size,
+                    lengths=lengths,
+                    model_input_name=model_input_name,
+                    generator=generator,
                 )
             else:
                 return DistributedLengthGroupedSampler(
```
```diff
@@ -553,6 +565,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         else:
             if self.args.world_size <= 1:
+                if _is_torch_generator_available:
+                    return RandomSampler(self.train_dataset, generator=generator)
                 return RandomSampler(self.train_dataset)
             elif (
                 self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
```
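For context, `RandomSampler` accepts a `generator` argument as of PyTorch 1.6, which is exactly what the `_is_torch_generator_available` gate checks. A minimal sketch, using a toy dataset, of how a seeded generator pins down the shuffle order:

```python
import torch
from torch.utils.data import RandomSampler

data = list(range(10))  # toy stand-in for self.train_dataset

def shuffle_order(seed: int) -> list:
    g = torch.Generator()
    g.manual_seed(seed)
    # RandomSampler yields a permutation of indices drawn from `g`.
    return [data[i] for i in RandomSampler(data, generator=g)]

# Same seed -> same epoch ordering, independent of the global RNG state.
assert shuffle_order(0) == shuffle_order(0)
```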
```diff
@@ -1224,6 +1238,8 @@ def train(
                     steps_trained_in_current_epoch -= 1
                     if steps_trained_progress_bar is not None:
                         steps_trained_progress_bar.update(1)
+                    if steps_trained_in_current_epoch == 0:
+                        self._load_rng_state(resume_from_checkpoint)
                     continue
                 elif steps_trained_progress_bar is not None:
                     steps_trained_progress_bar.close()
```
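In plain terms: when resuming, the batches that were already consumed before the checkpoint are skipped, and the checkpointed RNG state is restored exactly when the skip counter hits zero, so the first real batch sees the same random state (dropout masks, augmentations) as in the uninterrupted run. A simplified sketch of that control flow, with hypothetical names:

```python
# Simplified sketch with hypothetical names, not the Trainer's actual loop.
def resume_epoch(dataloader, steps_to_skip, load_rng_state):
    for step, batch in enumerate(dataloader):
        if steps_to_skip > 0:
            steps_to_skip -= 1
            if steps_to_skip == 0:
                # Restore python/numpy/torch RNG right before the first real step.
                load_rng_state()
            continue
        yield step, batch  # training resumes here, aligned with the original run
```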
```diff
@@ -1381,6 +1397,41 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
             self._save_checkpoint(model, trial, metrics=metrics)
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)

+    def _load_rng_state(self, checkpoint):
```
Member: Could this method also print a warning when we're on TPU, since we don't expect reproducibility on TPUs?

Collaborator (Author): Mmm, maybe just a comment in the README where we document resuming from a checkpoint? I don't really want to issue a warning for every run on TPU that uses a checkpoint.
```diff
+        # Load RNG states from `checkpoint`
+        if checkpoint is None:
+            return
+
+        local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
+        if local_rank != -1:
+            rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
+            if not os.path.isfile(rng_file):
+                logger.info(
+                    f"Didn't find an RNG file for process {local_rank}, if you are resuming a training that "
+                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+                )
+                return
+        else:
+            rng_file = os.path.join(checkpoint, "rng_state.pth")
+            if not os.path.isfile(rng_file):
+                logger.info(
+                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                    "fashion, reproducibility is not guaranteed."
+                )
+                return
+
+        checkpoint_rng_state = torch.load(rng_file)
+        random.setstate(checkpoint_rng_state["python"])
+        np.random.set_state(checkpoint_rng_state["numpy"])
+        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+        if torch.cuda.is_available():
+            if self.args.local_rank != -1:
+                torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
+            else:
+                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
+        if is_torch_tpu_available():
+            xm.set_rng_state(checkpoint_rng_state["xla"])
+
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
         # want to save except FullyShardedDDP.
```
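The naming convention used in both directions: distributed runs read and write one RNG file per local rank, while single-process runs use a single `rng_state.pth`. A hypothetical helper (not in the PR) capturing just that mapping:

```python
# Hypothetical helper, not in the PR: the per-process RNG file naming shared
# by _load_rng_state and _save_checkpoint.
def rng_state_filename(local_rank: int) -> str:
    return "rng_state.pth" if local_rank == -1 else f"rng_state_{local_rank}.pth"

assert rng_state_filename(-1) == "rng_state.pth"   # single-process training
assert rng_state_filename(3) == "rng_state_3.pth"  # rank 3 of a distributed run
```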
```diff
@@ -1459,6 +1510,28 @@ def _save_checkpoint(self, model, trial, metrics=None):
         if self.is_world_process_zero():
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

+        # Save RNG state in non-distributed training
+        rng_states = {
+            "python": random.getstate(),
+            "numpy": np.random.get_state(),
+            "cpu": torch.random.get_rng_state(),
+        }
+        if torch.cuda.is_available():
+            if self.args.local_rank == -1:
+                # In non-distributed training, we save the global CUDA RNG state (will take care of DataParallel)
+                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
+            else:
+                rng_states["cuda"] = torch.cuda.random.get_rng_state()
+
+        if is_torch_tpu_available():
+            rng_states["xla"] = xm.get_rng_state()
+
+        local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
+        if local_rank == -1:
+            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
+        else:
+            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))

     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         if checkpoint is None:
```
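A minimal round-trip sketch (not part of the PR) of what saving and restoring these states buys: once the states are put back, the Python, NumPy, and torch CPU streams replay exactly the draws they would otherwise have produced.

```python
import random
import numpy as np
import torch

# Capture the three host-side RNG states, as _save_checkpoint does.
state = {
    "python": random.getstate(),
    "numpy": np.random.get_state(),
    "cpu": torch.random.get_rng_state(),
}

expected = (random.random(), np.random.rand(), torch.rand(1).item())

# Restore, as _load_rng_state does, and replay the draws.
random.setstate(state["python"])
np.random.set_state(state["numpy"])
torch.random.set_rng_state(state["cpu"])
observed = (random.random(), np.random.rand(), torch.rand(1).item())

assert observed == expected  # identical draws after restoring the states
```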