2 changes: 2 additions & 0 deletions examples/legacy/seq2seq/finetune_trainer.py
@@ -178,6 +178,8 @@ def main():
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Set the verbosity to info of the Transformers logger (on main process only):
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
if is_main_process(training_args.local_rank):
transformers.utils.logging.set_verbosity_info()
logger.info("Training/evaluation parameters %s", training_args)
149 changes: 148 additions & 1 deletion src/transformers/integrations.py
@@ -269,7 +269,153 @@ def rewrite_logs(d):
return new_d


def init_deepspeed(trainer, num_training_steps):

import torch


# Model parallel group that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP_DEVICE_IDS = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_DEVICE_IDS = None

# adapted from Megatron-LM/mpu/
class MPU:
def initialize_model_parallel(self, model_parallel_size_):
"""
        Initialize model and data parallel groups.

        Arguments:
            model_parallel_size_: number of GPUs used to parallelize the model.
                **Important**: this is not the total number of GPUs!

Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model. The present function will
create 4 model parallel groups and 2 data parallel groups as:
4 model parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 data parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]

Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.

Let's say we have a total of 4 GPUs denoted by g0 ... g3 and we
use 2 GPUs to parallelize the model. The present function will
create 2 model parallel groups and 2 data parallel groups as:
2 model parallel groups:
[g0, g1], [g2, g3]
2 data parallel groups:
[g0, g2], [g1, g3]

"""

def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)

if torch.distributed.get_rank() == 0:
print("> initializing model parallel with size {}".format(model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
model_parallel_size = min(model_parallel_size_, world_size)
ensure_divisibility(world_size, model_parallel_size)
rank = torch.distributed.get_rank()

print(f"MP size: {model_parallel_size}")
print(f"world_size: {world_size}")
print(f"rank: {rank}")

# Build the data parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP_DEVICE_IDS
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
for i in range(model_parallel_size):
ranks = range(i, world_size, model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank % model_parallel_size):
#print(f"DP ranks: {list(ranks)}")
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_DEVICE_IDS = list(ranks)

# Build the model parallel groups.
global _MODEL_PARALLEL_GROUP
global _MODEL_PARALLEL_GROUP_DEVICE_IDS
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
for i in range(world_size // model_parallel_size):
ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank // model_parallel_size):
#print(f"MP ranks: {list(ranks)}")
_MODEL_PARALLEL_GROUP = group
_MODEL_PARALLEL_GROUP_DEVICE_IDS = list(ranks)

def model_parallel_is_initialized(self):
"""Check if model and data parallel groups are initialized."""
if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
return False
return True

def get_model_parallel_group_device_ids(self):
"""Get the model parallel device ids of the group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
return _MODEL_PARALLEL_GROUP_DEVICE_IDS

def get_model_parallel_group(self):
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
return _MODEL_PARALLEL_GROUP

def get_data_parallel_group_device_ids(self):
"""Get the data parallel device ids of the group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP_DEVICE_IDS

def get_data_parallel_group(self):
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP

def get_model_parallel_world_size(self):
"""Return world size for the model parallel group."""
return torch.distributed.get_world_size(group=self.get_model_parallel_group())

def get_model_parallel_rank(self):
"""Return my rank for the model parallel group."""
return torch.distributed.get_rank(group=self.get_model_parallel_group())

def get_model_parallel_src_rank(self):
"""Calculate the global rank corresponding to a local rank zero
in the model parallel group."""
global_rank = torch.distributed.get_rank()
local_world_size = get_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size

def get_data_parallel_world_size(self):
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=self.get_data_parallel_group())

def get_data_parallel_rank(self):
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=self.get_data_parallel_group())

def destroy_model_parallel(self):
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
global _MODEL_PARALLEL_GROUP_DEVICE_IDS
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP_DEVICE_IDS
_MODEL_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP_DEVICE_IDS = None
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_DEVICE_IDS = None
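
# Illustrative usage sketch (not part of the diff), assuming torch.distributed
# has already been initialized, e.g. via torch.distributed.init_process_group("nccl").
# With a world size of 4 and a model parallel size of 2, rank 0 belongs to
# model parallel group [0, 1] and data parallel group [0, 2]:
#
#   mpu = MPU()
#   mpu.initialize_model_parallel(2)
#   mpu.get_model_parallel_group_device_ids()  # -> [0, 1] on rank 0
#   mpu.get_data_parallel_group_device_ids()   # -> [0, 2] on rank 0
#   mpu.destroy_model_parallel()               # resets the module globals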


def init_deepspeed(trainer, num_training_steps, mpu):
"""
    Init DeepSpeed, after converting any relevant Trainer args into the DeepSpeed configuration

@@ -415,6 +561,7 @@ def init_deepspeed(trainer, num_training_steps):
model=model,
model_parameters=model_parameters,
config_params=config,
        mpu=mpu,
)

return model, optimizer, lr_scheduler
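
# Hypothetical call-site sketch (not part of the diff): the Trainer would build
# the MPU before engine init and pass it through, so that DeepSpeed wires its
# topology to the same process groups:
#
#   mpu = MPU()
#   mpu.initialize_model_parallel(2)
#   model, optimizer, lr_scheduler = init_deepspeed(trainer, num_training_steps, mpu)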
67 changes: 67 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1793,3 +1793,70 @@ def forward(self, hidden_states):
return torch.cat(output_chunks, dim=chunk_dim)

return forward_fn(*input_tensors)


def recursive_to(device, item):
"""
    Move any tensors found in `item` to `device`. Currently handles a single tensor, as well as arbitrarily nested
    list, tuple and dict structures.
"""

if torch.is_tensor(item):
return item.to(device)

elif isinstance(item, list):
for i, x in enumerate(item):
item[i] = recursive_to(device, x)
return item

elif isinstance(item, tuple):
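        # tuples are immutable, so recurse over a list copy and re-wrap as a tuple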
return tuple(recursive_to(device, list(item)))

elif isinstance(item, dict):
for k, v in item.items():
item[k] = recursive_to(device, v)
return item

else:
return item
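
# Example sketch (assumes a CUDA device is available): nested containers are
# traversed recursively and non-tensor leaves pass through unchanged:
#
#   batch = {"input_ids": torch.ones(2, 4), "extras": [torch.zeros(2), None]}
#   batch = recursive_to(torch.device("cuda:0"), batch)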


def pipe_none_or_empty_to_torch(x, batch_size, device):
    """Encode `None` as a sentinel tensor of -100s and `()` as an empty tensor,
    since only tensors can be sent between pipeline stages."""
    if x is None:
        return torch.tensor([-100] * batch_size, device=device)
    if isinstance(x, tuple) and len(x) == 0:
        return torch.empty(0, device=device)
    return x


def pipe_torch_to_none_or_empty(x, batch_size, device):
    """Invert `pipe_none_or_empty_to_torch`: turn the sentinel tensors back into
    `None` and `()` respectively, passing everything else through unchanged."""
    if torch.is_tensor(x):
        if x.numel() == 0:
            return ()
        tnone = torch.tensor([-100] * batch_size, device=device)
        if x.shape == tnone.shape and torch.all(x == tnone):
            return None
    return x


def pipe_encode_all(inputs, batch_size, device):
    """Encode every element of `inputs` so that the returned tuple contains only tensors."""
    return tuple(pipe_none_or_empty_to_torch(x, batch_size, device) for x in inputs)


def pipe_decode_all(inputs, batch_size, device):
    """Restore any `None` / `()` elements that `pipe_encode_all` encoded as tensors."""
    return tuple(pipe_torch_to_none_or_empty(x, batch_size, device) for x in inputs)