From 272fc1052e35ff645e3f7900ed03705fefbab38d Mon Sep 17 00:00:00 2001 From: Stephan Date: Fri, 8 Mar 2024 23:07:30 +0800 Subject: [PATCH 1/3] Remove unnecessary calls to deepcopy --- colossalai/tensor/d_tensor/sharding_spec.py | 3 +-- colossalai/tensor/sharding_spec.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 2ac0ca73e4b8..a1b3cb0fbbf6 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -67,7 +67,6 @@ def build_difference_2d_dict(self): difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = self._convert_str_to_shard_list(target_spec) @@ -112,7 +111,7 @@ def build_difference_2d_dict(self): else: difference = NAN - difference_dict[spec_pair] = difference + difference_dict[(source_spec, target_spec)] = difference self.difference_dict = difference_dict diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index b78ef6d97dd4..714e99498b76 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -71,7 +71,6 @@ def build_difference_2d_dict(self): difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = self._convert_str_to_shard_list(target_spec) @@ -116,7 +115,7 @@ def build_difference_2d_dict(self): else: difference = NAN - difference_dict[spec_pair] = difference + difference_dict[(source_spec, target_spec)] = difference self.difference_dict = difference_dict From 324269fd3e4614655a400164a78b154f142ab3f5 Mon Sep 17 00:00:00 2001 From: Stephan Date: Fri, 8 Mar 2024 23:19:02 +0800 Subject: [PATCH 2/3] Build DimSpec's difference dict only once This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough. --- colossalai/tensor/d_tensor/sharding_spec.py | 84 ++++++++++++--------- colossalai/tensor/sharding_spec.py | 84 ++++++++++++--------- 2 files changed, 100 insertions(+), 68 deletions(-) diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index a1b3cb0fbbf6..9d652b801571 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import Dict, List from ..utils import merge_same_dim_mesh_list @@ -23,10 +22,11 @@ class DimSpec: Otherwise, the element in shard_list means the data will be sharded in that dimension. """ + _DIFFERENCE_DICT = None + def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 self.shard_list = shard_list - self.build_difference_2d_dict() def __eq__(self, other): return str(self) == str(other) @@ -39,24 +39,43 @@ def __repr__(self): target += str(dim) return target - def _convert_str_to_shard_list(self, str_spec): + @property + def difference_dict(self): """ - Convert str_spec into shard_list. + Returns the difference dict, and lazily initializes it when needed - Argument: - str_spec(str): dim spec in str type. + Return: + difference_dict(Dict[Tuple[int, int], Union[int, float, str]]): + difference dict """ + if self._DIFFERENCE_DICT is None: + self._DIFFERENCE_DICT = self._build_difference_2d_dict() - if str_spec == "R": - return [] - if str_spec == "S0": - return [0] - if str_spec == "S1": - return [1] - if str_spec == "S01": - return [0, 1] + return self._DIFFERENCE_DICT - def build_difference_2d_dict(self): + def dim_diff(self, other): + """ + The difference between two _DimSpec. + + Argument: + other(_DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two _DimSpec. + + Example: + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) + print(dim_spec.difference(other_dim_spec)) + + Output: + 5 + """ + difference = self.difference_dict[(str(self), str(other))] + return difference + + @classmethod + def _build_difference_2d_dict(cls): """ Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. @@ -67,8 +86,8 @@ def build_difference_2d_dict(self): difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - source_shard_list = self._convert_str_to_shard_list(source_spec) - target_shard_list = self._convert_str_to_shard_list(target_spec) + source_shard_list = cls._convert_str_to_shard_list(source_spec) + target_shard_list = cls._convert_str_to_shard_list(target_spec) # source same as target if source_shard_list == target_shard_list: @@ -113,28 +132,25 @@ def build_difference_2d_dict(self): difference = NAN difference_dict[(source_spec, target_spec)] = difference - self.difference_dict = difference_dict + return difference_dict - def dim_diff(self, other): + @staticmethod + def _convert_str_to_shard_list(str_spec): """ - The difference between two _DimSpec. + Convert str_spec into shard_list. Argument: - other(_DimSpec): the dim spec to compare with. - - Return: - difference(int): the difference between two _DimSpec. - - Example: - dim_spec = _DimSpec([0]) - other_dim_spec = _DimSpec([0, 1]) - print(dim_spec.difference(other_dim_spec)) - - Output: - 5 + str_spec(str): dim spec in str type. """ - difference = self.difference_dict[(str(self), str(other))] - return difference + + if str_spec == "R": + return [] + if str_spec == "S0": + return [0] + if str_spec == "S1": + return [1] + if str_spec == "S01": + return [0, 1] class ShardingSpec: diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index 714e99498b76..d8e542cbb307 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -1,5 +1,4 @@ import operator -from copy import deepcopy from functools import reduce import torch @@ -27,10 +26,11 @@ class _DimSpec: Otherwise, the element in shard_list means the data will be sharded in that dimension. """ + _DIFFERENCE_DICT = None + def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 self.shard_list = shard_list - self.build_difference_2d_dict() def __eq__(self, other): return str(self) == str(other) @@ -43,24 +43,43 @@ def __repr__(self): target += str(dim) return target - def _convert_str_to_shard_list(self, str_spec): + @property + def difference_dict(self): """ - Convert str_spec into shard_list. + Returns the difference dict, and lazily initializes it when needed - Argument: - str_spec(str): dim spec in str type. + Return: + difference_dict(Dict[Tuple[int, int], Union[int, float, str]]): + difference dict """ + if self._DIFFERENCE_DICT is None: + self._DIFFERENCE_DICT = self._build_difference_2d_dict() - if str_spec == "R": - return [] - if str_spec == "S0": - return [0] - if str_spec == "S1": - return [1] - if str_spec == "S01": - return [0, 1] + return self._DIFFERENCE_DICT - def build_difference_2d_dict(self): + def difference(self, other): + """ + The difference between two _DimSpec. + + Argument: + other(_DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two _DimSpec. + + Example: + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) + print(dim_spec.difference(other_dim_spec)) + + Output: + 5 + """ + difference = self.difference_dict[(str(self), str(other))] + return difference + + @classmethod + def _build_difference_2d_dict(cls): """ Build a difference mapping for 2D device mesh case. It will be used to compute the difference between DimSpec pairs. @@ -71,8 +90,8 @@ def build_difference_2d_dict(self): difference_dict = {} for source_spec in source_spec_list: for target_spec in target_spec_list: - source_shard_list = self._convert_str_to_shard_list(source_spec) - target_shard_list = self._convert_str_to_shard_list(target_spec) + source_shard_list = cls._convert_str_to_shard_list(source_spec) + target_shard_list = cls._convert_str_to_shard_list(target_spec) # source same as target if source_shard_list == target_shard_list: @@ -117,28 +136,25 @@ def build_difference_2d_dict(self): difference = NAN difference_dict[(source_spec, target_spec)] = difference - self.difference_dict = difference_dict + return difference_dict - def difference(self, other): + @staticmethod + def _convert_str_to_shard_list(str_spec): """ - The difference between two _DimSpec. + Convert str_spec into shard_list. Argument: - other(_DimSpec): the dim spec to compare with. - - Return: - difference(int): the difference between two _DimSpec. - - Example: - dim_spec = _DimSpec([0]) - other_dim_spec = _DimSpec([0, 1]) - print(dim_spec.difference(other_dim_spec)) - - Output: - 5 + str_spec(str): dim spec in str type. """ - difference = self.difference_dict[(str(self), str(other))] - return difference + + if str_spec == "R": + return [] + if str_spec == "S0": + return [0] + if str_spec == "S1": + return [1] + if str_spec == "S01": + return [0, 1] class ShardingSpecException(Exception): From 462db3405f945dfd9c4079df8b9110ed327bf656 Mon Sep 17 00:00:00 2001 From: Stephan Date: Sat, 13 Jul 2024 19:31:52 +0800 Subject: [PATCH 3/3] Fix documentation of DimSpec's difference method --- colossalai/tensor/d_tensor/sharding_spec.py | 12 ++++++------ colossalai/tensor/sharding_spec.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py index 9d652b801571..307d98bea757 100644 --- a/colossalai/tensor/d_tensor/sharding_spec.py +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -55,18 +55,18 @@ def difference_dict(self): def dim_diff(self, other): """ - The difference between two _DimSpec. + The difference between two DimSpec. Argument: - other(_DimSpec): the dim spec to compare with. + other(DimSpec): the dim spec to compare with. Return: - difference(int): the difference between two _DimSpec. + difference(int): the difference between two DimSpec. Example: - dim_spec = _DimSpec([0]) - other_dim_spec = _DimSpec([0, 1]) - print(dim_spec.difference(other_dim_spec)) + dim_spec = DimSpec([0]) + other_dim_spec = DimSpec([0, 1]) + print(dim_spec.dim_diff(other_dim_spec)) Output: 5 diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index d8e542cbb307..fb42afab75b9 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -82,7 +82,7 @@ def difference(self, other): def _build_difference_2d_dict(cls): """ Build a difference mapping for 2D device mesh case. It will be used to - compute the difference between DimSpec pairs. + compute the difference between _DimSpec pairs. """ source_spec_list = ["R", "S0", "S1", "S01"]