Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 130 additions & 91 deletions colossalai/pipeline/schedule/interleaved_pp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
import torch.cuda
Expand All @@ -22,6 +22,7 @@ def __init__(
num_model_chunks: int,
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
) -> None:
super().__init__(stage_manager)
assert (
Expand All @@ -39,6 +40,7 @@ def __init__(
self.microbatch_offset: List[int]

# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
Expand All @@ -54,30 +56,33 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)

self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.last_batch_size is None:
self.last_batch_size = self.batch_size
else:
assert self.forward_only or self.last_batch_size == self.batch_size
# TODO: support arbitrary batch size when forward_only=True
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
if self.num_microbatch is not None:

if self.microbatch_size is None:
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
self.microbatch_size = self.batch_size // self.num_microbatch
elif self.microbatch_size is not None:
if self.num_microbatch is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
self.num_microbatch = self.batch_size // self.microbatch_size
else:
raise ValueError("Either num_microbatch or microbatch_size should be provided")

assert (
self.num_microbatch % self.num_model_chunks == 0
), "Number of microbatch should be an integer multiple of number of model chunks"
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatch

assert (
self.num_microbatch % self.stage_manager.num_stages == 0
), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
if self.forward_only:
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
# NOTE: disable metadata cache when batch size changes (not valid anymore)
if self.batch_size != self.last_batch_size:
self.enable_metadata_cache = False
self.send_metadata_forward = True
self.send_metadata_backward = True
self.metadata_recv_forward = None
self.metadata_recv_backward = None

self.last_batch_size = self.batch_size

def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
Expand All @@ -88,6 +93,7 @@ def load_micro_batch(self, model_chunk_id: int) -> Any:
Returns:
Any: Micro batch.
"""
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
Expand Down Expand Up @@ -122,7 +128,7 @@ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward)
if self.metadata_recv_forward is None:
if self.enable_metadata_cache and self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)

return input_tensor
Expand All @@ -141,7 +147,7 @@ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward)
if self.metadata_recv_backward is None:
if self.enable_metadata_cache and self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)

return output_tensor_grad
Expand All @@ -158,7 +164,7 @@ def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int =
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward)
self.send_metadata_forward = False
self.send_metadata_forward = not self.enable_metadata_cache

def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
Expand All @@ -172,7 +178,7 @@ def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int =
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward)
self.send_metadata_backward = False
self.send_metadata_backward = not self.enable_metadata_cache

def send_forward_recv_backward(
self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None
Expand All @@ -185,8 +191,8 @@ def send_forward_recv_backward(
send_metadata=self.send_metadata_forward,
metadata_recv=self.metadata_recv_backward,
)
self.send_metadata_forward = False
if self.metadata_recv_backward is None:
self.send_metadata_forward = not self.enable_metadata_cache
if self.enable_metadata_cache and self.metadata_recv_backward is None:
self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad)

return output_tensor_grad
Expand All @@ -202,8 +208,8 @@ def send_backward_recv_forward(
send_metadata=self.send_metadata_backward,
metadata_recv=self.metadata_recv_forward,
)
self.send_metadata_backward = False
if self.metadata_recv_forward is None:
self.send_metadata_backward = not self.enable_metadata_cache
if self.enable_metadata_cache and self.metadata_recv_forward is None:
self.metadata_recv_forward = create_fast_send_metadata(input_tensor)

return input_tensor
Expand Down Expand Up @@ -297,66 +303,74 @@ def backward_step(
input_obj_grad[k] = v.grad
return input_obj_grad

def forward_backward_step(
def run_forward_only(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""Runs interleaved schedule, with communication between pipeline stages.
) -> Dict:
assert self.forward_only

Args:
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
self.load_batch(data_iter)

Returns:
dict: A dict with keys: 'loss' and 'outputs'.
outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None

accum_loss = None
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_current_device())

# Run warmup forward passes.
for i in range(self.num_microbatch * self.num_model_chunks):
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)

if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}

def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
Runs interleaved schedule, with communication between pipeline stages.
"""
assert not self.forward_only

self.load_batch(data_iter)

num_microbatch = self.num_microbatch * self.num_model_chunks
if self.forward_only:
num_warmup_microbatch = num_microbatch
else:
num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)

num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
num_warmup_microbatch += (self.num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatch = min(num_warmup_microbatch, num_microbatch)
num_microbatch_remaining = num_microbatch - num_warmup_microbatch

# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None

if not self.forward_only:
input_objs = [[] for _ in range(self.num_model_chunks)]
output_objs = [[] for _ in range(self.num_model_chunks)]
input_objs = [[] for _ in range(self.num_model_chunks)]
output_objs = [[] for _ in range(self.num_model_chunks)]

outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None

accum_loss = None
if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True):
accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.scalar_tensor(0, device=get_current_device())

# Run warmup forward passes.
for i in range(num_warmup_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not self.forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)

if num_microbatch_remaining > 0:
Expand All @@ -369,47 +383,72 @@ def forward_backward_step(
last_iteration = i == num_microbatch_remaining - 1

output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if self.forward_only:
if not last_iteration:
input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj)
else:
self.send_forward(model_chunk_id, output_obj)

else:
self.send_forward(model_chunk_id, output_obj)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)

model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)

# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)

# backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
# backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)

if not last_iteration:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)
if not last_iteration:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True)
input_obj = self.recv_forward(model_chunk_id)

# Run cooldown backward passes.
if not self.forward_only:
for i in range(num_microbatch_remaining, num_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
for i in range(num_microbatch_remaining, num_microbatch):
model_chunk_id = self.get_model_chunk_id(i, is_forward=False)
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)

if not self.forward_only:
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)

if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}

def forward_backward_step(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.

Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."

if self.forward_only:
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
)

return result
Loading