From 9b206639f4c9e73402ebfaf5fc7d202a90991ded Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 1 Sep 2023 16:14:29 +0800 Subject: [PATCH 1/7] pytree test --- colossalai/pipeline/schedule/_utils.py | 44 ++++++++++++++++--- colossalai/pipeline/schedule/one_f_one_b.py | 19 ++++++-- colossalai/shardformer/policies/chatglm2.py | 5 +++ tests/test_shardformer/test_model/_utils.py | 11 ++--- .../test_model/test_shard_bert.py | 1 + 5 files changed, 63 insertions(+), 17 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 5cd934b76822..302501aed9c4 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -1,9 +1,41 @@ -from typing import Any, List, Optional +from collections import OrderedDict +from typing import Any, List, Optional, Tuple import torch import torch.cuda from torch.nn import Module -from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import SUPPORTED_NODES, LeafSpec, TreeSpec, _is_leaf, tree_flatten, tree_map, tree_unflatten + + +def tree_map_hf(fn: Any, pytree: Any): + flat_args, spec = tree_flatten_hf(pytree) + return tree_unflatten([fn(i) for i in flat_args], spec) + + +# use this flatten function to handle the ModelingOutput Class instance. +def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values and a TreeSpec that can be used + to reconstruct the pytree. + """ + if _is_leaf(pytree): + return [pytree], LeafSpec() + + if isinstance(pytree, OrderedDict): + node_type = OrderedDict + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(pytree) + + # Recursively flatten the children + result: List[Any] = [] + children_specs: List['TreeSpec'] = [] + for child in child_pytrees: + flat, child_spec = tree_flatten(child) + result += flat + children_specs.append(child_spec) + return result, TreeSpec(node_type, context, children_specs) + else: + result, tree_spec = tree_flatten(pytree) + return result, tree_spec def to_device(x: Any, device: Optional[torch.device] = None) -> Any: @@ -104,7 +136,7 @@ def detach(x: Any) -> Any: return x -def merge_batch(data: List[Any]) -> Any: +def merge_batch(data: List[Any], batch_size_dim=0) -> Any: """Merge micro batches into a batch. Args: @@ -118,15 +150,17 @@ def merge_batch(data: List[Any]) -> Any: flattened_data = [] tree_spec = None for d in data: - elems, tree_spec = tree_flatten(d) + # elems should be an instance of OrderedDict + elems, tree_spec = tree_flatten_hf(d) flattened_data.append(elems) merged_data = [] + for elem_batch in zip(*flattened_data): if isinstance(elem_batch[0], torch.Tensor): if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs merged_data.append(None) else: - merged_data.append(torch.cat(elem_batch, dim=0)) + merged_data.append(torch.cat(elem_batch, dim=batch_size_dim)) else: merged_data.append(list(elem_batch)) return tree_unflatten(merged_data, tree_spec) diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 11b2655a22c9..ec53a67716c4 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -6,12 +6,21 @@ from torch.nn import Module from torch.utils._pytree import tree_map -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils.cuda import get_current_device -from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device +from ._utils import ( + detach, + get_batch_size, + get_micro_batch, + merge_batch, + model_forward, + retain_grad, + to_device, + tree_map_hf, +) from .base import PipelineSchedule @@ -154,7 +163,7 @@ def forward_step(self, if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: - outputs.append(tree_map(detach, output_obj)) + outputs.append(tree_map_hf(detach, output_obj)) return loss else: return output_obj @@ -302,5 +311,7 @@ def forward_backward_step(self, self.send_backward(input_obj_grad) if outputs is not None: - outputs = merge_batch(outputs) + if isinstance(model, ModelWrapper): + model = model.unwrap() + outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0)) return {'loss': accum_loss, 'outputs': outputs} diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index 5bcbc2acc28e..44898847056a 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -41,6 +41,11 @@ def preprocess(self): new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) + if self.pipeline_stage_manager is not None: + # the batch_size_dim is bounded to Model + bsz_dim = 1 + setattr(self.model, 'batch_size_dim', bsz_dim) + return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 72bb2b025ba4..f77bf7495808 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -191,15 +191,10 @@ def check_output_hidden_state(org_output: Tensor, org_hidden_state = org_output.last_hidden_state - if stage_manager is None: - sharded_hidden_state = sharded_output.last_hidden_state - if stage_manager and stage_manager.is_last_stage(): - pipeline_output = sharded_output['outputs'] - if isinstance(pipeline_output, List): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim) - else: - sharded_hidden_state = pipeline_output.last_hidden_state + sharded_hidden_state = sharded_output['outputs']['last_hidden_state'] + else: + sharded_hidden_state = sharded_output.last_hidden_state assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 0855e2248710..c779e417052b 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -179,6 +179,7 @@ def run_bert_3d_test(test_config): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) clear_layout_converter() From f26b70eb7bdcf6b3ee811cd1d03c11ffeb607c6c Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Sat, 2 Sep 2023 23:38:12 +0800 Subject: [PATCH 2/7] test bert --- tests/test_shardformer/test_model/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index f77bf7495808..98032a67e481 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -192,6 +192,7 @@ def check_output_hidden_state(org_output: Tensor, org_hidden_state = org_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): + print(sharded_output['outputs']) sharded_hidden_state = sharded_output['outputs']['last_hidden_state'] else: sharded_hidden_state = sharded_output.last_hidden_state From 6c43e56ce0c8f8f4539b3ce90e3239bf436cab97 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Sun, 3 Sep 2023 01:59:58 +0800 Subject: [PATCH 3/7] test bert --- colossalai/pipeline/schedule/_utils.py | 1 + tests/test_shardformer/test_model/_utils.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 302501aed9c4..da732f6d9ed9 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -163,4 +163,5 @@ def merge_batch(data: List[Any], batch_size_dim=0) -> Any: merged_data.append(torch.cat(elem_batch, dim=batch_size_dim)) else: merged_data.append(list(elem_batch)) + print("merged_data: ", merged_data) return tree_unflatten(merged_data, tree_spec) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 98032a67e481..69e4feacf9f7 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -192,7 +192,9 @@ def check_output_hidden_state(org_output: Tensor, org_hidden_state = org_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): - print(sharded_output['outputs']) + if isinstance(sharded_output['outputs'], list): + print(len(sharded_output['outputs'])) + print("last hidden shape", sharded_output['outputs'][0].last_hidden_state.shape) sharded_hidden_state = sharded_output['outputs']['last_hidden_state'] else: sharded_hidden_state = sharded_output.last_hidden_state From 8333642963159aeee2f53a3308d1bd98f185fe6f Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Sun, 3 Sep 2023 09:20:56 +0800 Subject: [PATCH 4/7] test bert --- colossalai/pipeline/schedule/_utils.py | 1 + tests/kit/model_zoo/torchrec/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index da732f6d9ed9..17068f980021 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -150,6 +150,7 @@ def merge_batch(data: List[Any], batch_size_dim=0) -> Any: flattened_data = [] tree_spec = None for d in data: + print('d: ', d) # elems should be an instance of OrderedDict elems, tree_spec = tree_flatten_hf(d) flattened_data.append(elems) diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 43952e6998cf..4a19f2449602 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -from .torchrec import * +#from .torchrec import * From 009bc0ae16d321c4e0bb0156d46d52d54f769c08 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Sun, 3 Sep 2023 22:03:55 +0800 Subject: [PATCH 5/7] revise --- colossalai/pipeline/schedule/_utils.py | 4 +--- tests/kit/model_zoo/torchrec/__init__.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 17068f980021..f5b1450ecc6f 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -17,10 +17,8 @@ def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: """Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree. """ - if _is_leaf(pytree): - return [pytree], LeafSpec() - if isinstance(pytree, OrderedDict): + print("pytree: ") node_type = OrderedDict flatten_fn = SUPPORTED_NODES[node_type].flatten_fn child_pytrees, context = flatten_fn(pytree) diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py index 4a19f2449602..43952e6998cf 100644 --- a/tests/kit/model_zoo/torchrec/__init__.py +++ b/tests/kit/model_zoo/torchrec/__init__.py @@ -1 +1 @@ -#from .torchrec import * +from .torchrec import * From ae80a23b10b82bfea86b7c295de2adb95cfb4890 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 4 Sep 2023 11:16:38 +0800 Subject: [PATCH 6/7] add register --- colossalai/pipeline/schedule/_utils.py | 28 ++++++++++++++++++--- tests/test_shardformer/test_model/_utils.py | 3 --- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index f5b1450ecc6f..1d35d4ed64a1 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -4,7 +4,28 @@ import torch import torch.cuda from torch.nn import Module -from torch.utils._pytree import SUPPORTED_NODES, LeafSpec, TreeSpec, _is_leaf, tree_flatten, tree_map, tree_unflatten +from torch.utils._pytree import ( + SUPPORTED_NODES, + LeafSpec, + TreeSpec, + _is_leaf, + _register_pytree_node, + tree_flatten, + tree_map, + tree_unflatten, +) + + +# this register are for torch under version 1.13.1, maybe removed in the future +def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]: + return list(d.values()), list(d.keys()) + + +def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]': + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten) def tree_map_hf(fn: Any, pytree: Any): @@ -14,11 +35,11 @@ def tree_map_hf(fn: Any, pytree: Any): # use this flatten function to handle the ModelingOutput Class instance. def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used + """Flattens a pytree into a list of values an a TreeSpec that can be used to reconstruct the pytree. """ if isinstance(pytree, OrderedDict): - print("pytree: ") + print("pytree: Ordered dict") node_type = OrderedDict flatten_fn = SUPPORTED_NODES[node_type].flatten_fn child_pytrees, context = flatten_fn(pytree) @@ -162,5 +183,4 @@ def merge_batch(data: List[Any], batch_size_dim=0) -> Any: merged_data.append(torch.cat(elem_batch, dim=batch_size_dim)) else: merged_data.append(list(elem_batch)) - print("merged_data: ", merged_data) return tree_unflatten(merged_data, tree_spec) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 69e4feacf9f7..f77bf7495808 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -192,9 +192,6 @@ def check_output_hidden_state(org_output: Tensor, org_hidden_state = org_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): - if isinstance(sharded_output['outputs'], list): - print(len(sharded_output['outputs'])) - print("last hidden shape", sharded_output['outputs'][0].last_hidden_state.shape) sharded_hidden_state = sharded_output['outputs']['last_hidden_state'] else: sharded_hidden_state = sharded_output.last_hidden_state From d5651c552a77c96f641faa2d0b99a0d15c691d51 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 4 Sep 2023 14:11:17 +0800 Subject: [PATCH 7/7] add register --- colossalai/pipeline/schedule/_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 1d35d4ed64a1..583558551b3c 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -39,7 +39,6 @@ def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: to reconstruct the pytree. """ if isinstance(pytree, OrderedDict): - print("pytree: Ordered dict") node_type = OrderedDict flatten_fn = SUPPORTED_NODES[node_type].flatten_fn child_pytrees, context = flatten_fn(pytree) @@ -48,7 +47,7 @@ def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]: result: List[Any] = [] children_specs: List['TreeSpec'] = [] for child in child_pytrees: - flat, child_spec = tree_flatten(child) + flat, child_spec = tree_flatten_hf(child) result += flat children_specs.append(child_spec) return result, TreeSpec(node_type, context, children_specs) @@ -169,7 +168,6 @@ def merge_batch(data: List[Any], batch_size_dim=0) -> Any: flattened_data = [] tree_spec = None for d in data: - print('d: ', d) # elems should be an instance of OrderedDict elems, tree_spec = tree_flatten_hf(d) flattened_data.append(elems)