Skip to content
3 changes: 3 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,7 @@ def __init__(
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
overlap_p2p: bool = True,
fp8_communication: bool = False,
) -> None:
super().__init__()
assert (
Expand Down Expand Up @@ -1082,13 +1083,15 @@ def __init__(
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
fp8_communication=fp8_communication,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
fp8_communication=fp8_communication,
)
else:
raise NotImplementedError()
Expand Down
30 changes: 30 additions & 0 deletions colossalai/pipeline/schedule/interleaved_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
from colossalai.utils import get_current_device

from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
Expand All @@ -32,6 +33,7 @@ def __init__(
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
overlap_p2p: bool = True,
fp8_communication: bool = False,
) -> None:
super().__init__(stage_manager)
assert (
Expand All @@ -56,6 +58,8 @@ def __init__(
self.tensor_metadata_recv = None
self.grad_metadata_recv = None

self.fp8_communication = fp8_communication

def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.

Expand Down Expand Up @@ -191,8 +195,12 @@ def send_forward(self, model_chunk_id: int, output_tensor: Any, next_rank: int =
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_last_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
send_handles = self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor)
return send_handles
return []

Expand All @@ -210,10 +218,14 @@ def send_backward(self, model_chunk_id: int, input_tensor_grad: Any, prev_rank:
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if not self.stage_manager.is_first_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata
)
self.send_grad_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor_grad)
return send_handles
return []

Expand All @@ -224,6 +236,8 @@ def send_forward_recv_forward(
is_send = not self.stage_manager.is_last_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
is_recv = not self.stage_manager.is_first_stage()
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
input_tensor, wait_handles = self.comm.send_forward_recv_forward(
output_tensor,
is_send,
Expand All @@ -237,6 +251,8 @@ def send_forward_recv_forward(
if is_recv and self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)

if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor)
return input_tensor, wait_handles

def send_backward_recv_backward(
Expand All @@ -246,6 +262,8 @@ def send_backward_recv_backward(
is_send = not self.stage_manager.is_first_stage()
with self.stage_manager.switch_model_chunk_id(model_chunk_id_recv):
is_recv = not self.stage_manager.is_last_stage()
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
output_tensor_grad, wait_handles = self.comm.send_backward_recv_backward(
input_tensor_grad,
is_send,
Expand All @@ -258,6 +276,8 @@ def send_backward_recv_backward(
self.send_grad_metadata = not self.enable_metadata_cache and is_send
if is_recv and self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor_grad)
return output_tensor_grad, wait_handles

def forward_step(
Expand Down Expand Up @@ -379,6 +399,8 @@ def run_forward_only(

# Wait until current input is received
_wait_p2p(fwd_wait_handles)
if self.fp8_communication and input_obj is not None:
cast_from_fp8_pipeline(input_obj)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)

if not last_batch:
Expand Down Expand Up @@ -441,6 +463,8 @@ def run_forward_backward(

# Wait for input
_wait_p2p(fwd_wait_handles)
if self.fp8_communication and input_obj is not None:
cast_from_fp8_pipeline(input_obj)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
Expand All @@ -467,6 +491,8 @@ def run_forward_backward(

# Wait for input.
_wait_p2p(fwd_wait_handles)
if self.fp8_communication and input_obj is not None:
cast_from_fp8_pipeline(input_obj)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
Expand Down Expand Up @@ -511,6 +537,8 @@ def send_backward_recv_backward():
input_obj, fwd_wait_handles = send_forward_recv_forward()
# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
if self.fp8_communication and output_obj_grad is not None:
cast_from_fp8_pipeline(output_obj_grad)
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
# NOTE: It's documented by NCCL that running two concurrent communicators (batch_isend_irecv)
# risks deadlock (https://docs.nvidia.com/deeplearning/nccl/archives/nccl_2134/user-guide/docs/usage/communicators.html)
Expand All @@ -532,6 +560,8 @@ def send_backward_recv_backward():

# Wait for upstream grad
_wait_p2p(bwd_wait_handles)
if self.fp8_communication and output_obj_grad is not None:
cast_from_fp8_pipeline(output_obj_grad)
# backward local grads
input_obj_grad = self.backward_step(optimizer, _input_obj, _output_obj, output_obj_grad)
if not last_batch:
Expand Down
27 changes: 27 additions & 0 deletions colossalai/pipeline/schedule/one_f_one_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import cast_from_fp8_pipeline, cast_to_fp8_pipeline
from colossalai.utils import get_current_device

from ._utils import (
Expand All @@ -32,6 +33,7 @@ def __init__(
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
fp8_communication: bool = False,
) -> None:
"""1F1B pipeline schedule.

Expand Down Expand Up @@ -61,6 +63,8 @@ def __init__(
self.tensor_metadata_recv = None
self.grad_metadata_recv = None

self.fp8_communication = fp8_communication

def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.

Expand Down Expand Up @@ -129,6 +133,8 @@ def recv_forward(self, prev_rank: int = None) -> Any:
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)

if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor)
return input_tensor

def recv_backward(self, next_rank: int = None) -> Any:
Expand All @@ -143,6 +149,8 @@ def recv_backward(self, next_rank: int = None) -> Any:
"""
if not self.stage_manager.is_last_stage():
output_tensor_grad, _ = self.comm.recv_backward(next_rank, metadata_recv=self.grad_metadata_recv)
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor_grad)
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)

Expand All @@ -157,9 +165,14 @@ def send_forward(self, output_tensor: Any, next_rank: int = None) -> None:
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
self.comm.send_forward(output_tensor, next_rank, send_metadata=self.send_tensor_metadata)
self.send_tensor_metadata = not self.enable_metadata_cache

if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor, del_metadata=False)

def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
Expand All @@ -169,8 +182,12 @@ def send_backward(self, input_tensor_grad: Any, prev_rank: int = None) -> None:
prev_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_first_stage():
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
self.comm.send_backward(input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata)
self.send_grad_metadata = not self.enable_metadata_cache
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)

def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bool] = None) -> Any:
"""Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
Expand All @@ -183,6 +200,8 @@ def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bo
if not self.stage_manager.is_last_stage():
if not self.send_tensor_metadata and self.grad_metadata_recv is not None:
send_first = None
if self.fp8_communication:
cast_to_fp8_pipeline(output_tensor)
output_tensor_grad, _ = self.comm.send_forward_recv_backward(
output_tensor,
send_metadata=self.send_tensor_metadata,
Expand All @@ -192,6 +211,9 @@ def send_forward_recv_backward(self, output_tensor: Any, send_first: Optional[bo
self.send_tensor_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.grad_metadata_recv is None:
self.grad_metadata_recv = create_send_metadata(output_tensor_grad)
if self.fp8_communication:
cast_from_fp8_pipeline(output_tensor, del_metadata=False)
cast_from_fp8_pipeline(output_tensor_grad)

return output_tensor_grad

Expand All @@ -206,6 +228,8 @@ def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optiona
if not self.stage_manager.is_first_stage():
if not self.send_grad_metadata and self.tensor_metadata_recv is not None:
send_first = None # must not fall back
if self.fp8_communication:
cast_to_fp8_pipeline(input_tensor_grad)
input_tensor, _ = self.comm.send_backward_recv_forward(
input_tensor_grad,
send_metadata=self.send_grad_metadata,
Expand All @@ -215,6 +239,9 @@ def send_backward_recv_forward(self, input_tensor_grad: Any, send_first: Optiona
self.send_grad_metadata = not self.enable_metadata_cache
if self.enable_metadata_cache and self.tensor_metadata_recv is None:
self.tensor_metadata_recv = create_send_metadata(input_tensor)
if self.fp8_communication:
cast_from_fp8_pipeline(input_tensor)
cast_from_fp8_pipeline(input_tensor_grad, del_metadata=False)

return input_tensor

Expand Down
Loading