From 9806515b1db2cdbc9e61de8bc5995350d198c255 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Sat, 24 Jul 2021 13:44:42 -0700
Subject: [PATCH 1/8] add codecarbon

---
 megatron/arguments.py | 2 ++
 megatron/training.py  | 17 +++++++++++++++++
 requirements.txt      | 3 +++
 3 files changed, 22 insertions(+)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 0c2811117..1b6cbd771 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -440,6 +440,8 @@ def _add_training_args(parser):
                        help='Run optimizer on CPU')
     group.add_argument('--cpu_torch_adam', action='store_true',
                        help='Use Torch Adam as optimizer on CPU.')
+    group.add_argument('--codecarbon-dir', type=str, default=None,
+                       help='Write CodeCarbon logs to this directory.')
 
     return parser

diff --git a/megatron/training.py b/megatron/training.py
index a9debff30..667eb1ecd 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -110,6 +110,17 @@ def pretrain(train_valid_test_dataset_provider,
     args = get_args()
     timers = get_timers()
 
+    # XXX: quick hack-in for now - add a clean wrapper later
+    if args.codecarbon_dir is not None:
+        import codecarbon
+        from pathlib import Path
+        print("CC START")
+
+        Path(args.codecarbon_dir).mkdir(parents=True, exist_ok=True)
+        output_file = f"emissions-{args.rank:03d}.csv"
+        cc_tracker = codecarbon.EmissionsTracker(output_dir=args.codecarbon_dir, output_file=output_file)
+        cc_tracker.start()
+
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -162,6 +173,12 @@ def pretrain(train_valid_test_dataset_provider,
                                    test_data_iterator, model,
                                    0, True)
 
+    # XXX: clean up
+    if args.codecarbon_dir is not None:
+        print("CC STOP")
+        cc_tracker.stop()
+
+
 def update_train_iters(args):
 
     # For iteration-based training, we don't need to do anything

diff --git a/requirements.txt b/requirements.txt
index 1f7389c3e..776c5fb0b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,6 @@ torch
 six
 regex
 numpy
+
+# edit to a higher SHA or future release if needed
+git+git://github.com/mlco2/codecarbon.git@d772616c4e55a710c3541469eced4cec98eff329

From 6d867f771249c5e8fb0fb4d4cce97929be1daaab Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Sat, 24 Jul 2021 15:12:04 -0700
Subject: [PATCH 2/8] switch to offline

---
 megatron/training.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/megatron/training.py b/megatron/training.py
index 667eb1ecd..79ec6ff4a 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -118,7 +118,10 @@ def pretrain(train_valid_test_dataset_provider,
 
         Path(args.codecarbon_dir).mkdir(parents=True, exist_ok=True)
         output_file = f"emissions-{args.rank:03d}.csv"
-        cc_tracker = codecarbon.EmissionsTracker(output_dir=args.codecarbon_dir, output_file=output_file)
+        cc_tracker = codecarbon.OfflineEmissionsTracker(output_dir=args.codecarbon_dir,
+                                                        output_file=output_file,
+                                                        country_iso_code="FRA",
+                                                        )
         cc_tracker.start()
 
     # Model, optimizer, and learning rate.
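Patches 1-2 establish the basic pattern: every rank creates its own tracker writing to a per-rank CSV (so concurrent processes never write to the same file), and the offline tracker is used so no network geolocation is needed, with the electricity grid region pinned via country_iso_code. A minimal standalone sketch of that pattern, assuming codecarbon's OfflineEmissionsTracker API at the SHA pinned in requirements.txt; `rank` and `output_dir` are stand-ins for args.rank and --codecarbon-dir:

    # Minimal sketch, not the exact Megatron wiring: one tracker per process,
    # one CSV per rank, offline mode with an explicit grid region.
    from pathlib import Path

    from codecarbon import OfflineEmissionsTracker

    rank = 0                                   # stand-in for args.rank
    output_dir = "outputs/codecarbon"          # stand-in for --codecarbon-dir
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    tracker = OfflineEmissionsTracker(
        output_dir=output_dir,
        output_file=f"emissions-{rank:03d}.csv",  # per-rank file avoids write collisions
        country_iso_code="FRA",                # offline mode: region must be supplied
    )
    tracker.start()
    # ... training happens here ...
    emissions = tracker.stop()                 # appends a row to the CSV, returns kg CO2eq
    print(f"rank {rank}: {emissions:.6f} kg CO2eq")
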
From d5d9dd6c0fcc2d24673742009d0155e6dc6a7b7f Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 26 Jul 2021 20:12:54 -0700
Subject: [PATCH 3/8] rework to also restart the tracker at each checkpoint
 save to ensure as little data as possible is lost

---
 megatron/checkpointing.py | 10 +++++--
 megatron/global_vars.py   | 61 ++++++++++++++++++++++++++++++++++++++-
 megatron/training.py      | 22 +++-----------
 3 files changed, 72 insertions(+), 21 deletions(-)

diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 3cc6a8e2e..717c5b7a4 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -22,6 +22,7 @@
 
 import torch
 
+from megatron.global_vars import codecarbon_tracker_restart
 from megatron import (get_args,
                       mpu,
                       print_rank_0,
@@ -134,7 +135,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         for i in range(len(model)):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
             state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
-    
+
     # Optimizer stuff.
     if not args.no_save_optim:
         if optimizer is not None:
@@ -182,6 +183,11 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
 
+    # since the code can be exited or aborted in various places we use the checkpoint saving as
+    # a safe saving point for the codecarbon tracker. If the program doesn't run to its normal
+    # end, then only the data since the last saved checkpoint will be lost.
+    codecarbon_tracker_restart()
+
 
 def _transpose_first_dim(t, num_splits, num_splits_first, model):
     input_shape = t.size()
     # We use a self_attention module but the values extracted aren't
@@ -416,7 +422,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
 
 def load_biencoder_checkpoint(model, only_query_model=False,
                               only_context_model=False, custom_load_path=None):
     """
-    selectively load retrieval models for indexing/retrieving 
+    selectively load retrieval models for indexing/retrieving
     from saved checkpoints
     """

diff --git a/megatron/global_vars.py b/megatron/global_vars.py
index de0c12794..faec8ca61 100644
--- a/megatron/global_vars.py
+++ b/megatron/global_vars.py
@@ -19,6 +19,8 @@
 import sys
 import time
 
+from pathlib import Path
+
 import torch
 
 from megatron.tokenizer import build_tokenizer
@@ -29,10 +31,10 @@
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
 _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
+_GLOBAL_CODECARBON_TRACKER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 
-
 def get_args():
     """Return arguments."""
     _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
@@ -63,6 +65,10 @@ def get_tensorboard_writer():
     to check if it is initialized."""
     return _GLOBAL_TENSORBOARD_WRITER
 
+def get_codecarbon_tracker():
+    """Return codecarbon tracker. It can be None so no need
+    to check if it is initialized."""
+    return _GLOBAL_CODECARBON_TRACKER
 
 def get_adlr_autoresume():
     """ADLR autoresume object. It can be None so no need
@@ -86,6 +92,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     if args.vocab_file or args.tokenizer_name_or_path:
         _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
+    _set_codecarbon_tracker(args)
     _set_adlr_autoresume(args)
     _set_timers()
 
@@ -145,6 +152,58 @@ def _set_tensorboard_writer(args):
                   'no TensorBoard logs will be written.', flush=True)
 
 
+def _set_codecarbon_tracker(args):
+    global _GLOBAL_CODECARBON_TRACKER
+    if getattr(args, 'codecarbon_dir', None) is not None:
+        import codecarbon
+        print('> setting codecarbon ...')
+        output_dir = args.codecarbon_dir
+        output_file = f"emissions-{args.rank:03d}.csv"
+        #log_level = "warning"
+        log_level = "info"
+        country_iso_code = "FRA"
+
+        Path(output_dir).mkdir(parents=True, exist_ok=True)
+        _GLOBAL_CODECARBON_TRACKER = codecarbon.OfflineEmissionsTracker(
+            output_dir=output_dir,
+            output_file=output_file,
+            log_level=log_level,
+            country_iso_code=country_iso_code,
+        )
+
+
+def codecarbon_tracker_start():
+    global _GLOBAL_CODECARBON_TRACKER
+    if _GLOBAL_CODECARBON_TRACKER is None:
+        return
+
+    print('codecarbon START')
+    _GLOBAL_CODECARBON_TRACKER.start()
+
+
+def codecarbon_tracker_stop():
+    global _GLOBAL_CODECARBON_TRACKER
+    if _GLOBAL_CODECARBON_TRACKER is None:
+        return
+
+    print('codecarbon STOP')
+    _GLOBAL_CODECARBON_TRACKER.stop()
+
+
+def codecarbon_tracker_restart():
+    global _GLOBAL_CODECARBON_TRACKER
+    if _GLOBAL_CODECARBON_TRACKER is None:
+        return
+
+    # output_dir = _GLOBAL_CODECARBON_TRACKER._output_dir
+    # output_file = _GLOBAL_CODECARBON_TRACKER._output_file
+    # log_level = _GLOBAL_CODECARBON_TRACKER._log_level
+    # country_iso_code = _GLOBAL_CODECARBON_TRACKER._country_iso_code
+
+    codecarbon_tracker_stop()
+    codecarbon_tracker_start()
+
+
 def _set_adlr_autoresume(args):
     """Initialize ADLR autoresume."""
     global _GLOBAL_ADLR_AUTORESUME

diff --git a/megatron/training.py b/megatron/training.py
index 79ec6ff4a..ea8ef975c 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -51,6 +51,7 @@
 from megatron.schedules import forward_backward_pipelining_without_interleaving
 from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory, flops_calculator
+from megatron.global_vars import codecarbon_tracker_start, codecarbon_tracker_stop
 
 import deepspeed
 
@@ -95,6 +96,8 @@ def pretrain(train_valid_test_dataset_provider,
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
 
+    codecarbon_tracker_start()
+
     # Adjust the startup time so it reflects the largest value.
     # This will be closer to what scheduler will see (outside of
     # image ... launches.
@@ -110,20 +113,6 @@ def pretrain(train_valid_test_dataset_provider,
     args = get_args()
     timers = get_timers()
 
-    # XXX: quick hack-in for now - add a clean wrapper later
-    if args.codecarbon_dir is not None:
-        import codecarbon
-        from pathlib import Path
-        print("CC START")
-
-        Path(args.codecarbon_dir).mkdir(parents=True, exist_ok=True)
-        output_file = f"emissions-{args.rank:03d}.csv"
-        cc_tracker = codecarbon.OfflineEmissionsTracker(output_dir=args.codecarbon_dir,
-                                                        output_file=output_file,
-                                                        country_iso_code="FRA",
-                                                        )
-        cc_tracker.start()
-
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -176,10 +165,7 @@ def pretrain(train_valid_test_dataset_provider,
                                    test_data_iterator, model,
                                    0, True)
 
-    # XXX: clean up
-    if args.codecarbon_dir is not None:
-        print("CC STOP")
-        cc_tracker.stop()
+    codecarbon_tracker_stop()
 
 
 def update_train_iters(args):

From 4511b156e2fa42ee75c54f0cf304f0f7b97779a5 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Fri, 13 Aug 2021 15:53:25 -0700
Subject: [PATCH 4/8] adjust API to match
 https://github.com/bigscience-workshop/codecarbon/pull/1

---
 megatron/global_vars.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/megatron/global_vars.py b/megatron/global_vars.py
index 30b763bcd..6fa351bcc 100644
--- a/megatron/global_vars.py
+++ b/megatron/global_vars.py
@@ -190,18 +190,32 @@ def codecarbon_tracker_stop():
     _GLOBAL_CODECARBON_TRACKER.stop()
 
 
-def codecarbon_tracker_restart():
+def codecarbon_tracker_pause():
+    """ pause saves intermediate results """
+    global _GLOBAL_CODECARBON_TRACKER
+    if _GLOBAL_CODECARBON_TRACKER is None:
+        return
+
+    print('codecarbon PAUSE')
+    _GLOBAL_CODECARBON_TRACKER.pause()
+
+
+def codecarbon_tracker_resume():
     global _GLOBAL_CODECARBON_TRACKER
     if _GLOBAL_CODECARBON_TRACKER is None:
         return
 
-    # output_dir = _GLOBAL_CODECARBON_TRACKER._output_dir
-    # output_file = _GLOBAL_CODECARBON_TRACKER._output_file
-    # log_level = _GLOBAL_CODECARBON_TRACKER._log_level
-    # country_iso_code = _GLOBAL_CODECARBON_TRACKER._country_iso_code
+    print('codecarbon RESUME')
+    _GLOBAL_CODECARBON_TRACKER.resume()
+
+
+def codecarbon_tracker_restart():
+    global _GLOBAL_CODECARBON_TRACKER
+    if _GLOBAL_CODECARBON_TRACKER is None:
+        return
 
-    codecarbon_tracker_stop()
-    codecarbon_tracker_start()
+    codecarbon_tracker_pause()
+    codecarbon_tracker_resume()
 
 
 def _set_adlr_autoresume(args):
     """Initialize ADLR autoresume."""
     global _GLOBAL_ADLR_AUTORESUME

From f33ccd2fe42aaaa18280849608ed3c8e31a5b52e Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Fri, 13 Aug 2021 22:45:37 -0700
Subject: [PATCH 5/8] fix logging

---
 megatron/global_vars.py | 36 ++++++++++++++++++------------------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/megatron/global_vars.py b/megatron/global_vars.py
index 6fa351bcc..80694f4ef 100644
--- a/megatron/global_vars.py
+++ b/megatron/global_vars.py
@@ -154,22 +154,25 @@ def _set_tensorboard_writer(args):
 
 def _set_codecarbon_tracker(args):
     global _GLOBAL_CODECARBON_TRACKER
-    if getattr(args, 'codecarbon_dir', None) is not None:
-        import codecarbon
+    if getattr(args, 'codecarbon_dir', None) is None:
+        return
+
+    import codecarbon
+    if args.rank == 0:
         print('> setting codecarbon ...')
-        output_dir = args.codecarbon_dir
-        output_file = f"emissions-{args.rank:03d}.csv"
-        #log_level = "warning"
-        log_level = "info"
-        country_iso_code = "FRA"
 
-        Path(output_dir).mkdir(parents=True, exist_ok=True)
-        _GLOBAL_CODECARBON_TRACKER = codecarbon.OfflineEmissionsTracker(
-            output_dir=output_dir,
-            output_file=output_file,
-            log_level=log_level,
-            country_iso_code=country_iso_code,
-        )
+    output_dir = args.codecarbon_dir
+    output_file = f"emissions-{args.rank:03d}.csv"
+    log_level = "warning"
+    country_iso_code = "FRA"
+
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+    _GLOBAL_CODECARBON_TRACKER = codecarbon.OfflineEmissionsTracker(
+        output_dir=output_dir,
+        output_file=output_file,
+        log_level=log_level,
+        country_iso_code=country_iso_code,
+    )
 
 
 def codecarbon_tracker_start():
     global _GLOBAL_CODECARBON_TRACKER
     if _GLOBAL_CODECARBON_TRACKER is None:
         return
 
-    print('codecarbon START')
+
     _GLOBAL_CODECARBON_TRACKER.start()
 
@@ -186,7 +189,6 @@ def codecarbon_tracker_stop():
     if _GLOBAL_CODECARBON_TRACKER is None:
         return
 
-    print('codecarbon STOP')
     _GLOBAL_CODECARBON_TRACKER.stop()
 
@@ -196,7 +198,6 @@ def codecarbon_tracker_pause():
     if _GLOBAL_CODECARBON_TRACKER is None:
         return
 
-    print('codecarbon PAUSE')
     _GLOBAL_CODECARBON_TRACKER.pause()
 
@@ -205,7 +206,6 @@ def codecarbon_tracker_resume():
     if _GLOBAL_CODECARBON_TRACKER is None:
         return
 
-    print('codecarbon RESUME')
     _GLOBAL_CODECARBON_TRACKER.resume()

From 06d880243bb27ceac7219289ca4c19f710c39b41 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Sun, 22 Aug 2021 08:23:13 -0700
Subject: [PATCH 6/8] new implementation based on
 https://github.com/mlco2/codecarbon/pull/236

---
 megatron/checkpointing.py |  4 ++--
 megatron/global_vars.py   | 26 +++++---------------------
 2 files changed, 7 insertions(+), 23 deletions(-)

diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index dd089b989..f7328dcbb 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -22,7 +22,7 @@
 
 import torch
 
-from megatron.global_vars import codecarbon_tracker_restart
+from megatron.global_vars import codecarbon_tracker_flush
 from megatron import (get_args,
                       mpu,
                       print_rank_0,
@@ -187,7 +187,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     # since the code can be exited or aborted in various places we use the checkpoint saving as
     # a safe saving point for the codecarbon tracker. If the program doesn't run to its normal
     # end, then only the data since the last saved checkpoint will be lost.
-    codecarbon_tracker_restart()
+    codecarbon_tracker_flush()
 
 
 def _transpose_first_dim(t, num_splits, num_splits_first, model):
     input_shape = t.size()

diff --git a/megatron/global_vars.py b/megatron/global_vars.py
index 80694f4ef..b5dcac4d9 100644
--- a/megatron/global_vars.py
+++ b/megatron/global_vars.py
@@ -180,7 +180,7 @@ def codecarbon_tracker_start():
     if _GLOBAL_CODECARBON_TRACKER is None:
         return
 
-
+    #print("CC START")
     _GLOBAL_CODECARBON_TRACKER.start()
 
@@ -189,33 +189,17 @@ def codecarbon_tracker_stop():
     if _GLOBAL_CODECARBON_TRACKER is None:
         return
 
+    #print("CC STOP")
     _GLOBAL_CODECARBON_TRACKER.stop()
 
 
-def codecarbon_tracker_pause():
-    """ pause saves intermediate results """
-    global _GLOBAL_CODECARBON_TRACKER
-    if _GLOBAL_CODECARBON_TRACKER is None:
-        return
-
-    _GLOBAL_CODECARBON_TRACKER.pause()
-
-
-def codecarbon_tracker_resume():
-    global _GLOBAL_CODECARBON_TRACKER
-    if _GLOBAL_CODECARBON_TRACKER is None:
-        return
-
-    _GLOBAL_CODECARBON_TRACKER.resume()
-
-
-def codecarbon_tracker_restart():
+def codecarbon_tracker_flush():
     global _GLOBAL_CODECARBON_TRACKER
     if _GLOBAL_CODECARBON_TRACKER is None:
         return
 
-    codecarbon_tracker_pause()
-    codecarbon_tracker_resume()
+    #print("CC FLUSH")
+    _GLOBAL_CODECARBON_TRACKER.flush()

From 8088c0fed7c350dfd50be014791cba00a6700bb7 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Sun, 22 Aug 2021 09:05:24 -0700
Subject: [PATCH 7/8] add test

---
 tests/test_training.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_training.py b/tests/test_training.py
index 7306615f1..c0a8fff69 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -97,6 +97,7 @@ def test_training_all(self):
         --save {output_dir}/checkpoints
         --load {output_dir}/checkpoints
         --data-path {data_dir}/meg-gpt2-openwebtext_text_document
+        --codecarbon-dir {output_dir}/codecarbon
         --tensorboard-dir {output_dir}/tensorboard
         --tensorboard-queue-size 5
         --log-timers-to-tensorboard

From a5a511da65e7e9306f26b7b6e2ba76af68b9dc52 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Wed, 25 Aug 2021 12:20:20 -0700
Subject: [PATCH 8/8] update requirements

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index ea184c86d..a96ff42be 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,4 @@ numpy
 transformers
 # git+https://github.com/microsoft/DeepSpeed.git@big-science
 # edit to a higher SHA or future release if needed
-git+git://github.com/mlco2/codecarbon.git@d772616c4e55a710c3541469eced4cec98eff329
+git+git://github.com/mlco2/codecarbon.git@03479b695a771c28df6b877a809f5af3eb9ef3b8
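Taken together, the series converges on a simple lifecycle: codecarbon_tracker_start() once when pretrain() begins, codecarbon_tracker_flush() at every checkpoint save so at most one checkpoint interval of measurements is lost on a crash, and codecarbon_tracker_stop() at normal exit. A rough standalone sketch of that lifecycle, assuming the flush() API from mlco2/codecarbon#236 at the SHA pinned above; the loop, sleep, and modulo check are placeholders for the training loop and save interval:

    # Lifecycle sketch only: time.sleep stands in for a training step and
    # the modulo check stands in for save_checkpoint()'s save interval.
    import time
    from pathlib import Path

    from codecarbon import OfflineEmissionsTracker

    output_dir = "outputs/codecarbon"
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    tracker = OfflineEmissionsTracker(
        output_dir=output_dir,
        output_file="emissions-000.csv",
        log_level="warning",
        country_iso_code="FRA",
    )
    tracker.start()
    for iteration in range(1, 101):
        time.sleep(0.01)             # stand-in for one training step
        if iteration % 25 == 0:      # stand-in for the checkpoint save interval
            tracker.flush()          # persist totals so far; measurement keeps running
    tracker.stop()                   # final write at normal program end

Compared to the stop()/start() restart of patch 3 and the pause()/resume() pair of patch 4, flush() persists intermediate totals without tearing down or gating the measurement threads, which is why patches 6 and 8 settle on it.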