2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -452,6 +452,8 @@ def _add_training_args(parser):
                       help='Run optimizer on CPU')
    group.add_argument('--cpu_torch_adam', action='store_true',
                       help='Use Torch Adam as optimizer on CPU.')
    group.add_argument('--codecarbon-dir', type=str, default=None,
                       help='Write CodeCarbon logs to this directory.')

    return parser

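An aside on the flag above: a minimal standalone sketch (hypothetical script, not part of this PR) of how argparse treats it. Note that the attribute exists even when the flag is omitted, which matters for the review comment further down:

    import argparse

    # Minimal sketch: the new flag as declared in the hunk above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--codecarbon-dir', type=str, default=None,
                        help='Write CodeCarbon logs to this directory.')

    # argparse maps dashes to underscores, so the value lands on
    # args.codecarbon_dir; when the flag is omitted, the attribute
    # still exists and is None.
    args = parser.parse_args([])
    assert hasattr(args, 'codecarbon_dir') and args.codecarbon_dir is None

    args = parser.parse_args(['--codecarbon-dir', '/tmp/cc-logs'])  # hypothetical path
    assert args.codecarbon_dir == '/tmp/cc-logs'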
10 changes: 8 additions & 2 deletions megatron/checkpointing.py
@@ -22,6 +22,7 @@

import torch

from megatron.global_vars import codecarbon_tracker_flush
from megatron import (get_args,
                      mpu,
                      print_rank_0,
@@ -135,7 +136,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
            if optimizer is not None:
@@ -183,6 +184,11 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Since the code can exit or abort in various places, we use checkpoint
    # saving as a safe flush point for the codecarbon tracker. If the program
    # doesn't run to its normal end, only the data since the last saved
    # checkpoint will be lost.
    codecarbon_tracker_flush()

def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
@@ -417,7 +423,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
def load_biencoder_checkpoint(model, only_query_model=False,
                              only_context_model=False, custom_load_path=None):
    """
    selectively load retrieval models for indexing/retrieving
    from saved checkpoints
    """

59 changes: 58 additions & 1 deletion megatron/global_vars.py
@@ -19,6 +19,8 @@
import sys
import time

from pathlib import Path

import torch

from megatron.tokenizer import build_tokenizer
@@ -29,10 +31,10 @@
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_CODECARBON_TRACKER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None


def get_args():
    """Return arguments."""
    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
@@ -63,6 +65,10 @@ def get_tensorboard_writer():
    to check if it is initialized."""
    return _GLOBAL_TENSORBOARD_WRITER

def get_codecarbon_tracker():
    """Return codecarbon tracker. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_CODECARBON_TRACKER

def get_adlr_autoresume():
    """ADLR autoresume object. It can be None so no need
@@ -86,6 +92,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
    if args.vocab_file or args.tokenizer_name_or_path:
        _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_codecarbon_tracker(args)
    _set_adlr_autoresume(args)
    _set_timers()

@@ -145,6 +152,56 @@ def _set_tensorboard_writer(args):
              'no TensorBoard logs will be written.', flush=True)


def _set_codecarbon_tracker(args):
    global _GLOBAL_CODECARBON_TRACKER
    if not hasattr(args, 'codecarbon_dir'):
        return
Review comment from @thomasw21 (Member), Aug 26, 2021, on lines +157 to +158:

    I think this should check whether it is None. @TevenLeScao
    Basically, this line adds the attribute but assigns None when the flag is not set:
    https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/15/files#diff-5f7d1ddfb0666cb6bb4ec0f07fd2fd7b1cd0354f421df5560489091db2ff5a55R455
    So hasattr(args, "codecarbon_dir") returns True even when no path was given.

    Follow-up: #74
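A sketch of the guard the comment suggests (see #74); it tests the attribute's value rather than its presence:

    # Sketch only: argparse always creates the attribute (default None),
    # so test its value instead of using hasattr.
    if getattr(args, 'codecarbon_dir', None) is None:
        return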


    import codecarbon
    if args.rank == 0:
        print('> setting codecarbon ...')

    output_dir = args.codecarbon_dir
    output_file = f"emissions-{args.rank:03d}.csv"
    log_level = "warning"
    country_iso_code = "FRA"

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    _GLOBAL_CODECARBON_TRACKER = codecarbon.OfflineEmissionsTracker(
        output_dir=output_dir,
        output_file=output_file,
        log_level=log_level,
        country_iso_code=country_iso_code,
    )


def codecarbon_tracker_start():
    global _GLOBAL_CODECARBON_TRACKER
    if _GLOBAL_CODECARBON_TRACKER is None:
        return

    #print("CC START")
    _GLOBAL_CODECARBON_TRACKER.start()


def codecarbon_tracker_stop():
    global _GLOBAL_CODECARBON_TRACKER
    if _GLOBAL_CODECARBON_TRACKER is None:
        return

    #print("CC STOP")
    _GLOBAL_CODECARBON_TRACKER.stop()


def codecarbon_tracker_flush():
    global _GLOBAL_CODECARBON_TRACKER
    if _GLOBAL_CODECARBON_TRACKER is None:
        return

    #print("CC FLUSH")
    _GLOBAL_CODECARBON_TRACKER.flush()


def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume."""
    global _GLOBAL_ADLR_AUTORESUME
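For reference, a self-contained sketch of the codecarbon offline API that the helpers above wrap. The constructor arguments mirror the diff; the output directory and file name are illustrative:

    import codecarbon

    # Sketch under the diff's assumptions: offline tracker, one CSV per rank.
    tracker = codecarbon.OfflineEmissionsTracker(
        output_dir='/tmp/cc-logs',        # hypothetical directory
        output_file='emissions-000.csv',  # rank 0's file, per the naming above
        log_level='warning',
        country_iso_code='FRA',           # offline mode requires a country code
    )

    tracker.start()   # begin measuring; codecarbon samples power periodically
    # ... training work happens here ...
    tracker.flush()   # persist measurements so far without stopping the tracker
    tracker.stop()    # take a final measurement and write the CSV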
6 changes: 6 additions & 0 deletions megatron/training.py
@@ -51,6 +51,7 @@
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory, flops_calculator
from megatron.global_vars import codecarbon_tracker_start, codecarbon_tracker_stop

import deepspeed

@@ -95,6 +96,8 @@ def pretrain(train_valid_test_dataset_provider,
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    codecarbon_tracker_start()

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
@@ -162,6 +165,9 @@ def pretrain(train_valid_test_dataset_provider,
                                  test_data_iterator, model,
                                  0, True)

    codecarbon_tracker_stop()


def update_train_iters(args):

    # For iteration-based training, we don't need to do anything
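Putting the three files together, a minimal sketch of the lifecycle this PR wires in. The loop and save interval are stand-ins; the function names come from the diff:

    from megatron.global_vars import (codecarbon_tracker_flush,
                                      codecarbon_tracker_start,
                                      codecarbon_tracker_stop)

    def lifecycle_sketch():
        # pretrain() starts the tracker once, right after initialize_megatron().
        codecarbon_tracker_start()

        for step in range(1, 1001):          # stand-in for the training loop
            ...                              # forward/backward/optimizer step
            if step % 500 == 0:              # hypothetical --save-interval
                # save_checkpoint() flushes here, so an aborted run loses at
                # most the emissions data recorded since the last checkpoint.
                codecarbon_tracker_flush()

        # pretrain() stops the tracker once, after the final evaluation.
        codecarbon_tracker_stop()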
2 changes: 2 additions & 0 deletions requirements.txt
@@ -5,3 +5,5 @@ regex
numpy
transformers
# git+https://github.com/microsoft/DeepSpeed.git@big-science
# edit to a higher SHA or future release if needed
git+git://github.com/mlco2/codecarbon.git@03479b695a771c28df6b877a809f5af3eb9ef3b8
1 change: 1 addition & 0 deletions tests/test_training.py
@@ -97,6 +97,7 @@ def test_training_all(self):
            --save {output_dir}/checkpoints
            --load {output_dir}/checkpoints
            --data-path {data_dir}/meg-gpt2-openwebtext_text_document
            --codecarbon-dir {output_dir}/codecarbon
            --tensorboard-dir {output_dir}/tensorboard
            --tensorboard-queue-size 5
            --log-timers-to-tensorboard