From ff22c112ba4b6db06156a9e1c0f7fdea01ceed47 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 17 Nov 2023 15:56:46 +0800 Subject: [PATCH 01/22] test: add more p2p tests --- tests/test_pipeline/test_p2p_communication.py | 50 +++++++++++-------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 1665711ceeef..2201e444d079 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -9,39 +9,45 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device +WORLD_SIZE = 2 + def check_p2p_communication(): - pg_mesh = ProcessGroupMesh(2) + pg_mesh = ProcessGroupMesh(WORLD_SIZE) stage_manager = PipelineStageManager(pg_mesh, 0) p2p = PipelineP2PCommunication(stage_manager) rank = dist.get_rank() tensor = torch.ones(1, device=get_current_device()) + data = [ + "tensor", + tensor, + [tensor], + {"tensor": tensor}, + ] if rank == 0: - p2p.send_forward(tensor) - p2p.send_forward([tensor]) - p2p.send_forward({"tensor": tensor}) - else: - obj = p2p.recv_forward() - assert torch.equal(obj, tensor) - obj = p2p.recv_forward() - assert type(obj) == list and len(obj) == 1 and torch.equal(obj[0], tensor) - obj = p2p.recv_forward() - assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) + for obj in data: + p2p.send_forward(obj) + for obj in data: + recv_obj = p2p.send_forward_recv_backward(obj) + assert recv_obj == obj + elif rank == 1: + for obj in data: + recv_obj = p2p.recv_forward() + assert recv_obj == obj + for obj in data: + recv_obj = p2p.send_backward_recv_forward(obj) + assert recv_obj == obj if rank == 1: - p2p.send_backward(tensor) - p2p.send_backward([tensor]) - p2p.send_backward({"tensor": tensor}) - else: - obj = p2p.recv_backward() - assert torch.equal(obj, tensor) - obj = p2p.recv_backward() - assert type(obj) == list 
and len(obj) == 1 and torch.equal(obj[0], tensor) - obj = p2p.recv_backward() - assert type(obj) == dict and "tensor" in obj and torch.equal(obj["tensor"], tensor) + for obj in data: + p2p.send_backward(obj) + elif rank == 0: + for obj in data: + recv_obj = p2p.recv_backward() + assert recv_obj == obj def run_dist(rank, world_size, port): @@ -52,7 +58,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_pipeline_p2p(): - spawn(run_dist, 2) + spawn(run_dist, WORLD_SIZE) if __name__ == "__main__": From e9bcad6122dab880d8a4366b72d45692f707531c Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 17 Nov 2023 15:57:51 +0800 Subject: [PATCH 02/22] fix: remove send_forward_recv_forward as p2p op list need to use the same group --- colossalai/pipeline/p2p.py | 29 ++++------------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 6e49fa36bb83..34906c61cb82 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -597,33 +597,12 @@ def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) - cur_rank = self.stage_manager.get_rank() group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank) - return _communicate( - input_object, prev_rank, prev_rank, - send_group=group, recv_group=group, - ) - - def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any: - """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. - - Args: - input_object (Any): Object to be sent. 
- prev_rank (int, optional): The rank of the sender of the tensor - next_rank (int, optional): The rank of the recipient of the tensor - """ - if prev_rank is None: - prev_rank = self.stage_manager.get_prev_rank() - if next_rank is None: - next_rank = self.stage_manager.get_next_rank() - - cur_rank = self.stage_manager.get_rank() - recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank) - send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank) return _communicate( input_object, - send_dst=next_rank, - recv_src=prev_rank, - send_group=send_group, - recv_group=recv_group, + prev_rank, + prev_rank, + send_group=group, + recv_group=group, ) def p2p_communicate( From c79df79003a816de2d269dc7a76b88c0341eb88b Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 17 Nov 2023 15:59:44 +0800 Subject: [PATCH 03/22] fix: remove _broadcast_object_list as not used --- colossalai/pipeline/p2p.py | 91 +------------------------------------- 1 file changed, 1 insertion(+), 90 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 34906c61cb82..2bf522f1f5fd 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -57,96 +57,7 @@ def check_for_nccl_backend(group): while isinstance(pg, c10d._ProcessGroupWrapper): pg = pg.wrapped_pg - return ( - c10d.is_nccl_available() and - pg.name() == c10d.Backend.NCCL - ) - - -def _broadcast_object_list( - object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None -): - """This is a modified version of the broadcast_object_list in torch.distribution - The only difference is that object will be move to correct device after unpickled. - If local_rank = src, then object list will be sent to rank src. Otherwise, object list will - be updated with data sent from rank src. 
- - Args: - object_list (List[Any]): list of object to broadcast - src (int): source rank to broadcast - dst (int): dst rank to broadcast - device (:class:`torch.device`): device to do broadcast. current device in default - - """ - - if c10d._rank_not_in_group(group): - c10d._warn_not_in_group("broadcast_object_list") - return - - is_nccl_backend = check_for_nccl_backend(group) - current_device = None - - if device is not None: - if is_nccl_backend and device.type != "cuda": - raise ValueError("device type must be cuda for nccl backend") - current_device = device - else: - current_device = torch.device("cpu") - if is_nccl_backend: - current_device = torch.device("cuda", torch.cuda.current_device()) - - my_rank = dist.get_rank() - # Serialize object_list elements to tensors on src rank. - if my_rank == src: - if Version(torch.__version__) >= Version("1.13.0"): - tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list]) - else: - tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list]) - object_sizes_tensor = torch.cat(size_list) - else: - object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long) - - if is_nccl_backend: - object_sizes_tensor = object_sizes_tensor.to(current_device) - - # Broadcast object sizes - c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False) - - # Concatenate and broadcast serialized object tensors - if my_rank == src: - object_tensor = torch.cat(tensor_list) - else: - object_tensor = torch.empty( # type: ignore[call-overload] - torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] - dtype=torch.uint8, - ) - - if is_nccl_backend: - object_tensor = object_tensor.to(current_device) - - c10d.broadcast(object_tensor, src=src, group=group, async_op=False) - - # Deserialize objects using their stored sizes. 
- offset = 0 - - if my_rank != src: - for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset: offset + obj_size] - obj_view = obj_view.type(torch.uint8) - if obj_view.device != torch.device("cpu"): - obj_view = obj_view.cpu() - offset += obj_size - # unpickle - unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) - - # unconsistence in device - if ( - isinstance(unpickle_object, torch.Tensor) - and unpickle_object.device.index != torch.cuda.current_device() - ): - unpickle_object = unpickle_object.cuda() - - object_list[i] = unpickle_object + return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL def check_device(group): From 21f46bac97c2765b2e598cccf784a47691a0b8dc Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 20 Nov 2023 14:26:37 +0800 Subject: [PATCH 04/22] test: update p2p test --- tests/test_pipeline/test_p2p_communication.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 2201e444d079..eb833aa15d05 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -30,24 +30,32 @@ def check_p2p_communication(): if rank == 0: for obj in data: p2p.send_forward(obj) - for obj in data: - recv_obj = p2p.send_forward_recv_backward(obj) - assert recv_obj == obj + for i in range(len(data)): + recv_obj = p2p.send_forward_recv_backward(data[i]) + assert recv_obj == data[-(i + 1)] elif rank == 1: for obj in data: recv_obj = p2p.recv_forward() assert recv_obj == obj - for obj in data: - recv_obj = p2p.send_backward_recv_forward(obj) - assert recv_obj == obj + for i in range(len(data)): + p2p.send_backward(data[-(i + 1)]) + recv_obj = p2p.recv_forward() + assert recv_obj == data[i] if rank == 1: for obj in data: p2p.send_backward(obj) + for i in range(len(data)): + recv_obj = p2p.send_backward_recv_forward(data[i]) + assert 
recv_obj == data[-(i + 1)] elif rank == 0: for obj in data: recv_obj = p2p.recv_backward() assert recv_obj == obj + for i in range(len(data)): + recv_obj = p2p.recv_backward() + p2p.send_forward(data[-(i + 1)]) + assert recv_obj == data[i] def run_dist(rank, world_size, port): From fd05c2d534e17f9da43eaf10d343849ff32506da Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 20 Nov 2023 16:38:12 +0800 Subject: [PATCH 05/22] fix: make send and receive atomic --- colossalai/pipeline/p2p.py | 289 +++++++++++++++++++++---------------- 1 file changed, 168 insertions(+), 121 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 2bf522f1f5fd..5a62b7a9590f 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -4,13 +4,14 @@ import io import pickle import re -from typing import Any, List, Optional, Union +import warnings from collections import namedtuple +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, List, Optional, Union import torch import torch.distributed as dist -from dataclasses import dataclass -from enum import Enum from packaging.version import Version from torch.distributed import ProcessGroup from torch.distributed import distributed_c10d as c10d @@ -20,7 +21,7 @@ _unpickler = pickle.Unpickler -def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> object: +def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -> Any: """transform tensor to object with unpickle. 
Info of the device in bytes stream will be modified into current device before unpickling @@ -70,14 +71,14 @@ def check_device(group): return current_device, is_nccl_backend -TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad']) +TensorMetadata = namedtuple("TensorMetadata", ["key", "shape", "dtype", "requires_grad"]) class P2PDataType(Enum): - serialization = 0 - tensor = 1 - list = 2 - dict = 3 + Serialization = 0 + Tensor = 1 + List = 2 + Dict = 3 @dataclass @@ -86,45 +87,55 @@ class P2PMetadata: content: Union[List[TensorMetadata], TensorMetadata, Any] -def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group): +def filling_ops_queue(obj: Any, comm_op: Callable, comm_rank: int, ops_queue: List, group: ProcessGroup): if isinstance(obj, torch.Tensor): obj = obj.contiguous() op_to_add = dist.P2POp(comm_op, obj, comm_rank, group) ops_queue.append(op_to_add) else: for tensor_to_comm in obj: - tensor_to_comm = tensor_to_comm.contiguous() - op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group) - ops_queue.append(op_to_add) + assert isinstance(tensor_to_comm, torch.Tensor) + filling_ops_queue(tensor_to_comm, comm_op, comm_rank, ops_queue, group) -def create_recv_buffer(p2p_metadata: P2PMetadata, current_device): - if p2p_metadata.data_type == P2PDataType.tensor: +def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any): + if p2p_metadata.data_type == P2PDataType.Tensor: metadata = p2p_metadata.content - tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype) + tensor_recv = torch.empty( + metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype + ) return tensor_recv - elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict): + elif p2p_metadata.data_type in (P2PDataType.List, P2PDataType.Dict): buffer_recv = [] for metadata in p2p_metadata.content: - tensor_recv = 
torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype) + tensor_recv = torch.empty( + metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype + ) buffer_recv.append(tensor_recv) return buffer_recv else: raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}") -def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device): +def _batch_send_recv_tensor( + send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]], + recv_tensor_metadata: Optional[P2PMetadata], + send_dst: Optional[int], + recv_src: Optional[int], + send_group: Optional[ProcessGroup], + recv_group: Optional[ProcessGroup], + current_device: Any, +) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]: buffer_recv = None - if recv_tensor_metadata is not None: + if recv_tensor_metadata is not None and recv_tensor_metadata.data_type != P2PDataType.Serialization: buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device) ops = [] - - if send_dst is not None: + if send_dst is not None and send_tensor_list is not None: + assert send_group is not None filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group) - - if recv_src is not None: - assert buffer_recv is not None + if recv_src is not None and buffer_recv is not None: + assert recv_group is not None filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group) if len(ops) > 0: @@ -132,24 +143,26 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re for req in reqs: req.wait() - torch.cuda.synchronize() - # Remove synchronization according to Pytorch's documentation # However, the Megatron-LM does synchronization here # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112 # In case there is potential error, uncomment the following 
`torch.cuda.synchronize()` - # torch.cuda.synchronize() + torch.cuda.synchronize() return buffer_recv def _send_recv_serialization_object( - object: Any, - send_dst: Optional[int], recv_src: Optional[int], - send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup], - current_device, - is_nccl_backend): + object: Any, + send_dst: Optional[int], + recv_src: Optional[int], + send_group: Optional[ProcessGroup], + recv_group: Optional[ProcessGroup], + current_device: Any, + is_nccl_backend: bool, +) -> Optional[P2PMetadata]: ops = [] + send_object_tensor = None if object is not None and send_dst is not None: if Version(torch.__version__) >= Version("1.13.0"): @@ -175,10 +188,8 @@ def _send_recv_serialization_object( for req in reqs: req.wait() - torch.cuda.synchronize() - # See the comment in `_batch_send_recv_tensor` - # torch.cuda.synchronize() + torch.cuda.synchronize() ops = [] @@ -197,52 +208,79 @@ def _send_recv_serialization_object( for req in reqs: req.wait() - torch.cuda.synchronize() - # See the comment in `_batch_send_recv_tensor` - # torch.cuda.synchronize() + torch.cuda.synchronize() if recv_object_tensor is not None and recv_object_size_tensor is not None: recv_object_tensor = recv_object_tensor.type(torch.uint8) if recv_object_tensor.device != torch.device("cpu"): recv_object_tensor = recv_object_tensor.cpu() - unpickle_object = _cuda_safe_tensor_to_object( - recv_object_tensor, recv_object_size_tensor.item()) + unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item()) - if ( - isinstance(unpickle_object, torch.Tensor) - and unpickle_object.device.index != torch.cuda.current_device() - ): + if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): unpickle_object = unpickle_object.cuda() return unpickle_object -def _check_if_fast_send_available(object): - if type(object) is torch.Tensor: +def _check_if_fast_send_available(object: Any) -> bool: + if 
isinstance(object, torch.Tensor): return True - elif type(object) is list: - is_list_of_tensor = all([type(v) is torch.Tensor for v in object]) + elif isinstance(object, list): + is_list_of_tensor = all([isinstance(v, torch.Tensor) for v in object]) return is_list_of_tensor - elif type(object) is dict: - is_dict_of_tensor = all([type(k) is str and type( - v) is torch.Tensor for k, v in object.items()]) - + elif isinstance(object, dict): + is_dict_of_tensor = all([isinstance(k, str) and isinstance(v, torch.Tensor) for k, v in object.items()]) return is_dict_of_tensor return False def _communicate( - object, + object: Any, send_dst: Optional[int], recv_src: Optional[int], send_group: Optional[ProcessGroup] = None, recv_group: Optional[ProcessGroup] = None, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, ) -> Any: - if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group): - c10d._warn_not_in_group("_communicate") + """ + Send and receive object from send_dst and recv_src respectively + + Args: + object (Any): object needed to be sent + send_dst (int): rank of the destination + recv_src (int): rank of the source + send_group (ProcessGroup, optional): process group of sender + recv_group (ProcessGroup, optional): process group of receiver + send_metadata (bool, optional): whether to send metadata + metadata_recv (P2PMetadata, optional): metadata of the object to be received + """ + if send_dst is None and recv_src is None: return + assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None" + assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None" + send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object)) + assert ( + metadata_recv is None or metadata_recv.data_type != P2PDataType.Serialization + ), "metadata_recv type must not be Serialization" + + # NOTE: send & recv should be 
atomic operations. However, if we need to send metadata or receive metadata, + # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case. + if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None): + warnings.warn("Fall back to individual send & recv") + _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) + return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv) + + # NOTE: only the following 5 cases are valid: + # 1. send() [needs extra metadata] and no recv() + # 2. recv() [needs extra metadata] and no send() + # 3. neither send() or recv() need extra metadata + assert not (send_dst is not None and send_metadata) or recv_src is None + assert not (recv_src is not None and metadata_recv is None) or send_dst is None + assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None) + assert not c10d._rank_not_in_group(send_group) and not c10d._rank_not_in_group(recv_group) current_send_device, is_send_nccl_backend = check_device(send_group) current_recv_device, is_recv_nccl_backend = check_device(recv_group) @@ -252,64 +290,64 @@ def _communicate( assert current_send_device == current_recv_device current_device = current_send_device - assert (send_dst is not None) or (recv_src is not None) - - can_fast_send = False - send_metadata = None - if send_dst is not None: - can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend - if not can_fast_send: - send_metadata = P2PMetadata(P2PDataType.serialization, object) - else: - if type(object) is torch.Tensor: - data_type = P2PDataType.tensor - content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad) - elif type(object) is list: - data_type = P2PDataType.list - content = [] - for v in object: - 
content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad)) - elif type(object) is dict: - data_type = P2PDataType.dict - content = [] - for k, v in object.items(): - content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad)) + if send_metadata or metadata_recv is None: + metadata_send = None + if send_dst is not None and send_metadata: + can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend + if not can_fast_send: + metadata_send = P2PMetadata(P2PDataType.Serialization, object) else: - raise ValueError('Cannot send object of type {}'.format(type(object))) - send_metadata = P2PMetadata(data_type, content) - - recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend) - if recv_metadata is not None: - assert type(recv_metadata) is P2PMetadata - if recv_metadata.data_type == P2PDataType.serialization: - return recv_metadata.content - if not can_fast_send and send_dst is not None: - return + if isinstance(object, torch.Tensor): + data_type = P2PDataType.Tensor + content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad) + elif isinstance(object, list): + data_type = P2PDataType.List + content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object] + elif isinstance(object, dict): + data_type = P2PDataType.Dict + content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()] + else: + raise ValueError("Cannot send object of type {}".format(type(object))) + metadata_send = P2PMetadata(data_type, content) + + # Send and receive metadata + _metadata_recv = _send_recv_serialization_object( + object=metadata_send, + send_dst=send_dst if send_metadata else None, + recv_src=recv_src if metadata_recv is None else None, + send_group=send_group if send_metadata else None, + recv_group=recv_group if metadata_recv is None else None, + current_device=current_device, + 
is_nccl_backend=is_nccl_backend, + ) + assert metadata_recv is None or _metadata_recv is None + metadata_recv = _metadata_recv if metadata_recv is None else metadata_recv send_tensor_list = None - if type(object) is torch.Tensor: + if isinstance(object, torch.Tensor): send_tensor_list = object - elif type(object) is list: + elif isinstance(object, list): send_tensor_list = object - elif type(object) is dict: + elif isinstance(object, dict): send_tensor_list = list(object.values()) - recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device) - - if recv_metadata is not None: - assert recv_buffer is not None - if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]: - return recv_buffer - elif recv_metadata.data_type == P2PDataType.dict: - return { - k: v - for k, v in zip( - [m.key for m in recv_metadata.content], - recv_buffer, - ) - } + # Send and receive data + recv_buffer = _batch_send_recv_tensor( + send_tensor_list, metadata_recv, send_dst, recv_src, send_group, recv_group, current_device + ) + + if metadata_recv is not None: + assert isinstance(metadata_recv, P2PMetadata) + if metadata_recv.data_type == P2PDataType.Serialization: + return metadata_recv.content else: - raise ValueError('Unknown data type {}'.format(recv_metadata.data_type)) + assert recv_buffer is not None + if metadata_recv.data_type in [P2PDataType.Tensor, P2PDataType.List]: + return recv_buffer + elif metadata_recv.data_type == P2PDataType.Dict: + return {k: v for k, v in zip([m.key for m in metadata_recv.content], recv_buffer)} + else: + raise ValueError("Unknown data type {}".format(metadata_recv.data_type)) def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: @@ -347,7 +385,7 @@ def _p2p_comm( """ Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication. 
- Agrs: + Args: tensor_send_next (torch.Tensor): tensor to be sent to next stage recv_prev (bool): whether to receive tensor from previous stage peer (int): rank of the peer @@ -378,7 +416,6 @@ def _p2p_comm( group=group, ) ops.append(recv_prev_op) - if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) for req in reqs: @@ -401,7 +438,6 @@ def _p2p_comm( group=group, ) ops.append(send_next_op) - if tensor_recv_prev is not None: recv_prev_op = dist.P2POp( dist.irecv, @@ -421,7 +457,7 @@ class PipelineP2PCommunication: def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager - def recv_forward(self, prev_rank: int = None) -> Any: + def recv_forward(self, prev_rank: Optional[int] = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. Args: @@ -437,7 +473,7 @@ def recv_forward(self, prev_rank: int = None) -> Any: return input_tensor - def recv_backward(self, next_rank: int = None) -> Any: + def recv_backward(self, next_rank: Optional[int] = None) -> Any: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. Args: @@ -455,7 +491,7 @@ def recv_backward(self, next_rank: int = None) -> Any: return output_tensor_grad - def send_forward(self, output_object: Any, next_rank: int = None) -> None: + def send_forward(self, output_object: Any, next_rank: Optional[int] = None) -> None: """Sends the input tensor to the next stage in pipeline. Args: @@ -467,7 +503,7 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None: cur_rank = self.stage_manager.get_rank() _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) - def send_backward(self, input_object: Any, prev_rank: int = None) -> None: + def send_backward(self, input_object: Any, prev_rank: Optional[int] = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. 
Args: @@ -479,7 +515,7 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) - def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any: + def send_forward_recv_backward(self, input_object: Any, next_rank: Optional[int] = None) -> Any: """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline Args: @@ -492,11 +528,14 @@ def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) - cur_rank = self.stage_manager.get_rank() group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank) return _communicate( - input_object, next_rank, next_rank, - send_group=group, recv_group=group, + input_object, + next_rank, + next_rank, + send_group=group, + recv_group=group, ) - def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any: + def send_backward_recv_forward(self, input_object: Any, prev_rank: Optional[int] = None) -> Any: """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline Args: @@ -517,7 +556,11 @@ def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) - ) def p2p_communicate( - self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16 + self, + output_object: Any, + recv_pre: bool, + next_rank: Optional[int] = None, + comm_dtype: torch.dtype = torch.float16, ) -> None: """ Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. @@ -526,10 +569,14 @@ def p2p_communicate( output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. 
""" - if peer is None: - peer = self.stage_manager.get_next_rank() + if next_rank is None: + next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() recv_tensor = _p2p_comm( - output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype + output_object, + recv_pre, + next_rank, + self.stage_manager.get_p2p_process_group(cur_rank, next_rank), + comm_dtype, ) return recv_tensor From c6ac68c2e3b20af3f68c68eaf0e08721196892f8 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 20 Nov 2023 18:10:34 +0800 Subject: [PATCH 06/22] feat: update P2PComm fn --- colossalai/pipeline/p2p.py | 58 +++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 5a62b7a9590f..9272205422b9 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -350,7 +350,7 @@ def _communicate( raise ValueError("Unknown data type {}".format(metadata_recv.data_type)) -def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: +def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool = True) -> None: """send anything to dst rank Args: @@ -360,10 +360,10 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: Returns: None """ - _communicate(object, send_dst=dst, recv_src=None, send_group=group) + _communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata) -def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: +def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata] = None) -> Any: """recv anything from src Args: @@ -372,7 +372,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: Returns: Any: Object received from src. 
""" - return _communicate(None, send_dst=None, recv_src=src, recv_group=group) + return _communicate(None, send_dst=None, recv_src=src, recv_group=group, metadata_recv=metadata_recv) def _p2p_comm( @@ -457,7 +457,7 @@ class PipelineP2PCommunication: def __init__(self, stage_manager: PipelineStageManager) -> None: self.stage_manager = stage_manager - def recv_forward(self, prev_rank: Optional[int] = None) -> Any: + def recv_forward(self, prev_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. Args: @@ -469,11 +469,13 @@ def recv_forward(self, prev_rank: Optional[int] = None) -> Any: if prev_rank is None: prev_rank = self.stage_manager.get_prev_rank() cur_rank = self.stage_manager.get_rank() - input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)) + input_tensor = _recv_object( + prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank), metadata_recv + ) return input_tensor - def recv_backward(self, next_rank: Optional[int] = None) -> Any: + def recv_backward(self, next_rank: Optional[int] = None, metadata_recv: Optional[P2PMetadata] = None) -> Any: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. 
Args: @@ -486,12 +488,12 @@ def recv_backward(self, next_rank: Optional[int] = None) -> Any: next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() output_tensor_grad = _recv_object( - next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank) + next_rank, cur_rank, self.stage_manager.get_p2p_process_group(next_rank, cur_rank), metadata_recv ) return output_tensor_grad - def send_forward(self, output_object: Any, next_rank: Optional[int] = None) -> None: + def send_forward(self, output_object: Any, next_rank: Optional[int] = None, send_metadata: bool = True) -> None: """Sends the input tensor to the next stage in pipeline. Args: @@ -501,9 +503,15 @@ def send_forward(self, output_object: Any, next_rank: Optional[int] = None) -> N if next_rank is None: next_rank = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() - _send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank)) + _send_object( + output_object, + cur_rank, + next_rank, + self.stage_manager.get_p2p_process_group(cur_rank, next_rank), + send_metadata, + ) - def send_backward(self, input_object: Any, prev_rank: Optional[int] = None) -> None: + def send_backward(self, input_object: Any, prev_rank: Optional[int] = None, send_metadata: bool = True) -> None: """Sends the gradient tensor to the previous stage in pipeline. 
Args: @@ -513,9 +521,21 @@ def send_backward(self, input_object: Any, prev_rank: Optional[int] = None) -> N if prev_rank is None: prev_rank = self.stage_manager.get_prev_rank() cur_rank = self.stage_manager.get_rank() - _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) + _send_object( + input_object, + cur_rank, + prev_rank, + self.stage_manager.get_p2p_process_group(cur_rank, prev_rank), + send_metadata, + ) - def send_forward_recv_backward(self, input_object: Any, next_rank: Optional[int] = None) -> Any: + def send_forward_recv_backward( + self, + input_object: Any, + next_rank: Optional[int] = None, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, + ) -> Any: """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline Args: @@ -533,9 +553,17 @@ def send_forward_recv_backward(self, input_object: Any, next_rank: Optional[int] next_rank, send_group=group, recv_group=group, + send_metadata=send_metadata, + metadata_recv=metadata_recv, ) - def send_backward_recv_forward(self, input_object: Any, prev_rank: Optional[int] = None) -> Any: + def send_backward_recv_forward( + self, + input_object: Any, + prev_rank: Optional[int] = None, + send_metadata: bool = True, + metadata_recv: Optional[P2PMetadata] = None, + ) -> Any: """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline Args: @@ -553,6 +581,8 @@ def send_backward_recv_forward(self, input_object: Any, prev_rank: Optional[int] prev_rank, send_group=group, recv_group=group, + send_metadata=send_metadata, + metadata_recv=metadata_recv, ) def p2p_communicate( From 0520503f6be9b9bd6ea916781c5dbc8c8123f581 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 20 Nov 2023 18:12:15 +0800 Subject: [PATCH 07/22] test: add warning test --- tests/test_pipeline/test_p2p_communication.py | 37 +++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff 
--git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index eb833aa15d05..5ecf1868bb2c 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -1,10 +1,12 @@ +import warnings + import pytest import torch import torch.distributed as dist import colossalai from colossalai.cluster import ProcessGroupMesh -from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, PipelineP2PCommunication, TensorMetadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device @@ -30,9 +32,11 @@ def check_p2p_communication(): if rank == 0: for obj in data: p2p.send_forward(obj) - for i in range(len(data)): - recv_obj = p2p.send_forward_recv_backward(data[i]) - assert recv_obj == data[-(i + 1)] + with warnings.catch_warnings(record=True) as w: + for i in range(len(data)): + recv_obj = p2p.send_forward_recv_backward(data[i]) + assert recv_obj == data[-(i + 1)] + assert "Fall back" in str(w[-1].message) elif rank == 1: for obj in data: recv_obj = p2p.recv_forward() @@ -45,9 +49,11 @@ def check_p2p_communication(): if rank == 1: for obj in data: p2p.send_backward(obj) - for i in range(len(data)): - recv_obj = p2p.send_backward_recv_forward(data[i]) - assert recv_obj == data[-(i + 1)] + with warnings.catch_warnings(record=True) as w: + for i in range(len(data)): + recv_obj = p2p.send_backward_recv_forward(data[i]) + assert recv_obj == data[-(i + 1)] + assert "Fall back" in str(w[-1].message) elif rank == 0: for obj in data: recv_obj = p2p.recv_backward() @@ -57,6 +63,23 @@ def check_p2p_communication(): p2p.send_forward(data[-(i + 1)]) assert recv_obj == data[i] + warnings.filterwarnings("error") + tensor_metadata = TensorMetadata( + key=None, shape=tensor.shape, dtype=tensor.dtype, 
requires_grad=tensor.requires_grad + ) + comm_metadata = P2PMetadata(data_type=P2PDataType.Tensor, content=tensor_metadata) + if rank == 0: + recv_obj = p2p.send_forward_recv_backward( + tensor, + send_metadata=False, + metadata_recv=comm_metadata, + ) + assert recv_obj == tensor + elif rank == 1: + recv_obj = p2p.recv_forward(metadata_recv=comm_metadata) + assert recv_obj == tensor + p2p.send_backward(tensor, send_metadata=False) + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost") From 1fe08d24f878cf6ccb3264585c50fdc2f67e7e97 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 21 Nov 2023 15:20:42 +0800 Subject: [PATCH 08/22] test: modify 1f1b test --- .../test_schedule/test_oneF_oneB.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 1d77edc2db11..7424cb160a40 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -4,6 +4,7 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn import colossalai @@ -14,12 +15,17 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all +WORLD_SIZE = 2 +DIM = 8 +NUM_MICRO_BATCHS = 4 +BATCH_SIZE = 4 + class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() - self.linear1 = nn.Linear(4, 8) - self.linear2 = nn.Linear(8, 4) + self.linear1 = nn.Linear(DIM, DIM) + self.linear2 = nn.Linear(DIM, DIM) def forward(self, x): x = self.linear1(x) @@ -43,13 +49,10 @@ def examine_pp(): This test is to examine the correctness of 1F1B, compared with torch. Be aware it contains some hardcodes. 
""" - world_size = torch.distributed.get_world_size() - local_rank = torch.distributed.get_rank() + world_size = dist.get_world_size() + local_rank = dist.get_rank() seed_all(1453) - NUM_MICRO_BATCHS = 4 - BATCH_SIZE = 4 - # create models torch_model = MlpModel().cuda() @@ -73,13 +76,10 @@ def examine_pp(): # create seed_all(1453) - if stage_manager.is_first_stage(): - input_list = [torch.rand(BATCH_SIZE, 4).cuda()] - else: - input_list = [torch.zeros(BATCH_SIZE, 4).cuda()] - torch.distributed.all_reduce(input_list[0]) + input_list = [torch.rand(BATCH_SIZE, DIM).cuda()] + dist.all_reduce(input_list[0]) - criterion = lambda x, y: torch.mean(x) + criterion = lambda x, y: (x * x).mean() # forward and backward torch_output = torch_model(input_list[0]) @@ -113,7 +113,7 @@ def examine_pp(): assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) -def run_dist(rank, world_size, port): +def run_dist(rank: int, world_size: int, port: int): colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") examine_pp() @@ -121,7 +121,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_pp(): - spawn(run_dist, 2) + spawn(run_dist, WORLD_SIZE) if __name__ == "__main__": From f365452908ffa7426f5b1025e21ba1c2ad8642f6 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 21 Nov 2023 15:42:38 +0800 Subject: [PATCH 09/22] feat: add metadata cache in 1f1b --- colossalai/pipeline/p2p.py | 35 +++--- colossalai/pipeline/schedule/one_f_one_b.py | 133 +++++++++++--------- 2 files changed, 94 insertions(+), 74 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 9272205422b9..f396a3d54eb7 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -117,6 +117,22 @@ def create_recv_buffer(p2p_metadata: P2PMetadata, current_device: Any): raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}") +def create_fast_send_metadata(object: Any) -> 
P2PMetadata: + assert _check_if_fast_send_available(object) + if isinstance(object, torch.Tensor): + data_type = P2PDataType.Tensor + content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad) + elif isinstance(object, list): + data_type = P2PDataType.List + content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object] + elif isinstance(object, dict): + data_type = P2PDataType.Dict + content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()] + else: + raise RuntimeError("Cannot handle object of type {}".format(type(object))) + return P2PMetadata(data_type, content) + + def _batch_send_recv_tensor( send_tensor_list: Optional[Union[torch.Tensor, List[torch.Tensor]]], recv_tensor_metadata: Optional[P2PMetadata], @@ -290,25 +306,14 @@ def _communicate( assert current_send_device == current_recv_device current_device = current_send_device - if send_metadata or metadata_recv is None: + if (send_dst is not None and send_metadata) or (recv_src is not None and metadata_recv is None): metadata_send = None if send_dst is not None and send_metadata: can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend if not can_fast_send: metadata_send = P2PMetadata(P2PDataType.Serialization, object) else: - if isinstance(object, torch.Tensor): - data_type = P2PDataType.Tensor - content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad) - elif isinstance(object, list): - data_type = P2PDataType.List - content = [TensorMetadata(None, v.shape, v.dtype, v.requires_grad) for v in object] - elif isinstance(object, dict): - data_type = P2PDataType.Dict - content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in object.items()] - else: - raise ValueError("Cannot send object of type {}".format(type(object))) - metadata_send = P2PMetadata(data_type, content) + metadata_send = create_fast_send_metadata(object) # Send and receive metadata _metadata_recv = 
_send_recv_serialization_object( @@ -350,7 +355,7 @@ def _communicate( raise ValueError("Unknown data type {}".format(metadata_recv.data_type)) -def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool = True) -> None: +def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_metadata: bool) -> None: """send anything to dst rank Args: @@ -363,7 +368,7 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup, send_meta _communicate(object, send_dst=dst, recv_src=None, send_group=group, send_metadata=send_metadata) -def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata] = None) -> Any: +def _recv_object(src: int, dst: int, group: ProcessGroup, metadata_recv: Optional[P2PMetadata]) -> Any: """recv anything from src Args: diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index fd918cf1921c..c4199017d7a8 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -7,7 +7,7 @@ from torch.utils._pytree import tree_map from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.device import get_current_device @@ -47,9 +47,16 @@ def __init__( self.microbatch_size = microbatch_size self.batch: Optional[Any] = None self.batch_size: Optional[int] = None + self.last_batch_size: Optional[int] = None self.microbatch_offset: Optional[int] = None self._use_microbatch_size = num_microbatches is None + # P2PMeta cache + self.send_metadata_forward = True + self.send_metadata_backward = True + self.metadata_recv_forward = None + self.metadata_recv_backward = None + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] 
= None) -> None: """Load a batch from data iterator. @@ -60,8 +67,14 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) batch = next(data_iter) if device is not None: batch = tree_map(partial(to_device, device=device), batch) + self.batch = batch self.batch_size = get_batch_size(batch) + if self.last_batch_size is None: + self.last_batch_size = self.batch_size + else: + assert self.forward_only or self.last_batch_size == self.batch_size + # TODO: support arbitrary batch size when forward_only=True self.microbatch_offset = 0 if not self._use_microbatch_size: assert ( @@ -92,12 +105,12 @@ def recv_forward(self, prev_rank: int = None) -> Any: Returns: Any: The input tensor or input tensor list. """ - if self.stage_manager.is_first_stage(): - input_tensor = None - else: - input_tensor = self.comm.recv_forward(prev_rank) + if not self.stage_manager.is_first_stage(): + input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward) + if self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) - return input_tensor + return input_tensor def recv_backward(self, next_rank: int = None) -> Any: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. @@ -109,12 +122,12 @@ def recv_backward(self, next_rank: int = None) -> Any: Returns: Any: The input gradient tensor or gradient tensor list. 
""" - if self.stage_manager.is_last_stage(): - output_tensor_grad = None - else: - output_tensor_grad = self.comm.recv_backward(next_rank) + if not self.stage_manager.is_last_stage(): + output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) - return output_tensor_grad + return output_tensor_grad def send_forward(self, output_object: Any, next_rank: int = None) -> None: """Sends the input tensor to the next stage in pipeline. @@ -125,18 +138,8 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None: next_rank (int, optional): The rank of the recipient of the tensor. """ if not self.stage_manager.is_last_stage(): - self.comm.send_forward(output_object, next_rank) - - def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any: - """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. - For 1F1B. - - Args: - output_object (Any): Object to be sent. - next_rank (int, optional): The rank of the recipient of the tensor. - """ - if not self.stage_manager.is_last_stage(): - return self.comm.send_forward_recv_backward(output_object, next_rank) + self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) + self.send_metadata_forward = False def send_backward(self, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. 
@@ -147,34 +150,50 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: prev_rank (int, optional): The rank of the recipient of the tensor """ if not self.stage_manager.is_first_stage(): - self.comm.send_backward(input_object, prev_rank) + self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) + self.send_metadata_backward = False - def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any: - """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. + def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any: + """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline. For 1F1B. Args: output_object (Any): Object to be sent. - prev_rank (int, optional): The rank of the recipient of the tensor. + next_rank (int, optional): The rank of the recipient of the tensor. """ - if not self.stage_manager.is_first_stage(): - return self.comm.send_backward_recv_forward(output_object, prev_rank) + if not self.stage_manager.is_last_stage(): + output_tensor_grad = self.comm.send_forward_recv_backward( + output_object, + next_rank, + send_metadata=self.send_metadata_forward, + metadata_recv=self.metadata_recv_backward, + ) + self.send_metadata_forward = False + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + + return output_tensor_grad - def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any: - """Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline. + def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any: + """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline. For 1F1B. 
Args: - input_object (Any): Object to be sent. - prev_rank (int, optional): The previous rank of the recipient of the tensor. - next_rank (int, optional): The next rank of the recipient of the tensor. + output_object (Any): Object to be sent. + prev_rank (int, optional): The rank of the recipient of the tensor. """ - if self.stage_manager.is_first_stage(): - return self.comm.send_forward(input_object, next_rank) - elif self.stage_manager.is_last_stage(): - return self.comm.recv_forward(prev_rank) - else: - return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank) + if not self.stage_manager.is_first_stage(): + input_tensor = self.comm.send_backward_recv_forward( + output_object, + prev_rank, + send_metadata=self.send_metadata_backward, + metadata_recv=self.metadata_recv_forward, + ) + self.send_metadata_backward = False + if self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + + return input_tensor def forward_step( self, @@ -276,9 +295,10 @@ def forward_backward_step( Returns: dict: A dict with keys: 'loss' and 'outputs'. """ - forward_only = not torch.is_grad_enabled() + + self.forward_only = not torch.is_grad_enabled() if optimizer is None: - assert forward_only, "Optimizer should be passed when doing backward." + assert self.forward_only, "Optimizer should be passed when doing backward." self.load_batch(data_iter) @@ -291,25 +311,22 @@ def forward_backward_step( input_objs = None output_objs = None - if not forward_only: + if not self.forward_only: input_objs = [] output_objs = [] - - outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None + + accum_loss = None if return_loss and self.stage_manager.is_last_stage(): accum_loss = torch.zeros(1, device=get_current_device()) - else: - accum_loss = None + outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None # Run warmup forward passes. 
for i in range(num_warmup_microbatches): input_obj = self.recv_forward() - output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - self.send_forward(output_obj) - if not forward_only: + if not self.forward_only: input_objs.append(input_obj) output_objs.append(output_obj) @@ -324,16 +341,15 @@ def forward_backward_step( last_iteration = i == (num_microbatches_remaining - 1) output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs) - if forward_only: + + if self.forward_only: self.send_forward(output_obj) if not last_iteration: input_obj = self.recv_forward() - else: - # TODO adjust here - self.send_forward(output_obj) - output_obj_grad = self.recv_backward() + else: + output_obj_grad = self.send_forward_recv_backward(output_obj) # Add input_obj and output_obj to end of list. input_objs.append(input_obj) output_objs.append(output_obj) @@ -345,13 +361,12 @@ def forward_backward_step( input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) if last_iteration: - input_obj = None + self.send_backward(input_obj_grad) else: - input_obj = self.recv_forward() - self.send_backward(input_obj_grad) + input_obj = self.send_backward_recv_forward(input_obj_grad) # Run cooldown backward passes. 
- if not forward_only: + if not self.forward_only: for i in range(num_warmup_microbatches): input_obj = input_objs.pop(0) output_obj = output_objs.pop(0) From 9f8bf84cb78881bbab379c70312d168765d94614 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 23 Nov 2023 10:56:17 +0800 Subject: [PATCH 10/22] test: add more 1f1b tests --- colossalai/pipeline/schedule/one_f_one_b.py | 5 +- .../test_schedule/test_oneF_oneB.py | 126 ++++++++++++------ 2 files changed, 88 insertions(+), 43 deletions(-) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index c4199017d7a8..35cfde1f5aff 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -314,7 +314,7 @@ def forward_backward_step( if not self.forward_only: input_objs = [] output_objs = [] - + accum_loss = None if return_loss and self.stage_manager.is_last_stage(): accum_loss = torch.zeros(1, device=get_current_device()) @@ -375,6 +375,9 @@ def forward_backward_step( input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) self.send_backward(input_obj_grad) + if not self.forward_only: + assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) + if outputs is not None: if isinstance(model, ModelWrapper): model = model.unwrap() diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 7424cb160a40..ea60e76e5428 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -15,26 +15,26 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all -WORLD_SIZE = 2 DIM = 8 -NUM_MICRO_BATCHS = 4 -BATCH_SIZE = 4 +NUM_LAYER = 8 class MlpModel(nn.Module): def __init__(self): - super(MlpModel, self).__init__() - self.linear1 = nn.Linear(DIM, DIM) - self.linear2 = nn.Linear(DIM, DIM) + 
super().__init__() + self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)]) def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) + for layer in self.layers: + x = layer(x) return x def pp_linear_fwd( - forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None + forward, + data: torch.Tensor = None, + input_obj: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, ): if stage_mgr.is_first_stage(): return {"input_obj": forward(data)} @@ -44,13 +44,13 @@ def pp_linear_fwd( return {"input_obj": forward(input_obj)} -def examine_pp(): +def examine_pp(num_microbatch: int, batch_size: int): """ This test is to examine the correctness of 1F1B, compared with torch. Be aware it contains some hardcodes. """ world_size = dist.get_world_size() - local_rank = dist.get_rank() + dist.get_rank() seed_all(1453) # create models @@ -58,17 +58,31 @@ def examine_pp(): pp_model = copy.deepcopy(torch_model).cuda() - DP_DIM, PP_DIM, TP_DIM = 0, 1, 2 - pg_mesh = ProcessGroupMesh(1, world_size, 1) - stage_manager = PipelineStageManager(pg_mesh, PP_DIM) - schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS) - - for idx, (_, sub_model) in enumerate(pp_model.named_children()): - if idx % (world_size) == local_rank: - sharded_model = sub_model.cuda() + pg_mesh = ProcessGroupMesh(world_size) + stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0) + schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch) + + rank = dist.get_rank() + sharded_model = torch.nn.ModuleList() + num_local_layer = NUM_LAYER // world_size + for idx, sub_model in enumerate(pp_model.layers): + if idx // num_local_layer == rank: + sharded_model.append(sub_model.cuda()) + assert len(sharded_model) == num_local_layer + + def custom_fwd(self, x): + for layer in self._modules.values(): + x = layer(x) + return x - sharded_model._forward = sharded_model.forward 
- sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward) + sharded_model._forward = MethodType(custom_fwd, sharded_model) + sharded_model.forward = MethodType( + partial( + pp_linear_fwd, + stage_mgr=stage_manager, + ), + sharded_model._forward, + ) # create optimizer torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) @@ -76,16 +90,15 @@ def examine_pp(): # create seed_all(1453) - input_list = [torch.rand(BATCH_SIZE, DIM).cuda()] + input_list = [torch.rand(batch_size, DIM).cuda()] dist.all_reduce(input_list[0]) - criterion = lambda x, y: (x * x).mean() + criterion = lambda x, *arg, **kwargs: (x * x).mean() # forward and backward torch_output = torch_model(input_list[0]) - torch_loss = criterion(torch_output, _) + torch_loss = criterion(torch_output) torch_loss.backward() - pp_ret = schedule.forward_backward_step( sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True ) @@ -95,34 +108,63 @@ def examine_pp(): assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients - torch_grad = [] - for torch_p in torch_model.parameters(): - torch_grad.append(torch_p.grad.data) - for idx, pp_p in enumerate(sharded_model.parameters()): - assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data) + for i in range(len(sharded_model)): + idx = rank * num_local_layer + i + assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad) + assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad) # step torch_optimizer.step() pp_optimizer.step() + pp_optimizer.zero_grad() # check updated param - torch_param = [] - for torch_p in torch_model.parameters(): - torch_param.append(torch_p.data) - for idx, pp_p in enumerate(sharded_model.parameters()): - assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data) - - -def run_dist(rank: int, world_size: int, port: int): + for i in 
range(len(sharded_model)): + idx = rank * num_local_layer + i + assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) + assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) + + # forward only + with torch.no_grad(): + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True + ) + if stage_manager.is_last_stage(): + assert torch.allclose(torch_loss, pp_ret["loss"]) + + for layer in sharded_model: + assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + + +def run_dist( + rank: int, + world_size: int, + port: int, + num_microbatch: int, + batch_size: int, +): colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") - examine_pp() + examine_pp(num_microbatch, batch_size) @pytest.mark.dist +@pytest.mark.parametrize("num_microbatch", [4, 12]) +@pytest.mark.parametrize("batch_size", [12]) +@pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() -def test_pp(): - spawn(run_dist, WORLD_SIZE) +def test_pp(num_microbatch: int, batch_size: int, world_size: int): + assert NUM_LAYER % world_size == 0 + spawn( + run_dist, + world_size, + num_microbatch=num_microbatch, + batch_size=batch_size, + ) if __name__ == "__main__": - test_pp() + test_pp(num_microbatch=4, batch_size=4, world_size=4) From 03dda55c86bea4adb49005143c2760d067c6928b Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 23 Nov 2023 13:18:18 +0800 Subject: [PATCH 11/22] feat: add metadata cache in interleaved pp --- .../pipeline/schedule/interleaved_pp.py | 139 +++++++++++------- 1 file changed, 83 insertions(+), 56 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 
7c3f15e80726..0f743a3e4206 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -7,7 +7,7 @@ from torch.utils._pytree import tree_map from colossalai.interface import OptimizerWrapper -from colossalai.pipeline.p2p import PipelineP2PCommunication +from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.device import get_current_device @@ -34,8 +34,15 @@ def __init__( self.batch: Any self.batch_size: int + self.last_batch_size: Optional[int] = None self.microbatch_offset: List[int] + # P2PMeta cache + self.send_metadata_forward = True + self.send_metadata_backward = True + self.metadata_recv_forward = None + self.metadata_recv_backward = None + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. @@ -48,6 +55,11 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) batch = tree_map(partial(to_device, device=device), batch) self.batch = batch self.batch_size = get_batch_size(batch) + if self.last_batch_size is None: + self.last_batch_size = self.batch_size + else: + assert self.forward_only or self.last_batch_size == self.batch_size + # TODO: support arbitrary batch size when forward_only=True self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] if self.num_microbatch is not None: assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" @@ -106,12 +118,12 @@ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: Returns: Any: The input tensor or input tensor list. 
""" - if self.stage_manager.is_first_stage(model_chunk_id): - input_tensor = None - else: - input_tensor = self.comm.recv_forward(prev_rank) + if not self.stage_manager.is_first_stage(model_chunk_id): + input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward) + if self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) - return input_tensor + return input_tensor def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. @@ -124,14 +136,14 @@ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: Returns: Any: The input gradient tensor or gradient tensor list. """ - if self.stage_manager.is_last_stage(model_chunk_id): - output_tensor_grad = None - else: - output_tensor_grad = self.comm.recv_backward(next_rank) + if not self.stage_manager.is_last_stage(model_chunk_id): + output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) - return output_tensor_grad + return output_tensor_grad - def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None: + def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None: """Sends the input tensor to the next stage in pipeline. For interleaved 1F1B. @@ -141,9 +153,10 @@ def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None next_rank (int, optional): The rank of the recipient of the tensor. 
""" if not self.stage_manager.is_last_stage(model_chunk_id): - self.comm.send_forward(output_object, next_rank) + self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) + self.send_metadata_forward = False - def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None: + def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. For interleaved 1F1B. @@ -153,7 +166,40 @@ def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None prev_rank (int, optional): The rank of the recipient of the tensor """ if not self.stage_manager.is_first_stage(model_chunk_id): - self.comm.send_backward(input_object, prev_rank) + self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) + self.send_metadata_backward = False + + def send_forward_recv_backward( + self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None + ) -> Any: + if not self.stage_manager.is_last_stage(model_chunk_id): + output_tensor_grad = self.comm.send_forward_recv_backward( + output_object, + next_rank, + send_metadata=self.send_metadata_forward, + metadata_recv=self.metadata_recv_backward, + ) + self.send_metadata_forward = False + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + + return output_tensor_grad + + def send_backward_recv_forward( + self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None + ) -> Any: + if not self.stage_manager.is_first_stage(model_chunk_id): + input_tensor = self.comm.send_backward_recv_forward( + output_object, + prev_rank, + send_metadata=self.send_metadata_backward, + metadata_recv=self.metadata_recv_forward, + ) + self.send_metadata_backward = False + if self.metadata_recv_forward is None: + self.metadata_recv_forward = 
create_fast_send_metadata(input_tensor) + + return input_tensor def forward_step( self, @@ -267,15 +313,14 @@ def forward_backward_step( Returns: dict: A dict with keys: 'loss' and 'outputs'. """ - # TODO: handle arbitrary batch size when forward_only == True - forward_only = not torch.is_grad_enabled() + self.forward_only = not torch.is_grad_enabled() if optimizer is None: - assert forward_only, "Optimizer should be passed when doing backward." + assert self.forward_only, "Optimizer should be passed when doing backward." self.load_batch(data_iter) num_microbatch = self.num_microbatch * self.num_model_chunks - if forward_only: + if self.forward_only: num_warmup_microbatch = num_microbatch else: num_warmup_microbatch = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2 @@ -288,43 +333,29 @@ def forward_backward_step( input_objs = None output_objs = None - if not forward_only: + if not self.forward_only: input_objs = [[] for _ in range(self.num_model_chunks)] output_objs = [[] for _ in range(self.num_model_chunks)] outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None + accum_loss = None if return_loss and self.stage_manager.is_last_stage(-1): accum_loss = torch.zeros(1, device=get_current_device()) - else: - accum_loss = None - - # for ranks except the first one, get into recv state - input_obj = self.recv_forward(0) # Run warmup forward passes. 
for i in range(num_warmup_microbatch): model_chunk_id = self.get_model_chunk_id(i, is_forward=True) - # recv first on first rank to avoid sending or receiving at the same time - if self.stage_manager.is_first_stage(-1): - input_obj = self.recv_forward(model_chunk_id) - output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - self.send_forward(model_chunk_id, output_obj) - if not forward_only: - input_objs[model_chunk_id].append(input_obj) - output_objs[model_chunk_id].append(output_obj) - else: - output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - if not forward_only: - input_objs[model_chunk_id].append(input_obj) - output_objs[model_chunk_id].append(output_obj) - self.send_forward(model_chunk_id, output_obj) - - if num_microbatch_remaining == 0 and i + 1 == num_warmup_microbatch: - break + input_obj = self.recv_forward(model_chunk_id) + output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) + if not self.forward_only: + input_objs[model_chunk_id].append(input_obj) + output_objs[model_chunk_id].append(output_obj) + self.send_forward(model_chunk_id, output_obj) - model_chunk_id = self.get_model_chunk_id(i + 1, is_forward=True) - input_obj = self.recv_forward(model_chunk_id) + if num_microbatch_remaining > 0: + model_chunk_id = self.get_model_chunk_id(num_warmup_microbatch, is_forward=True) + input_obj = self.recv_forward(model_chunk_id) # Run 1F1B in steady state. 
for i in range(num_microbatch_remaining): @@ -332,11 +363,11 @@ def forward_backward_step( last_iteration = i == num_microbatch_remaining - 1 output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs) - if forward_only: - self.send_forward(model_chunk_id, output_obj) - + if self.forward_only: if not last_iteration: - input_obj = self.recv_forward(model_chunk_id) + input_obj = self.send_forward_recv_backward(model_chunk_id, output_obj) + else: + self.send_forward(model_chunk_id, output_obj) else: self.send_forward(model_chunk_id, output_obj) @@ -354,18 +385,14 @@ def forward_backward_step( # backward input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) + self.send_backward(model_chunk_id, input_obj_grad) - if last_iteration: - input_obj = None - else: + if not last_iteration: model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatch + 1, is_forward=True) input_obj = self.recv_forward(model_chunk_id) - model_chunk_id = self.get_model_chunk_id(i, is_forward=False) - self.send_backward(model_chunk_id, input_obj_grad) - # Run cooldown backward passes. 
- if not forward_only: + if not self.forward_only: for i in range(num_microbatch_remaining, num_microbatch): model_chunk_id = self.get_model_chunk_id(i, is_forward=False) input_obj = input_objs[model_chunk_id].pop(0) @@ -374,7 +401,7 @@ def forward_backward_step( input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad) self.send_backward(model_chunk_id, input_obj_grad) - if not forward_only: + if not self.forward_only: assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) if outputs is not None: From 3620ab5243e9a8cd97356ebcd54c45445b0e8aec Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 23 Nov 2023 13:18:30 +0800 Subject: [PATCH 12/22] test: add forward test --- .../test_schedule/test_interleaved.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 4de50245feeb..a0d1f3c0f8cc 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -119,6 +119,7 @@ def criterion(x, *args, **kwargs): # step torch_optimizer.step() pp_optimizer.step() + pp_optimizer.zero_grad() # check updated param for i in range(num_model_chunk): @@ -126,6 +127,28 @@ def criterion(x, *args, **kwargs): assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight) assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias) + # forward only + with torch.no_grad(): + torch_output = torch_model(input_list[0]) + torch_loss = criterion(torch_output) + + pp_ret = schedule.forward_backward_step( + sharded_model, + iter(input_list), + criterion, + pp_optimizer, + return_loss=True, + return_outputs=True + ) + if stage_manager.is_last_stage(-1): + assert torch.allclose(torch_loss, pp_ret["loss"]) + + for layer in sharded_model: + assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) 
+ assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + + + @pytest.mark.dist @pytest.mark.parametrize("num_microbatch", [4, 12]) From 28ca6abb697426f407c1e2e2429ba57cb792d288 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 24 Nov 2023 17:53:45 +0800 Subject: [PATCH 13/22] feat: modify is_xx_stage fn --- .../pipeline/schedule/interleaved_pp.py | 127 +++++++++--------- colossalai/pipeline/stage_manager.py | 36 ++--- colossalai/shardformer/policies/bert.py | 2 +- examples/language/bert/finetune.py | 8 +- .../test_schedule/test_interleaved.py | 17 +-- 5 files changed, 96 insertions(+), 94 deletions(-) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 0f743a3e4206..2b59e6f23316 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -118,12 +118,13 @@ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: Returns: Any: The input tensor or input tensor list. """ - if not self.stage_manager.is_first_stage(model_chunk_id): - input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward) - if self.metadata_recv_forward is None: - self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + with self.stage_manager.set_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_first_stage(): + input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward) + if self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) - return input_tensor + return input_tensor def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. 
@@ -136,12 +137,13 @@ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: Returns: Any: The input gradient tensor or gradient tensor list. """ - if not self.stage_manager.is_last_stage(model_chunk_id): - output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) - if self.metadata_recv_backward is None: - self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + with self.stage_manager.set_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_last_stage(): + output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) - return output_tensor_grad + return output_tensor_grad def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = None) -> None: """Sends the input tensor to the next stage in pipeline. @@ -152,9 +154,10 @@ def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - if not self.stage_manager.is_last_stage(model_chunk_id): - self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) - self.send_metadata_forward = False + with self.stage_manager.set_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_last_stage(): + self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) + self.send_metadata_forward = False def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = None) -> None: """Sends the gradient tensor to the previous stage in pipeline. @@ -165,41 +168,44 @@ def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = input_object (Any): Object to be sent. 
prev_rank (int, optional): The rank of the recipient of the tensor """ - if not self.stage_manager.is_first_stage(model_chunk_id): - self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) - self.send_metadata_backward = False + with self.stage_manager.set_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_first_stage(): + self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) + self.send_metadata_backward = False def send_forward_recv_backward( self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None ) -> Any: - if not self.stage_manager.is_last_stage(model_chunk_id): - output_tensor_grad = self.comm.send_forward_recv_backward( - output_object, - next_rank, - send_metadata=self.send_metadata_forward, - metadata_recv=self.metadata_recv_backward, - ) - self.send_metadata_forward = False - if self.metadata_recv_backward is None: - self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) - - return output_tensor_grad + with self.stage_manager.set_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_last_stage(): + output_tensor_grad = self.comm.send_forward_recv_backward( + output_object, + next_rank, + send_metadata=self.send_metadata_forward, + metadata_recv=self.metadata_recv_backward, + ) + self.send_metadata_forward = False + if self.metadata_recv_backward is None: + self.metadata_recv_backward = create_fast_send_metadata(output_tensor_grad) + + return output_tensor_grad def send_backward_recv_forward( self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None ) -> Any: - if not self.stage_manager.is_first_stage(model_chunk_id): - input_tensor = self.comm.send_backward_recv_forward( - output_object, - prev_rank, - send_metadata=self.send_metadata_backward, - metadata_recv=self.metadata_recv_forward, - ) - self.send_metadata_backward = False - if self.metadata_recv_forward is None: - self.metadata_recv_forward 
= create_fast_send_metadata(input_tensor) - - return input_tensor + with self.stage_manager.set_model_chunk_id(model_chunk_id): + if not self.stage_manager.is_first_stage(): + input_tensor = self.comm.send_backward_recv_forward( + output_object, + prev_rank, + send_metadata=self.send_metadata_backward, + metadata_recv=self.metadata_recv_forward, + ) + self.send_metadata_backward = False + if self.metadata_recv_forward is None: + self.metadata_recv_forward = create_fast_send_metadata(input_tensor) + + return input_tensor def forward_step( self, @@ -226,25 +232,24 @@ def forward_step( # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict - self.stage_manager.model_chunk_id = model_chunk_id - if isinstance(model_chunk, ModuleList): - output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) - else: - # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers - internal_inputs = {} if input_obj is None else input_obj - internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] - output_obj = model_forward(model_chunk, micro_batch, internal_inputs) - self.stage_manager.model_chunk_id = None - - if self.stage_manager.is_last_stage(model_chunk_id): - loss = criterion(output_obj, micro_batch) / self.num_microbatch - if accum_loss is not None: - accum_loss.add_(loss.detach()) - if outputs is not None: - outputs.append(tree_map(detach, output_obj)) - return loss - else: - return output_obj + with self.stage_manager.set_model_chunk_id(model_chunk_id): + if isinstance(model_chunk, ModuleList): + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) + else: + # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers + internal_inputs = {} if input_obj is None else input_obj + internal_inputs["stage_index"] = 
self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk, micro_batch, internal_inputs) + + if self.stage_manager.is_last_stage(): + loss = criterion(output_obj, micro_batch) / self.num_microbatch + if accum_loss is not None: + accum_loss.add_(loss.detach()) + if outputs is not None: + outputs.append(tree_map(detach, output_obj)) + return loss + else: + return output_obj def backward_step( self, @@ -337,10 +342,10 @@ def forward_backward_step( input_objs = [[] for _ in range(self.num_model_chunks)] output_objs = [[] for _ in range(self.num_model_chunks)] - outputs = [] if return_outputs and self.stage_manager.is_last_stage(-1) else None + outputs = [] if return_outputs and self.stage_manager.is_last_stage(ignore_chunk=True) else None accum_loss = None - if return_loss and self.stage_manager.is_last_stage(-1): + if return_loss and self.stage_manager.is_last_stage(ignore_chunk=True): accum_loss = torch.zeros(1, device=get_current_device()) # Run warmup forward passes. diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index d7853938ab4e..9bfed2e29cd6 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -1,3 +1,4 @@ +import contextlib from typing import Dict, List, Optional, Tuple import torch.distributed as dist @@ -68,45 +69,37 @@ def __init__( # for shardformer, hold model chunk id self.model_chunk_id: Optional[int] = None - def is_first_stage(self, model_chunk_id: Optional[int] = None) -> bool: + def is_first_stage(self, ignore_chunk: bool = False) -> bool: """Is the current stage the first stage. NOTE: 1. if using interleaved pipeline parallel, the first stage is the first chunk of the first device. - 2. invoke is_first_stage() with model_chunk_id < 0 is equivalent to invoke is_first_device() + 2. invoke is_first_stage() with ignore_chunk=True is equivalent to invoke is_first_device() Returns: bool: Whether the current stage is the first stage. 
""" - if self.is_interleave and model_chunk_id is None: - model_chunk_id = self.model_chunk_id - assert self.is_interleave ^ ( - model_chunk_id is None - ), "model_chunk_id must be specified when using interleaved pipeline" - if not self.is_interleave or model_chunk_id < 0: + assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None) + if not self.is_interleave or ignore_chunk: return self.stage == 0 else: - return self.stage == 0 and model_chunk_id == 0 + return self.stage == 0 and self.model_chunk_id == 0 - def is_last_stage(self, model_chunk_id: Optional[int] = None) -> bool: + def is_last_stage(self, ignore_chunk: bool = False) -> bool: """Is the current stage the last stage. NOTE: 1. if using interleaved pipeline parallel, the last stage is the last chunk of the last device. - 2. invoke is_last_stage() with model_chunk_id < 0 is equivalent to invoke is_last_device() + 2. invoke is_last_stage() with ignore_chunk=True is equivalent to invoke is_last_device() Returns: bool: Whether the current stage is the last stage. """ - if self.is_interleave and model_chunk_id is None: - model_chunk_id = self.model_chunk_id - assert self.is_interleave ^ ( - model_chunk_id is None - ), "model_chunk_id must be specified when using interleaved pipeline" - if not self.is_interleave or model_chunk_id < 0: + assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None) + if not self.is_interleave or ignore_chunk: return self.stage == self.num_stages - 1 else: - return self.stage == self.num_stages - 1 and model_chunk_id == self.num_model_chunks - 1 + return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1 @property def num_stages(self) -> int: @@ -174,3 +167,10 @@ def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: ProcessGroup: Process group of the given stages. 
""" return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) + + @contextlib.contextmanager + def set_model_chunk_id(self, model_chunk_id: int): + old_model_chunk_id = self.model_chunk_id + self.model_chunk_id = model_chunk_id + yield + self.model_chunk_id = old_model_chunk_id diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 78363bf5ea99..da158e1c6e19 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -313,7 +313,7 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.embeddings) for start_idx, end_idx in stage_indices: held_layers.extend(module.encoder.layer[start_idx:end_idx]) - if stage_manager.is_last_stage(-1): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.pooler) else: diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index b349d7edfdd8..aad12c9c2c59 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -57,9 +57,7 @@ def evaluate_model( def evaluate_subset(dataloader: DataLoader): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( - None if not booster.plugin.stage_manager.is_interleave else -1 - ) + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True) accum_loss = torch.zeros(1, device=get_current_device()) for batch in dataloader: @@ -136,9 +134,7 @@ def train_epoch( coordinator: DistCoordinator, ): use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 - is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage( - None if not booster.plugin.stage_manager.is_interleave else -1 - ) + is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(ignore_chunk=True) print_flag 
= (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device) total_step = len(train_dataloader) diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index a0d1f3c0f8cc..0d7b1dcc86ac 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -37,12 +37,13 @@ def pp_linear_fwd( stage_mgr: PipelineStageManager = None, model_chunk_id: int = None, ): - if stage_mgr.is_first_stage(model_chunk_id): - return {"input_obj": forward(data)} - elif stage_mgr.is_last_stage(model_chunk_id): - return forward(input_obj) - else: - return {"input_obj": forward(input_obj)} + with stage_mgr.set_model_chunk_id(model_chunk_id): + if stage_mgr.is_first_stage(): + return {"input_obj": forward(data)} + elif stage_mgr.is_last_stage(): + return forward(input_obj) + else: + return {"input_obj": forward(input_obj)} def run_pp( @@ -107,7 +108,7 @@ def criterion(x, *args, **kwargs): ) # check loss - if stage_manager.is_last_stage(-1): + if stage_manager.is_last_stage(ignore_chunk=True): assert torch.allclose(torch_loss, pp_ret["loss"]) # check gradients @@ -140,7 +141,7 @@ def criterion(x, *args, **kwargs): return_loss=True, return_outputs=True ) - if stage_manager.is_last_stage(-1): + if stage_manager.is_last_stage(ignore_chunk=True): assert torch.allclose(torch_loss, pp_ret["loss"]) for layer in sharded_model: From 71cddab9bdbbe9e0108d61e91742a9739a950aab Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 27 Nov 2023 15:15:54 +0800 Subject: [PATCH 14/22] fix: add shard bert test and fix is_xx_stage arg --- colossalai/pipeline/stage_manager.py | 2 ++ colossalai/shardformer/policies/bert.py | 18 ++++++++--------- tests/test_shardformer/test_model/_utils.py | 6 +++++- .../test_model/test_shard_bert.py | 20 +++++++++++++++---- 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/colossalai/pipeline/stage_manager.py 
b/colossalai/pipeline/stage_manager.py index 9bfed2e29cd6..7b3edd084835 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -79,6 +79,7 @@ def is_first_stage(self, ignore_chunk: bool = False) -> bool: Returns: bool: Whether the current stage is the first stage. """ + assert isinstance(ignore_chunk, bool) assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None) if not self.is_interleave or ignore_chunk: return self.stage == 0 @@ -95,6 +96,7 @@ def is_last_stage(self, ignore_chunk: bool = False) -> bool: Returns: bool: Whether the current stage is the last stage. """ + assert isinstance(ignore_chunk, bool) assert not self.is_interleave or (ignore_chunk or self.model_chunk_id is not None) if not self.is_interleave or ignore_chunk: return self.stage == self.num_stages - 1 diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index da158e1c6e19..0ab63b7650c1 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -309,7 +309,7 @@ def get_held_layers(self) -> List[Module]: num_model_chunks=stage_manager.num_model_chunks, num_stages=stage_manager.num_stages, ) - if stage_manager.is_first_stage(-1): + if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.embeddings) for start_idx, end_idx in stage_indices: held_layers.extend(module.encoder.layer[start_idx:end_idx]) @@ -370,7 +370,7 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage""" held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.cls) return held_layers @@ -409,7 +409,7 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if 
stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.cls) return held_layers @@ -447,7 +447,7 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.cls) return held_layers @@ -499,7 +499,7 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(None if not stage_manager.is_interleave else -1): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers @@ -543,7 +543,7 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers @@ -574,7 +574,7 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.cls) return held_layers @@ -617,7 +617,7 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.dropout) held_layers.append(self.model.classifier) return held_layers @@ -647,7 +647,7 @@ def get_held_layers(self) -> List[Module]: """ held_layers = super().get_held_layers() stage_manager = self.pipeline_stage_manager - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): 
held_layers.append(self.model.qa_outputs) return held_layers diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 6acbe4ff523d..87e6618023d3 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -203,7 +203,7 @@ def check_output_hidden_state( ): org_hidden_state = org_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(): + if stage_manager and stage_manager.is_last_stage(ignore_chunk=True): sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"] else: sharded_hidden_state = sharded_output.last_hidden_state @@ -229,6 +229,10 @@ def check_weight( org_weight = getattr_(org_model, suffix).weight sharded_weight = getattr_(sharded_model, suffix).weight + # skip if layer is not held by this process + if sharded_weight is None: + continue + if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): sharded_weight_list = [ torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group)) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index b38793b7c388..768bd95bdb42 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -37,6 +37,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"] col_layer_for_check = ["encoder.layer[0].output.dense"] row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"] + weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"] # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
grads_to_check = {} @@ -44,7 +45,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 1e-4, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0: col_layer_grads = get_grad_tensors_for_check( bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False ) @@ -72,7 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_optimizer.step() # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): + if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: @@ -87,8 +88,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 - if stage_manager is None or stage_manager.is_first_stage(): - check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): + check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1) # check grads check_all_grad_tensors(grads_to_check) @@ -183,6 +184,17 @@ def run_bert_test(test_config): "zero_stage": 1, "initial_scale": 1, }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "interleaved", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, ], ) def run_bert_3d_test(test_config): From c207ff690ede67c975901e642200381dc2741b67 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 29 Nov 2023 14:23:04 +0800 Subject: [PATCH 15/22] fix: fix pipeline test --- 
.../test_schedule/test_interleaved.py | 14 +++----------- .../test_pipeline/test_schedule/test_oneF_oneB.py | 3 +-- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 0d7b1dcc86ac..88d745e2fa7c 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -134,21 +134,13 @@ def criterion(x, *args, **kwargs): torch_loss = criterion(torch_output) pp_ret = schedule.forward_backward_step( - sharded_model, - iter(input_list), - criterion, - pp_optimizer, - return_loss=True, - return_outputs=True + sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True ) if stage_manager.is_last_stage(ignore_chunk=True): assert torch.allclose(torch_loss, pp_ret["loss"]) - - for layer in sharded_model: - assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) - assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) - + for layer in sharded_model: + assert layer.weight.grad is None and layer.bias.grad is None @pytest.mark.dist diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index ea60e76e5428..999d5845841d 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -136,8 +136,7 @@ def custom_fwd(self, x): assert torch.allclose(torch_loss, pp_ret["loss"]) for layer in sharded_model: - assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) - assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) + assert layer.weight.grad is None and layer.bias.grad is None def run_dist( From eccc9d65404b747258bbfe0823c1c83826674823 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 29 Nov 2023 14:36:47 +0800 Subject: [PATCH 16/22] revert: add 
_broadcast_object_list --- colossalai/pipeline/p2p.py | 85 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index f396a3d54eb7..ea92167fc928 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -49,6 +49,91 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - return unpickle +# NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use +def _broadcast_object_list( + object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None +): + """This is a modified version of the broadcast_object_list in torch.distribution + The only difference is that object will be move to correct device after unpickled. + If local_rank = src, then object list will be sent to rank src. Otherwise, object list will + be updated with data sent from rank src. + Args: + object_list (List[Any]): list of object to broadcast + src (int): source rank to broadcast + dst (int): dst rank to broadcast + device (:class:`torch.device`): device to do broadcast. current device in default + """ + + if c10d._rank_not_in_group(group): + c10d._warn_not_in_group("broadcast_object_list") + return + + is_nccl_backend = check_for_nccl_backend(group) + current_device = None + + if device is not None: + if is_nccl_backend and device.type != "cuda": + raise ValueError("device type must be cuda for nccl backend") + current_device = device + else: + current_device = torch.device("cpu") + if is_nccl_backend: + current_device = torch.device("cuda", torch.cuda.current_device()) + + my_rank = dist.get_rank() + # Serialize object_list elements to tensors on src rank. 
+    if my_rank == src:
+        if Version(torch.__version__) >= Version("1.13.0"):
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj, device=current_device) for obj in object_list])
+        else:
+            tensor_list, size_list = zip(*[c10d._object_to_tensor(obj) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
+
+    if is_nccl_backend:
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+    # Broadcast object sizes
+    c10d.broadcast(object_sizes_tensor, src=src, group=group, async_op=False)
+
+    # Concatenate and broadcast serialized object tensors
+    if my_rank == src:
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.empty(  # type: ignore[call-overload]
+            torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
+            dtype=torch.uint8,
+        )
+
+    if is_nccl_backend:
+        object_tensor = object_tensor.to(current_device)
+
+    c10d.broadcast(object_tensor, src=src, group=group, async_op=False)
+
+    # Deserialize objects using their stored sizes.
+    offset = 0
+
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset : offset + obj_size]
+            obj_view = obj_view.type(torch.uint8)
+            if obj_view.device != torch.device("cpu"):
+                obj_view = obj_view.cpu()
+            offset += obj_size
+            # unpickle
+            unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)
+
+            # fix device inconsistency
+            if (
+                isinstance(unpickle_object, torch.Tensor)
+                and unpickle_object.device.index != torch.cuda.current_device()
+            ):
+                unpickle_object = unpickle_object.cuda()
+
+            object_list[i] = unpickle_object
+
+
 def check_for_nccl_backend(group):
     pg = group or c10d._get_default_group()
     # Gate PG wrapper check on Gloo availability.
From 4bdc79b7ddbe182da5e62e7f65399afd14902fd8 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Wed, 29 Nov 2023 18:33:18 +0800 Subject: [PATCH 17/22] fix: fix grad check --- tests/test_pipeline/test_schedule/test_interleaved.py | 6 +++++- tests/test_pipeline/test_schedule/test_oneF_oneB.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index 88d745e2fa7c..e1c2b30d282f 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -140,7 +140,11 @@ def criterion(x, *args, **kwargs): assert torch.allclose(torch_loss, pp_ret["loss"]) for layer in sharded_model: - assert layer.weight.grad is None and layer.bias.grad is None + if layer.weight.grad is None: + assert layer.weight.grad is None and layer.bias.grad is None + else: + assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) @pytest.mark.dist diff --git a/tests/test_pipeline/test_schedule/test_oneF_oneB.py b/tests/test_pipeline/test_schedule/test_oneF_oneB.py index 999d5845841d..5f27be39657d 100644 --- a/tests/test_pipeline/test_schedule/test_oneF_oneB.py +++ b/tests/test_pipeline/test_schedule/test_oneF_oneB.py @@ -136,7 +136,11 @@ def custom_fwd(self, x): assert torch.allclose(torch_loss, pp_ret["loss"]) for layer in sharded_model: - assert layer.weight.grad is None and layer.bias.grad is None + if layer.weight.grad is None: + assert layer.weight.grad is None and layer.bias.grad is None + else: + assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad)) + assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad)) def run_dist( From 4ec7313fe5e62f7c9bf86265be886a728185eb79 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Thu, 30 Nov 2023 13:20:29 +0800 Subject: [PATCH 18/22] fix: 
rename fn, fix typo and change debug info --- colossalai/pipeline/p2p.py | 10 ++++++---- colossalai/pipeline/schedule/interleaved_pp.py | 14 +++++++------- colossalai/pipeline/stage_manager.py | 2 +- tests/test_pipeline/test_p2p_communication.py | 16 ++++++---------- .../test_schedule/test_interleaved.py | 2 +- 5 files changed, 21 insertions(+), 23 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index ea92167fc928..7ea4e751f872 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -16,8 +16,11 @@ from torch.distributed import ProcessGroup from torch.distributed import distributed_c10d as c10d +from colossalai.logging import get_dist_logger + from .stage_manager import PipelineStageManager +logger = get_dist_logger() _unpickler = pickle.Unpickler @@ -358,8 +361,7 @@ def _communicate( send_metadata (bool, optional): whether to send metadata metadata_recv (P2PMetadata, optional): metadata of the object to be received """ - if send_dst is None and recv_src is None: - return + assert send_dst is not None or recv_src is not None, "send_dst and recv_src cannot be both None" assert send_dst is None or send_group is not None, "send_group must be specified when send_dst is not None" assert recv_src is None or recv_group is not None, "recv_group must be specified when recv_src is not None" send_metadata = send_metadata or (object is not None and not _check_if_fast_send_available(object)) @@ -370,14 +372,14 @@ def _communicate( # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata, # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case. 
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None): - warnings.warn("Fall back to individual send & recv") + logger.debug("Fall back to individual send & recv") _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv) # NOTE: only the following 5 cases are valid: # 1. send() [needs extra metadata] and no recv() # 2. recv() [needs extra metadata] and no send() - # 3. neither send() or recv() need extra metadata + # 3. neither send() nor recv() need extra metadata assert not (send_dst is not None and send_metadata) or recv_src is None assert not (recv_src is not None and metadata_recv is None) or send_dst is None assert not (send_dst is not None and recv_src is not None) or (not send_metadata and metadata_recv is not None) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 2b59e6f23316..853c27d68c4b 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -118,7 +118,7 @@ def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any: Returns: Any: The input tensor or input tensor list. """ - with self.stage_manager.set_model_chunk_id(model_chunk_id): + with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): input_tensor = self.comm.recv_forward(prev_rank, metadata_recv=self.metadata_recv_forward) if self.metadata_recv_forward is None: @@ -137,7 +137,7 @@ def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any: Returns: Any: The input gradient tensor or gradient tensor list. 
""" - with self.stage_manager.set_model_chunk_id(model_chunk_id): + with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): output_tensor_grad = self.comm.recv_backward(next_rank, metadata_recv=self.metadata_recv_backward) if self.metadata_recv_backward is None: @@ -154,7 +154,7 @@ def send_forward(self, model_chunk_id: int, output_object: Any, next_rank: int = output_object (Any): Object to be sent. next_rank (int, optional): The rank of the recipient of the tensor. """ - with self.stage_manager.set_model_chunk_id(model_chunk_id): + with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): self.comm.send_forward(output_object, next_rank, send_metadata=self.send_metadata_forward) self.send_metadata_forward = False @@ -168,7 +168,7 @@ def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = input_object (Any): Object to be sent. prev_rank (int, optional): The rank of the recipient of the tensor """ - with self.stage_manager.set_model_chunk_id(model_chunk_id): + with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): self.comm.send_backward(input_object, prev_rank, send_metadata=self.send_metadata_backward) self.send_metadata_backward = False @@ -176,7 +176,7 @@ def send_backward(self, model_chunk_id: int, input_object: Any, prev_rank: int = def send_forward_recv_backward( self, model_chunk_id: int, output_object: Any, next_rank: Optional[int] = None ) -> Any: - with self.stage_manager.set_model_chunk_id(model_chunk_id): + with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_last_stage(): output_tensor_grad = self.comm.send_forward_recv_backward( output_object, @@ -193,7 +193,7 @@ def send_forward_recv_backward( def send_backward_recv_forward( self, model_chunk_id: int, output_object: Any, prev_rank: Optional[int] = None ) -> Any: - with 
self.stage_manager.set_model_chunk_id(model_chunk_id): + with self.stage_manager.switch_model_chunk_id(model_chunk_id): if not self.stage_manager.is_first_stage(): input_tensor = self.comm.send_backward_recv_forward( output_object, @@ -232,7 +232,7 @@ def forward_step( # for the first stage, input_obj is None # for the non-first stage, input_obj is the output of the previous stage and it's must be a dict - with self.stage_manager.set_model_chunk_id(model_chunk_id): + with self.stage_manager.switch_model_chunk_id(model_chunk_id): if isinstance(model_chunk, ModuleList): output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj) else: diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index 7b3edd084835..c8f9042084da 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -171,7 +171,7 @@ def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) @contextlib.contextmanager - def set_model_chunk_id(self, model_chunk_id: int): + def switch_model_chunk_id(self, model_chunk_id: int): old_model_chunk_id = self.model_chunk_id self.model_chunk_id = model_chunk_id yield diff --git a/tests/test_pipeline/test_p2p_communication.py b/tests/test_pipeline/test_p2p_communication.py index 5ecf1868bb2c..40b6ac8eb6ff 100644 --- a/tests/test_pipeline/test_p2p_communication.py +++ b/tests/test_pipeline/test_p2p_communication.py @@ -32,11 +32,9 @@ def check_p2p_communication(): if rank == 0: for obj in data: p2p.send_forward(obj) - with warnings.catch_warnings(record=True) as w: - for i in range(len(data)): - recv_obj = p2p.send_forward_recv_backward(data[i]) - assert recv_obj == data[-(i + 1)] - assert "Fall back" in str(w[-1].message) + for i in range(len(data)): + recv_obj = p2p.send_forward_recv_backward(data[i]) + assert recv_obj == data[-(i + 1)] elif rank == 1: for obj in data: recv_obj = 
p2p.recv_forward() @@ -49,11 +47,9 @@ def check_p2p_communication(): if rank == 1: for obj in data: p2p.send_backward(obj) - with warnings.catch_warnings(record=True) as w: - for i in range(len(data)): - recv_obj = p2p.send_backward_recv_forward(data[i]) - assert recv_obj == data[-(i + 1)] - assert "Fall back" in str(w[-1].message) + for i in range(len(data)): + recv_obj = p2p.send_backward_recv_forward(data[i]) + assert recv_obj == data[-(i + 1)] elif rank == 0: for obj in data: recv_obj = p2p.recv_backward() diff --git a/tests/test_pipeline/test_schedule/test_interleaved.py b/tests/test_pipeline/test_schedule/test_interleaved.py index e1c2b30d282f..0e81818eb239 100644 --- a/tests/test_pipeline/test_schedule/test_interleaved.py +++ b/tests/test_pipeline/test_schedule/test_interleaved.py @@ -37,7 +37,7 @@ def pp_linear_fwd( stage_mgr: PipelineStageManager = None, model_chunk_id: int = None, ): - with stage_mgr.set_model_chunk_id(model_chunk_id): + with stage_mgr.switch_model_chunk_id(model_chunk_id): if stage_mgr.is_first_stage(): return {"input_obj": forward(data)} elif stage_mgr.is_last_stage(): From 5e98471264aec992b5e6d732caf2e1e5858932df Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Mon, 27 Nov 2023 16:51:01 +0800 Subject: [PATCH 19/22] feat: add interleaved pp in llama policy --- colossalai/shardformer/policies/llama.py | 86 ++++++++++++++----- .../test_model/test_shard_llama.py | 17 +++- 2 files changed, 78 insertions(+), 25 deletions(-) diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index eee2259f2c56..39a4d40234aa 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -8,7 +8,11 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D -from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward, get_lm_forward_with_dist_cross_entropy +from ..modeling.llama import ( + 
LlamaPipelineForwards, + get_llama_flash_attention_forward, + get_lm_forward_with_dist_cross_entropy, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["LlamaPolicy", "LlamaForCausalLMPolicy", "LlamaForSequenceClassificationPolicy"] @@ -140,21 +144,42 @@ def postprocess(self): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface to customized forward method, and add this changing to policy.""" - if self.pipeline_stage_manager: - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "LlamaModel": - module = self.model - else: - module = self.model.model + if self.pipeline_stage_manager is None: + return + + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + + if stage_manager.is_interleave: + layers_per_stage = self.distribute_layers( + len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_manager.stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + method_replacement = { + "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) + } + else: layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config)} + method_replacement = { + "forward": partial( + new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + ) + } self.append_or_create_method_replacement( 
description=method_replacement, policy=policy, target_key=model_cls ) - return + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" @@ -167,13 +192,32 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - if stage_manager.is_first_stage(): - held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) - held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.is_last_stage(): - held_layers.append(module.norm) + if stage_manager.is_interleave: + assert stage_manager.num_model_chunks is not None + layers_per_stage = self.distribute_layers( + len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks + ) + stage_indices = Policy.get_stage_index( + layers_per_stage, + stage_manager.stage, + num_model_chunks=stage_manager.num_model_chunks, + num_stages=stage_manager.num_stages, + ) + if stage_manager.is_first_stage(ignore_chunk=True): + held_layers.append(module.embed_tokens) + for start_idx, end_idx in stage_indices: + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(ignore_chunk=True): + held_layers.append(module.norm) + + else: + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) return held_layers @@ -211,11 +255,9 @@ def module_policy(self): new_item = { LlamaForCausalLM: ModulePolicyDescription( sub_module_replacement=[ - 
SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col - ) + SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col) ], - method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)} + method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}, ) } policy.update(new_item) @@ -232,7 +274,7 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.lm_head) return held_layers @@ -285,7 +327,7 @@ def get_held_layers(self) -> List[Module]: """Get pipeline layers for current stage.""" stage_manager = self.pipeline_stage_manager held_layers = super().get_held_layers() - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(self.model.score) return held_layers diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index f8f08e1d0075..c7edcfb3510c 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # Save gradient tensors for comparison between the original model and the sharded model before optimizer step. 
grads_to_check = {} - if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: + if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0: if test_config["precision"] == "fp32": atol, rtol = 1e-6, 1e-4 else: @@ -63,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, sharded_optimizer.step() # check last hidden state & loss - if stage_manager is None or stage_manager.is_last_stage(): + if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True): if test_config["precision"] == "fp32": atol, rtol = 1e-5, 1e-3 else: @@ -75,7 +75,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) # check weights - if stage_manager is None or stage_manager.is_first_stage(): + if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True): if test_config["precision"] == "fp32": atol, rtol = 1e-4, 1e-3 else: @@ -179,6 +179,17 @@ def run_llama_test(test_config): "zero_stage": 1, "initial_scale": 1, }, + { + "tp_size": 2, + "pp_size": 2, + "pp_style": "interleaved", + "num_model_chunks": 2, + "num_microbatches": 4, + "enable_all_optimization": False, + "precision": "fp16", + "zero_stage": 1, + "initial_scale": 1, + }, ], ) def run_llama_3d_test(test_config): From 9c0f96915ce61695476a461b224ec12c1212dbd5 Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Fri, 1 Dec 2023 11:22:39 +0800 Subject: [PATCH 20/22] feat: use interleaved pp in llama and add warning --- colossalai/pipeline/schedule/interleaved_pp.py | 4 ++++ colossalai/pipeline/schedule/one_f_one_b.py | 4 ++++ examples/language/llama2/benchmark.py | 2 ++ 3 files changed, 10 insertions(+) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 853c27d68c4b..3bf0604f90d5 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py 
+++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -7,6 +7,7 @@ from torch.utils._pytree import tree_map from colossalai.interface import OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.device import get_current_device @@ -27,6 +28,9 @@ def __init__( assert ( num_microbatch is not None or microbatch_size is not None ), "Either num_microbatch or microbatch_size should be provided" + self.logger = get_dist_logger() + self.logger.warning("If pipeline hangs, please enlarge NCCL_BUFFSIZE.") + self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 35cfde1f5aff..a716ab804bd5 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -7,6 +7,7 @@ from torch.utils._pytree import tree_map from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.device import get_current_device @@ -42,6 +43,9 @@ def __init__( assert ( num_microbatches is not None or microbatch_size is not None ), "Either num_microbatches or microbatch_size should be provided" + self.logger = get_dist_logger() + self.logger.warning("If pipeline hangs, please enlarge NCCL_BUFFSIZE.") + self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches self.microbatch_size = microbatch_size diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index daf7d2fd4b0b..a4c29b7c8231 100644 --- 
a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -133,7 +133,9 @@ def empty_init(): plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, + pp_style="interleaved", zero_stage=args.zero, + num_model_chunks=2, enable_fused_normalization=torch.cuda.is_available(), num_microbatches=args.mbs, precision="bf16", From a881b4d3da91dc1c4da518c712717b2f3927337c Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 5 Dec 2023 14:42:42 +0800 Subject: [PATCH 21/22] feat: set NCCL_BUFFSIZE in HybridParallelPlugin --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 11 ++++++++++- colossalai/pipeline/p2p.py | 5 ----- colossalai/pipeline/schedule/interleaved_pp.py | 3 --- colossalai/pipeline/schedule/one_f_one_b.py | 3 --- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 91fcba55a0aa..6430345c88d2 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,4 +1,5 @@ import ctypes +import os import random from contextlib import contextmanager from functools import partial @@ -21,7 +22,8 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh -from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin +from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer @@ -982,6 +984,13 @@ def __init__( self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: 
+ if os.getenv("NCCL_BUFFSIZE") is None: + logger = get_dist_logger() + logger.warning( + "Setting NCCL_BUFFSIZE to 256MB to avoid p2p hangs. " "Please increase it if hangs still happen." + ) + os.environ["NCCL_BUFFSIZE"] = "268435456" + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" assert ( diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 7ea4e751f872..cdb7a6a1e539 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -4,7 +4,6 @@ import io import pickle import re -import warnings from collections import namedtuple from dataclasses import dataclass from enum import Enum @@ -16,11 +15,8 @@ from torch.distributed import ProcessGroup from torch.distributed import distributed_c10d as c10d -from colossalai.logging import get_dist_logger - from .stage_manager import PipelineStageManager -logger = get_dist_logger() _unpickler = pickle.Unpickler @@ -372,7 +368,6 @@ def _communicate( # NOTE: send & recv should be atomic operations. However, if we need to send metadata or receive metadata, # we are not able to do that (1. send & recv metadata 2. send & recv). So we need to split the send & recv into two parts in this case. 
if (send_dst is not None and recv_src is not None) and (send_metadata or metadata_recv is None): - logger.debug("Fall back to individual send & recv") _communicate(object, send_dst=send_dst, recv_src=None, send_group=send_group, send_metadata=send_metadata) return _communicate(None, send_dst=None, recv_src=recv_src, recv_group=recv_group, metadata_recv=metadata_recv) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 3bf0604f90d5..3c8b00977429 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -7,7 +7,6 @@ from torch.utils._pytree import tree_map from colossalai.interface import OptimizerWrapper -from colossalai.logging import get_dist_logger from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.device import get_current_device @@ -28,8 +27,6 @@ def __init__( assert ( num_microbatch is not None or microbatch_size is not None ), "Either num_microbatch or microbatch_size should be provided" - self.logger = get_dist_logger() - self.logger.warning("If pipeline hangs, please enlarge NCCL_BUFFSIZE.") self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatch = num_microbatch diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index a716ab804bd5..8c161efec9d8 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -7,7 +7,6 @@ from torch.utils._pytree import tree_map from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.logging import get_dist_logger from colossalai.pipeline.p2p import PipelineP2PCommunication, create_fast_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.device import get_current_device @@ -43,8 +42,6 @@ def 
__init__( assert ( num_microbatches is not None or microbatch_size is not None ), "Either num_microbatches or microbatch_size should be provided" - self.logger = get_dist_logger() - self.logger.warning("If pipeline hangs, please enlarge NCCL_BUFFSIZE.") self.comm = PipelineP2PCommunication(stage_manager) self.num_microbatches = num_microbatches From fc6d26e1cf001d6c3d9866089f1597b13fc6ab7c Mon Sep 17 00:00:00 2001 From: Wenhao Chen Date: Tue, 5 Dec 2023 15:36:17 +0800 Subject: [PATCH 22/22] fix: fix buffer size --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 6430345c88d2..ea74f75f43c8 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -987,9 +987,9 @@ def __init__( if os.getenv("NCCL_BUFFSIZE") is None: logger = get_dist_logger() logger.warning( - "Setting NCCL_BUFFSIZE to 256MB to avoid p2p hangs. " "Please increase it if hangs still happen." + "Setting NCCL_BUFFSIZE to 128MB to avoid p2p hangs. " "Please increase it if hangs still happen." ) - os.environ["NCCL_BUFFSIZE"] = "268435456" + os.environ["NCCL_BUFFSIZE"] = "134217728" assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"