From 823e655b42012523c77183c0c76ecb3af87e4738 Mon Sep 17 00:00:00 2001 From: Elsa Date: Mon, 6 Nov 2023 01:02:42 +0800 Subject: [PATCH 01/14] Use p2p --- colossalai/pipeline/p2p.py | 173 +++++++++++++++++++++++++++++++++++-- 1 file changed, 165 insertions(+), 8 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index f822c1819adc..8d8d72d0ea4e 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -4,7 +4,7 @@ import io import pickle import re -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Union, Dict import torch import torch.distributed as dist @@ -45,6 +45,21 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - return unpickle +def check_for_nccl_backend(group): + pg = group or c10d._get_default_group() + # Gate PG wrapper check on Gloo availability. + if c10d._GLOO_AVAILABLE: + # It is not expected for PG to be wrapped many times, but support it just + # in case + while isinstance(pg, c10d._ProcessGroupWrapper): + pg = pg.wrapped_pg + + return ( + c10d.is_nccl_available() and + pg.name() == c10d.Backend.NCCL + ) + + def _broadcast_object_list( object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None ): @@ -65,7 +80,7 @@ def _broadcast_object_list( c10d._warn_not_in_group("broadcast_object_list") return - is_nccl_backend = c10d._check_for_nccl_backend(group) + is_nccl_backend = check_for_nccl_backend(group) current_device = None if device is not None: @@ -113,7 +128,7 @@ def _broadcast_object_list( if my_rank != src: for i, obj_size in enumerate(object_sizes_tensor): - obj_view = object_tensor[offset : offset + obj_size] + obj_view = object_tensor[offset: offset + obj_size] obj_view = obj_view.type(torch.uint8) if obj_view.device != torch.device("cpu"): obj_view = obj_view.cpu() @@ -131,6 +146,126 @@ def _broadcast_object_list( object_list[i] = unpickle_object +def check_device(group): + is_nccl_backend = check_for_nccl_backend(group) + current_device = None + + current_device = torch.device("cpu") + if is_nccl_backend: + current_device = torch.device("cuda", torch.cuda.current_device()) + return current_device, is_nccl_backend + + +def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group): + if isinstance(obj, torch.Tensor): + op_to_add = dist.P2POp(comm_op, obj, comm_rank, group) + ops_queue.append(op_to_add) + else: + for tensor_to_comm in obj: + op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group) + ops_queue.append(op_to_add) + + +def _batch_send_tensor(dst, tensor_list: List[torch.Tensor], group): + ops = [] + filling_ops_queue(tensor_list, dist.isend, dst, ops, group) + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + + +def create_recv_buffer(tensor_metadata, current_device): + buffer_recv = [] + for (_, tensor_shape, tensor_dtype, tensor_requires_grad) in tensor_metadata: + tensor_recv = torch.empty( + tensor_shape, requires_grad=tensor_requires_grad, device=current_device, dtype=tensor_dtype) + buffer_recv.append(tensor_recv) + return buffer_recv + + +def _batch_recv_tensor(src, tensor_metadata, current_device, group): + buffer_recv = create_recv_buffer(tensor_metadata, current_device) + ops = [] + + filling_ops_queue(buffer_recv, dist.irecv, src, ops, group) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + return buffer_recv + + +def 
_send_object_with_serialization(object: Any, src: int, dst: int, group: ProcessGroup, current_device, is_nccl_backend): + if Version(torch.__version__) >= Version("1.13.0"): + object_tensor, object_size_tensor = c10d._object_to_tensor( + object, device=current_device) + else: + object_tensor, object_size_tensor = c10d._object_to_tensor(object) + + if is_nccl_backend: + object_size_tensor = object_size_tensor.to(current_device) + object_tensor = object_tensor.to(current_device) + + c10d.send(object_size_tensor, dst=dst, group=group) + c10d.send(object_tensor, dst=dst, group=group) + + +def _recv_object_with_serialization(src: int, dst: int, group: ProcessGroup, current_device, is_nccl_backend): + object_size_tensor = torch.empty(1, dtype=torch.long) + if is_nccl_backend: + object_size_tensor = object_size_tensor.to(current_device) + + c10d.recv(object_size_tensor, src=src, group=group) + + object_tensor = torch.empty(object_size_tensor.item(), dtype=torch.uint8) + if is_nccl_backend: + object_tensor = object_tensor.to(current_device) + + c10d.recv(object_tensor, src=src, group=group) + + object_tensor = object_tensor.type(torch.uint8) + if object_tensor.device != torch.device("cpu"): + object_tensor = object_tensor.cpu() + + unpickle_object = _cuda_safe_tensor_to_object( + object_tensor, object_size_tensor.item()) + + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != torch.cuda.current_device() + ): + unpickle_object = unpickle_object.cuda() + + return unpickle_object + + +def _send_dict_of_tensor(object: Any, src: int, dst: int, group: ProcessGroup, current_device, is_nccl_backend): + metadata = [] + metadata.append('__metadata__transfer__') + for k, v in object.items(): + metadata.append((k, v.shape, v.dtype, v.requires_grad)) + + _send_object_with_serialization(metadata, src, dst, group, current_device, is_nccl_backend) + + _batch_send_tensor(dst, list(object.values()), group) + + +def _recv_dict_of_tensor(metadata, src: int, dst: int, group: ProcessGroup, current_device, is_nccl_backend): + recved_batch = _batch_recv_tensor(src, metadata, current_device, group) + object = { + k: v + for k, v in zip( + [m[0] for m in metadata], + recved_batch, + ) + } + return object + + def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: """send anything to dst rank @@ -141,8 +276,21 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: Returns: None """ - # then broadcast safely - _broadcast_object_list([object], src, group) + if c10d._rank_not_in_group(group): + c10d._warn_not_in_group("_send_object") + return + + current_device, is_nccl_backend = check_device(group) + + if type(object) is dict: + is_dict_of_tensor = all([type(k) is str and type( + v) is torch.Tensor for k, v in object.items()]) + + if is_dict_of_tensor: + _send_dict_of_tensor(object, src, dst, group, current_device, is_nccl_backend) + return + + _send_object_with_serialization(object, src, dst, group, current_device, is_nccl_backend) def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: @@ -154,10 +302,19 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: Returns: Any: Object received from src. 
""" - object_list = [None] - _broadcast_object_list(object_list, src, group) + if c10d._rank_not_in_group(group): + c10d._warn_not_in_group("_recv_object") + return - return object_list[0] + current_device, is_nccl_backend = check_device(group) + + object = _recv_object_with_serialization(src, dst, group, current_device, is_nccl_backend) + if type(object) is list and len(object) >= 1 and object[0] == '__metadata__transfer__': + object.pop(0) + object = _recv_dict_of_tensor(object, src, dst, group, current_device, is_nccl_backend) + return object + else: + return object def _p2p_comm( From 5da7d6a7e9f00b9e7ac6e04973ab11b8cf749d20 Mon Sep 17 00:00:00 2001 From: Elsa Date: Mon, 6 Nov 2023 01:03:40 +0800 Subject: [PATCH 02/14] Cannot bidirectonal send p2p --- colossalai/pipeline/p2p.py | 150 ++++++++++++++++++++++--------------- 1 file changed, 88 insertions(+), 62 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 8d8d72d0ea4e..0449e0f6a052 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -166,16 +166,6 @@ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group): ops_queue.append(op_to_add) -def _batch_send_tensor(dst, tensor_list: List[torch.Tensor], group): - ops = [] - filling_ops_queue(tensor_list, dist.isend, dst, ops, group) - if len(ops) > 0: - reqs = dist.batch_isend_irecv(ops) - for req in reqs: - req.wait() - torch.cuda.synchronize() - - def create_recv_buffer(tensor_metadata, current_device): buffer_recv = [] for (_, tensor_shape, tensor_dtype, tensor_requires_grad) in tensor_metadata: @@ -185,11 +175,19 @@ def create_recv_buffer(tensor_metadata, current_device): return buffer_recv -def _batch_recv_tensor(src, tensor_metadata, current_device, group): - buffer_recv = create_recv_buffer(tensor_metadata, current_device) +def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, group, current_device): + buffer_recv = None + if recv_tensor_metadata is not None: + buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device) + ops = [] - filling_ops_queue(buffer_recv, dist.irecv, src, ops, group) + if send_dst is not None: + filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, group) + + if recv_src is not None: + assert buffer_recv is not None + filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, group) if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) @@ -199,7 +197,7 @@ def _batch_recv_tensor(src, tensor_metadata, current_device, group): return buffer_recv -def _send_object_with_serialization(object: Any, src: int, dst: int, group: ProcessGroup, current_device, is_nccl_backend): +def _send_object_with_serialization(object: Any, dst: int, group: ProcessGroup, current_device, is_nccl_backend): if Version(torch.__version__) >= Version("1.13.0"): object_tensor, object_size_tensor = c10d._object_to_tensor( object, device=current_device) @@ -214,7 +212,7 @@ def _send_object_with_serialization(object: Any, src: int, dst: int, group: Proc c10d.send(object_tensor, dst=dst, group=group) -def _recv_object_with_serialization(src: int, dst: int, group: ProcessGroup, current_device, is_nccl_backend): +def _recv_object_with_serialization(src: int, group: ProcessGroup, current_device, is_nccl_backend): object_size_tensor = torch.empty(1, dtype=torch.long) if is_nccl_backend: object_size_tensor = object_size_tensor.to(current_device) @@ -231,8 +229,7 @@ def _recv_object_with_serialization(src: int, dst: int, group: ProcessGroup, cur if object_tensor.device != 
torch.device("cpu"): object_tensor = object_tensor.cpu() - unpickle_object = _cuda_safe_tensor_to_object( - object_tensor, object_size_tensor.item()) + unpickle_object = _cuda_safe_tensor_to_object(object_tensor, object_size_tensor.item()) if ( isinstance(unpickle_object, torch.Tensor) @@ -243,27 +240,82 @@ def _recv_object_with_serialization(src: int, dst: int, group: ProcessGroup, cur return unpickle_object -def _send_dict_of_tensor(object: Any, src: int, dst: int, group: ProcessGroup, current_device, is_nccl_backend): - metadata = [] - metadata.append('__metadata__transfer__') - for k, v in object.items(): - metadata.append((k, v.shape, v.dtype, v.requires_grad)) +def _check_if_fast_send_available(object): + if type(object) is list: + is_list_of_tensor = all([type(v) is torch.Tensor for v in object]) + return is_list_of_tensor + elif type(object) is dict: + is_dict_of_tensor = all([type(k) is str and type( + v) is torch.Tensor for k, v in object.items()]) - _send_object_with_serialization(metadata, src, dst, group, current_device, is_nccl_backend) + return is_dict_of_tensor + return False - _batch_send_tensor(dst, list(object.values()), group) +def _send_recv( + object, + send_dst: Optional[int], + recv_src: Optional[int], + group: ProcessGroup +) -> Any: + if c10d._rank_not_in_group(group): + c10d._warn_not_in_group("_send_recv") + return -def _recv_dict_of_tensor(metadata, src: int, dst: int, group: ProcessGroup, current_device, is_nccl_backend): - recved_batch = _batch_recv_tensor(src, metadata, current_device, group) - object = { - k: v - for k, v in zip( - [m[0] for m in metadata], - recved_batch, - ) - } - return object + current_device, is_nccl_backend = check_device(group) + + assert (send_dst is not None) or (recv_src is not None) + + can_fast_send = False + if send_dst is not None: + can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend + if not can_fast_send: + send_metadata = ['__serialization__', object] + _send_object_with_serialization(send_metadata, send_dst, group, current_device, is_nccl_backend) + else: + send_metadata = [] + send_metadata.append('__tensor__') + if type(object) is list: + for v in object: + send_metadata.append((None, v.shape, v.dtype, v.requires_grad)) + elif type(object) is dict: + for k, v in object.items(): + send_metadata.append((k, v.shape, v.dtype, v.requires_grad)) + _send_object_with_serialization(send_metadata, send_dst, group, current_device, is_nccl_backend) + + recv_metadata = None + if recv_src is not None: + recv_metadata = _recv_object_with_serialization(recv_src, group, current_device, is_nccl_backend) + assert type(recv_metadata) is list and len(recv_metadata) > 1 + if recv_metadata[0] == '__serialization__': + return recv_metadata[1] + else: + recv_metadata.pop(0) + + # The following code is for fast send and recv + if not can_fast_send and send_dst is not None: + return + + send_tensor_list = None + if type(object) is list: + send_tensor_list = object + elif type(object) is dict: + send_tensor_list = list(object.values()) + + recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, group, current_device) + + if recv_metadata is not None: + assert recv_buffer is not None + if recv_metadata[0][0] is None: + return recv_buffer + else: + return { + k: v + for k, v in zip( + [m[0] for m in recv_metadata], + recv_buffer, + ) + } def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None: @@ -276,21 +328,7 @@ def _send_object(object: Any, src: int, dst: int, group: 
ProcessGroup) -> None: Returns: None """ - if c10d._rank_not_in_group(group): - c10d._warn_not_in_group("_send_object") - return - - current_device, is_nccl_backend = check_device(group) - - if type(object) is dict: - is_dict_of_tensor = all([type(k) is str and type( - v) is torch.Tensor for k, v in object.items()]) - - if is_dict_of_tensor: - _send_dict_of_tensor(object, src, dst, group, current_device, is_nccl_backend) - return - - _send_object_with_serialization(object, src, dst, group, current_device, is_nccl_backend) + _send_recv(object, dst, None, group) def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: @@ -302,19 +340,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any: Returns: Any: Object received from src. """ - if c10d._rank_not_in_group(group): - c10d._warn_not_in_group("_recv_object") - return - - current_device, is_nccl_backend = check_device(group) - - object = _recv_object_with_serialization(src, dst, group, current_device, is_nccl_backend) - if type(object) is list and len(object) >= 1 and object[0] == '__metadata__transfer__': - object.pop(0) - object = _recv_dict_of_tensor(object, src, dst, group, current_device, is_nccl_backend) - return object - else: - return object + return _send_recv(None, None, src, group) def _p2p_comm( From 9bd0f8181ca24c52869b5e4e118afa7583c62293 Mon Sep 17 00:00:00 2001 From: Elsa Date: Mon, 6 Nov 2023 10:38:45 +0800 Subject: [PATCH 03/14] Refactor tensor creation and serialization in P2P communication --- colossalai/pipeline/p2p.py | 113 +++++++++++++++++++++++++++++++------ 1 file changed, 95 insertions(+), 18 deletions(-) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 0449e0f6a052..72d84d5452e9 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -167,12 +167,19 @@ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group): def create_recv_buffer(tensor_metadata, current_device): - buffer_recv = [] - for (_, tensor_shape, tensor_dtype, tensor_requires_grad) in tensor_metadata: + if tensor_metadata[0] == 0: + tensor_shape, tensor_dtype, tensor_requires_grad = tensor_metadata[1] tensor_recv = torch.empty( tensor_shape, requires_grad=tensor_requires_grad, device=current_device, dtype=tensor_dtype) - buffer_recv.append(tensor_recv) - return buffer_recv + return tensor_recv + elif tensor_metadata[0] == 1 or tensor_metadata[0] == 2: + buffer_recv = [] + for tensor_data in tensor_metadata[1:]: + tensor_shape, tensor_dtype, tensor_requires_grad = tensor_data[-3:] + tensor_recv = torch.empty( + tensor_shape, requires_grad=tensor_requires_grad, device=current_device, dtype=tensor_dtype) + buffer_recv.append(tensor_recv) + return buffer_recv def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, group, current_device): @@ -240,8 +247,73 @@ def _recv_object_with_serialization(src: int, group: ProcessGroup, current_devic return unpickle_object +def _send_recv_serialization_object(object: Any, send_dst: Optional[int], recv_src: Optional[int], group: ProcessGroup, current_device, is_nccl_backend): + ops = [] + send_object_tensor = None + if object is not None and send_dst is not None: + if Version(torch.__version__) >= Version("1.13.0"): + send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device) + else: + send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object) + + if is_nccl_backend: + send_object_size_tensor = send_object_size_tensor.to(current_device) + 
send_object_tensor = send_object_tensor.to(current_device) + + filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, group) + + recv_object_size_tensor = None + if recv_src is not None: + recv_object_size_tensor = torch.empty(1, dtype=torch.long) + if is_nccl_backend: + recv_object_size_tensor = recv_object_size_tensor.to(current_device) + filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, group) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + + ops = [] + + if send_dst is not None and send_object_tensor is not None: + filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, group) + + recv_object_tensor = None + if recv_src is not None and recv_object_size_tensor is not None: + recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8) + if is_nccl_backend: + recv_object_tensor = recv_object_tensor.to(current_device) + filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, group) + + if len(ops) > 0: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + torch.cuda.synchronize() + + if recv_object_tensor is not None and recv_object_size_tensor is not None: + recv_object_tensor = recv_object_tensor.type(torch.uint8) + if recv_object_tensor.device != torch.device("cpu"): + recv_object_tensor = recv_object_tensor.cpu() + + unpickle_object = _cuda_safe_tensor_to_object( + recv_object_tensor, recv_object_size_tensor.item()) + + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != torch.cuda.current_device() + ): + unpickle_object = unpickle_object.cuda() + + return unpickle_object + + def _check_if_fast_send_available(object): - if type(object) is list: + if type(object) is torch.Tensor: + return True + elif type(object) is list: is_list_of_tensor = all([type(v) is torch.Tensor for v in object]) return is_list_of_tensor elif type(object) is dict: @@ -267,37 +339,40 @@ def _send_recv( assert (send_dst is not None) or (recv_src is not None) can_fast_send = False + send_metadata = None if send_dst is not None: can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend if not can_fast_send: send_metadata = ['__serialization__', object] - _send_object_with_serialization(send_metadata, send_dst, group, current_device, is_nccl_backend) else: send_metadata = [] send_metadata.append('__tensor__') - if type(object) is list: + if type(object) is torch.Tensor: + send_metadata.append(0) + send_metadata.append((object.shape, object.dtype, object.requires_grad)) + elif type(object) is list: + send_metadata.append(1) for v in object: - send_metadata.append((None, v.shape, v.dtype, v.requires_grad)) + send_metadata.append((v.shape, v.dtype, v.requires_grad)) elif type(object) is dict: + send_metadata.append(2) for k, v in object.items(): send_metadata.append((k, v.shape, v.dtype, v.requires_grad)) - _send_object_with_serialization(send_metadata, send_dst, group, current_device, is_nccl_backend) - recv_metadata = None - if recv_src is not None: - recv_metadata = _recv_object_with_serialization(recv_src, group, current_device, is_nccl_backend) - assert type(recv_metadata) is list and len(recv_metadata) > 1 + recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, group, current_device, is_nccl_backend) + if recv_metadata is not None: + assert type(recv_metadata) is list and len(recv_metadata) >= 2 if recv_metadata[0] == '__serialization__': return recv_metadata[1] else: 
recv_metadata.pop(0) - - # The following code is for fast send and recv if not can_fast_send and send_dst is not None: return send_tensor_list = None - if type(object) is list: + if type(object) is torch.Tensor: + send_tensor_list = object + elif type(object) is list: send_tensor_list = object elif type(object) is dict: send_tensor_list = list(object.values()) @@ -306,13 +381,15 @@ def _send_recv( if recv_metadata is not None: assert recv_buffer is not None - if recv_metadata[0][0] is None: + if recv_metadata[0] == 0: + return recv_buffer + elif recv_metadata[0] == 1: return recv_buffer else: return { k: v for k, v in zip( - [m[0] for m in recv_metadata], + [m[0] for m in recv_metadata[1:]], recv_buffer, ) } From c33e5ad5d66fbc93b45a27aea1c2567cf31c93f5 Mon Sep 17 00:00:00 2001 From: Elsa Date: Mon, 6 Nov 2023 14:46:40 +0800 Subject: [PATCH 04/14] Fix llama forward args in flash attention --- colossalai/shardformer/modeling/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 4b6c8342534a..0f911be4870f 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -413,6 +413,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4." From f0f302aaf374c618f76dab5300647128fd8e4adc Mon Sep 17 00:00:00 2001 From: Elsa Date: Mon, 6 Nov 2023 14:52:53 +0800 Subject: [PATCH 05/14] Add flop estimate from megatron --- examples/language/llama2/benchmark.py | 6 +++++- examples/language/llama2/performance_evaluator.py | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index ce13ebbf617d..b38ddbb4a33e 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -183,7 +183,11 @@ def empty_init(): model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") performance_evaluator = PerformanceEvaluator( - model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size + model_numel, + model.config.num_hidden_layers, + model.config.hidden_size, + model.config.vocab_size, + args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size ) optimizer = HybridAdam(model.parameters()) diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py index a57c1e0e9ae3..05e71edf15a5 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -58,6 +58,9 @@ class PerformanceEvaluator: def __init__( self, model_numel: int, + num_layers: int, + hidden_size: int, + vocab_size: int, enable_grad_checkpoint: bool = False, ignore_steps: int = 0, dp_world_size: Optional[int] = None, @@ -65,12 +68,16 @@ def __init__( self.model_numel = model_numel self.enable_grad_checkpoint = enable_grad_checkpoint self.ignore_steps = ignore_steps + self.num_layers = num_layers + self.hidden_size = hidden_size + self.vocab_size = vocab_size self.coordinator = DistCoordinator() self.dp_world_size = dp_world_size or self.coordinator.world_size self.disable: bool = False self.timer = Timer() 
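+        # FLOP bookkeeping for `on_step_end` below: `flop` uses the
+        # parameter-count approximation (roughly 6 * param count * tokens per
+        # training step), while `flop_megatron` uses the analytic transformer
+        # formula from the Megatron-LM papers.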
         self.num_samples: int = 0
+        self.flop_megatron = 0
         self.flop: int = 0
 
     def on_step_start(self, step: int) -> None:
@@ -89,17 +96,20 @@ def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
         batch_size, seq_len = input_ids.shape
 
         self.num_samples += batch_size
+        checkpoint_activations_factor = 3 + int(self.enable_grad_checkpoint)
+        self.flop_megatron += (
+            24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)
+        ) * (
+            1.0
+            + (seq_len / (6.0 * self.hidden_size))
+            + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size))
+        )
         self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
 
     def on_fit_end(self) -> None:
         avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
         avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
         mp_world_size = self.coordinator.world_size // self.dp_world_size
+        avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
         avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
         self.coordinator.print_on_master(
-            f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
+            f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
             f"avg_throughput: {avg_throughput}"
         )
         self.coordinator.print_on_master(
-            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
+            f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
         )

From cc1e5af15aa058d9979e90e8334871b515705dd8 Mon Sep 17 00:00:00 2001
From: Elsa
Date: Mon, 6 Nov 2023 14:54:15 +0800
Subject: [PATCH 06/14] Support loading weight not in weight_map when
 strict=False in hybrid_parallel

---
 .../hybrid_parallel_checkpoint_io.py          | 29 +++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 779ff42d75a1..13e2400b412b 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,4 +1,5 @@
 import copy
+from functools import reduce
 import logging
 import os
 from pathlib import Path
@@ -313,9 +314,13 @@ def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, s
         # Keep a record of loaded files so that file will not be repeatedly loaded.
         loaded_file = set()
 
+        missing_keys = []
+        missing_file_keys = []
+
         def _load(name: str):
             if name not in weight_map:
-                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+                missing_file_keys.append(name)
+                return
             filename = weight_map[name]
 
             # If this param/buffer has been loaded before, directly return.
@@ -324,7 +329,6 @@ def _load(name: str):
 
             file_path = os.path.join(ckpt_root_path, filename)
             state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
-            missing_keys = []
 
             load_state_dict_into_model(
                 model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@@ -357,6 +361,27 @@ def _load(name: str):
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
 
+        if len(missing_keys) == 0:
+            raise RuntimeError(
+                "No weight is loaded into the model. Please check the checkpoint file and the model structure."
+            )
+
+        remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+        remain_keys = remain_keys.union(set(missing_file_keys))
+        if len(remain_keys) > 0:
+            if strict:
+                error_msgs = [
+                    "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in remain_keys))
+                ]
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        self.__class__.__name__, "\n\t".join(error_msgs)
+                    )
+                )
+            else:
+                if self.coordinator.is_master():
+                    logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")
+
     def save_sharded_optimizer(
         self,
         optimizer: OptimizerWrapper,

From 43975e3508c1d0637ec177e61fc2fb9e9783d944 Mon Sep 17 00:00:00 2001
From: Elsa
Date: Mon, 6 Nov 2023 21:09:35 +0800
Subject: [PATCH 07/14] Use send_forward_recv_backward, etc in 1f1b

---
 colossalai/pipeline/p2p.py                  | 147 ++++++++++++--------
 colossalai/pipeline/schedule/one_f_one_b.py |  55 ++++++--
 2 files changed, 132 insertions(+), 70 deletions(-)

diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index 72d84d5452e9..b1c910130cd3 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -182,7 +182,7 @@ def create_recv_buffer(tensor_metadata, current_device):
     return buffer_recv
 
 
-def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, group, current_device):
+def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
     buffer_recv = None
     if recv_tensor_metadata is not None:
         buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
@@ -190,11 +190,11 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
     ops = []
 
     if send_dst is not None:
-        filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, group)
+        filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
 
     if recv_src is not None:
         assert buffer_recv is not None
-        filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, group)
+        filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
 
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops)
@@ -204,50 +204,12 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
     return buffer_recv
 
 
-def _send_object_with_serialization(object: Any, dst: int, group: ProcessGroup, current_device, is_nccl_backend):
-    if Version(torch.__version__) >= Version("1.13.0"):
-        object_tensor, object_size_tensor = c10d._object_to_tensor(
-            object, device=current_device)
-    else:
-        object_tensor, object_size_tensor = c10d._object_to_tensor(object)
-
-    if is_nccl_backend:
-        object_size_tensor = object_size_tensor.to(current_device)
-        object_tensor = object_tensor.to(current_device)
-
-    c10d.send(object_size_tensor, dst=dst, group=group)
-    c10d.send(object_tensor, dst=dst, group=group)
-
-
-def _recv_object_with_serialization(src: int, group: ProcessGroup, 
current_device, is_nccl_backend): - object_size_tensor = torch.empty(1, dtype=torch.long) - if is_nccl_backend: - object_size_tensor = object_size_tensor.to(current_device) - - c10d.recv(object_size_tensor, src=src, group=group) - - object_tensor = torch.empty(object_size_tensor.item(), dtype=torch.uint8) - if is_nccl_backend: - object_tensor = object_tensor.to(current_device) - - c10d.recv(object_tensor, src=src, group=group) - - object_tensor = object_tensor.type(torch.uint8) - if object_tensor.device != torch.device("cpu"): - object_tensor = object_tensor.cpu() - - unpickle_object = _cuda_safe_tensor_to_object(object_tensor, object_size_tensor.item()) - - if ( - isinstance(unpickle_object, torch.Tensor) - and unpickle_object.device.index != torch.cuda.current_device() - ): - unpickle_object = unpickle_object.cuda() - - return unpickle_object - - -def _send_recv_serialization_object(object: Any, send_dst: Optional[int], recv_src: Optional[int], group: ProcessGroup, current_device, is_nccl_backend): +def _send_recv_serialization_object( + object: Any, + send_dst: Optional[int], recv_src: Optional[int], + send_group: Optional[ProcessGroup], recv_group: Optional[ProcessGroup], + current_device, + is_nccl_backend): ops = [] send_object_tensor = None if object is not None and send_dst is not None: @@ -260,14 +222,14 @@ def _send_recv_serialization_object(object: Any, send_dst: Optional[int], recv_s send_object_size_tensor = send_object_size_tensor.to(current_device) send_object_tensor = send_object_tensor.to(current_device) - filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, group) + filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group) recv_object_size_tensor = None if recv_src is not None: recv_object_size_tensor = torch.empty(1, dtype=torch.long) if is_nccl_backend: recv_object_size_tensor = recv_object_size_tensor.to(current_device) - filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, group) + filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group) if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) @@ -278,14 +240,14 @@ def _send_recv_serialization_object(object: Any, send_dst: Optional[int], recv_s ops = [] if send_dst is not None and send_object_tensor is not None: - filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, group) + filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group) recv_object_tensor = None if recv_src is not None and recv_object_size_tensor is not None: recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8) if is_nccl_backend: recv_object_tensor = recv_object_tensor.to(current_device) - filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, group) + filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group) if len(ops) > 0: reqs = dist.batch_isend_irecv(ops) @@ -324,17 +286,24 @@ def _check_if_fast_send_available(object): return False -def _send_recv( +def _communicate( object, send_dst: Optional[int], recv_src: Optional[int], - group: ProcessGroup + send_group: Optional[ProcessGroup] = None, + recv_group: Optional[ProcessGroup] = None, ) -> Any: - if c10d._rank_not_in_group(group): - c10d._warn_not_in_group("_send_recv") + if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group): + c10d._warn_not_in_group("_communicate") return - current_device, is_nccl_backend = check_device(group) + current_send_device, is_send_nccl_backend = check_device(send_group) + 
current_recv_device, is_recv_nccl_backend = check_device(recv_group)
+
+    is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend
+
+    assert current_send_device == current_recv_device
+    current_device = current_send_device
 
     assert (send_dst is not None) or (recv_src is not None)
 
     can_fast_send = False
@@ -359,7 +321,7 @@ def _send_recv(
             for k, v in object.items():
                 send_metadata.append((k, v.shape, v.dtype, v.requires_grad))
 
-    recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, group, current_device, is_nccl_backend)
+    recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
     if recv_metadata is not None:
         assert type(recv_metadata) is list and len(recv_metadata) >= 2
         if recv_metadata[0] == '__serialization__':
             return recv_metadata[1]
@@ -377,7 +339,7 @@ def _send_recv(
     elif type(object) is dict:
         send_tensor_list = list(object.values())
 
-    recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, group, current_device)
+    recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
 
     if recv_metadata is not None:
         assert recv_buffer is not None
@@ -405,7 +367,7 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
     Returns:
         None
     """
-    _send_recv(object, dst, None, group)
+    _communicate(object, send_dst=dst, recv_src=None, send_group=group)
 
 
 def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
@@ -417,7 +379,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
     Returns:
         Any: Object received from src.
     """
-    return _send_recv(None, None, src, group)
+    return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
 
 
 def _p2p_comm(
@@ -562,6 +531,64 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
         cur_rank = self.stage_manager.get_rank()
         _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
 
+    def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
+        """Sends the output tensor to the next stage and copies the gradient tensor from the next stage in pipeline
+
+        Args:
+            input_object (Any): Object to be sent.
+            next_rank (int, optional): The rank of the sender and recipient of the tensor
+        """
+        if next_rank is None:
+            next_rank = self.stage_manager.get_next_rank()
+
+        cur_rank = self.stage_manager.get_rank()
+        group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
+        return _communicate(
+            input_object, next_rank, next_rank,
+            send_group=group, recv_group=group,
+        )
+
+    def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
+        """Sends the gradient tensor to the previous stage and copies the input tensor from the previous stage in pipeline
+
+        Args:
+            input_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank of the sender and recipient of the tensor
+        """
+        if prev_rank is None:
+            prev_rank = self.stage_manager.get_prev_rank()
+
+        cur_rank = self.stage_manager.get_rank()
+        group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
+        return _communicate(
+            input_object, prev_rank, prev_rank,
+            send_group=group, recv_group=group,
+        )
+
+    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
+        """Sends the input tensor to the next stage and copies the input tensor from the previous stage in pipeline.
+
+        Args:
+            input_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank of the sender of the tensor
+            next_rank (int, optional): The rank of the recipient of the tensor
+        """
+        if prev_rank is None:
+            prev_rank = self.stage_manager.get_prev_rank()
+        if next_rank is None:
+            next_rank = self.stage_manager.get_next_rank()
+
+        cur_rank = self.stage_manager.get_rank()
+        recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
+        send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
+        return _communicate(
+            input_object,
+            send_dst=next_rank,
+            recv_src=prev_rank,
+            send_group=send_group,
+            recv_group=recv_group,
+        )
+
     def p2p_communicate(
         self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
     ) -> None:
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 4eaf135fd5db..0735628bcb9d 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -127,6 +127,17 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None:
         if not self.stage_manager.is_last_stage():
             self.comm.send_forward(output_object, next_rank)
 
+    def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
+        """Sends the input tensor to the next stage and copies the gradient tensor from the next stage in pipeline.
+        For 1F1B.
+
+        Args:
+            output_object (Any): Object to be sent.
+            next_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if not self.stage_manager.is_last_stage():
+            return self.comm.send_forward_recv_backward(output_object, next_rank)
+
     def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
         """Sends the gradient tensor to the previous stage in pipeline.
         For 1F1B.
@@ -138,6 +149,33 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
         if not self.stage_manager.is_first_stage():
             self.comm.send_backward(input_object, prev_rank)
 
+    def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
+        """Sends the gradient tensor to the previous stage and copies the input tensor from the previous stage in pipeline.
+        For 1F1B.
+
+        Args:
+            output_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if not self.stage_manager.is_first_stage():
+            return self.comm.send_backward_recv_forward(output_object, prev_rank)
+
+    def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
+        """Sends the input tensor to the next stage and copies the input tensor from the previous stage in pipeline.
+        For 1F1B.
+
+        Args:
+            input_object (Any): Object to be sent.
+            prev_rank (int, optional): The rank of the sender of the tensor.
+            next_rank (int, optional): The rank of the recipient of the tensor.
+        """
+        if self.stage_manager.is_first_stage():
+            return self.comm.send_forward(input_object, next_rank)
+        elif self.stage_manager.is_last_stage():
+            return self.comm.recv_forward(prev_rank)
+        else:
+            return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
+
     def forward_step(
         self,
         model: Module,
@@ -287,15 +325,12 @@ def forward_backward_step(
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
 
             if forward_only:
-                self.send_forward(output_obj)
-
-                if not last_iteration:
-                    input_obj = self.recv_forward()
-
+                if last_iteration:
+                    self.send_forward(output_obj)
+                else:
+                    input_obj = self.send_forward_recv_forward(output_obj)
             else:
-                # TODO adjust here
-                self.send_forward(output_obj)
-                output_obj_grad = self.recv_backward()
+                output_obj_grad = self.send_forward_recv_backward(output_obj)
 
             # Add input_obj and output_obj to end of list.
             input_objs.append(input_obj)
@@ -308,10 +343,10 @@ def forward_backward_step(
             input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
 
             if last_iteration:
+                self.send_backward(input_obj_grad)
                 input_obj = None
             else:
-                input_obj = self.recv_forward()
-                self.send_backward(input_obj_grad)
+                input_obj = self.send_backward_recv_forward(input_obj_grad)
 
         # Run cooldown backward passes.
         if not forward_only:

From 9427324949d61add25495bdeffc233a9f30b5c8e Mon Sep 17 00:00:00 2001
From: Elsa
Date: Tue, 7 Nov 2023 20:21:04 +0800
Subject: [PATCH 08/14] Use dataclass for metadata

Remove torch.cuda.synchronize() as suggested

---
 colossalai/pipeline/p2p.py | 82 +++++++++++++++++++++++---------------
 1 file changed, 49 insertions(+), 33 deletions(-)

diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index b1c910130cd3..f6b1124778bb 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -4,10 +4,13 @@
 import io
 import pickle
 import re
-from typing import Any, List, Optional, Union, Dict
+from typing import Any, List, Optional, Union
+from collections import namedtuple
 
 import torch
 import torch.distributed as dist
+from dataclasses import dataclass
+from enum import Enum
 from packaging.version import Version
 from torch.distributed import ProcessGroup
 from torch.distributed import distributed_c10d as c10d
@@ -156,6 +159,22 @@ def check_device(group):
     return current_device, is_nccl_backend
 
 
+TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
+
+
+class P2PDataType(Enum):
+    serialization = 0
+    tensor = 1
+    list = 2
+    dict = 3
+
+
+@dataclass
+class P2PMetadata:
+    data_type: P2PDataType
+    content: Union[List[TensorMetadata], TensorMetadata, Any]
+
+
 def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
     if isinstance(obj, torch.Tensor):
         op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
@@ -166,20 +185,19 @@ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
         ops_queue.append(op_to_add)
 
 
-def create_recv_buffer(tensor_metadata, current_device):
-    if tensor_metadata[0] == 0:
-        tensor_shape, tensor_dtype, tensor_requires_grad = tensor_metadata[1]
-        tensor_recv = torch.empty(
-            tensor_shape, requires_grad=tensor_requires_grad, device=current_device, dtype=tensor_dtype)
-        return tensor_recv
-    elif tensor_metadata[0] == 1 or 
tensor_metadata[0] == 2: + elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict): buffer_recv = [] - for tensor_data in tensor_metadata[1:]: - tensor_shape, tensor_dtype, tensor_requires_grad = tensor_data[-3:] - tensor_recv = torch.empty( - tensor_shape, requires_grad=tensor_requires_grad, device=current_device, dtype=tensor_dtype) + for metadata in p2p_metadata.content: + tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype) buffer_recv.append(tensor_recv) return buffer_recv + else: + raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}") def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device): @@ -200,7 +218,6 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re reqs = dist.batch_isend_irecv(ops) for req in reqs: req.wait() - torch.cuda.synchronize() return buffer_recv @@ -235,7 +252,6 @@ def _send_recv_serialization_object( reqs = dist.batch_isend_irecv(ops) for req in reqs: req.wait() - torch.cuda.synchronize() ops = [] @@ -253,7 +269,6 @@ def _send_recv_serialization_object( reqs = dist.batch_isend_irecv(ops) for req in reqs: req.wait() - torch.cuda.synchronize() if recv_object_tensor is not None and recv_object_size_tensor is not None: recv_object_tensor = recv_object_tensor.type(torch.uint8) @@ -312,29 +327,30 @@ def _communicate( if send_dst is not None: can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend if not can_fast_send: - send_metadata = ['__serialization__', object] + send_metadata = P2PMetadata(P2PDataType.serialization, object) else: - send_metadata = [] - send_metadata.append('__tensor__') if type(object) is torch.Tensor: - send_metadata.append(0) - send_metadata.append((object.shape, object.dtype, object.requires_grad)) + data_type = P2PDataType.tensor + content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad) elif type(object) is list: - send_metadata.append(1) + data_type = P2PDataType.list + content = [] for v in object: - send_metadata.append((v.shape, v.dtype, v.requires_grad)) + content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad)) elif type(object) is dict: - send_metadata.append(2) + data_type = P2PDataType.dict + content = [] for k, v in object.items(): - send_metadata.append((k, v.shape, v.dtype, v.requires_grad)) + content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad)) + else: + raise ValueError('Cannot send object of type {}'.format(type(object))) + send_metadata = P2PMetadata(data_type, content) recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend) if recv_metadata is not None: - assert type(recv_metadata) is list and len(recv_metadata) >= 2 - if recv_metadata[0] == '__serialization__': - return recv_metadata[1] - else: - recv_metadata.pop(0) + assert type(recv_metadata) is P2PMetadata + if recv_metadata.data_type == P2PDataType.serialization: + return recv_metadata.content if not can_fast_send and send_dst is not None: return @@ -350,18 +366,18 @@ def _communicate( if recv_metadata is not None: assert recv_buffer is not None - if recv_metadata[0] == 0: - return recv_buffer - elif recv_metadata[0] == 1: + if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]: return recv_buffer - else: + elif recv_metadata.data_type == P2PDataType.dict: return { k: v for k, v in zip( - [m[0] for m in 
recv_metadata[1:]],
+                    [m.key for m in recv_metadata.content],
                     recv_buffer,
                 )
             }
+        else:
+            raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))

From c6a1fadd1e0149b755813098a765a4f9a895851a Mon Sep 17 00:00:00 2001
From: Elsa
Date: Wed, 8 Nov 2023 14:41:32 +0800
Subject: [PATCH 09/14] Add comment about torch.cuda.synchronize for potential
 errors

---
 colossalai/pipeline/p2p.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index f6b1124778bb..035b809b5211 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -218,6 +218,13 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
         reqs = dist.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()
+
+    # Remove synchronization according to Pytorch' documentation
+    # However, the Megatron-LM does synchronization here
+    # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
+    # In case there is a potential error, uncomment the following `torch.cuda.synchronize()`
+    # torch.cuda.synchronize()
+
     return buffer_recv
 
 
@@ -253,6 +260,9 @@ def _send_recv_serialization_object(
         for req in reqs:
             req.wait()
 
+    # See the comment in `_batch_send_recv_tensor`
+    # torch.cuda.synchronize()
+
     ops = []
 
@@ -270,6 +280,9 @@ def _send_recv_serialization_object(
         for req in reqs:
             req.wait()
 
+    # See the comment in `_batch_send_recv_tensor`
+    # torch.cuda.synchronize()
+
     if recv_object_tensor is not None and recv_object_size_tensor is not None:

From 2bc04fff3e977fd376953a66058b30d68b478912 Mon Sep 17 00:00:00 2001
From: Elsa
Date: Wed, 8 Nov 2023 14:42:44 +0800
Subject: [PATCH 10/14] Typo

---
 colossalai/pipeline/p2p.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index 035b809b5211..af8c75379c40 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -219,7 +219,7 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
     for req in reqs:
         req.wait()
 
-    # Remove synchronization according to Pytorch' documentation
+    # Remove synchronization according to Pytorch's documentation
     # However, the Megatron-LM does synchronization here

From 8f7dee9f1c6e8c675937e51c2e1ec74bddc10cf8 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Mon, 13 Nov 2023 12:33:25 +0800
Subject: [PATCH 11/14] Update hybrid_parallel_checkpoint_io.py

---
 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 13e2400b412b..b7900bc0f217 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -363,7 +363,7 @@ def _load(name: str):
 
         if len(missing_keys) == 0:
             raise RuntimeError(
-                "No weight is loaded into the model. 
Please check the checkpoint file and the model structure."
+                "No weight is loaded into the model. Please check the checkpoint files and the model structure."
             )

From 28bd4f907b91312cfe9cd5123a2004313dad0fab Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 15 Nov 2023 10:40:34 +0800
Subject: [PATCH 12/14] Update p2p.py

---
 colossalai/pipeline/p2p.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index af8c75379c40..14c5c94b88f4 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -219,6 +219,8 @@ def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, re
         for req in reqs:
             req.wait()
 
+    torch.cuda.synchronize()
+
     # Remove synchronization according to Pytorch's documentation
     # However, the Megatron-LM does synchronization here
@@ -260,6 +262,8 @@ def _send_recv_serialization_object(
         for req in reqs:
             req.wait()
 
+    torch.cuda.synchronize()
+
     # See the comment in `_batch_send_recv_tensor`
     # torch.cuda.synchronize()
@@ -278,6 +282,8 @@ def _send_recv_serialization_object(
         for req in reqs:
             req.wait()
 
+    torch.cuda.synchronize()
+
     # See the comment in `_batch_send_recv_tensor`
     # torch.cuda.synchronize()

From 1264fef294f709ff267784cf1c69dee9e611e1fe Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 15 Nov 2023 17:01:43 +0800
Subject: [PATCH 13/14] Update one_f_one_b.py

---
 colossalai/pipeline/schedule/one_f_one_b.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 0735628bcb9d..1f3b80857d6e 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -325,12 +325,14 @@ def forward_backward_step(
             output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
 
             if forward_only:
-                if last_iteration:
-                    self.send_forward(output_obj)
-                else:
-                    input_obj = self.send_forward_recv_forward(output_obj)
+                self.send_forward(output_obj)
+
+                if not last_iteration:
+                    input_obj = self.recv_forward()
             else:
-                output_obj_grad = self.send_forward_recv_backward(output_obj)
+                # TODO adjust here
+                self.send_forward(output_obj)
+                output_obj_grad = self.recv_backward()
 
             # Add input_obj and output_obj to end of list.
             input_objs.append(input_obj)
@@ -343,10 +345,10 @@ def forward_backward_step(
             input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
 
             if last_iteration:
-                self.send_backward(input_obj_grad)
                 input_obj = None
             else:
-                input_obj = self.send_backward_recv_forward(input_obj_grad)
+                input_obj = self.recv_forward()
+                self.send_backward(input_obj_grad)
 
         # Run cooldown backward passes.
if not forward_only: From b180eb1676826e42503a989afebd04c651e41f42 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 15 Nov 2023 18:01:24 +0800 Subject: [PATCH 14/14] Update p2p.py --- colossalai/pipeline/p2p.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 14c5c94b88f4..6e49fa36bb83 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -177,10 +177,12 @@ class P2PMetadata: def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group): if isinstance(obj, torch.Tensor): + obj = obj.contiguous() op_to_add = dist.P2POp(comm_op, obj, comm_rank, group) ops_queue.append(op_to_add) else: for tensor_to_comm in obj: + tensor_to_comm = tensor_to_comm.contiguous() op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group) ops_queue.append(op_to_add)
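
A minimal sketch (illustrative tensor names only, not part of the patch series) of why the final patch makes each tensor contiguous before wrapping it in dist.P2POp: NCCL point-to-point operations transfer a tensor's dense storage directly, so a non-contiguous view such as a transpose must be materialized first or the wrong bytes may be sent.

import torch

x = torch.randn(4, 8)
y = x.t()                      # transpose is a view; it shares x's row-major storage
assert not y.is_contiguous()   # handing this to dist.P2POp would send the wrong layout
y = y.contiguous()             # materialize a dense, row-major copy that is safe to send
assert y.is_contiguous()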