diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 4262d76173e4..b4b7b0e794d1 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -188,7 +188,7 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV remove_strategy_list = [] for strategy in self.strategies_vector: shard_axis_list = [] - last_axis = len(self.device_mesh.mesh_shape) - 1 + last_axis = len(self.device_mesh.shape) - 1 for op_data, sharding_spec in strategy.sharding_specs.items(): if op_data.data is not None and isinstance(op_data.data, torch.Tensor): for dim, shard_axes in sharding_spec.dim_partition_dict.items(): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index 1ce5a08f2d6b..aa1581b99e0f 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -984,7 +984,7 @@ def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] device_mesh_is_1d = True - if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape: + if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape: device_mesh_is_1d = False if device_mesh_is_1d: @@ -992,10 +992,10 @@ def collate_strategies(self) -> List[ShardingStrategy]: # Sb = Sb x Sb # can be None as it is only for 1D device mesh # only for 1D device mesh - if len(self.device_mesh.mesh_shape) == 1: + if len(self.device_mesh.shape) == 1: mesh_dim = 0 else: - mesh_dim = self.device_mesh.mesh_shape.index(1) + mesh_dim = self.device_mesh.shape.index(1) strategy_list.append(self.split_one_batch_dim(mesh_dim)) else: # for 2D device mesh diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 9e402dab7578..475e95fc4326 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -46,8 +46,8 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens # make sure all dims are covered in sharding spec sharding_len = len(sharding_spec.sharding_sequence) tensor_num_dim = tensor.dim() - num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] - num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] + num_devices_in_col = sharding_spec.device_mesh.shape[0] + num_devices_in_row = sharding_spec.device_mesh.shape[1] assert sharding_len == tensor_num_dim, \ f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 3dada00cd9b5..485577b9650c 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -10,7 +10,7 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.tensor.d_tensor.d_tensor import DTensor +from colossalai.tensor.d_tensor import is_distributed_tensor SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) for key, weight in state_dict.items(): ret_block = None ret_block_size = 0 - if type(weight) != DTensor: + if not is_distributed_tensor(weight): weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. @@ -146,7 +146,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> continue # If the states are stored as DTensors, mark isDTensor as true. - if type(state_tensor) == DTensor: + if is_distributed_tensor(state_tensor): isDTensor = True state_size += calculate_tensor_size(state_tensor) diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index 2a5f747fbc23..3e96310e1890 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -3,11 +3,19 @@ with some changes. """ import operator +from dataclasses import dataclass from functools import reduce -from typing import List, Tuple +from typing import Dict, List, Union import torch import torch.distributed as dist +from torch.distributed import ProcessGroup + + +@dataclass +class ProcessGroupContainer: + process_group: ProcessGroup + ranks: List[int] # modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py) @@ -27,9 +35,11 @@ class DeviceMesh: during initializing the DeviceMesh instance if the init_process_group set to True. Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. (default: False) - need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True. + device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda') """ + _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"} + def __init__(self, physical_mesh_id: torch.Tensor, mesh_shape: torch.Size = None, @@ -37,160 +47,442 @@ def __init__(self, mesh_alpha: List[float] = None, mesh_beta: List[float] = None, init_process_group: bool = False, - need_flatten: bool = True): - self.physical_mesh_id = physical_mesh_id + device: str = 'cuda'): + # ============================ + # Physical & Logical Mesh IDs + # ============================ + self._physical_mesh_id = physical_mesh_id + assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor." + + # logical mesh ids can be obtained via two ways + # 1. provide physical mesh id and provide mesh shape + # 2. directly supply the logical mesh id + assert mesh_shape is None or logical_mesh_id is None, \ + "Only one of mesh_shape and logical_mesh_id can be specified." \ + "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id" + if logical_mesh_id is None: - self.mesh_shape = mesh_shape - self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape) + self._mesh_shape = mesh_shape + self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape) else: self._logical_mesh_id = logical_mesh_id - self.mesh_shape = self._logical_mesh_id.shape + self._mesh_shape = self._logical_mesh_id.shape + + # ensure two things: + # 1. logical and physical mesh IDs should contain the same elements + # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed + assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \ + "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." + assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \ + "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again." + assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \ + "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." - # map global rank into logical rank - self.convert_map = {} - self._global_rank_to_logical_rank_map(self._logical_mesh_id, []) + # =============================================== # coefficient for alpha-beta communication model + # alpha is latency and beta is bandwidth + # =============================================== + # if the values are not provided, we assume they are 1 for simplicity if mesh_alpha is None: - mesh_alpha = [1] * len(self.mesh_shape) + mesh_alpha = [1] * len(self._mesh_shape) if mesh_beta is None: - mesh_beta = [1] * len(self.mesh_shape) + mesh_beta = [1] * len(self._mesh_shape) + self.mesh_alpha = tuple(mesh_alpha) self.mesh_beta = tuple(mesh_beta) - self.init_process_group = init_process_group - self.need_flatten = need_flatten - if self.init_process_group: - self.process_groups_dict = self.create_process_groups_for_logical_mesh() - if self.need_flatten and self._logical_mesh_id.dim() > 1: - self.flatten_device_mesh = self.flatten() - # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) - # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha, - # self.mesh_beta) + + # ensure the alpha and beta have the same shape + assert len(self.mesh_alpha) == len(self.mesh_beta), \ + "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." + + # ========================= + # Device for Process Group + # ========================= + self._device = device + self._dist_backend = self._DIST_BACKEND[device] + + # ========================= + # Process Group Management + # ========================= + # the _global_to_local_rank_mapping is structured as follows + # { + # : [ , , , ...] + # } + self._global_to_local_rank_mapping = dict() + self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping, + tensor=self.logical_mesh_id) + + # create process group + self._process_group_dict = {} + self._ranks_in_the_process_group = {} + self._global_rank_of_current_process = None + self._is_initialized = False + + # attribute used to inidicate whether this objectd + # is created using DeviceMesh.from_process_group + # this attribute can be used to do some check in methods + # such get_process_group as no global rank information + # is known if created with from_process_group + self._is_init_from_process_group = False + + # initialize process group if specified + self._init_ranks_in_the_same_group() + self._init_process_group = init_process_group + if init_process_group: + self.init_logical_process_group() @property - def shape(self): - return self.mesh_shape + def shape(self) -> torch.Size: + """ + Return the shape of the logical mesh. + """ + return self._mesh_shape @property - def num_devices(self): - return reduce(operator.mul, self.physical_mesh_id.shape, 1) + def num_devices(self) -> int: + """ + Return the number of devices contained in the device mesh. + """ + return reduce(operator.mul, self._physical_mesh_id.shape, 1) @property - def logical_mesh_id(self): + def logical_mesh_id(self) -> torch.Tensor: + """ + Return the logical mesh id. + """ return self._logical_mesh_id - def __deepcopy__(self, memo): + @property + def is_initialized(self) -> bool: + """ + Return whether the process group is initialized. + """ + return self._is_initialized + + @staticmethod + def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh": + """ + Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method + will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication. + + Args: + process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh. + If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects, + the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh. + + Returns: + DeviceMesh: the device mesh instance. + """ + + def _get_device_by_backend(process_group): + """ + Get the device type given a process group's backend. + """ + backend = dist.get_backend(process_group) + for _device, _backend in DeviceMesh._DIST_BACKEND.items(): + if _backend == backend: + return _device + return None + + if isinstance(process_group, ProcessGroup): + process_group = [process_group] + + # get mesh shape + mesh_shape = [dist.get_world_size(pg) for pg in process_group] + + # get device + device_list = [_get_device_by_backend(pg) for pg in process_group] + + # make sure all devices are the same + assert all([device == device_list[0] for device in device_list]), \ + "All devices should be the same, please check your input process groups are created with the same distributed backend." + + # create a fake physical mesh id + # as we only get the process group associated with the current process, + # we cannot get the global ranks for all processes in the mesh + # therefore, we only use this fake physical mesh id to create the device mesh + # and will remove this fake physical mesh id later + fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1)) + + # create the device mesh + device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0]) + + # hack the device attribute + device_mesh._physical_mesh_id = None + device_mesh._logical_mesh_id = None + device_mesh._global_rank_of_current_process = dist.get_rank() + device_mesh._is_initialized = False + device_mesh._process_group_dict = { + device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)} + } + + return device_mesh + + def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup: + """ + Return the process group on the specified axis. + + Args: + axis (int): the axis of the process group. + global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None) + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + return self._process_group_dict[global_rank][axis] + + def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]: + """ + Return the process groups for all axes. + + Args: + global_rank (int, optional): the global rank of the process + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + return self._process_group_dict[global_rank] + + def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]: + """ + Return the ranks in the process group on the specified axis. + + Args: + axis (int): the axis of the process group. + global_rank (int, optional): the global rank of the process + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + return self._ranks_in_the_process_group[global_rank][axis] + + def __deepcopy__(self, memo) -> "DeviceMesh": cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k != 'process_groups_dict': + if k != '_process_group_dict': setattr(result, k, __import__("copy").deepcopy(v, memo)) else: + # process group cannot be copied + # thus, we share them directly setattr(result, k, v) - return result - def flatten(self): - """ - Flatten the logical mesh into an effective 1d logical mesh, + def _init_global_to_logical_rank_mapping(self, + mapping: Dict, + tensor: torch.Tensor, + index_list: List[int] = []) -> Dict[int, List[int]]: """ - flatten_mesh_shape_size = len(self.mesh_shape) - flatten_mesh_shape = [self.num_devices] - return DeviceMesh(self.physical_mesh_id, - tuple(flatten_mesh_shape), - mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), - mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), - init_process_group=self.init_process_group, - need_flatten=False) + Build a global rank to local rank mapping for each process group in different axis in the logical device mesh. - def _global_rank_to_logical_rank_map(self, tensor, index_list): - ''' - This method is a helper function to build convert_map recursively. - ''' + Args: + mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. + tensor (torch.Tensor): the tensor that contains the logical mesh ids. + index_list (List[int]) + + Returns: + mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. + The value is a list of integers and each integer represents the local rank in the indexed axis. + """ for index, inner_tensor in enumerate(tensor): + # index means the local rank in the current axis + # inner_tensor refers to the processes with the same local rank + if inner_tensor.numel() == 1: - self.convert_map[int(inner_tensor)] = index_list + [index] + # if the inner_tensor only has one element, it means that + # it already reaches the last axis + # we append its local_rank in the last axis to the index_list + # and assign to the mapping + # the value of the mapping is the the local rank at the indexed axis of the device mesh + mapping[int(inner_tensor)] = index_list + [index] else: - self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index]) + # we recursively go into the function until we reach the last axis + # meanwhile, we should add the local rank in the current axis in the index_list + self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index]) - def create_process_groups_for_logical_mesh(self): + def init_logical_process_group(self): ''' This method is used to initialize the logical process groups which will be used in communications among logical device mesh. Note: if init_process_group set to False, you have to call this method manually. Otherwise, the communication related function, such as ShapeConsistencyManager.apply will raise errors. ''' - process_groups_dict = {} - check_duplicate_list = [] - global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist() + # sanity check + assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group" + assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice" + + # update the global rank of the current process + self._global_rank_of_current_process = dist.get_rank() + duplicate_check_list = [] + + # flatten the global ranks to 1D list + global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() + for global_rank in global_rank_flatten_list: - process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank) - for axis, process_group in process_groups.items(): - if axis not in process_groups_dict: - process_groups_dict[axis] = [] - if process_group not in check_duplicate_list: - check_duplicate_list.append(process_group) - process_group_handler = dist.new_group(process_group) - process_groups_dict[axis].append((process_group, process_group_handler)) + # find the other ranks which are in the same process group as global_rank + ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) - return process_groups_dict + for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): + # skip duplicated process group creation + if ranks_in_same_group in duplicate_check_list: + continue - def global_rank_to_logical_rank(self, rank): - return self.convert_map[rank] + # create the process group + pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend) - def global_rank_to_process_groups_with_logical_rank(self, rank): - ''' - Give a global rank and return all logical process groups of this rank. - for example: - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) - mesh_shape = (4, 4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - print(device_mesh.global_rank_to_process_groups_with_logical_rank(0)) - output: - # key is axis name - # value is a list of logical ranks in same axis with rank 0 - {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]} - ''' - process_groups = {} - for d in range(self.logical_mesh_id.dim()): - for replacer in range(self.logical_mesh_id.shape[d]): - if d not in process_groups: - process_groups[d] = [] - process_group_member = self.convert_map[rank].copy() - process_group_member[d] = replacer - process_groups[d].append(process_group_member) - return process_groups - - def global_rank_to_process_groups_with_global_rank(self, rank): + # keep this process group in the process_groups_dict + for rank in ranks_in_same_group: + if rank not in self._process_group_dict: + self._process_group_dict[rank] = dict() + self._process_group_dict[rank][axis] = pg_handler + + # update the init flag + # we only allow init for once + self._is_initialized = True + + def _init_ranks_in_the_same_group(self): + """ + This method is used to initialize the ranks_in_the_same_group dictionary. + """ + # flatten the global ranks to 1D list + global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() + + for global_rank in global_rank_flatten_list: + # find the other ranks which are in the same process group as global_rank + ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) + + for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): + # create dict for each rank + if global_rank not in self._process_group_dict: + self._ranks_in_the_process_group[global_rank] = dict() + + # keep this process group in the process_groups_dict + self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group + + def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]: + """ + Return the local rank of the given global rank in the logical device mesh. + + Args: + rank (int): the global rank in the logical device mesh. + axis (int): the axis of the logical device mesh. + """ + if self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + + local_ranks = self._global_to_local_rank_mapping[rank] + if axis: + return local_ranks[axis] + else: + return local_ranks + + def _collate_global_ranks_in_same_process_group(self, global_rank): ''' - Give a global rank and return all process groups of this rank. - for example: - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) - mesh_shape = (4, 4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - print(device_mesh.global_rank_to_process_groups_with_global_rank(0)) - output: - # key is axis name - # value is a list of global ranks in same axis with rank 0 - {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]} + Give a global rank and return all global ranks involved in its associated process group in each axis. + + Example: + + ```python + sphysical_mesh_id = torch.arange(0, 16) + mesh_shape = (4, 4) + + # logical mesh will look like + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.collate_global_ranks_in_same_process_group(0)) + + # key is axis name + # value is a list of global ranks in same axis with rank 0 + # output will look like + # { + 0: [0, 4, 8, 12], + 1: [0, 1, 2, 3] + # } ''' - logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank) - process_groups = {} - for dim, logical_ranks in logical_process_groups.items(): - process_groups[dim] = [] - for logical_rank in logical_ranks: - for g_rank, l_rank in self.convert_map.items(): - if l_rank == logical_rank: - process_groups[dim].append(g_rank) - return process_groups + # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping + # for self._global_to_local_rank_mapping + # the key is the global rank + # the value is the list of local ranks corresponding to the global rank with respect of different axes + # we can see the list of local ranks as the process coordinates for simplicity + # the key and value are all unique, therefore, + # we can also to use the coordinates to find the global rank + + # ========================================================================= + # Step 1 + # find all the process_coordinates for processes in the same process group + # as the given global rank + # ========================================================================= + + # each + processes_in_the_same_process_group = {} + + for dim in range(self.logical_mesh_id.dim()): + # iterate over the dimension size so that we can include all processes + # in the same process group in the given axis + # the _local_rank refers to the local rank of the current process + for _local_rank in range(self.logical_mesh_id.shape[dim]): + + # if this dimension is not initailized yet, + # initialize it with an empty array + if dim not in processes_in_the_same_process_group: + processes_in_the_same_process_group[dim] = [] + + # get the local rank corresponding to the global rank + process_coordinates = self._global_to_local_rank_mapping[global_rank].copy() + + # replace the local rank in the given dimension with the + # lcoal rank of the current process iterated + process_coordinates[dim] = _local_rank + processes_in_the_same_process_group[dim].append(process_coordinates) + + # ================================================================= + # Step 2 + # Use local rank combination to find its corresponding global rank + # ================================================================= + # the key of the dict is the axis + # the value is the list of global ranks which are in the same process group as the given global rank + global_pg_ranks = {} + for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items(): + global_pg_ranks[dim] = [] + for process_coordinates in coordinates_of_all_processes: + # find the global rank by local rank combination + for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items(): + if process_coordinates == _process_coordinates: + global_pg_ranks[dim].append(_global_rank) + return global_pg_ranks + + def flatten(self): + """ + Flatten the logical mesh into an effective 1d logical mesh, + """ + if self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + + flatten_mesh_shape_size = len(self._mesh_shape) + flatten_mesh_shape = [self.num_devices] + return DeviceMesh(self._physical_mesh_id, + tuple(flatten_mesh_shape), + mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), + mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + init_process_group=self._init_process_group) def all_gather_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] @@ -211,39 +503,4 @@ def all_to_all_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] penalty_factor = num_devices / 2.0 return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * - (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) - - -class FlattenDeviceMesh(DeviceMesh): - - def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None): - super().__init__(physical_mesh_id, - mesh_shape, - mesh_alpha, - mesh_beta, - init_process_group=False, - need_flatten=False) - # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars - self.mesh_alpha = max(self.mesh_alpha) - self.mesh_beta = min(self.mesh_beta) - # Different from original process_groups_dict, rank_list is not stored - self.process_number_dict = self.create_process_numbers_for_logical_mesh() - - def create_process_numbers_for_logical_mesh(self): - ''' - Build 1d DeviceMesh in column-major(0) and row-major(1) - for example: - mesh_shape = (2,4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7]] - # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} - ''' - num_devices = reduce(operator.mul, self.mesh_shape, 1) - process_numbers_dict = {} - process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist() - process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist() - return process_numbers_dict - - def mix_gather_cost(self, num_bytes): - num_devices = reduce(operator.mul, self.mesh_shape, 1) - return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1) + (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) \ No newline at end of file diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 76f550dc4392..8b911407307c 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,5 +1,5 @@ from types import MethodType -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import torch import torch.distributed as dist @@ -8,8 +8,9 @@ from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor -from colossalai.tensor.d_tensor.d_tensor import DTensor -from colossalai.tensor.d_tensor.layout import Layout +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.d_tensor import distribute_tensor +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _NORMAL_FACTORY = [ @@ -172,7 +173,7 @@ def materialize(self) -> torch.Tensor: self.clean() return _convert_cls(self, target) - def distribute(self, layout: Layout) -> torch.Tensor: + def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. Args: @@ -183,7 +184,7 @@ def distribute(self, layout: Layout) -> torch.Tensor: """ target = self._materialize_data() self.clean() - local_tensor = DTensor(target, layout).local_tensor + local_tensor = distribute_tensor(target, device_mesh, sharding_spec) return _convert_cls(self, local_tensor) def clean(self) -> None: @@ -536,7 +537,10 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: + def distribute(module: nn.Module, + device_mesh: DeviceMesh, + sharding_spec_dict: Dict[str, ShardingSpec], + verbose: bool = False) -> nn.Module: """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: @@ -546,7 +550,7 @@ def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> n """ def apply_fn(name: str, p: LazyTensor): - p.distribute(layout_dict[name]) + p.distribute(device_mesh, sharding_spec_dict[name]) return _apply_to_lazy_module(module, apply_fn, verbose) diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py index 5234b6b1a1b5..4a06bdcb7629 100644 --- a/colossalai/nn/layer/base_layer.py +++ b/colossalai/nn/layer/base_layer.py @@ -10,6 +10,7 @@ class ParallelLayer(nn.Module): + global_state_dict: bool = True def __init__(self): diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index 394334558275..300baf9c12ba 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist + from colossalai.core import global_context as gpc try: diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md new file mode 100644 index 000000000000..fca401562be6 --- /dev/null +++ b/colossalai/shardformer/README.md @@ -0,0 +1,392 @@ +# โšก๏ธ ShardFormer + +## ๐Ÿ“š Table of Contents + +- [โšก๏ธ ShardFormer](#๏ธ-shardformer) + - [๐Ÿ“š Table of Contents](#-table-of-contents) + - [๐Ÿ”— Introduction](#-introduction) + - [๐Ÿ”จ Usage](#-usage) + - [Quick Start](#quick-start) + - [Write your own policy](#write-your-own-policy) + - [๐Ÿ—บ Roadmap](#-roadmap) + - [๐Ÿ’ก API Design](#-api-design) + - [Distributed Modules](#distributed-modules) + - [Shard Config](#shard-config) + - [Policy](#policy) + - [Model Sharder](#model-sharder) + - [User-facing API](#user-facing-api) + - [โŒจ๏ธ Development Notes](#๏ธ-development-notes) + - [Add New Policy to Shardformer](#add-new-policy-to-shardformer) + - [Write Your Unit Testing](#write-your-unit-testing) + - [๐Ÿ“Š Benchmarking](#-benchmarking) + - [System Performance](#system-performance) + - [Convergence](#convergence) + + +## ๐Ÿ”— Introduction + +**Shardformer** is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. This module aims to make parallelization hassle-free for users who are not from the system background. + +## ๐Ÿ”จ Usage + +### Quick Start + +The sample API usage is given below: + +``` python +from colossalai.shardformer import ShardConfig, Shard +from transformers import BertForMaskedLM + +# launch colossalai +colossalai.launch_from_torch() + +# create model +config = BertConfig.from_pretrained('bert-base-uncased') +model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config) + +# create huggingface model as normal +shard_config = ShardConfig() +shard_former = ShardFormer(shard_config=shard_config) +sharded_model = shard_former.optimize(model).to('cuda') + +# do everything like normal +... +``` + +### Write your own policy + +If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design). + +```python +from colossalai.shardformer import Policy + +class MyPolicy(Policy): + # implement your own policy + ... + +# init model and shard former +... + +# use customized policy to shard model +my_policy = MyPolicy() +shard_former.optimize(model, my_policy) + + + +``` +## ๐Ÿ—บ Roadmap + +We will follow this roadmap to develop Shardformer: + +- [x] API Design +- [x] API Implementation +- [x] Unit Testing +- [ ] Policy Implementation + - [ ] Hugging Face + - [ ] NLP + - [x] BERT + - [x] T5 + - [x] LlaMa + - [x] GPT2 + - [x] OPT + - [x] BLOOM + - [ ] GLM + - [ ] RoBERTa + - [ ] ALBERT + - [ ] ERNIE + - [ ] GPT Neo + - [ ] GPT-J + - [ ] CV + - [x] ViT + - [ ] BEiT + - [ ] SwinTransformer + - [ ] SwinTransformer V2 + - [ ] Audio + - [ ] Whisper + - [ ] Multi-modal + - [ ] To be added + +## ๐Ÿ’ก API Design + +We will discuss the major components of `ShardFormer` below to help you better understand how things work. +This section serves as the design doc for Shardformer and the function signature might differ from the actual implementation. +Please refer to the code for more details. + +

+ +
+

+ + + +### Distributed Modules + +`ShardFormer` replaces the original PyTorch module with a distributed module. +The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation. +Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module. + +```python +class ParallelModule(torch.nn.Module): + + @abstractmethod + def from_native_module(module: torch.nn.Module, process_group: Union[ProcessGroup, Tuple[ProcessGroup]]) -> ParallelModule + """ + Convert a native module to a parallelized + + Examples: + + ```python + # replace module + my_linear = Linear1D_Col.from_native_module(my_linear, process_group) + ``` + """ +``` + +### Shard Config + +`ShardConfig` is a simple data class to tell `ShardFormer` how sharding will be performed. + +```python +@dataclass +class ShardConfig: + tensor_parallel_process_group: ProcessGroup = None + enable_fused_normalization: bool = False + ... + + # Some possible future config fields + tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode + inference_only: bool # only inject inference-suitable sharding policy + use_flash_attention: bool # whether to use flash attention to speed up attention +``` + +### Policy + +The `Policy` class describes how to handle the model sharding. +It is merely a description, the actual sharding will be performed by `ModelSharder`. +We abstract the policy into four stages: + +1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding +2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted. +3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model. + +``` python +@dataclass +class ModulePolicyDescription: + r""" + Describe how the attributes and parameters will be transformed in a policy. + + Args: + attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding + param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive only one arguments: module. + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription + object which specifies the module to be replaced and the target module used to replacement. + method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement + """ + attribute_replacement: Dict[str, Any] = None + param_replacement: List[Callable] = None + sub_module_replacement: List[SubModuleReplacementDescription] = None + method_replacement: Dict[str, Callable] = None + +@dataclass +class SubModuleReplacementDescription: + r""" + Describe how a submodule will be replaced + + Args: + suffix (str): used to get the submodule object + target_module (ParallelModule): specifies the module class used to replace to submodule + kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. + ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception + """ + suffix: str + target_module: ParallelModule + kwargs: Dict[str, Any] = None + ignore_if_not_exist: bool = False + + +class Policy(ABC): + + def __init__(self) + self.model = None + + def set_model(self, model: nn.Module) -> None: + """ + Set model as an attribute of the Policy object so that we can access the model's attributes. + """ + self.model = model + + @abstractmethod + def preprocess(self) -> nn.Module: + """ + Perform some preprocessing on the model, such as resizing the embedding size + """ + ... + + @abstractmethod + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + """ + Return the dict for the modify policy, the key is the original layer class and the value is the + argument for the modify layer + """ + ... + + @abstractmethods + def postprocess(self) -> nn.Module: + """ + Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head + """ + ... +``` + + +### Model Sharder + +`ModelSharder` is the class in charge of sharding the model based on the given policy. + +```python +class ModelSharder: + + def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None): + #TODO: input is a cls or a obj + ... + + def shard(self) -> None: + """ + Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing. + """ + ... + + def replace_module(self) -> None: + """ + Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively. + """ + ... +``` + +### User-facing API + +We only expose a limited number of APIs to the user to keep their user experience simple and clean. + +```python +class ShardFormer: + """ + Parallelize model based on the given config and policy + + Example: + + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + model = shard_former.optimize(model, policy=policy) + dataloader = shard_former.shard_dataset(dataset) + + """ + + def __init__(self, shard_config: ShardConfig): + """ + Do two things: + 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp + 2. serve as a store for shard config + """ + self.shard_config = shard_config + self.pg_manager = None + + def init_distributed(self) -> colossalai.cluster.ProcessGroupManager: + """ + Initialize the distributed process group according to the + """ + pg_manager = ... + self.pg_manager = pg_manager + return pg_manager + + def shard_model(self, model: torch.nn.Module๏ผŒpolicy: Policy) -> torch.nn.Module: + """ + Shard model for TP and PP + """ + ... + + def shard_dataset(self, dataset: Dataset) -> Dataloader: + """ + Shard dataset for DP + """ + ... +``` + +## โŒจ๏ธ Development Notes + +### Add New Policy to Shardformer + +This section serves as the guideline for writing new policies and register them into `shardformer`. + +- Step 1. Write your own model policy + +You can create a new file in the `colossalai/shardformer/policies` folder and name the file with the model name. You can implement your policy in this file. You should not import the any model zoo library at the header section of the file because we do not want to import the library when we do not use the policy. Libraries such as `transformers` should be imported only in the function body when needed. + +Please follow the following protocols when writing your policy: + +- You have to make a clear decision what you want to replace exactly in the original PyTorch module + - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes + - Use `ModulePolicyDescription.param_replacement` to replace the module parameters + - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement the `from_native_module` for the . + - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in the `shardformer/modeling/.py`**. +- You can implement the `ParallelModule` for primitive modules in the `shardformer/layer/.py` file. Primitive modules refer to modules which are not composed of other modules. For example, the `torch.nn.Linear` module is a primitive module while modules such as `BertEncoder` module in the `transformers` library is a composite module. Primitive modules do not nested inner `nn.Module` members. For composite modules, you should consider using `ModulePolicyDescription` to implement your replacement. +- `ParallelModule` is meant to be used in two ways: `ParallelModule.from_native_module` to convert native PyTorch module to the `ParallelModule` and `ParallelModule(...)` to instantiate the module directly just like a normal PyTorch module. `ParallelModule` should be only implemented for modules whose weights are sharded. If you want to make your module compatible with the `ModulePolicyDescription.sub_module_replacement` and there is no weight sharding in your module, you can just implement the `from_native_module` method without inheriting the `ParallelModule` like `colossalai/shardformer/layer/normalization.py`. +- **Do not import any file in the `colossalai/shardformer/policies` and `colossalai/shardformer/modeling` to avoid unwanted import error**. For example, a file in these folders accidentally imports `transformers` library at the top of the file, then the user will have to install `transformers` library even if they do not use this file. Any file in the `modeling` folder should be only imported by the policy file. A policy implementation should be only imported dynamically via the autopolicy or manually via the `ShardFormer` module. +- Try to keep your import statement on third-party libraries such as `transformers` within the function body instead of the header section of the file. This is because we do not want to import the library when we do not use the policy. + + +- Step 2. Register your policy to the autopolicy + +Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file. + +For example, if we register the policy for the BERT model, we just add a key-value in the `_POLICY_LIST` dictionary. The key if the `qualname` of the model object (you can get it by model.__class__.__qualname__). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy. + +```python +_POLICY_LIST = { + # BERT + "transformers.models.bert.modeling_bert.BertModel": + PolicyLocation(file_name="bert", class_name="BertModelPolicy"), +} +``` + +### Write Your Unit Testing + +This section serves as the guideline for testing the `shardformer` module. + +- Step 1. Add your model to the model zoo in the test kits. + +Add your model to the `tests/kit/model_zoo` file. This allows you to define test-related components for this model. You can take `tests/kit/model_zoo/transformers/llama.py` as an example for reference. + +- Step 2. Write your unit testing for the model + +Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency. + + +- Step 3. Execute your test + +When you run tests locally, you should run tests for both your newly-added test file and the whole `shardformer` module tests. + +```bash +# test for your own test file +pytest tests/test_shardformer/test_model/.py + +# test for the whole shardformer module +pytest tests/test_shardformer +``` + +## ๐Ÿ“Š Benchmarking + +### System Performance + +To be added. + +### Convergence + +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. + +| accuracy | f1 | loss | GPU number | model shard | +| :------: | :-----: | :-----: | :--------: | :---------: | +| 0.82594 | 0.87441 | 0.09913 | 4 | True | +| 0.81884 | 0.87299 | 0.10120 | 2 | True | +| 0.81855 | 0.87124 | 0.10357 | 1 | False | + +Overall, the results demonstrate that using shardformers during model training does not affect the convergence. diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py new file mode 100644 index 000000000000..77c2af8d18f7 --- /dev/null +++ b/colossalai/shardformer/__init__.py @@ -0,0 +1 @@ +from .shard import ShardConfig, ShardFormer diff --git a/colossalai/shardformer/_utils.py b/colossalai/shardformer/_utils.py new file mode 100644 index 000000000000..4ad877e72357 --- /dev/null +++ b/colossalai/shardformer/_utils.py @@ -0,0 +1,80 @@ +import re + + +def get_obj_list_element(obj, a): + r""" + Get the element of the list in the object + """ + re_pattern = r'\[\d+\]' + prog = re.compile(re_pattern) + result = prog.search(a) + if result: + matched_brackets = result.group() + matched_index = matched_brackets.replace('[', '') + matched_index = matched_index.replace(']', '') + a_ = a.replace(matched_brackets, '') + container_obj = getattr(obj, a_) + obj = container_obj[int(matched_index)] + else: + obj = getattr(obj, a) + return obj + + +def hasattr_(obj, attr: str): + r""" + Check whether the object has the multi sublevel attr + + Args: + obj (object): The object to check + attr (str): The multi level attr to check + """ + attrs = attr.split('.') + for a in attrs: + try: + obj = get_obj_list_element(obj, a) + except AttributeError: + return False + return True + + +def setattr_(obj, attr: str, value, ignore: bool = False): + r""" + Set the object's multi sublevel attr to value, if ignore, ignore when it doesn't exist + + Args: + obj (object): The object to set + attr (str): The multi level attr to set + value (Any): The value to set + ignore (bool): Whether to ignore when the attr doesn't exist + """ + + attrs = attr.split('.') + for a in attrs[:-1]: + try: + obj = get_obj_list_element(obj, a) + except AttributeError: + if ignore: + return + raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}") + setattr(obj, attrs[-1], value) + + +def getattr_(obj, attr: str, ignore: bool = False): + r""" + Get the object's multi sublevel attr + + Args: + obj (object): The object to set + attr (str): The multi level attr to set + ignore (bool): Whether to ignore when the attr doesn't exist + """ + + attrs = attr.split('.') + for a in attrs: + try: + obj = get_obj_list_element(obj, a) + except AttributeError: + if ignore: + return None + raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}") + return obj diff --git a/colossalai/shardformer/examples/data.py b/colossalai/shardformer/examples/data.py new file mode 100644 index 000000000000..6296d4be4eb0 --- /dev/null +++ b/colossalai/shardformer/examples/data.py @@ -0,0 +1,146 @@ +import datasets +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase + + +class GLUEDataBuilder: + + task_text_field_map = { + "cola": ["sentence"], + "sst2": ["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], + "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], + } + + glue_task_num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, + } + + loader_columns = [ + "datasets_idx", + "input_ids", + "token_type_ids", + "attention_mask", + "start_positions", + "end_positions", + "labels", + ] + + def __init__( + self, + model_name_or_path: str, + plugin: DPPluginBase = None, + task_name: str = "mrpc", + max_seq_length: int = 128, + train_batch_size: int = 32, + eval_batch_size: int = 32, + **kwargs, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.task_name = task_name + self.max_seq_length = max_seq_length + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.plugin = plugin + + self.text_fields = self.task_text_field_map[task_name] + self.num_labels = self.glue_task_num_labels[task_name] + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + self.setup() + + def setup(self): + self.dataset = datasets.load_dataset("glue", self.task_name) + + for split in self.dataset.keys(): + self.dataset[split] = self.dataset[split].map( + self.convert_to_features, + batched=True, + remove_columns=["label"], + ) + self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] + self.dataset[split].set_format(type="torch", columns=self.columns) + + self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + + def prepare_data(self): + datasets.load_dataset("glue", self.task_name) + AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + + def train_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["train"], + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True) + return self.plugin.prepare_dataloader(self.dataset["train"], + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True) + + def val_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def test_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["test"], batch_size=self.train_batch_size) + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def convert_to_features(self, example_batch): + + # Either encode single sentence or sentence pairs + if len(self.text_fields) > 1: + texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) + else: + texts_or_text_pairs = example_batch[self.text_fields[0]] + + # Tokenize the text/text pairs + features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, + max_length=self.max_seq_length, + padding='max_length', + truncation=True) + + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] + + return features + + def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False): + + return DataLoader(dataset, + batch_size=batch_size, + sampler=None, + shuffle=shuffle, + drop_last=drop_last, + pin_memory=pin_memory) diff --git a/colossalai/shardformer/examples/shardformer_benchmark.py b/colossalai/shardformer/examples/shardformer_benchmark.py new file mode 100644 index 000000000000..de82305b2547 --- /dev/null +++ b/colossalai/shardformer/examples/shardformer_benchmark.py @@ -0,0 +1,154 @@ +import argparse +import math +from typing import Any, List, Union + +import evaluate +import torch +import torch.distributed as dist +from data import GLUEDataBuilder +from torch import nn +from torch.optim import Adam, AdamW, Optimizer +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer import ShardConfig, ShardFormer + + +def to_device(x: Any, device: torch.device) -> Any: + + def _to(t: Any): + if isinstance(t, torch.Tensor): + return t.to(device) + return t + + return tree_map(_to, x) + + +def train(args): + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # prepare for data and dataset + data_builder = GLUEDataBuilder(model_name_or_path=args.pretrain, + task_name=args.task, + train_batch_size=args.batch_size, + eval_batch_size=args.batch_size) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + if args.model == "bert": + cfg = BertConfig.from_pretrained(args.pretrain, num_labels=data_builder.num_labels) + model = BertForSequenceClassification.from_pretrained(args.pretrain, config=cfg) + + model.to(torch.cuda.current_device()) + + # if multiple GPUs, shard the model + if dist.get_world_size() > 1: + shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm) + shard_former = ShardFormer(shard_config=shard_config) + model = shard_former.optimize(model) + + optim = Adam(model.parameters(), lr=args.lr) + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + lr_scheduler = get_linear_schedule_with_warmup( + optim, + num_warmup_steps=math.ceil(max_steps * args.warmup_fraction), + num_training_steps=max_steps, + ) + fit(model, optim, lr_scheduler, train_dataloader, args.max_epochs, args.accumulation_steps, args.batch_size, + coordinator) + results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, + coordinator) + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and 'f1' in results: + assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max_epochs, accumulation_steps, batch_size, + coordinator): + step_bar = tqdm(range(len(train_dataloader) // accumulation_steps * max_epochs), + desc=f'steps', + disable=not coordinator.is_master()) + total_loss = 0 + for epoch in range(max_epochs): + model.train() + for batch_id, batch in enumerate(train_dataloader): + batch = to_device(batch, torch.cuda.current_device()) + outputs = model(**batch) + loss = outputs.loss + loss = loss / accumulation_steps + loss.backward() + total_loss += loss.item() + if (batch_id + 1) % accumulation_steps == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + step_bar.set_postfix({ + 'epoch': epoch, + 'loss': total_loss / batch_size, + 'lr': scheduler.get_last_lr()[0] + }) + total_loss = 0 + step_bar.update() + + +# evaluate +@torch.no_grad() +def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, + task_name: str, eval_splits: List[str], coordinator: DistCoordinator): + metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + accum_loss = torch.zeros(1, device=torch.cuda.current_device()) + for batch in dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + labels = batch["labels"] + metric.add_batch(predictions=preds, references=labels) + + results = metric.compute() + if coordinator.is_master(): + results['loss'] = accum_loss.item() / (len(dataloader) * dataloader.batch_size) + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f'{k}_{split}': v for k, v in results.items()}) + return final_results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") + parser.add_argument('--model', type=str, default="bert") + parser.add_argument('--pretrain', type=str, default="bert-base-uncased") + parser.add_argument('--max_epochs', type=int, default=1) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--lr', type=float, default=2.4e-5) + parser.add_argument('--fused_layernorm', type=bool, default=False) + parser.add_argument('--accumulation_steps', type=int, default=8) + parser.add_argument('--warmup_fraction', type=float, default=0.03) + parser.add_argument('--target_f1', type=float, default=None) + args = parser.parse_args() + train(args) diff --git a/colossalai/shardformer/examples/shardformer_benchmark.sh b/colossalai/shardformer/examples/shardformer_benchmark.sh new file mode 100644 index 000000000000..f42b19a32d35 --- /dev/null +++ b/colossalai/shardformer/examples/shardformer_benchmark.sh @@ -0,0 +1,9 @@ +torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \ + --model "bert" \ + --pretrain "bert-base-uncased" \ + --max_epochs 1 \ + --batch_size 2 \ + --lr 2.4e-5 \ + --fused_layernorm False \ + --accumulation_steps 8 \ + --warmup_fraction 0.03 diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py new file mode 100644 index 000000000000..7fad4948dfd0 --- /dev/null +++ b/colossalai/shardformer/layer/__init__.py @@ -0,0 +1,12 @@ +from .dropout import DropoutForParallelInput, DropoutForReplicatedInput +from .embedding import Embedding1D, VocabParallelEmbedding1D +from .linear import Linear1D_Col, Linear1D_Row +from .loss import cross_entropy_1d +from .normalization import FusedLayerNorm, FusedRMSNorm +from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row + +__all__ = [ + "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', + 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", + 'FusedLayerNorm', 'FusedRMSNorm' +] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py new file mode 100644 index 000000000000..c025daaeccc7 --- /dev/null +++ b/colossalai/shardformer/layer/_operation.py @@ -0,0 +1,290 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F + +try: + import fused_mix_prec_layer_norm_cuda +except: + fused_mix_prec_layer_norm_cuda = None + + +class FusedLayerNormAffineFunction1D(torch.autograd.Function): + r"""Layernorm + + Args: + input: input matrix. + weight: weight matrix. + bias: bias matrix. + normalized_shape: input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability + """ + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, + bias_, ctx.eps) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + +class MatmulWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight) + + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_allreduce = async_grad_allreduce + + if bias is not None: + output = F.linear(input_, weight, bias) + else: + output = F.linear(input_, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _split(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.dim, ctx.process_group), None, None + + +class _ReduceForward(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def forward(ctx, input_, process_group): + return _reduce(input_, process_group) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _ReduceBackward(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def forward(ctx, input_, process_group): + ctx.process_group = process_group + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.process_group), None + + +def _reduce(input_, process_group): + # skip if only one rank involved + if dist.get_world_size(process_group) == 1: + return input_ + else: + dist.all_reduce(input_, group=process_group) + return input_ + + +def _split(input_, dim=-1, process_group=None): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, \ + f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = dist.get_rank(process_group) + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, dim=-1, process_group=None): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # all gather + rank = dist.get_rank(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=process_group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + +def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): + return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) + + +def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) + + +def gather_forward_split_backward(input_, dim, process_group): + return _GatherForwardSplitBackward.apply(input_, dim, process_group) + + +def split_forward_gather_backward(input_, dim, process_group): + return _SplitForwardGatherBackward.apply(input_, dim, process_group) + + +def reduce_forward(input_, process_group): + return _ReduceForward.apply(input_, process_group) + + +def reduce_backward(input_, process_group): + return _ReduceBackward.apply(input_, process_group) \ No newline at end of file diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py new file mode 100644 index 000000000000..2625fe97889a --- /dev/null +++ b/colossalai/shardformer/layer/dropout.py @@ -0,0 +1,83 @@ +from typing import List, Union + +import torch +import torch.nn as nn +from torch.distributed import ProcessGroup + +from .parallel_module import ParallelModule +from .utils import create_randomizer_with_offset + +__all__ = ['DropoutForParallelInput', 'DropoutForReplicatedInput'] + + +class DropoutForParallelInput(ParallelModule, nn.Dropout): + """ + The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with + randomness on different ranks of the given process group. This can avoid the same dropout mask is generated + and applied on the same position of different ranks, leading to poor convergence performance. + + Args: + p (float): probability of an element to be zeroed. Defaults to 0.5. + inplace (bool): If set to True, will do this operation in-place. Defaults to False. + process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None): + # init with nn.Dropout + super(nn.Dropout, self).__init__(p=p, inplace=inplace) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=process_group) + + @staticmethod + def from_native_module(module: nn.Dropout, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForParallelInput": + """ + Create a DropoutForParallelInput layer from a native dropout layer. + """ + p = module.p + inplace = module.inplace + return DropoutForParallelInput(p=p, inplace=inplace, process_group=process_group) + + def forward(self, input): + with self.randomizer.fork_rng(): + input = super().forward(input) + return input + + +class DropoutForReplicatedInput(ParallelModule, nn.Dropout): + """ + The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with + randomness on different ranks of the given process group. This can avoid the same dropout mask is generated + and applied on the same position of different ranks, leading to poor convergence performance. + + Args: + p (float): probability of an element to be zeroed. Defaults to 0.5. + inplace (bool): If set to True, will do this operation in-place. Defaults to False. + process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None): + # init with nn.Dropout + super(nn.Dropout, self).__init__(p=p, inplace=inplace) + + # offset the seed with randomizer index only + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=process_group, offset_by_rank=False) + + @staticmethod + def from_native_module( + module: nn.Dropout, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForReplicatedInput": + """ + Create a Dropout1D layer from a native dropout layer. + """ + p = module.p + inplace = module.inplace + return DropoutForReplicatedInput(p=p, inplace=inplace, process_group=process_group) + + def forward(self, input): + with self.randomizer.fork_rng(): + input = super().forward(input) + return input diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py new file mode 100644 index 000000000000..db39a457b7fd --- /dev/null +++ b/colossalai/shardformer/layer/embedding.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Callable, List, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param + +from ._operation import gather_forward_split_backward, reduce_forward +from .parallel_module import ParallelModule +from .utils import create_randomizer_with_offset + +__all__ = ['Embedding1D', 'VocabParallelEmbedding1D'] + + +class Embedding1D(ParallelModule): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = True, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.process_group = process_group + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + self.gather_output = gather_output + + # Parameters. + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) + sharded_weight = shard_colwise(weight, process_group) + self.weight = sharded_tensor_to_param(sharded_weight) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None, + *args, + **kwargs) -> "Embedding1D": + r""" + Build a 1D parallelized Embedding from a native nn.Embedding module. + """ + # get the attributes + num_embedding = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + max_norm = module.max_norm + norm_type = module.norm_type + scale_grad_by_freq = module.scale_grad_by_freq + sparse = module.sparse + dtype = module.weight.dtype + device = module.weight.device + + # sparse is not support yet + if sparse: + raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") + + embedding = Embedding1D(num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + *args, + **kwargs) + + # copy the weight + with torch.no_grad(): + sharded_weight = shard_colwise(module.weight.data, process_group) + embedding.weight.copy_(sharded_weight) + + return embedding + + def reset_parameters(self, weight_initializer) -> None: + fan_in, fan_out = self.num_embeddings, self.embedding_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + if self.gather_output: + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + return output + else: + return output_parallel + + +class VocabParallelEmbedding1D(ParallelModule): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed โ€œpadโ€, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + self.process_group = process_group + + tensor_parallel_size = dist.get_world_size(group=process_group) + tensor_parallel_rank = dist.get_rank(group=process_group) + + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings = self.num_embeddings_per_partition + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + # parameter + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) + sharded_weight = shard_rowwise(weight, process_group) + self.weight = sharded_tensor_to_param(sharded_weight) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native pytorch embedding module to a parallel module. + """ + # get the origin attributes + num_embeddings = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + device = module.weight.device + + # ensure only one process group is used + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + # create the parallel module + vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + process_group=process_group, + *args, + **kwargs) + with torch.no_grad(): + # shard and slice the weight along the vocabulary(num_embeddings) dimension + # the shape of the weight is (num_embeddings, embedding_dim) + shard_weight = shard_rowwise(module.weight.data, process_group) + vocab_embedding_1d.weight.data.copy_(shard_weight) + + return vocab_embedding_1d + + def reset_parameters(self, weight_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.num_embeddings, self.embedding_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_forward(output_parallel, self.process_group) + return output diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py new file mode 100644 index 000000000000..26ba5883c64f --- /dev/null +++ b/colossalai/shardformer/layer/linear.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param + +from ._operation import ( + gather_forward_split_backward, + linear_with_async_comm, + reduce_forward, + split_forward_gather_backward, +) +from .parallel_module import ParallelModule +from .utils import create_randomizer_with_offset + +__all__ = ['Linear1D_Col', 'Linear1D_Row'] + + +class Linear1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Parameters. + factory_kwargs = {'device': device, 'dtype': dtype} + + weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) + sharded_weight = shard_rowwise(weight, self.process_group) + self.weight = sharded_tensor_to_param(sharded_weight) + + if bias: + bias = torch.empty(self.out_features, **factory_kwargs) + sharded_bias = shard_colwise(bias, self.process_group) + self.bias = sharded_tensor_to_param(sharded_bias) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = Linear1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on row is equal to shard on column + sharded_weight = shard_rowwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + if bias: + sharded_bias = shard_colwise(module.bias.data, process_group) + linear_1d.bias.copy_(sharded_bias) + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + + # Set up backprop all-reduce. + input_parallel = input_ + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class Linear1D_Row(ParallelModule): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) + sharded_weight = shard_colwise(weight, self.process_group) + self.weight = sharded_tensor_to_param(sharded_weight) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = Linear1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + sharded_weight = shard_colwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + + if bias: + linear_1d.bias.copy_(module.bias.data) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + @torch.no_grad() + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + bias = self.bias.cuda() + dist.broadcast(bias, src=src_rank, group=self.process_group) + bias = bias.to(origin_device) + self.bias.copy_(bias) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + output = reduce_forward(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py new file mode 100644 index 000000000000..38a5395a0f57 --- /dev/null +++ b/colossalai/shardformer/layer/loss.py @@ -0,0 +1,109 @@ +import torch +import torch.distributed as dist +from torch.autograd import Function +from torch.distributed import ProcessGroup + +__all__ = ['DistCrossEntropy', 'cross_entropy_1d'] + + +class DistCrossEntropy(Function): + r""" + Overwrite the forward and backward function to calculate the cross entropy loss before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup): + r""" + Calculate the cross entropy loss before gather, the origin loss function is as follows: + loss = -log(exp(x[class])/sum(exp(x[i])) + and can be rewrite as: + loss = log(sum(exp(x[i])) - x[class] + + To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i] + + Args: + vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is + [batch_size, seq_len, vocab_size] + labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is + [batch_size, seq_len] + + Returns: + :class:`torch.Tensor`: The cross entropy loss + """ + # get the max + logits_max = torch.max(vocab_logits, dim=-1)[0] + dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) + + # minus the max to avoid the result of sum of exp is too large and the log is nan + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + + # mask the target in the local device + partition_vocab_size = vocab_logits.size()[-1] + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + global_vocab_size = partition_vocab_size * world_size + + # [down, up) => false, other device and -100 => true + delta = (global_vocab_size + world_size - 1) // world_size + down_shreshold = rank * delta + up_shreshold = down_shreshold + delta + mask = (target < down_shreshold) | (target >= up_shreshold) + masked_target = target.clone() - down_shreshold + masked_target[mask] = 0 + + # reshape the logist and target + # reshape the vocab_logits to [bath_size * seq_len, vocab_size] + # reshape the labels to [bath_size * seq_len] + logits_2d = vocab_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + + # extract the x[class] and set the x[other device] to zero + pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), + masked_target_1d] + pred_logits_1d = pred_logits_1d.clone().contiguous() + pred_logits = pred_logits_1d.view_as(target) + pred_logits[mask] = 0.0 + + # allreduce the get all x(i,y) + dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group) + exp_logits = vocab_logits + torch.exp(vocab_logits, out=exp_logits) + sum_exp_logits = torch.sum(exp_logits, dim=-1) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) + + # calculate the loss + # loss = log(sum(exp(x[i]))) - x[class] + loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) + loss = torch.sum(loss).div_(torch.sum(loss != 0.0)) + + # caculate the softmax + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + # retrieve the saved tensors + exp_logits, mask, masked_target_1d = ctx.saved_tensors + + # use exp logits as the input grad + grad_logits = exp_logits + partion_vocab_size = grad_logits.shape[-1] + grad_logits_2d = grad_logits.view(-1, partion_vocab_size) + + update = 1.0 - mask.view(-1).float() + grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update + + grad_logits.mul_(grad_output.unsqueeze(dim=-1)) + return grad_logits, None, None + + +def cross_entropy_1d(vocab_logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = -100, + process_group: ProcessGroup = None) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py new file mode 100644 index 000000000000..b27307154a76 --- /dev/null +++ b/colossalai/shardformer/layer/normalization.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn + +__all__ = ['FusedLayerNorm', 'FusedRMSNorm'] + +FAST_LAYERNORM_SUPPORTED_SIZE = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576, + 25600, 30720, 32768, 40960, 49152, 65536 +] + + +class FusedLayerNorm(): + r""" + This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + 'FusedLayerNorm is not implemented as a physical class. ' + 'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.' + ) + + @staticmethod + def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: + r""" + Convert a native pytorch layer norm module to colossalai layer norm module + """ + # check if apex is installed + try: + import apex + except ImportError: + raise ImportError( + 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel') + + # get the attributes of the module + normalized_shape = module.normalized_shape + eps = module.eps + elementwise_affine = module.elementwise_affine + dtype = module.weight.dtype + device = module.weight.device + + # pick the suitable layernorm implementation + use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE + + if use_fast_ln: + try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm + except ImportError: + # fall back to the normal fused layernorm is not built + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + else: + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + + layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps, + elementwise_affine=elementwise_affine).to(dtype).to(device) + + with torch.no_grad(): + # copy weight and bias + layernorm.weight.copy_(module.weight) + layernorm.bias.copy_(module.bias) + return layernorm + + +class FusedRMSNorm(): + """ + This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + 'FusedRMSNorm is not implemented as a physical class. ' + 'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.' + ) + + @staticmethod + def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: + try: + from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm + except ImportError: + raise ImportError( + 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel' + ) + + # to check if it is huggingface LlamaRMSNorm + if module.__class__.__name__ == "LlamaRMSNorm": + normalized_shape = module.weight.shape[0] + eps = module.variance_epsilon + elementwise_affine = True + else: + # get the attributes of the module + normalized_shape = module.normalized_shape + eps = module.eps + elementwise_affine = module.elementwise_affine + + rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + + with torch.no_grad(): + # copy weight and bias + rmsnorm.weight.copy_(module.weight) + + return rmsnorm diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py new file mode 100644 index 000000000000..bda147b121ab --- /dev/null +++ b/colossalai/shardformer/layer/parallel_module.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import itertools +from abc import ABC, abstractmethod +from typing import List, Union + +import torch +import torch.nn as nn +from torch.distributed import ProcessGroup +from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module + +from colossalai.tensor.d_tensor import ( + distribute_tensor, + distribute_tensor_with_customization, + get_device_mesh, + get_sharding_spec, + is_customized_distributed_tensor, + is_distributed_tensor, + sharded_tensor_to_param, + to_global, + to_global_for_customized_distributed_tensor, +) + +__all__ = ['ParallelModule'] + + +class ParallelModule(nn.Module, ABC): + + @abstractmethod + def from_native_module(module: nn.Module, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": + """ + Convert a native PyTorch module to a parallelized module. + + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. + """ + pass + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + destination[prefix + name] = to_global(param_) + elif is_customized_distributed_tensor(param_): + destination[prefix + name] = to_global_for_customized_distributed_tensor(param_) + else: + destination[prefix + name] = param_ + + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append('While copying the parameter named "{}", ' + 'expected torch.Tensor or Tensor-like object from checkpoint but ' + 'received {}'.format(key, type(input_param))) + continue + + if is_distributed_tensor(param): + # shard the input param + device_mesh = get_device_mesh(param) + sharding_spec = get_sharding_spec(param) + sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec) + input_param = sharded_tensor_to_param(sharded_tensor) + elif is_customized_distributed_tensor(param): + input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn) + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.'.format(key, input_param.shape, param.shape)) + continue + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), + ex.args)) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py new file mode 100644 index 000000000000..9d51670c65dd --- /dev/null +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -0,0 +1,473 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import ( + customized_distributed_tensor_to_param, + distribute_tensor_with_customization, + shard_rowwise, + sharded_tensor_to_param, +) + +from ._operation import ( + gather_forward_split_backward, + matmul_with_async_comm, + reduce_backward, + reduce_forward, + split_forward_gather_backward, +) +from .parallel_module import ParallelModule +from .utils import create_randomizer_with_offset + +__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row'] + +# ==================================== +# For GPT Only +# ==================================== + + +def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor, + n_fused: int, + process_group: ProcessGroup, + is_transposed: bool = False): + """ + The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2]. + + Args: + qkv (torch.Tensor): The fused qkv tensor. + n_fused (int): The number items fused together, defaults to 3 (query, key and value). + process_group (ProcessGroup): The process group for distributed communication. + is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). + """ + # get the number of slice for the fused qkv + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_fused) + + # split the fused qkv + # from + # [Q, K, V] + # to + # [Q1, Q2, K1, K2, V1, V2] + if is_transposed: + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) + else: + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0) + + # rearrange the slice into the final order + # from + # [Q1, Q2, K1, K2, V1, V2] + # to + # [Q1, K1, V1], [Q2, K2, V2] + weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]] + + if is_transposed: + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1) + else: + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=0) + return weight_of_current_rank + + +def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor, + n_fused: int, + process_group: ProcessGroup, + is_transposed: bool = False): + """ + The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2]. + + Args: + qkv (torch.Tensor): The fused qkv tensor. + n_fused (int): The number items fused together, defaults to 3 (query, key and value). + process_group (ProcessGroup): The process group for distributed communication. + is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). + """ + world_size = dist.get_world_size(group=process_group) + + # gather the tensors + # from + # [Q1, K1, V1], [Q2, K2, V2] + # to + # [Q1, K1, V1, Q2, K2, V2] + origin_device = qkv.device + qkv = qkv.cuda() + gather_list = [torch.zeros_like(qkv) for _ in range(world_size)] + dist.all_gather(gather_list, qkv, group=process_group) + + if is_transposed: + gather_weight = torch.cat(gather_list, dim=-1) + else: + gather_weight = torch.cat(gather_list, dim=0) + gather_weight = gather_weight.to(origin_device) + qkv = qkv.to(origin_device) + + # rearrange the tensor slices + # from + # [Q1, K1, V1, Q2, K2, V2] + # to + # [Q1, Q2, K1, K2, V1, V2] + if is_transposed: + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) + else: + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0) + + reordered_chunk_list = [] + for i in range(n_fused): + reordered_chunk_list.extend(weight_chunks[i::n_fused]) + + if is_transposed: + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) + else: + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=0) + return reordered_gather_weight + + +class GPT2FusedLinearConv1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + n_fused (int): The number items fused, defaults to 3 (QKV). + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + async_communication: bool = False, + gather_output: bool = False, + skip_bias_add: bool = False, + n_fused: int = 3, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.n_fused = n_fused + self.process_group = process_group + self.async_communication = async_communication + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) + + def shard_fn(tensor): + return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) + + def gather_fn(tensor): + return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True) + + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) + self.weight = customized_distributed_tensor_to_param(sharded_weight) + + if bias: + bias = torch.empty(self.out_features, **factory_kwargs) + + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn) + self.bias = customized_distributed_tensor_to_param(sharded_bias) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, + *args, **kwargs) -> ParallelModule: + r""" + Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. + n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight. + """ + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=True) + linear_1d.weight.data.copy_(sharded_weight.data) + + if bias: + sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=True) + linear_1d.bias.data.copy_(sharded_bias.data) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[0], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + input_parallel = reduce_backward(input_, self.process_group) + # input_parallel = input_ + + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + + output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, + self.async_communication) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class GPT2FusedLinearConv1D_Row(ParallelModule): + r""" Linear layer with row parallelism. + This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) + sharded_weight = shard_rowwise(weight, self.process_group) + self.weight = sharded_tensor_to_param(sharded_weight) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + # init weights + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + sharded_weight = shard_rowwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight.data) + + if bias: + linear_1d.bias.copy_(module.bias.data) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[0], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = torch.matmul(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = torch.matmul(input_, self.weight) + output = reduce_forward(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py new file mode 100644 index 000000000000..f2ac6563c46f --- /dev/null +++ b/colossalai/shardformer/layer/utils.py @@ -0,0 +1,202 @@ +from contextlib import contextmanager + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_global_rank + + +class Randomizer: + """ + Randomizer enables the program to be executed under a different seed within the context. + + Example: + + ```python + randomizer = Randomizer(seed=1024) + + with randomizer.fork(): + # do something here with seed 1024 + do_something() + ``` + + Args: + seed (int): The random seed to set. + enable_cpu (bool): fork the CPU RNG state as well. + with_index (bool): whether to use the index of the randomizer. + """ + + _INDEX = 0 + + def __init__(self, seed: int): + # TODO: remove colossalai.context.random + + self.seed = seed + + # Handle CUDA rng state + # 1. get the current rng state + # 2. set the seed and store the rng state + # 3. recover the original rng state + cuda_original_rng_state = torch.cuda.get_rng_state() + torch.cuda.manual_seed(seed) + self.cuda_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(cuda_original_rng_state) + + # to the same for cpu rng state + cpu_original_rng_state = torch.get_rng_state() + torch.manual_seed(seed) + self.cpu_rng_state = torch.get_rng_state() + torch.set_rng_state(cpu_original_rng_state) + + def _set_cuda_rng_state(self, rng_state): + torch.cuda.set_rng_state(rng_state) + + def _get_cuda_rng_state(self): + current_state = torch.cuda.get_rng_state() + return current_state + + def _set_cpu_rng_state(self, rng_state): + torch.set_rng_state(rng_state) + + def _get_cpu_rng_state(self): + current_state = torch.get_rng_state() + return current_state + + @contextmanager + def fork_rng(self, enable_cpu: bool = False): + """ + This is a context manager to change the dropout state and recover the original state. + + Usage: + :: + >>> with _seed_manager.dropout_mode(): + >>> input = super().forward(input) + """ + try: + current_cuda_rng_state = self._get_cuda_rng_state() + self._set_cuda_rng_state(self.cuda_rng_state) + + if enable_cpu: + current_cpu_rng_state = self._get_cpu_rng_state() + self._set_cpu_rng_state(self.cpu_rng_state) + yield + finally: + self.cuda_rng_state = self._get_cuda_rng_state() + self._set_cuda_rng_state(current_cuda_rng_state) + + if enable_cpu: + self.cpu_rng_state = self._get_cpu_rng_state() + self._set_cpu_rng_state(current_cpu_rng_state) + + @staticmethod + def index(): + """ + Return the index of the randomizer. The index is useful when the user wants + to introduce some randomness in the program. + + Note: + The index will increment by one each time this method is called. + + Example: + + ```python + # assume we need a randomizer to init the weight of different layers + # we can use the index of the randomizer to do so that + # each layer has its own randomizer with a different seed + base_seed = torch.random.initial_seed() + seed = base_seed + Randomizer.index() + randomizer = Randomizer(seed) + + with randomizer.fork(): + init_weights() + ``` + + """ + idx = Randomizer._INDEX + return idx + + @staticmethod + def increment_index(): + """ + Increment the index of the randomizer by one. + """ + Randomizer._INDEX += 1 + + @staticmethod + def is_randomizer_index_synchronized(process_group: ProcessGroup = None): + """ + Return whether the randomizer index is synchronized across processes. + """ + index = Randomizer.index() + if dist.is_initialized(): + # convert the index to tensor + index_tensor = torch.tensor(index, dtype=torch.int32).cuda() + + # all gather the index + gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] + dist.all_gather(gathered_index, index_tensor, process_group) + + # make sure all the gathered index are the same + for i in range(1, dist.get_world_size(process_group)): + if gathered_index[i] != gathered_index[0]: + return False + + return True + + @staticmethod + def synchronize_index(process_group: ProcessGroup = None): + """ + All gather the index and pick the largest value. + """ + index = Randomizer.index() + + if dist.is_initialized(): + # convert the index to tensor + index_tensor = torch.tensor(index, dtype=torch.int32).cuda() + + # all gather the index + gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] + dist.all_gather(gathered_index, index_tensor, process_group) + + # pick the largest index + for i in range(1, dist.get_world_size(process_group)): + if gathered_index[i] > index_tensor: + index_tensor = gathered_index[i] + + # set the index + Randomizer._INDEX = index_tensor.item() + + +def create_randomizer_with_offset(seed: int, + process_group: ProcessGroup = None, + offset_by_rank: bool = True, + offset_by_index: bool = True): + """ + Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer. + + Args: + seed (int): The base random seed to set. + process_group (ProcessGroup): the process group to get the rank from. + offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True. + offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True. + + Returns: + Randomizer: the randomizer with offset. + """ + base_seed = seed + + if offset_by_rank and dist.is_initialized(): + rank = dist.get_rank(process_group) + base_seed += rank + + if offset_by_index: + # check if the randomizer index is synchronized + is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group) + assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes." + "This is not allowed when we want to create a randomizer with offset by index." + "Please call Randomizer.synchronize_index() first.") + + base_seed += Randomizer.index() + Randomizer.increment_index() + + return Randomizer(seed=base_seed) diff --git a/colossalai/shardformer/modeling/__init__.py b/colossalai/shardformer/modeling/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py new file mode 100644 index 000000000000..a3d774ff2abb --- /dev/null +++ b/colossalai/shardformer/modeling/bloom.py @@ -0,0 +1,69 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: + + def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, + dtype: torch.dtype) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + import math + + if dist.is_initialized(): + world_size = dist.get_world_size(process_group) + num_heads = num_heads * world_size + + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2**math.floor(math.log2(num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, + 1 + 2 * num_remaining_heads, + 2, + device=attention_mask.device, + dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + if dist.is_initialized(): + num_heads_per_rank = int(num_heads / dist.get_world_size(process_group)) + offset = dist.get_rank(process_group) * num_heads_per_rank + alibi = alibi.view(batch_size, num_heads, 1, seq_length) + alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] + return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) + else: + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + return build_bloom_alibi_tensor diff --git a/colossalai/shardformer/policies/__init__.py b/colossalai/shardformer/policies/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py new file mode 100644 index 000000000000..8051433e8d71 --- /dev/null +++ b/colossalai/shardformer/policies/autopolicy.py @@ -0,0 +1,137 @@ +import importlib +from dataclasses import dataclass + +import torch.nn as nn + +from .basepolicy import Policy + +__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"] + + +@dataclass +class PolicyLocation: + """ + PolicyLocation describes the location of a policy class. + + Args: + file_name (str): The file name of the policy under colossalai.shardformer.policies + class_name (str): The class name of the policy class + """ + file_name: str + class_name: str + + +# we don't want to import all policies here +# as each policy file imports its own model zoo library +# we will allow the user to only import the policy file needed +_POLICY_LIST = { + # BERT + "transformers.models.bert.modeling_bert.BertModel": + PolicyLocation(file_name="bert", class_name="BertModelPolicy"), + "transformers.models.bert.modeling_bert.BertForPreTraining": + PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"), + "transformers.models.bert.modeling_bert.BertLMHeadModel": + PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"), + "transformers.models.bert.modeling_bert.BertForMaskedLM": + PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"), + "transformers.models.bert.modeling_bert.BertForSequenceClassification": + PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"), + "transformers.models.bert.modeling_bert.BertForTokenClassification": + PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"), + "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": + PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), + "transformers.models.bert.modeling_bert.BertForMultipleChoice": + PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), + + # LLaMA + "transformers.models.llama.modeling_llama.LlamaModel": + PolicyLocation(file_name="llama", class_name="LlamaPolicy"), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": + PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"), + "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": + PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"), + + # T5 + "transformers.models.t5.modeling_t5.T5Model": + PolicyLocation(file_name="t5", class_name="T5ModelPolicy"), + "transformers.models.t5.modeling_t5.T5ForConditionalGeneration": + PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"), + "transformers.models.t5.modeling_t5.T5EncoderModel": + PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), + + # GPT2 + "transformers.models.gpt2.modeling_gpt2.GPT2Model": + PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": + PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": + PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": + PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": + PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), + + # OPT + "transformers.models.opt.modeling_opt.OPTModel": + PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), + "transformers.models.opt.modeling_opt.OPTForCausalLM": + PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"), + "transformers.models.opt.modeling_opt.OPTForSequenceClassification": + PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"), + "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": + PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"), + + # Bloom + "transformers.models.bloom.modeling_bloom.BloomModel": + PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": + PolicyLocation(file_name="bloom", class_name="BloomForCausalLMPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": + PolicyLocation(file_name="bloom", class_name="BloomForSequenceClassificationPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForTokenClassification": + PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": + PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"), +} + + +def import_policy(policy_location: PolicyLocation) -> Policy: + """ + Dynamically import a Policy class based on the policy location. + """ + module_name = f"colossalai.shardformer.policies.{policy_location.file_name}" + module = importlib.import_module(module_name) + return getattr(module, policy_location.class_name) + + +def _fullname(obj): + """ + Return the full name of an object, including the module name. + """ + klass = obj.__class__ + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + '.' + klass.__qualname__ + + +def get_autopolicy(model: nn.Module) -> Policy: + r""" + Return the auto policy for the model + + Args: + model (:class:`nn.Module`): The model to get the auto policy + + Return: + :class:`Policy`: The auto policy for the model + """ + full_name = _fullname(model) + policy_location = _POLICY_LIST.get(full_name, None) + + if policy_location is None: + raise NotImplementedError( + f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}" + ) + else: + policy = import_policy(policy_location) + return policy() diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py new file mode 100644 index 000000000000..85e6d509c81b --- /dev/null +++ b/colossalai/shardformer/policies/basepolicy.py @@ -0,0 +1,153 @@ +# part of code modified from https://github.com/tunib-ai/parallelformers + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Type, Union + +import torch.nn as nn + +from ..shard.shard_config import ShardConfig + +__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] + + +class ParallelModule(): + + def __init__(self): + pass + + +@dataclass +class SubModuleReplacementDescription: + r""" + Describe how a submodule will be replaced + + Args: + suffix (str): used to get the submodule object + target_module (ParallelModule): specifies the module class used to replace to submodule + kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. + ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception + """ + suffix: str + target_module: ParallelModule + kwargs: Dict[str, Any] = None + ignore_if_not_exist: bool = False + + +@dataclass +class ModulePolicyDescription: + r""" + Describe how the attributes and parameters will be transformed in a policy. + + Args: + attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding + param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function + must receive only one arguments: module. One example is + + ```python + def example_replace_weight(module: torch.nn.Module): + weight = module.weight + new_weight = shard_rowwise(weight, process_group) + module.weight = torch.nn.Parameter(new_weight) + ``` + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription + object which specifies the module to be replaced and the target module used to replacement. + method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement + """ + attribute_replacement: Dict[str, Any] = None + param_replacement: List[Callable] = None + sub_module_replacement: List[SubModuleReplacementDescription] = None + method_replacement: Dict[str, Callable] = None + + +class Policy(ABC): + r""" + The base class for all the policies. For each different model, it should have a different policy class, + like BertPolicy for Bert Model or OPTPolicy for OPT model. + + Shardformer has provided many built-in sharding policies for the mainstream models. You can use the + built-in policies by setting `policy = None`, which is already the default arguemnt for `Shardformer.optimize`. + If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify. + """ + + def __init__(self) -> None: + self.shard_config = None + self.model = None + self.shard_config = None + + def set_model(self, model: nn.Module) -> None: + r""" + Set model as an attribute of the Policy object so that we can access the model's attributes. + + Args: + model (:class:`nn.Module`): The model to be perform + """ + self.model = model + + def set_shard_config(self, shard_config: ShardConfig) -> None: + r""" + Set shard config as an attribute of the Policy object. + + Args: + shard_config (:class:`ShardConfig`): The shard config to be perform + """ + self.shard_config = shard_config + self.config_sanity_check() + + @abstractmethod + def config_sanity_check(self): + """ + Check if the shard config is valid for the model. Raise an exception if the config is invalid. + This method is made abstractmethod with no default implementation because we want to the policy writer + to take note of the feature supported by his/her model and policy. + """ + pass + + @abstractmethod + def preprocess(self) -> nn.Module: + r""" + Perform some preprocessing of the model, like reshaping the embedding layer. + """ + pass + + @abstractmethod + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + r""" + This method returns the module policy, which is a dictionary. The key is the module name or the module object, + and the value is the ModulePolicyDescription object. The ModulePolicyDescription object describes how the module + will be transformed. + """ + pass + + @abstractmethod + def postprocess(self) -> nn.Module: + r""" + Perform some postprocessing of the model, like binding the weight of embedding layer with + the classifier layer + """ + pass + + def append_or_create_submodule_replacement( + self, description: Union[SubModuleReplacementDescription, + List[SubModuleReplacementDescription]], policy: Dict[Union[str, nn.Module], + ModulePolicyDescription], + target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + r""" + Append or create a new submodule replacement description to the policy for the given key. + + Args: + submodule_replace_desc (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + target_key (Union[str, nn.Module]): the key of the policy to be updated + """ + # convert to list + if isinstance(description, SubModuleReplacementDescription): + description = [description] + + # append or create a new description + if target_key in policy: + policy[target_key].sub_module_replacement.extend(description) + else: + policy[target_key] = ModulePolicyDescription(sub_module_replacement=description) + + return policy diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py new file mode 100644 index 000000000000..9c2736cc64d3 --- /dev/null +++ b/colossalai/shardformer/policies/bert.py @@ -0,0 +1,293 @@ +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn + +from .._utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + 'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', + 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', + 'BertForMultipleChoicePolicy' +] + + +class BertPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + # TODO: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ + "attention.self.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "crossattention.self.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attention.self.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "crossattention.self.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + + policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ) + ]) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # Handle bert layer + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="output.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=BertLayer) + + # handle embedding layer + self.append_or_create_submodule_replacement( + description=[SubModuleReplacementDescription( + suffix="LayerNorm", + target_module=col_nn.FusedLayerNorm, + )], + policy=policy, + target_key=BertEmbeddings) + return policy + + def add_lm_head_policy(self, base_policy): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + + # optimize for tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), + policy=base_policy, + target_key=BertLMPredictionHead) + + # optimize with fused normalization + if self.shard_config.enable_fused_normalization: + # Handle bert lm prediction head + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="transform.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + policy=base_policy, + target_key=BertLMPredictionHead) + return base_policy + + def postprocess(self): + return self.model + + +# BertModel +class BertModelPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + +# BertForPreTraining +class BertForPretrainingPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + module_policy = self.add_lm_head_policy(module_policy) + return module_policy + + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) + return self.model + + +# BertLMHeadModel +class BertLMHeadModelPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + module_policy = self.add_lm_head_policy(module_policy) + return module_policy + + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) + return self.model + + +# BertForMaskedLM +class BertForMaskedLMPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + module_policy = self.add_lm_head_policy(module_policy) + return module_policy + + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) + return self.model + + +# BertForSequenceClassification +class BertForSequenceClassificationPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForSequenceClassification + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForSequenceClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + } + module_policy.update(addon_module) + return module_policy + + +# BertForTokenClassification +class BertForTokenClassificationPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForTokenClassification + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForTokenClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + } + module_policy.update(addon_module) + return module_policy + + +# BertForNextSentencePrediction +class BertForNextSentencePredictionPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + +# BertForMultipleChoice +class BertForMultipleChoicePolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bert.modeling_bert import BertForMultipleChoice + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForMultipleChoice: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + } + module_policy.update(addon_module) + return module_policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py new file mode 100644 index 000000000000..a0b5340f72bc --- /dev/null +++ b/colossalai/shardformer/policies/bloom.py @@ -0,0 +1,185 @@ +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn + +from .._utils import getattr_, setattr_ +from ..modeling.bloom import build_bloom_alibi_tensor_fn +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +class BloomPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + ), + ]) + + policy[BloomModel] = ModulePolicyDescription( + attribute_replacement={ + "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + method_replacement={ + "build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + # handle bloom model + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="word_embeddings_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=BloomModel) + + # handle bloom block + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ], + policy=policy, + target_key=BloomBlock) + + return policy + + def postprocess(self): + return self.model + + +class BloomModelPolicy(BloomPolicy): + pass + + +class BloomForCausalLMPolicy(BloomPolicy): + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForCausalLM + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=BloomForCausalLM) + + return policy + + def postprocess(self): + binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} + + for k, v in binding_map.items(): + param = getattr_(self.model, k) + + if not isinstance(param, nn.Parameter): + param = nn.Parameter(param) + + # tie weights + setattr_(self.model, v, param) + return self.model + + +class BloomForSequenceClassificationPolicy(BloomPolicy): + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=BloomForSequenceClassification) + + return policy + + +class BloomForTokenClassificationPolicy(BloomPolicy): + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForTokenClassification + policy = super().module_policy() + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="classifier", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=BloomForTokenClassification) + + return policy + + +class BloomForQuestionAnsweringPolicy(BloomPolicy): + # No head sharding as the output features is only 2 + pass diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py new file mode 100644 index 000000000000..549cdbf87a80 --- /dev/null +++ b/colossalai/shardformer/policies/gpt2.py @@ -0,0 +1,193 @@ +import torch.nn as nn + +import colossalai.shardformer.layer as col_nn + +from .._utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + 'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy', + 'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy' +] + + +class GPT2Policy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ]) + policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + policy=policy, + target_key=GPT2Model) + + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="ln_2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription(suffix="ln_cross_attn", + target_module=col_nn.FusedLayerNorm, + ignore_if_not_exist=True) + ], + policy=policy, + target_key=GPT2Block) + return policy + + def postprocess(self): + return self.model + + +# GPT2Model +class GPT2ModelPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + +# GPT2LMHeadModel +class GPT2LMHeadModelPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2LMHeadModel: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + return module_policy + + def postprocess(self): + binding_map = {"transformer.wte.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) + return self.model + + +# GPT22DoubleHeadsModel +class GPT2DoubleHeadsModelPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel + + module_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2DoubleHeadsModel: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + return module_policy + + def postprocess(self): + binding_map = {"transformer.wte.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + setattr_(self.model, v, param) + return self.model + + +# GPT2ForTokenClassification +class GPT2ForTokenClassificationPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + +# GPT2ForSequenceClassification +class GPT2ForSequenceClassificationPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py new file mode 100644 index 000000000000..157785bdcf13 --- /dev/null +++ b/colossalai/shardformer/policies/llama.py @@ -0,0 +1,145 @@ +from typing import Dict, Union + +import torch.nn as nn + +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D + +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] + + +class LlamaPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.hidden_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ) + ], + ) + + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=LlamaModel) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ) + ], + policy=policy, + target_key=LlamaDecoderLayer) + + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=LlamaModel) + + return policy + + def postprocess(self): + return self.model + + +class LlamaForCausalLMPolicy(LlamaPolicy): + + def module_policy(self): + from transformers import LlamaForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + LlamaForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + return policy + + +class LlamaForSequenceClassificationPolicy(LlamaPolicy): + + def module_policy(self): + from transformers import LlamaForSequenceClassification + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py new file mode 100644 index 000000000000..b87db53f45f1 --- /dev/null +++ b/colossalai/shardformer/policies/opt.py @@ -0,0 +1,140 @@ +from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D + +from .._utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = [ + 'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy', + 'OPTForQuestionAnsweringPolicy' +] + + +class OPTPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) + policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]) + + policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="self_attn_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True), + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True) + ], + policy=policy, + target_key=OPTDecoderLayer) + + return policy + + def postprocess(self): + return self.model + + +class OPTModelPolicy(OPTPolicy): + + def __init__(self) -> None: + super().__init__() + + +class OPTForCausalLMPolicy(OPTPolicy): + + def module_policy(self): + from transformers.models.opt.modeling_opt import OPTForCausalLM + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=OPTForCausalLM) + return policy + + def postprocess(self): + binding_map = { + 'model.decoder.embed_tokens': 'lm_head', + } + + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight + + return self.model + + +class OPTForSequenceClassificationPolicy(OPTPolicy): + + def __init__(self) -> None: + super().__init__() + + +class OPTForQuestionAnsweringPolicy(OPTPolicy): + + def __init__(self) -> None: + super().__init__() diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py new file mode 100644 index 000000000000..cde59ab77042 --- /dev/null +++ b/colossalai/shardformer/policies/t5.py @@ -0,0 +1,249 @@ +from colossalai.shardformer.layer import ( + DropoutForParallelInput, + Embedding1D, + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + VocabParallelEmbedding1D, +) +from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription + +from .._utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] + + +class T5BasePolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5DenseActDense, + T5DenseGatedActDense, + T5LayerCrossAttention, + T5LayerFF, + T5LayerSelfAttention, + T5Stack, + ) + + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=Embedding1D, + ) + ]) + policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ]) + policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) + policy[T5Attention] = ModulePolicyDescription(attribute_replacement={ + "d_model": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "n_heads": + self.model.config.num_heads // self.shard_config.tensor_parallel_size, + "inner_dim": + self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict(gather_output=False), + ignore_if_not_exist=True) + ]) + policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ]) + policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) + policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerSelfAttention) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerCrossAttention) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5Stack) + return policy + + def postprocess(self): + binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] + + for k, v in binding_map: + mod = getattr_(self.model, k) + setattr_(self.model, v, mod) + return self.model + + +class T5ModelPolicy(T5BasePolicy): + + def module_policy(self): + from transformers import T5Model + base_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + policy=base_policy, + target_key=T5Model) + return base_policy + + +class T5ForConditionalGenerationPolicy(T5BasePolicy): + + def module_policy(self): + from transformers import T5ForConditionalGeneration + + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription(suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ], + policy=policy, + target_key=T5ForConditionalGeneration) + return policy + + def postprocess(self): + super().postprocess() + + binding_map = {"shared": "lm_head"} + + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight + + return self.model + + +class T5EncoderPolicy(T5BasePolicy): + + def module_policy(self): + from transformers import T5EncoderModel + + base_policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + policy=base_policy, + target_key=T5EncoderModel) + return base_policy + + def postprocess(self): + binding_map = [ + ["shared", "encoder.embed_tokens"], + ] + + for k, v in binding_map: + mod = getattr_(self.model, k) + setattr_(self.model, v, mod) + return self.model diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py new file mode 100644 index 000000000000..eaebe2eee0ba --- /dev/null +++ b/colossalai/shardformer/policies/vit.py @@ -0,0 +1,110 @@ +from typing import Dict, Union + +import torch.nn as nn + +from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row + +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ['ViTPolicy'] + + +class ViTPolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer + + base_policy = { + ViTEmbeddings: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ]), + ViTLayer: + ModulePolicyDescription(attribute_replacement={ + "attention.attention.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=DropoutForParallelInput, + ), + ]), + } + + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[ViTAttention].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="layernorm_before", + target_module=FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layernorm_after", + target_module=FusedLayerNorm, + ) + ]) + base_policy[ViTModel].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="layernorm", + target_module=FusedLayerNorm, + )) + + return base_policy + + def new_model_class(self): + return None + + def postprocess(self): + return self.model diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py new file mode 100644 index 000000000000..7abdd45ec7c5 --- /dev/null +++ b/colossalai/shardformer/shard/__init__.py @@ -0,0 +1,5 @@ +from .shard_config import ShardConfig +from .sharder import ModelSharder +from .shardformer import ShardFormer + +__all__ = ['ShardConfig', 'ModelSharder', 'ShardFormer'] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py new file mode 100644 index 000000000000..83c08d275df3 --- /dev/null +++ b/colossalai/shardformer/shard/shard_config.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass + +import torch.distributed as dist +from torch.distributed import ProcessGroup + +__all__ = ['ShardConfig'] + + +@dataclass +class ShardConfig: + r""" + The config for sharding the huggingface model + + Args: + tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. + enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. + enable_fused_normalization (bool): Whether to use fused layernorm, default is False. + enable_all_optimization (bool): Whether to turn on all optimization, default is False. + """ + tensor_parallel_process_group: ProcessGroup = None + enable_tensor_parallelism: bool = True + enable_fused_normalization: bool = False + enable_all_optimization: bool = False + + # TODO: add support for tensor parallel + # pipeline_parallel_size: int + # data_parallel_size: int + # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] + # inference_only: bool = True + # gather_output: bool = True + + @property + def tensor_parallel_size(self): + return self._tensor_parallel_size + + def __post_init__(self): + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() + + def _turn_on_all_optimization(self): + """ + Turn on all optimization. + """ + # you can add all the optimization flag here + self.enable_fused_normalization = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py new file mode 100644 index 000000000000..2867a0a4fd77 --- /dev/null +++ b/colossalai/shardformer/shard/sharder.py @@ -0,0 +1,174 @@ +from typing import Any, Callable, Dict, List, Union + +import torch.nn as nn + +from .._utils import getattr_, setattr_ +from ..policies.autopolicy import get_autopolicy +from ..policies.basepolicy import Policy, SubModuleReplacementDescription +from .shard_config import ShardConfig + +__all__ = ['ModelSharder', 'shard_model'] + + +class ModelSharder(object): + r""" + Shard the original huggingface model according to the policy + + Args: + policy (:class:`Policy`): The policy to shard the model + model (:class:`torch.Module`): The model to shard + shard_config: The setting of distributed model + """ + + def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None: + self.model = model + self.policy = get_autopolicy(self.model) if policy is None else policy + self.shard_config = shard_config + + def shard(self) -> None: + r""" + Shard the model according to the policy + """ + self.policy.set_model(self.model) + self.policy.set_shard_config(self.shard_config) + self._preprocess() + self._replace_module() + self._postprocess() + + def _preprocess(self) -> None: + self.model = self.policy.preprocess() + + def _postprocess(self) -> None: + self.model = self.policy.postprocess() + + def _replace_module(self,) -> None: + r""" + Replace the module according to the policy, and replace the module one by one + + Args: + model (:class:`torch.nn.Module`): The model to shard + """ + module_descriptions = self.policy.module_policy() + for layer_cls, module_description in module_descriptions.items(): + attr_replacement = module_description.attribute_replacement + param_replacement = module_description.param_replacement + sub_module_replacement = module_description.sub_module_replacement + method_replacement = module_description.method_replacement + self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement, + method_replacement, sub_module_replacement) + + def _recursive_replace_layer( + self, + module: nn.Module, + origin_cls: Union[str, nn.Module], + attr_replacement: Dict[str, Any], + param_replacement: List[Callable], + method_replacement: Dict[str, Callable], + sub_module_replacement: List[Callable], + ) -> None: + r""" + Reverse the replace layer operation + + Args: + layer (torch.nn.Module): The object of layer to shard + origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name. + attr_replacement (Dict): The attribute dict to modify + param_replacement (List[Callable]): The function list to get parameter shard information in polic + sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy + """ + if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ + (module.__class__ == origin_cls): + if attr_replacement is not None: + self._replace_attr(module, attr_replacement) + + if param_replacement is not None: + self._replace_param(module, param_replacement) + + if method_replacement is not None: + self._replace_method(module, method_replacement) + + if sub_module_replacement is not None: + self._replace_sub_module(module, sub_module_replacement) + + for name, child in module.named_children(): + self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement, + sub_module_replacement) + + def _replace_attr( + self, + module: nn.Module, + attr_replacement: Dict[str, Any], + ) -> None: + r""" + Replace the attribute of the layer + + Args: + layer (:class:`torch.nn.Module`): The object of layer to shard + attr_replacement (Dict): The attribute dict to modify + """ + for k, v in attr_replacement.items(): + setattr_(module, k, v, ignore=True) + + def _replace_param( + self, + module: nn.Module, + param_replacement: List[Callable], + ) -> None: + r""" + Replace the parameter of the layer + + Args: + layer (:class:`torch.nn.Module`): The object of layer to shard + param_replacement (List[Callable]): The function list to get parameter shard information in policy + """ + for param_func in param_replacement: + param_func(module) + + def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]): + for method_name, new_method in method_replacement.items(): + # bind the new method to the module + setattr(module, method_name, new_method.__get__(module, module.__class__)) + + def _replace_sub_module( + self, + org_layer: nn.Module, + sub_module_replacement: List[SubModuleReplacementDescription], + ) -> None: + r""" + Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict + + Args: + org_layer (torch.nn.Module): The origin layer object to shard + sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list + + """ + for description in sub_module_replacement: + suffix = description.suffix + target_module = description.target_module + kwargs = {} if description.kwargs is None else description.kwargs + + assert target_module is not None, 'target_module should not be None' + + # TODO: support different parallel mode + native_sub_module = getattr_(org_layer, suffix, ignore=True) + + assert not isinstance(native_sub_module, target_module), \ + f"The module with suffix {suffix} has been replaced, please check the policy" + + # if it is None and we are allowed to ignore this module + # just skip + if description.ignore_if_not_exist and native_sub_module is None: + continue + + try: + replace_layer = target_module.from_native_module(native_sub_module, + self.shard_config.tensor_parallel_process_group, + **kwargs) + except Exception as e: + raise RuntimeError( + f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}" + f" with {target_module.__qualname__} with the exception: {e}. " + "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well." + ) + + setattr_(org_layer, suffix, replace_layer) diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py new file mode 100644 index 000000000000..3fce12463414 --- /dev/null +++ b/colossalai/shardformer/shard/shardformer.py @@ -0,0 +1,46 @@ +import torch.nn as nn + +from colossalai.cluster import DistCoordinator + +from ..policies.basepolicy import Policy +from .shard_config import ShardConfig +from .sharder import ModelSharder + + +class ShardFormer: + """ + Parallelize model based on the given config and policy + + Example: + + ```python + from colossalai.shardformer import ShardFormer, ShardConfig + from transformers import BertForMaskedLM + import colossalai + import torch + + colossalai.launch_from_torch(config={}) + + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') + shard_config = ShardConfig() + shard_former = ShardFormer(shard_config=shard_config) + model = shard_former.optimize(org_model) + ``` + """ + + def __init__(self, shard_config: ShardConfig): + self.coordinator = DistCoordinator() + self.shard_config = shard_config + + def optimize(self, model: nn.Module, policy: Policy = None): + r""" + This method will optimize the model based on the given policy. + + Args: + model (`torch.nn.Model`): the origin huggingface model + shard_config (`ShardConfig`): the config for distribute information + policy (`Policy`): the custom policy for sharding + """ + sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) + sharder.shard() + return model diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index af38d2a502c2..204f81343199 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) - for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) + for _ in range(comm_spec.device_mesh.shape[comm_spec.logical_process_axis]) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor, comm_spec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor, comm_spec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor, comm_spec, async_op=False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor def _mix_gather(tensor, comm_spec): @@ -128,7 +125,7 @@ def _mix_gather(tensor, comm_spec): process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] ''' - total_slices = comm_spec.device_mesh.mesh_shape[0] + total_slices = comm_spec.device_mesh.shape[0] tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] leading_group_dim = comm_spec.logical_process_axes[0] assert len(comm_spec.device_mesh.process_groups_dict) == 1 @@ -149,7 +146,7 @@ def _mix_gather(tensor, comm_spec): if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous() else: - mesh_shape = comm_spec.device_meshes.mesh_shape + mesh_shape = comm_spec.device_meshes.shape cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] tmp_tensor_shape = list(tensor.shape) tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0] @@ -181,9 +178,9 @@ def _mix_split(tensor, comm_spec): # [4, 5, 6, 7]] # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} ''' - mesh_shape = comm_spec.device_meshes.mesh_shape + mesh_shape = comm_spec.device_meshes.shape dim = comm_spec.gather_dim - total_slices = comm_spec.device_mesh.mesh_shape[0] + total_slices = comm_spec.device_mesh.shape[0] # Get global rank rank = dist.get_rank() @@ -414,7 +411,7 @@ def __init__(self, self.forward_only = forward_only if isinstance(self.logical_process_axis, list): if not mix_gather: - self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.device_mesh = self.sharding_spec.device_mesh.flatten() self.logical_process_axis = 0 else: self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes diff --git a/colossalai/tensor/d_tensor/RAEDME.md b/colossalai/tensor/d_tensor/RAEDME.md new file mode 100644 index 000000000000..3d862dddbf20 --- /dev/null +++ b/colossalai/tensor/d_tensor/RAEDME.md @@ -0,0 +1,103 @@ +# ๐Ÿ”ข Distributed Tensor + +## ๐Ÿ“š Table of Contents + +- [๐Ÿ”ข Distributed Tensor](#-distributed-tensor) + - [๐Ÿ“š Table of Contents](#-table-of-contents) + - [๐Ÿ”— Introduction](#-introduction) + - [๐Ÿ“ Design](#-design) + - [๐Ÿ”จ Usage](#-usage) + - [๐ŸŽˆ Progress Log](#-progress-log) + +## ๐Ÿ”— Introduction + +Distributed tensor is a type of tensor that is distributed across multiple devices. It is a wrapper of PyTorch tensor, and it is used to support distributed training. +It can represent the device topology and tensor placement over the devices in the topology. It also provides a set of APIs to manipulate the distributed tensor. + +## ๐Ÿ“ Design + +Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses notations `S` to represent the sharded dimension and `R` to represent the replicated dimension. For example, given a 2D matrix, `[S, R]` represents the tensor is sharded over the first dimension. + +Each sharded dimension will have a subscript to represent its placement over the devices. Assuming we have 4 GPUs and the GPUs are arranged in a 2 x 2 manner. Let's say we have a 2D matrix like below: + + +```text + [1, 2, 3, 4 ] +A = [4, 5, 6, 7 ] + [8, 9, 10, 11] + [12, 13, 14, 15] +``` + +`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology. + +```text +| --------------------โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”-| +| | | +| [1, 2, 3, 4 ] | [1, 2, 3, 4 ] | +| [4, 5, 6, 7 ] | [4, 5, 6, 7 ] | +| | | +| --------------------โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”----- +| | | +| [8, 9, 10, 11] | [8, 9, 10, 11] | +| [12, 13, 14, 15] | [12, 13, 14, 15] | +| | | +| --------------------โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”----- +``` + +`[S01, R]` would mean that the first dimension is sharded over both the row and column in the device topology. + +```text +| --------------------โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”-| +| | | +| [1, 2, 3, 4 ] | [4, 5, 6, 7 ] | +| | | +| --------------------โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”----- +| | | +| [8, 9, 10, 11] | [12, 13, 14, 15] | +| | | +| --------------------โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”โ€”----- +``` + +## ๐Ÿ”จ Usage + +A sample API usage is given below. + +```python +import torch + +import colossalai +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.d_tensor import DTensor, ShardingSpec + +colossalai.launch_from_torch(config={}) + +# define your device mesh +# assume you have 4 GPUs +physical_mesh_id = torch.arange(0, 4) +mesh_shape = (2, 2) +device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + +# define a tensor +a = torch.rand(16, 32).cuda() + +# create sharding spec for the tensor +# assume the sharding spec is [S0, R] +dim_partition_dict = {0: [0]} +sharding_spec = ShardingSpec(a.dim(), dim_partition_dict) + +# create a distributed tensor +d_tensor = DTensor(a, device_mesh, sharding_spec) +print(d_tensor) + +global_tensor = d_tensor.to_global() +print(global_tensor) +``` + + +## ๐ŸŽˆ Progress Log + +- [x] Support layout conversion +- [x] Support sharding on 2D device mesh +- [ ] Support sharding on 3D device mesh +- [ ] Support sharding 4D device mesh +- [ ] Support sharding info saving and offline tensor merge (we can save tensor as dtensor and gather the tensors back to the global tensor based on the sharding info in a single process in CPU, useful for distributed training checkpoint load and save.) diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index e69de29bb2d1..3ae38a12555b 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -0,0 +1,28 @@ +from .api import ( + compute_global_numel, + customized_distributed_tensor_to_param, + distribute_tensor, + distribute_tensor_with_customization, + get_device_mesh, + get_global_shape, + get_layout, + get_sharding_spec, + is_customized_distributed_tensor, + is_distributed_tensor, + is_sharded, + redistribute, + shard_colwise, + shard_rowwise, + sharded_tensor_to_param, + to_global, + to_global_for_customized_distributed_tensor, +) +from .layout import Layout +from .sharding_spec import ShardingSpec + +__all__ = [ + 'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise', + 'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh', + 'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization', + 'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec' +] diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py new file mode 100644 index 000000000000..95a44e09e16a --- /dev/null +++ b/colossalai/tensor/d_tensor/api.py @@ -0,0 +1,434 @@ +import copy +import operator +from functools import reduce +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.device.device_mesh import DeviceMesh + +from .layout import Layout +from .layout_converter import LayoutConverter +from .sharding_spec import ShardingSpec + +layout_converter = LayoutConverter() + + +def is_distributed_tensor(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a distributed tensor. + """ + return hasattr(tensor, "dist_layout") + + +def is_sharded(dtensor: torch.Tensor) -> bool: + """ + Check if a tensor is sharded. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: True if the tensor is sharded, False otherwise. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return list(dtensor.shape) == list(dtensor.dist_layout.global_shape) + + +def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + dtensor._old_detach = dtensor.detach + dtensor._old_clone = dtensor.clone + + def new_detach(self): + t_ = self._old_detach() + t_.dist_layout = copy.deepcopy(self.dist_layout) + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._old_clone(*args, **kwargs) + t_.dist_layout = copy.deepcopy(self.dist_layout) + return t_ + + # bind the new methods to the tensor + dtensor.detach = new_detach.__get__(dtensor) + dtensor.clone = new_clone.__get__(dtensor) + return dtensor + + +def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: + ''' + Construct the default sharding specification for the tensor. + + Args: + tensor (`torch.Tensor`): the tensor to be sharded. + + Returns: + A `ShardingSpec` object without any sharding specified. + ''' + return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) + + +def _apply_layout(tensor, layout): + ''' + Apply the layout to the local tensor during initializing process. + ''' + # layout converter requires a source and target laytout + # we construct the source layer for an unsharded tensor + # and use self.dist_layer as the targer layout for the sharded tensor + source_spec = _construct_default_sharding_spec(tensor) + source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape) + sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout) + return sharded_tensor + + +def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: + """ + Convert the given tensor to a distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be converted. + device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices. + sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape) + + # shard tensor + sharded_tensor = _apply_layout(tensor, dist_layout) + + # hack some tensor methods + _hijack_detach_and_clone(sharded_tensor) + + return sharded_tensor + + +def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: + ''' + Convert the layout of the tensor from source_spec to target_spec. + This will update the `local_tensor` and `dist_layout` in place. + + Args: + dtensor (torch.Tensor): the distributed tensor to be converted. + device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices. + target_layout (Layout): the target layout specification. + ''' + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + global_shape = get_global_shape(dtensor) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) + resharded_tensor = layout_converter.apply(tensor=dtensor, + source_layout=dtensor.dist_layout, + target_layout=target_layout) + return resharded_tensor + + +def to_global(dtensor: torch.Tensor) -> torch.Tensor: + """ + Convert a distributed tensor to the global tensor with the given layout. + This function returns a native `torch.Tensor` object. + + Args: + dtensor (torch.Tensor): the distributed tensor to be converted. + + Returns: + torch.Tensor: the global tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + layout_converter = LayoutConverter() + + global_sharding_spec = ShardingSpec(dtensor.dim(), {}) + device_mesh = get_device_mesh(dtensor) + global_shape = get_global_shape(dtensor) + global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape) + + global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout) + return global_tensor + + +def shard_rowwise( + tensor: torch.Tensor, + group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, +) -> torch.Tensor: + """ + Shard the first dim of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be sharded. + group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor. + If None, the tensor will be sharded with respect to the global process group. + Defaults to None. + inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. + + Returns: + torch.Tensor: The sharded tensor. + """ + # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group + if group_or_device_mesh is None: + group_or_device_mesh = dist.GroupMember.WORLD + + if isinstance(group_or_device_mesh, ProcessGroup): + device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) + else: + assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' + device_mesh = group_or_device_mesh + + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) + + return distribute_tensor(tensor, device_mesh, sharding_spec) + + +def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor: + """ + Shard the first dim of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be sharded. + group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor. + If None, the tensor will be sharded with respect to the global process group. + Defaults to None. + inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. + + Returns: + torch.Tensor: The sharded tensor. + """ + # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group + if group_or_device_mesh is None: + group_or_device_mesh = dist.GroupMember.WORLD + + if isinstance(group_or_device_mesh, ProcessGroup): + device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) + else: + assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' + device_mesh = group_or_device_mesh + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]}) + + return distribute_tensor(tensor, device_mesh, sharding_spec) + + +def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) + + # make it distributed as well + param.dist_layout = dtensor.dist_layout + _hijack_detach_and_clone(param) + + return param + + +def compute_global_numel(dtensor: torch.Tensor) -> int: + """ + Compute the global number of elements in the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + int: The global number of elements in the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + numel = reduce(operator.mul, dtensor.dist_layout.global_shape) + return numel + + +def get_layout(dtensor: torch.Tensor) -> Layout: + """ + Get the layout of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + Layout: The layout of the distributed tensor. + + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout + + +def get_global_shape(dtensor: torch.Tensor) -> torch.Size: + """ + Get the global shape of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + torch.Size: The global shape of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout.global_shape + + +def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh: + """ + Get the device mesh of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + DeviceMesh: The device mesh of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout.device_mesh + + +def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec: + """ + Get the sharding spec of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + ShardingSpec: The sharding spec of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout.sharding_spec + + +# ====================================================== +# Some sharding does not obey the SPMD style +# e.g. Fused QKV layer in GPT2 +# we support customize sharding with the following APIs +# ====================================================== +def is_customized_distributed_tensor(tensor: torch.Tensor): + """ + Check whether the given tensor is a customized distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a customized distributed tensor. + """ + return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn') + + +def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + dtensor._old_detach = dtensor.detach + dtensor._old_clone = dtensor.clone + + def new_detach(self): + t_ = self._old_detach() + t_.shard_fn = self.shard_fn + t_.gather_fn = self.gather_fn + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._old_clone(*args, **kwargs) + t_.shard_fn = self.shard_fn + t_.gather_fn = self.gather_fn + return t_ + + # bind the new methods to the tensor + dtensor.detach = new_detach.__get__(dtensor) + dtensor.clone = new_clone.__get__(dtensor) + return dtensor + + +def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable): + """ + Distribute the given tensor with the given shard_fn and gather_fn. + + Example: + + ```python + # define shard and gather functions + def shard_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + return tensor.chunk(world_size, dim=0)[rank] + + def gather_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + shard_list = [torch.zeros_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(shard_list, tensor) + return torch.cat(shard_list, dim=0) + + # create a distributed tensor + tensor = torch.rand(4, 4) + dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn) + ``` + + Args: + tensor (torch.Tensor): The tensor to be distributed. + shard_fn (callable): The function to shard the tensor. + gather_fn (callable): The function to gather the tensor. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert callable(shard_fn), 'The shard_fn must be callable.' + assert callable(gather_fn), 'The gather_fn must be callable.' + assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + + sharded_tensor = shard_fn(tensor) + + # set the shard_fn and gather_fn as attributes of the distributed tensor + sharded_tensor.shard_fn = shard_fn + sharded_tensor.gather_fn = gather_fn + + # set the shard_fn and gather_fn as attributes of the distributed tensor + _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor) + + return sharded_tensor + + +def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: + """ + Gather the given tensor to the global tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + torch.Tensor: The global tensor. + """ + assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + return dtensor.gather_fn(dtensor) + + +def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): + """ + Convert the given customized distributed tensor to a parameter. + """ + assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + + param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) + + # make it distributed as well + param.shard_fn = dtensor.shard_fn + param.gather_fn = dtensor.gather_fn + _hijack_detach_and_clone_for_customized_distributed_tensor(param) + return param diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 159125fa16db..79b2e3ef936a 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -24,12 +24,12 @@ class CommSpec: ''' Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the - communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim + communication method, process_group_dict to determine the process groups, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis Argument: - comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. - process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec. + comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. + process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. @@ -37,7 +37,7 @@ class CommSpec: def __init__(self, comm_pattern: CollectiveCommPattern, - process_groups_dict: Dict, + process_group_dict: Dict, gather_dim: int = None, shard_dim: int = None, logical_process_axis: int = None): @@ -45,7 +45,7 @@ def __init__(self, self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis - self.process_groups_dict = process_groups_dict + self.process_group_dict = process_group_dict def __repr__(self): res_list = ["CommSpec:("] @@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor class _ReduceGrad(torch.autograd.Function): @@ -269,7 +257,7 @@ def symbolic(graph, input_): def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_groups_dict=comm_spec.process_groups_dict, + process_group_dict=comm_spec.process_group_dict, gather_dim=comm_spec.shard_dim, shard_dim=comm_spec.gather_dim, logical_process_axis=comm_spec.logical_process_axis) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index ee7ef74a99ae..a35b2f43e44b 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -1,12 +1,11 @@ import operator -from dataclasses import dataclass from functools import reduce import torch from colossalai.device.device_mesh import DeviceMesh -from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError +from .misc import DuplicatedShardingDimensionError, ShardingNotDivisibleError from .sharding_spec import ShardingSpec @@ -15,26 +14,23 @@ class Layout: Attributes: device_mesh: the device mesh to store the tensor distributed. - device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'. sharding_spec: the sharding specification to describe how the tensor is sharded. - entire_shape: the entire shape of the global tensor. + global_shape: the entire shape of the global tensor. """ - def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, - entire_shape: torch.Size): + def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size): self.device_mesh = device_mesh - self.device_type = device_type self.sharding_spec = sharding_spec - self.entire_shape = entire_shape + self.global_shape = global_shape self._sanity_check() def __hash__(self) -> int: return hash(f'{self.sharding_spec}') def get_sharded_shape_per_device(self): - sharded_shape = list(self.entire_shape) + sharded_shape = list(self.global_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): - mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) assert sharded_shape[ dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' @@ -45,22 +41,23 @@ def _sanity_check(self): sharding_spec = self.sharding_spec # make sure all axes in logical device mesh only be used once - dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) - for dim, shard_list in sharding_spec.dim_partition_dict.items(): - for element in shard_list: - if element in dim_check_list: - dim_check_list.remove(element) - else: - raise DuplicatedShardingDimensionError( - f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + if self.device_mesh.logical_mesh_id is not None: + dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) + for dim, shard_list in sharding_spec.dim_partition_dict.items(): + for element in shard_list: + if element in dim_check_list: + dim_check_list.remove(element) + else: + raise DuplicatedShardingDimensionError( + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): - tensor_dim_size = self.entire_shape[dim] + tensor_dim_size = self.global_shape[dim] num_devices = 1 for element in shard_list: - num_devices *= self.device_mesh.mesh_shape[element] + num_devices *= self.device_mesh.shape[element] if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index cf02aac309f4..528ed7901c4f 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -3,10 +3,8 @@ from dataclasses import dataclass from typing import Dict, List, Tuple -import numpy as np import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout @@ -28,18 +26,6 @@ class LayoutConverterOptions: pass -def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor: - layout_converter = LayoutConverter() - global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {}) - global_layout = Layout(device_mesh=layout.device_mesh, - device_type=layout.device_type, - sharding_spec=global_sharding_spec, - entire_shape=layout.entire_shape) - with torch.no_grad(): - global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout) - return global_tensor - - def set_layout_converting_options(options: LayoutConverterOptions): """ Configure the shape consistency manager via function call. @@ -49,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions): class LayoutConverter(metaclass=SingletonMeta): + """ + LayoutConverter is a singleton class which converts the layout of a distributed tensor. + """ def __init__(self): self._options = None @@ -91,15 +80,14 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) for layout, comm_spec in rst_dict.items(): @@ -112,7 +100,12 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] @@ -130,7 +123,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, - process_groups_dict=process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, # shard_dim will be used during backward shard_dim=gather_dim, @@ -141,8 +134,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -167,15 +159,14 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_to_all_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -188,7 +179,12 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + source_spec = source_layout.sharding_spec tensor_dims = source_spec.dims for f_index in range(tensor_dims - 1): @@ -229,7 +225,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com shard_dim = f_index logical_process_axis = b_target_pair[1][-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -252,8 +248,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -278,16 +273,15 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0]} # [S0,R,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.shard_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -301,10 +295,14 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] # legal sharding dims means the mesh_id is still available to use. - legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))] + legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) @@ -329,7 +327,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec shard_dim = index logical_process_axis = shard_list[-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=shard_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -340,8 +338,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -399,7 +396,7 @@ def layout_converting(self, source_layout: Layout, # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -407,16 +404,14 @@ def layout_converting(self, source_layout: Layout, # [R,S01,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [S01,R,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) @@ -505,21 +500,19 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) # [S0,R,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [R,S0,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) @@ -553,4 +546,5 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo _, comm_action_sequence = self.layout_converting(source_layout, target_layout) for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) + tensor.dist_layout = target_layout return tensor diff --git a/colossalai/tensor/d_tensor/utils.py b/colossalai/tensor/d_tensor/utils.py index 644bb6306b42..fc22b990d879 100644 --- a/colossalai/tensor/d_tensor/utils.py +++ b/colossalai/tensor/d_tensor/utils.py @@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals # the comm size for all gather is the size of the gathered tensor gather_dim = comm_spec.gather_dim all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1] - all_gather_size = device_mesh.mesh_shape[all_gather_axis] + all_gather_size = device_mesh.shape[all_gather_axis] comm_size_for_all_gather = comm_size * all_gather_size forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis) # give a tiny cost to shard diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index 5bec552d69d5..99d782c3f6e8 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -285,7 +285,7 @@ def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD # legal sharding dims means the mesh_id is still available to use. - legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))] + legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) @@ -435,7 +435,7 @@ def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, """ input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) - output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis] peak_numel = max(peak_numel, alloc_numel + output_numel * 2) alloc_numel += output_numel if discard_input: @@ -461,7 +461,7 @@ def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, p # generate a new tensor input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) - output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis] alloc_numel += output_numel peak_numel = max(peak_numel, alloc_numel) if discard_input: diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index 406ad49097b5..e594fd297dc4 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -195,7 +195,7 @@ def __init__(self, def __repr__(self): res_list = ["DistSpec:"] res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) - res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}") + res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}") return ' '.join(res_list) def _sanity_check(self): @@ -222,7 +222,7 @@ def _sanity_check(self): num_devices = 1 for element in shard_list: - num_devices *= self.device_mesh.mesh_shape[element] + num_devices *= self.device_mesh.shape[element] if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( @@ -288,7 +288,7 @@ def get_sharded_shape_per_device(self): sharded_shape = list(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): - mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) assert sharded_shape[ dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index 9d0475ed064c..0db33361c6a0 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -3,6 +3,7 @@ assert_close_loose, assert_equal, assert_equal_in_group, + assert_hf_output_close, assert_not_equal, check_state_dict_equal, ) @@ -20,5 +21,5 @@ __all__ = [ 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn', - 'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal' + 'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal', 'assert_hf_output_close' ] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index faf61638d8bb..5cbfb936b144 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -1,4 +1,4 @@ -from typing import OrderedDict +from typing import Any, List, OrderedDict import torch import torch.distributed as dist @@ -52,3 +52,52 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool assert torch.equal(v, d2[k]) else: assert v == d2[k] + + +def assert_hf_output_close(out1: Any, + out2: Any, + ignore_keys: List[str] = None, + track_name: str = "", + atol=1e-5, + rtol=1e-5): + """ + Check if two outputs from huggingface are equal. + + Args: + out1 (Any): the first output + out2 (Any): the second output + ignore_keys (List[str]): the keys to ignore when comparing two dicts + track_name (str): the name of the value compared, used to track the path + """ + if isinstance(out1, dict) and isinstance(out2, dict): + # if two values are dict + # we recursively check the keys + assert set(out1.keys()) == set(out2.keys()) + for k in out1.keys(): + if ignore_keys is not None and k in ignore_keys: + continue + assert_hf_output_close(out1[k], + out2[k], + track_name=f"{track_name}.{k}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol) + elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)): + # if two values are list + # we recursively check the elements + assert len(out1) == len(out2) + for i in range(len(out1)): + assert_hf_output_close(out1[i], + out2[i], + track_name=f"{track_name}.{i}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol) + elif isinstance(out1, Tensor) and isinstance(out2, Tensor): + if out1.shape != out2.shape: + raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") + assert torch.allclose( + out1, out2, atol=atol, rtol=rtol + ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}" + else: + assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}" diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 6895113bc637..50121a9283f2 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -15,3 +15,4 @@ einops triton==2.0.0.dev20221202 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 +SentencePiece diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 6cc4c8ef370d..1e7ef3b62736 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -28,27 +28,35 @@ def register(self, model_fn: Callable, data_gen_fn: Callable, output_transform_fn: Callable, + loss_fn: Callable = None, model_attribute: ModelAttribute = None): """ Register a model and data generation function. Examples: - >>> # Register - >>> model_zoo = ModelZooRegistry() - >>> model_zoo.register('resnet18', resnet18, resnet18_data_gen) - >>> # Run the model - >>> data = resnet18_data_gen() # do not input any argument - >>> model = resnet18() # do not input any argument - >>> out = model(**data) + + ```python + # normal forward workflow + model = resnet18() + data = resnet18_data_gen() + output = model(**data) + transformed_output = output_transform_fn(output) + loss = loss_fn(transformed_output) + + # Register + model_zoo = ModelZooRegistry() + model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn) + ``` Args: name (str): Name of the model. - model_fn (callable): A function that returns a model. **It must not contain any arguments.** - output_transform_fn (callable): A function that transforms the output of the model into Dict. - data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + model_fn (Callable): A function that returns a model. **It must not contain any arguments.** + data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + output_transform_fn (Callable): A function that transforms the output of the model into Dict. + loss_fn (Callable): a function to compute the loss from the given output. Defaults to None model_attribute (ModelAttribute): Attributes of the model. Defaults to None. """ - self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute) + self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) def get_sub_registry(self, keyword: str): """ @@ -62,6 +70,8 @@ def get_sub_registry(self, keyword: str): for k, v in self.items(): if keyword in k: new_dict[k] = v + + assert len(new_dict) > 0, f'No model found with keyword {keyword}' return new_dict diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index f56ff7ad84eb..4aa01abe13ee 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -1,5 +1,7 @@ from .albert import * from .bert import * +from .bloom import * from .gpt import * +from .llama import * from .opt import * from .t5 import * diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 99135704da70..d2d3de7b7bee 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -6,83 +6,147 @@ # =============================== # Register single-sentence BERT # =============================== -BATCH_SIZE = 2 -SEQ_LENGTH = 16 -def data_gen_fn(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from transformers import BertTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + # token_type_ids = tokenized_input['token_type_ids'] + input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64) + token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['labels'] = data['input_ids'].clone() + return data + + +def data_gen_for_pretraining(): + # pretraining data gen + # `next_sentence_label` is the label for next sentence prediction, 0 or 1 + data = data_gen_for_lm() + data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + # `labels` is the label for sequence classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([1], dtype=torch.int64) + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_mcq(): + # multiple choice question data gen + # Generated from following code snippet + # + # tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + # prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + # choice0 = "It is eaten with a fork and a knife." + # choice1 = "It is eaten while held in the hand." + # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + # data = {k: v.unsqueeze(0) for k, v in encoding.items()} + # data['labels'] = torch.tensor([0], dtype=torch.int64) + input_ids = torch.tensor([[[ + 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, + 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102 + ], + [ + 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, + 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, + 2218, 1999, 1996, 2192, 1012, 102, 0 + ]]]) + token_type_ids = torch.tensor( + [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) + attention_mask = torch.tensor( + [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) + labels = torch.tensor([0], dtype=torch.int64) + + return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) + + +# define output transform function output_transform_fn = lambda x: x -config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) +# define loss funciton +loss_fn_for_bert_model = lambda x: x.pooler_output.mean() +loss_fn = lambda x: x.loss + +config = transformers.BertConfig(hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256, + hidden_dropout_prob=0, + attention_probs_dropout_prob=0) # register the BERT variants model_zoo.register(name='transformers_bert', model_fn=lambda: transformers.BertModel(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bert_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_pretraining', model_fn=lambda: transformers.BertForPreTraining(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_pretraining, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_lm_head_model', model_fn=lambda: transformers.BertLMHeadModel(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_masked_lm', model_fn=lambda: transformers.BertForMaskedLM(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_sequence_classification', model_fn=lambda: transformers.BertForSequenceClassification(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_token_classification', model_fn=lambda: transformers.BertForTokenClassification(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_token_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) - - -# =============================== -# Register multi-sentence BERT -# =============================== -def data_gen_for_next_sentence(): - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - next_sentence = "The sky is blue due to the shorter wavelength of blue light." - encoding = tokenizer(prompt, next_sentence, return_tensors="pt") - return encoding - - -def data_gen_for_mcq(): - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - choice0 = "It is eaten with a fork and a knife." - choice1 = "It is eaten while held in the hand." - encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) - encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} - return encoding - - -# register the following models model_zoo.register(name='transformers_bert_for_next_sentence', model_fn=lambda: transformers.BertForNextSentencePrediction(config), - data_gen_fn=data_gen_for_next_sentence, + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_mcq', model_fn=lambda: transformers.BertForMultipleChoice(config), data_gen_fn=data_gen_for_mcq, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py new file mode 100644 index 000000000000..71146c0b9819 --- /dev/null +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -0,0 +1,107 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register Bloom +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import BloomTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['labels'] = data['input_ids'].clone() + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data['labels'] = torch.tensor([0], dtype=torch.int64) + return data + + +def data_gen_for_question_answering(): + # obtained with the following code + # + # from transformers import AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + # question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + # inputs = tokenizer(question, text, return_tensors="pt") + + input_ids = torch.tensor( + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_bloom_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_causal_lm = lambda x: x.loss +loss_fn_for_classification = lambda x: x.logits.mean() +loss_fn_for_question_answering = lambda x: x.end_logits.mean() + +config = transformers.BloomConfig(n_layer=1, + n_head=4, + vocab_size=250880, + hidden_dropout=0, + attention_dropout=0, + hidden_size=64) + +# register the following models +model_zoo.register(name='transformers_bloom', + model_fn=lambda: transformers.BloomModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bloom_model, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_causal_lm', + model_fn=lambda: transformers.BloomForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_sequence_classification', + model_fn=lambda: transformers.BloomForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_token_classification', + model_fn=lambda: transformers.BloomForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_question_answering', + model_fn=lambda: transformers.BloomForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_question_answering, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 5ed4fbe70dc9..b9e0310780af 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -6,52 +6,89 @@ # =============================== # Register single-sentence GPT # =============================== -BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined. -SEQ_LENGTH = 16 def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + # Generated from following code snippet + # + # from transformers import GPT2Tokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) -def seq_classification_data_gen(): - # batch sizes should be 1 if no padding token is defined. - input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['labels'] = data['input_ids'].clone() + return data +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data['labels'] = torch.tensor([0], dtype=torch.int64) + return data + + +# define output transform function output_transform_fn = lambda x: x -config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4) +# define loss function +loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean() +loss_fn = lambda x: x.loss + +config = transformers.GPT2Config(n_layer=2, + n_head=4, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification") # register the following models model_zoo.register(name='transformers_gpt', model_fn=lambda: transformers.GPT2Model(config), data_gen_fn=data_gen, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_gpt2_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_lm', model_fn=lambda: transformers.GPT2LMHeadModel(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_double_heads', model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_token_classification', model_fn=lambda: transformers.GPT2ForTokenClassification(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_token_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_sequence_classification', model_fn=lambda: transformers.GPT2ForSequenceClassification(config), - data_gen_fn=seq_classification_data_gen, + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py new file mode 100644 index 000000000000..705bbc7364ba --- /dev/null +++ b/tests/kit/model_zoo/transformers/llama.py @@ -0,0 +1,76 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +try: + from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel + HAS_LLAMA = True +except ImportError: + HAS_LLAMA = False + +if HAS_LLAMA: + # =============================== + # Register LLaMA + # =============================== + + def data_gen(): + # the input ids are corresponding to the sentence + # 'Hello, my dog is cute' + # + # the code is give below: + # ----------------------------------- + # from transformers import LlamaTokenizerFast + # tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + # ----------------------------------- + + input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) + + # label is needed for casual lm + def data_gen_for_casual_lm(): + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = labels + return data + + # transform the output to a dict + output_transform_fn = lambda x: x + + # function to get the loss + loss_fn = lambda output: output.last_hidden_state.mean() + loss_fn_for_casual_lm = lambda output: output.loss + loss_fn_for_seq_classification = lambda output: output.logits.mean() + + config = LlamaConfig(num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128, + num_labels=16) + + # register the following models + # transformers.LlamaModel, + # transformers.LlamaForCausalLM, + # transformers.LlamaForSequenceClassification, + model_zoo.register(name='transformers_llama', + model_fn=lambda: transformers.LlamaModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + model_zoo.register(name='transformers_llama_for_casual_lm', + model_fn=lambda: transformers.LlamaForCausalLM(config), + data_gen_fn=data_gen_for_casual_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_casual_lm, + model_attribute=ModelAttribute(has_control_flow=True)) + model_zoo.register(name='transformers_llama_for_sequence_classification', + model_fn=lambda: transformers.LlamaForSequenceClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index d9c4a0b3c23c..4463ae12b901 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -11,14 +11,47 @@ def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) -output_transform_fn = lambda x: x +def data_gen_for_causal_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = labels + return data + + +def data_gen_for_sequence_classification(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = torch.tensor([1]) + return data + + +def data_gen_for_question_answering(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['start_positions'] = torch.tensor([0]) + data['end_positions'] = torch.tensor([1]) + return data + -config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4) +output_transform_fn = lambda x: x +loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_lm = lambda x: x.loss +config = transformers.OPTConfig( + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + dropout=0, +) # register the following models # transformers.OPTModel, @@ -27,9 +60,23 @@ def data_gen(): model_fn=lambda: transformers.OPTModel(config), data_gen_fn=data_gen, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_opt_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_opt_for_causal_lm', model_fn=lambda: transformers.OPTForCausalLM(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_causal_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_opt_for_question_answering', + model_fn=lambda: transformers.OPTForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_opt_for_sequence_classification', + model_fn=lambda: transformers.OPTForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index b81bcad90db8..689db2c40abb 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -6,24 +6,50 @@ # =============================== # Register single-sentence T5 # =============================== -BATCH_SIZE = 2 -SEQ_LENGTH = 16 - - -def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) +# define data gen function def data_gen_for_encoder_only(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + # Generated from following code snippet + # + # from transformers import T5Config, T5Tokenizer + # config = T5Config(decoder_start_token_id=0) + # tokenizer = T5Tokenizer.from_pretrained("t5-small") + # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids + input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long() return dict(input_ids=input_ids) +def data_gen_for_conditional_generation(): + # labels is generated with the following code + # + # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids + data = data_gen_for_encoder_only() + labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long() + data['labels'] = labels + return data + + +def data_gen_for_t5_model(): + # decoder_inputs_ids is obtained with the following code + # + # decoder_input_ids = model._shift_right(input_ids) + data = data_gen_for_encoder_only() + decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long() + data['decoder_input_ids'] = decoder_input_ids + return data + + +# output transform function output_transform_fn = lambda x: x -config = transformers.T5Config(d_model=128, num_layers=2) +# define loss funciton +loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean() +loss_fn_for_conditional_generation = lambda x: x.loss + +# define model config +config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0) # register the following models # transformers.T5Model, @@ -31,16 +57,19 @@ def data_gen_for_encoder_only(): # transformers.T5EncoderModel, model_zoo.register(name='transformers_t5', model_fn=lambda: transformers.T5Model(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_t5_model, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_t5_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_t5_for_conditional_generation', model_fn=lambda: transformers.T5ForConditionalGeneration(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_conditional_generation, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_conditional_generation, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_t5_encoder_model', model_fn=lambda: transformers.T5EncoderModel(config), data_gen_fn=data_gen_for_encoder_only, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_encoder_only, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index fc9d8455ed5c..f0cf2a5fcbca 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -58,13 +58,4 @@ def test_evoformer_block(model, shape, max_memory): if __name__ == "__main__": - run_test( - rank=0, - data=get_data(LATENTS_SHAPE), - max_memory=None, - model=UNet2DModel, - print_code=False, - print_mem=True, - print_est_mem=False, - print_progress=False, - ) + test_evoformer_block() diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index 963387da262b..26ce00e94869 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -11,7 +11,7 @@ def run_torch_amp(rank, world_size, port): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') sub_model_zoo = model_zoo.get_sub_registry('timm') - for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items(): # dlrm_interactionarch has not parameters, so skip if name == 'dlrm_interactionarch': continue diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index d606d6d89bd4..d29c92926066 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -71,7 +71,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): passed_models = [] failed_info = {} # (model_name, error) pair - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): # These models lead to CUDA error if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp', 'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'): diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index f70f27be2aa7..eedd8c59a3a8 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -61,7 +61,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS skipped_models = [] - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): # FIXME(ver217): fix these models if name in ignore_models: skipped_models.append(name) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index fbe44e5ce6fb..1484273973ae 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -40,7 +40,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_ddp_plugin(): - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): if name == 'dlrm_interactionarch': continue run_fn(model_fn, data_gen_fn, output_transform_fn) diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index 44767f051fdd..cbd5d57800db 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -42,7 +42,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_fsdp_plugin(): - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): if any(element in name for element in [ 'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet', 'torchvision_inception_v3' diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 14d69cab2176..602cf468c944 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -22,7 +22,7 @@ @parameterize('use_safetensors', [False, True]) def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool): from transformers import BertForSequenceClassification - (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() with shared_tempdir() as tempdir: @@ -53,7 +53,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b @parameterize('shard', [True, False]) @parameterize('model_name', ['transformers_gpt']) def exam_state_dict(placement_policy, shard: bool, model_name: str): - (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() plugin = GeminiPlugin(placement_policy=placement_policy) booster = Booster(plugin=plugin) diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 789ce8ab35b8..1f8db99c9236 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -1,22 +1,89 @@ +import pytest import torch +import torch.distributed as dist +import colossalai from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import rerun_if_address_is_in_use, spawn def test_device_mesh(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], # [8, 9, 10,11], # [12,13,14,15]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - assert device_mesh.convert_map[5] == [1, 1] - assert device_mesh.convert_map[11] == [2, 3] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] - assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] + assert device_mesh.global_rank_to_local_rank(5) == [1, 1] + assert device_mesh.global_rank_to_local_rank(11) == [2, 3] + assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3] + + +def check_1d_device_mesh(): + # check for 1D device mesh + process_group = dist.GroupMember.WORLD + device_mesh = DeviceMesh.from_process_group(process_group) + + # checks + assert device_mesh.shape == [4] + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict' + assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group' + assert device_mesh.is_initialized + assert device_mesh.num_devices == 4 + assert device_mesh.is_initialized + assert device_mesh.logical_mesh_id is None + assert device_mesh._is_init_from_process_group + + +def check_2d_device_mesh(): + # create process group for 2D device mesh + first_row_ranks = [0, 1] + second_row_ranks = [2, 3] + first_col_ranks = [0, 2] + second_col_ranks = [1, 3] + + first_row_pg = dist.new_group(first_row_ranks, backend='nccl') + second_row_pg = dist.new_group(second_row_ranks, backend='nccl') + first_col_pg = dist.new_group(first_col_ranks, backend='nccl') + second_col_pg = dist.new_group(second_col_ranks, backend='nccl') + + # check for + current_rank = dist.get_rank() + + if current_rank in first_row_ranks: + row_pg = first_row_pg + else: + row_pg = second_row_pg + + if current_rank in first_col_ranks: + col_pg = first_col_pg + else: + col_pg = second_col_pg + + device_mesh = DeviceMesh.from_process_group([col_pg, row_pg]) + + # checks + assert device_mesh.shape == [2, 2] + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict' + assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group' + assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group' + assert device_mesh.num_devices == 4 + assert device_mesh.is_initialized + assert device_mesh.logical_mesh_id is None + assert device_mesh._is_init_from_process_group + + +def check_init_from_process_group(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_device_mesh_from_process_group(): + spawn(check_init_from_process_group, 4) if __name__ == '__main__': test_device_mesh() + test_device_mesh_from_process_group() \ No newline at end of file diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 2b7060c4846a..7c6339eff67e 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -20,16 +20,12 @@ def check_layer(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} - logical_process_groups = device_mesh.process_groups_dict - - for mesh_dim, pgs in logical_pg_dict.items(): - for index, pg in enumerate(pgs): - if rank in pg: - tensor = torch.ones(4).cuda() - group = logical_process_groups[mesh_dim][index][1] - dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) - assert tensor.equal(tensor_to_check) + + for axis in range(len(mesh_shape)): + tensor = torch.ones(4).cuda() + pg = device_mesh.get_process_group(axis=axis) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg) + assert tensor.equal(tensor_to_check) gpc.destroy() diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index 7a4bf131ae36..58c8132e1490 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -1,3 +1,5 @@ +from typing import List + import torch from numpy import isin from torch.fx import GraphModule @@ -7,19 +9,23 @@ from colossalai._analyzer.fx import symbolic_trace -def trace_model_and_compare_output(model, data_gen): +def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None): # must turn on eval mode to ensure the output is consistent model.eval() + inputs = data_gen() + + if ignore_data is not None: + # drop the ignore_data key + inputs = {k: v for k, v in inputs.items() if k not in ignore_data} + try: - kwargs = data_gen() - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to('meta') for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") # run forward - inputs = data_gen() non_fx_out = model(**inputs) fx_out = gm(**inputs) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index f4d681221191..a1470400ad82 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -15,7 +15,7 @@ def test_albert(): sub_registry = model_zoo.get_sub_registry('transformers_albert') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() trace_model_and_compare_output(model, data_gen_fn) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index a833bb30c056..632ad366ccc4 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -12,9 +12,9 @@ def test_bert(): sub_registry = model_zoo.get_sub_registry('transformers_bert') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index 0cbea82e083a..ac87a7fcb13b 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -47,7 +47,7 @@ def test_diffusers(): sub_model_zoo = model_zoo.get_sub_registry('diffusers') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() trace_and_compare(model_fn, data, output_transform_fn) torch.cuda.synchronize() @@ -60,7 +60,7 @@ def test_torch_diffusers(): sub_model_zoo = model_zoo.get_sub_registry('diffusers') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() model = model_fn() output = model(**data) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 67107469d8bb..31bcb7028e25 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -12,7 +12,7 @@ def test_gpt(): sub_registry = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() # TODO: support the following models @@ -21,7 +21,7 @@ def test_gpt(): if model.__class__.__name__ in ['GPT2DoubleHeadsModel']: continue - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index 369545b03de1..c68b89e82fbe 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -11,10 +11,9 @@ @clear_cache_before_run() def test_opt(): sub_registry = model_zoo.get_sub_registry('transformers_opt') - - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'start_positions', 'end_positions']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 811cf3b21430..45e06bc2bbb0 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -12,9 +12,14 @@ def test_t5(): sub_registry = model_zoo.get_sub_registry('transformers_t5') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): + if name == "transformers_t5_for_conditional_generation": + # cannot trace for loss function yet + # so we use a data gen which does not produce labels + data_gen_fn = sub_registry.get('transformers_t5')[1] + model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 11302e8f36b0..98433b8f7c3b 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -56,7 +56,7 @@ def test_timm_models(): sub_model_zoo = model_zoo.get_sub_registry('timm') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: meta_args = {k: v.to('meta') for k, v in data.items()} diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index eafcaca10b1d..2b7def5bef85 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -16,7 +16,7 @@ def test_torchaudio_models(): sub_model_zoo = model_zoo.get_sub_registry('torchaudio') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): model = model_fn() trace_and_compare(model, data_gen_fn, diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index df02568c0049..f969c8e6c3da 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -53,7 +53,7 @@ def test_torchrec_deepfm_models(): deepfm_models = model_zoo.get_sub_registry('deepfm') torch.backends.cudnn.deterministic = True - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: meta_args = {k: v.to('meta') for k, v in data.items()} diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 9776452be9c8..94fb24f33376 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -53,7 +53,7 @@ def test_torchrec_dlrm_models(): torch.backends.cudnn.deterministic = True dlrm_models = model_zoo.get_sub_registry('dlrm') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items(): data = data_gen_fn() # dlrm_interactionarch is not supported diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index bd259475ae5a..74cb753e2937 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -10,7 +10,7 @@ def test_torchvision_models(): torch.backends.cudnn.deterministic = True tv_sub_registry = model_zoo.get_sub_registry('torchvision') - for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, model_attribute) in tv_sub_registry.items(): data = data_gen_fn() if model_attribute is not None and model_attribute.has_stochastic_depth_prob: diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 85bfd0e27801..73c3c5422d8a 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -6,8 +6,10 @@ import torch from packaging import version +from colossalai.device.device_mesh import DeviceMesh from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor -from colossalai.tensor.d_tensor.layout_converter import to_global +from colossalai.tensor.d_tensor import to_global +from colossalai.tensor.d_tensor.layout import Layout from tests.kit.model_zoo.registry import ModelAttribute SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0') @@ -60,7 +62,7 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None: - model_fn, data_gen_fn, output_transform_fn, model_attr = entry + model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry _MyTensor._pre_op_fn = lambda *args: set_seed(seed) LazyTensor._pre_op_fn = lambda *args: set_seed(seed) ctx = LazyInitContext(tensor_cls=_MyTensor) @@ -81,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, print(f'{model.__class__.__name__} pass') -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None: +def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, + sharding_spec_dict: dict) -> None: state = model.state_dict() distributed_state = distributed_model.state_dict() @@ -91,6 +94,8 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn. assert n1 == n2 t1 = t1.cuda() t2 = t2.cuda() - if n2 in layout_dict: - t2 = to_global(t2, layout_dict[n2]) + if n2 in sharding_spec_dict: + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape) + t2.dist_layout = layout + t2 = to_global(t2) assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py index d515b175a9ea..622d9deb601d 100644 --- a/tests/test_lazy/test_distribute.py +++ b/tests/test_lazy/test_distribute.py @@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]: return dim -def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout: +def make_sharding_spec(original_tensor: torch.Tensor) -> Layout: shard_dim = find_shard_dim(original_tensor.shape) dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - return layout + return target_sharding_spec def _get_current_name(prefix: str, name: str) -> str: return f'{prefix}.{name}'.lstrip('.') -def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict: - layout_dict = {} +def generate_sharding_spec_dict(model: nn.Module) -> dict: + sharding_spec_dict = {} @torch.no_grad() def generate_recursively(module: nn.Module, prefix: str = ''): @@ -53,17 +49,17 @@ def generate_recursively(module: nn.Module, prefix: str = ''): # initialize tensors directly attached to the current module for name, param in module.named_parameters(recurse=False): if isinstance(param, LazyTensor): - layout = make_layout(device_mesh, param) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(param) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec for name, buf in module.named_buffers(recurse=False): if isinstance(buf, LazyTensor): - layout = make_layout(device_mesh, buf) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(buf) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec generate_recursively(model) - return layout_dict + return sharding_spec_dict @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) @@ -75,19 +71,19 @@ def run_dist_lazy_init(subset, seed: int = 42): for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): continue print_rank_0(name) - model_fn, data_gen_fn, output_transform_fn, model_attr = entry + model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry ctx = LazyInitContext(tensor_cls=_MyTensor) with ctx: model = model_fn() ctx = LazyInitContext() with ctx: deferred_model = model_fn() - layout_dict = generate_layout_dict(deferred_model, device_mesh) - ctx.distribute(deferred_model, layout_dict, verbose=True) - assert_dist_model_equal(model, deferred_model, layout_dict) + sharding_spec_dict = generate_sharding_spec_dict(deferred_model) + ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True) + assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict) def run_dist(rank, world_size, port) -> None: diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index f828b23a94c4..4b7aeed73a69 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -10,7 +10,7 @@ def test_torchvision_models_lazy_init(subset): sub_model_zoo = model_zoo.get_sub_registry(subset) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): continue check_lazy_init(entry, verbose=True) diff --git a/tests/test_shardformer/__init__.py b/tests/test_shardformer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_shardformer/test_layer/test_dist_crossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py new file mode 100644 index 000000000000..72e6e5cf26ed --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py @@ -0,0 +1,42 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer import cross_entropy_1d +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) + + +def check_dist_crossentropy(rank, world_size, port, ignore_index): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') + + # prepare data + pred = torch.randn(2, 4, 8, requires_grad=True) + labels = torch.randint(8, (2, 4)) + # set some label to -100 to test the ignore index + labels[0, -1] = ignore_index + + org_pred = pred.view(-1, 8) + org_labels = labels.view(-1) + org_loss = F.cross_entropy(org_pred, org_labels) + + dist_pred = pred.chunk(world_size, -1)[rank] + dist_loss = cross_entropy_1d(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) + + assert torch.allclose(org_loss, dist_loss, + atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_crossentropy(): + ignore_index = -100 + spawn(check_dist_crossentropy, 2, ignore_index=ignore_index) + + +if __name__ == '__main__': + test_dist_crossentropy() diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py new file mode 100644 index 000000000000..332e377110a4 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -0,0 +1,70 @@ +import torch +import torch.distributed as dist +import torch.nn as nn + +import colossalai +from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput +from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn + + +def check_dropout_parallel_input(): + dropout = nn.Dropout().cuda() + dropout_1d = DropoutForParallelInput.from_native_module(dropout, process_group=None) + + # check computation correctness + x = torch.rand(4, 128).cuda() + + # we set seed so that dropout will generate the same mask + torch.cuda.manual_seed(1024) + out = dropout(x) + + # we set seed to simulate the same scenario + # but expect the dropout mask to be different + # due to the internal randomness control + torch.cuda.manual_seed(1024) + out_1d = dropout_1d(x) + + # ensure out is the same across all ranks + world_size = dist.get_world_size() + out_all = [torch.empty_like(out) for _ in range(world_size)] + dist.all_gather(out_all, out) + + for i in range(world_size): + assert_equal(out_all[i], out_all[0]) + + # ensure out_1d is different across ranks + out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)] + dist.all_gather(out_1d_all, out_1d) + for i in range(1, world_size): + assert_not_equal(out_1d_all[i], out_1d_all[0]) + + +def check_dropout_replicated_input(): + dropout = nn.Dropout().cuda() + dropout_replica = DropoutForReplicatedInput.from_native_module(dropout, process_group=None) + + # check computation correctness + x = torch.rand(4, 128).cuda() + out_1d = dropout_replica(x) + + # ensure out_1d is different across ranks + world_size = dist.get_world_size() + out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)] + dist.all_gather(out_1d_all, out_1d) + for i in range(1, world_size): + assert_equal(out_1d_all[i], out_1d_all[0]) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_dropout_parallel_input() + check_dropout_replicated_input() + + +@rerun_if_address_is_in_use() +def test_dropout(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_dropout() diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py new file mode 100644 index 000000000000..8a6aa42a42f2 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -0,0 +1,47 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import Embedding1D +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_embedding_1d(): + embedding = nn.Embedding(32, 128).cuda() + embedding_1d = Embedding1D.from_native_module(embedding, process_group=None) + + assert embedding_1d.weight.shape == torch.Size([32, 64]) + + # ensure state dict is reversibly loadable + embedding.load_state_dict(embedding_1d.state_dict()) + embedding_1d.load_state_dict(embedding.state_dict()) + + # check computation correctness + x = torch.randint(low=0, high=32, size=(4, 32)).cuda() + out = embedding(x) + gather_out = embedding_1d(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(embedding.weight.grad, 2, dim=1)[rank] + assert_close(target_grad, embedding_1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_embedding_1d() + + +@rerun_if_address_is_in_use() +def test_embedding_1d(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_embedding_1d() diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py new file mode 100644 index 000000000000..080fae034956 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -0,0 +1,44 @@ +import torch +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import FusedLayerNorm +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_layernorm(): + norm = nn.LayerNorm(128, 0.00001).cuda() + norm1d = FusedLayerNorm.from_native_module(norm, process_group=None) + + assert norm1d.weight.shape == torch.Size([128]) + + # ensure state dict is reversibly loadable + norm.load_state_dict(norm1d.state_dict()) + norm1d.load_state_dict(norm.state_dict()) + + # check computation correctness + x = torch.rand(4, 128).cuda() + out = norm(x) + gather_out = norm1d(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + assert_close(norm.weight.grad, norm1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_layernorm() + + +@rerun_if_address_is_in_use() +def test_layernorm(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_layernorm_1d() \ No newline at end of file diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py new file mode 100644 index 000000000000..da3bdc1d78d3 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -0,0 +1,131 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.tensor.d_tensor import is_distributed_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_linear_1d_col(): + linear = nn.Linear(32, 128).cuda() + linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) + + # ensure that the parameters are distributed + assert is_distributed_tensor(linear_col.weight) + assert is_distributed_tensor(linear_col.bias) + + # ensure the shape is correct + assert linear_col.weight.shape == torch.Size([64, 32]) + assert linear_col.bias.shape == torch.Size([64]) + + # ensure state dict is reversibly loadable + linear.load_state_dict(linear_col.state_dict()) + linear_col.load_state_dict(linear.state_dict()) + + # check computation correctness + x = torch.rand(4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + out = linear(x_for_unshard) + gather_out = linear_col(x_for_shard) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, linear_col.weight.grad) + + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + + +def check_linear_1d_row(): + linear = nn.Linear(32, 128).cuda() + linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + + assert linear_row.weight.shape == torch.Size([128, 16]) + assert linear_row.bias.shape == torch.Size([128]) + + # check computation correctness + x = torch.rand(4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + # run forward + out = linear(x_for_unshard) + gather_out = linear_row(x_for_shard) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] + assert_close(target_grad, linear_row.weight.grad) + + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + + +def check_linear_col_plus_row(): + linear_1 = nn.Linear(32, 128).cuda() + linear_2 = nn.Linear(128, 32).cuda() + linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False) + linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True) + + # check computation correctness + x = torch.rand(4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + # run forward + unshard_out = linear_2(linear_1(x_for_unshard)) + shard_out = linear_row(linear_col(x_for_shard)) + assert_close(unshard_out, shard_out) + + # check backward correctness + unshard_out.sum().backward() + shard_out.sum().backward() + + rank = dist.get_rank() + target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank] + assert_close(target_1_grad, linear_col.weight.grad) + + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_linear_1d_col() + check_linear_1d_row() + check_linear_col_plus_row() + + +@rerun_if_address_is_in_use() +def test_linear(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_linear() diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py new file mode 100644 index 000000000000..681c4f6dd9f1 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -0,0 +1,120 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +# This code is copied from https://github.com/huggingface/transformers +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + self.weight = nn.Parameter(torch.empty(nx, nf)) + self.bias = nn.Parameter(torch.zeros(nf)) + nn.init.normal_(self.weight, std=0.02) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def rearrange(tensor: torch.Tensor, dim: int): + tensor = tensor.clone() + world_size = 2 + order = torch.arange(world_size * 3) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim) + rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order] + rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim) + return rearanged_tensor + + +def check_linear_conv_1d_col(): + linear = Conv1D(192, 48).cuda() + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, + process_group=None, + gather_output=True, + n_fused=3) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear.bias.shape == torch.Size([192]) + assert linear_conv_col.weight.shape == torch.Size([48, 96]) + assert linear_conv_col.bias.shape == torch.Size([96]) + + # ensure weights are reversibly loadable + linear_conv_col.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_conv_col.state_dict()) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_conv_col(x) + assert_close(rearrange(out, 1), gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) + assert_close(target_grad, linear_conv_col.weight.grad) + + +def check_linear_conv_1d_row(): + linear = Conv1D(192, 48).cuda() + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + + assert linear.weight.shape == torch.Size([48, 192]) + assert linear_row.weight.shape == torch.Size([24, 192]) + assert linear_row.bias.shape == torch.Size([192]) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_row(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, linear_row.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # test for linear conv + check_linear_conv_1d_col() + check_linear_conv_1d_row() + + +@rerun_if_address_is_in_use() +def test_linearconv(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_linearconv() diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py new file mode 100644 index 000000000000..8991d9b304f5 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -0,0 +1,49 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import VocabParallelEmbedding1D +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def check_vocab_embedding_1d(): + embedding = nn.Embedding(128, 32).to('cuda') + dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None) + + assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) + assert dist_embedding_1d.num_embeddings == 64 + assert dist_embedding_1d.embedding_dim == 32 + + # ensure state dict is reversibly loadable + embedding.load_state_dict(dist_embedding_1d.state_dict()) + dist_embedding_1d.load_state_dict(embedding.state_dict()) + + # check embedding correctness + x = torch.randint(0, 128, (4, 32)).to('cuda') + org_out = embedding(x) + dist_out = dist_embedding_1d(x) + assert_close(org_out, dist_out) + + # check backward correctness + org_out.sum().backward() + dist_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(embedding.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, dist_embedding_1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_vocab_embedding_1d() + + +@rerun_if_address_is_in_use() +def test_vocab_embedding(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_vocab_embedding() diff --git a/tests/test_shardformer/test_model/__init__.py b/tests/test_shardformer/test_model/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py new file mode 100644 index 000000000000..d83d9ecd39e0 --- /dev/null +++ b/tests/test_shardformer/test_model/_utils.py @@ -0,0 +1,35 @@ +import copy + +from colossalai.shardformer import ShardConfig, ShardFormer + + +def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): + # create new model + org_model = model_fn().cuda() + + # shard model + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + sharded_model = shard_former.optimize(model_copy).cuda() + return org_model, sharded_model + + +def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # prepare input + data = data_gen_fn() + data = {k: v.cuda() for k, v in data.items()} + + # switch to train mode + original_model.train() + sharded_model.train() + # run forward + org_output = original_model(**data) + org_output = output_transform_fn(org_output) + org_loss = loss_fn(org_output) + + shard_output = sharded_model(**data) + shard_output = output_transform_fn(shard_output) + shard_loss = loss_fn(shard_output) + return org_output, org_loss, shard_output, shard_loss diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py new file mode 100644 index 000000000000..1afedb7079ea --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -0,0 +1,95 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output) + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check grad + + if org_model.__class__.__name__ == 'BertModel': + bert = org_model + sharded_bert = sharded_model + else: + bert = org_model.bert + sharded_bert = sharded_model.bert + + # compare self attention grad + org_grad = bert.encoder.layer[0].attention.self.query.weight.grad + shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad + shard_weight = sharded_bert.encoder.layer[0].attention.self.query.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # compare embedding grad + org_grad = bert.embeddings.word_embeddings.weight.grad + shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad + shard_weight = sharded_bert.embeddings.word_embeddings.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bert(): + spawn(check_bert, 2) + + +if __name__ == "__main__": + test_bert() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py new file mode 100644 index 000000000000..a3389652269c --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -0,0 +1,94 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model + if org_model.__class__.__name__ == 'BloomModel': + bloom = org_model + sharded_bloom = sharded_model + else: + bloom = org_model.transformer + sharded_bloom = sharded_model.transformer + + # check attention grad + org_grad = bloom.h[0].self_attention.query_key_value.weight.grad + shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad + shard_weight = sharded_bloom.h[0].self_attention.query_key_value.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # check embedding weights + org_grad = bloom.word_embeddings.weight.grad + shard_grad = sharded_bloom.word_embeddings.weight.grad + shard_weight = sharded_bloom.word_embeddings.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() + + +def check_bloom(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bloom_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(): + spawn(check_bloom, 2) + + +if __name__ == "__main__": + test_bloom() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py new file mode 100644 index 000000000000..ee7737687d99 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -0,0 +1,94 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model + if org_model.__class__.__name__ == 'GPT2Model': + org_model = org_model + sharded_model = sharded_model + else: + org_model = org_model.transformer + sharded_model = sharded_model.transformer + + # check mlp grad + org_grad = org_model.h[0].mlp.c_fc.weight.grad + shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad + shard_weight = sharded_model.h[0].mlp.c_fc.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=1) + else: + all_shard_grad = shard_grad + assert torch.allclose( + org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" + + # check embedding weights + org_grad = org_model.wte.weight.grad + shard_grad = sharded_model.wte.weight.grad + shard_weight = sharded_model.wte.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose( + org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() + + +def check_gpt2(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gpt2_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2(): + spawn(check_gpt2, 2) + + +if __name__ == "__main__": + test_gpt2() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py new file mode 100644 index 000000000000..74b5fdd18af8 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -0,0 +1,97 @@ +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + + # forward check + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) + + # run backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model + if hasattr(org_model, 'model'): + llama_model = org_model.model + shard_llama_model = sharded_model.model + else: + llama_model = org_model + shard_llama_model = sharded_model + + # check attention grad + org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad + shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad + shard_weight = shard_llama_model.layers[0].self_attn.q_proj.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + # check embedding grad + org_grad = llama_model.embed_tokens.weight.grad + shard_grad = shard_llama_model.embed_tokens.weight.grad + shard_weight = shard_llama_model.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gpt2_llama() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, 4) + + +if __name__ == "__main__": + test_llama() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py new file mode 100644 index 000000000000..25bccb13b1a8 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -0,0 +1,96 @@ +import copy +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) + + # run backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model + if hasattr(org_model, 'model'): + opt_model = org_model.model + shard_opt_model = sharded_model.model + else: + opt_model = org_model + shard_opt_model = sharded_model + + # check attention grad + org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # check embedding grad + org_grad = opt_model.decoder.embed_tokens.weight.grad + shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad + shard_weight = shard_opt_model.decoder.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() + + +def check_OPTModel(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_t5_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_OPTModel(): + spawn(check_OPTModel, 4) + + +if __name__ == '__main__': + test_OPTModel() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py new file mode 100644 index 000000000000..0762dc09e5af --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -0,0 +1,107 @@ +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + # the value "past_key_values" is sharded, so we ignore + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + + # do backward + org_loss.backward() + shard_loss.backward() + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check attention grad + org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad + shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad + shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + # check self attention embed + org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad + shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad + shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=1) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # check token embedding grad + org_grad = org_model.shared.weight.grad + + # check weights are tied + if hasattr(org_model, 'lm_head'): + assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr() + assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr() + + shard_grad = sharded_model.shared.weight.grad + shard_weight = sharded_model.shared.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): + sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + torch.cuda.empty_cache() + + +def check_t5(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_t5_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_t5(): + spawn(check_t5, 2) + + +if __name__ == "__main__": + test_t5() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py new file mode 100644 index 000000000000..af1605b6b659 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -0,0 +1,56 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output) + + # do backward + org_loss.backward() + shard_loss.backward() + + # check grad + org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad + shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +def check_vit(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(world_size, model_fn) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@pytest.mark.skip +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit(): + spawn(check_vit, 4) + + +if __name__ == "__main__": + test_vit() diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py new file mode 100644 index 000000000000..9f8a5db6c94f --- /dev/null +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -0,0 +1,77 @@ +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +def check_shardformer_with_ddp(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + + # create shardformer + # ranks: [0, 1, 2, 3] + # tp ranks = [0, 1], [2, 3] + # dp ranks = [0, 2], [1, 3] + dp_process_group_1 = dist.new_group([0, 2]) + dp_process_group_2 = dist.new_group([1, 3]) + tp_process_group_1 = dist.new_group([0, 1]) + tp_process_group_2 = dist.new_group([2, 3]) + + coordinator = DistCoordinator() + + if coordinator.rank in [0, 1]: + tp_process_group = tp_process_group_1 + else: + tp_process_group = tp_process_group_2 + + if coordinator.rank in [0, 2]: + dp_process_group = dp_process_group_1 + else: + dp_process_group = dp_process_group_2 + + shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True) + shardformer = ShardFormer(shard_config=shard_config) + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + # create and shard model + model = model_fn().cuda() + sharded_model = shardformer.optimize(model) + + # add ddp + sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group) + + # prepare input + data = data_gen_fn() + data = {k: v.cuda() for k, v in data.items()} + + # switch to train mode + sharded_ddp_model.train() + + # run forward + output = sharded_ddp_model(**data) + loss = loss_fn(output) + + # backward + loss.backward() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2(): + spawn(check_shardformer_with_ddp, 4) + + +if __name__ == "__main__": + test_gpt2() + test_gpt2() diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index d1f5b9299397..95fcd2aaf8f3 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -1,14 +1,11 @@ import pytest import torch -import torch.distributed as dist -from torch.distributed import ReduceOp from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -125,23 +122,6 @@ def check_all_reduce_bwd(process_groups_dict, rank): assert tensor_to_comm.equal(tensor_to_check) -def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank): - # tensor to comm - tensor_to_comm = torch.ones(2, 2).cuda() * rank - - # reduce through logical process axis 0 at flatten device mesh - # tensor to check - # tensor([[6., 6.], - # [6., 6.]]) - tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() - - # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) - comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0) - tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) - - assert tensor_to_comm.equal(tensor_to_check) - - def check_comm(rank, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -153,24 +133,22 @@ def check_comm(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - process_groups_dict = device_mesh.process_groups_dict + + process_group_dict = device_mesh._process_group_dict[rank] # test all gather - check_all_gather(process_groups_dict, rank) + check_all_gather(process_group_dict, rank) # test shard - check_shard(process_groups_dict, rank) + check_shard(process_group_dict, rank) # test all to all - check_all_to_all(process_groups_dict, rank) + check_all_to_all(process_group_dict, rank) # test all reduce - check_all_reduce_fwd(process_groups_dict, rank) - check_all_reduce_bwd(process_groups_dict, rank) + check_all_reduce_fwd(process_group_dict, rank) + check_all_reduce_bwd(process_group_dict, rank) - flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict - # test all reduce in 1D flatten device mesh - check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank) gpc.destroy() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 3ca369acbf87..5a1aef79f332 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -3,9 +3,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor -from colossalai.tensor.d_tensor.layout import Layout -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec +from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -31,22 +29,18 @@ def check_dtensor(rank, world_size, port): device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]}) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - d_tensor = DTensor(original_tensor, layout) + d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec) - assert d_tensor.entire_shape == original_tensor.shape - assert d_tensor.data_type == original_tensor.dtype + assert get_global_shape(d_tensor) == original_tensor.shape + assert d_tensor.dtype == original_tensor.dtype if rank in (0, 1): - assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2)) + assert d_tensor.equal(original_tensor.narrow(0, 0, 2)) elif rank in (2, 3): - assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2)) + assert d_tensor.equal(original_tensor.narrow(0, 2, 2)) else: raise ValueError(f'rank {rank} is not in the device mesh') - assert d_tensor.to_global().equal(original_tensor) + assert to_global(d_tensor).equal(original_tensor) output = test_model(d_tensor) if rank in (0, 1): @@ -57,34 +51,29 @@ def check_dtensor(rank, world_size, port): raise ValueError(f'rank {rank} is not in the device mesh') new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]}) - new_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=new_sharding_spec, - entire_shape=original_tensor.shape) - - d_tensor.layout_convert(new_layout) + d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec) if rank == 0: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 0, 1)) elif rank == 1: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 1, 1)) elif rank == 2: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 2, 1)) elif rank == 3: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 3, 1)) else: raise ValueError(f'rank {rank} is not in the device mesh') - dtensor_from_local = distribute_tensor(original_tensor, new_layout) + dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec) if rank == 0: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1)) elif rank == 1: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1)) elif rank == 2: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1)) elif rank == 3: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1)) else: raise ValueError(f'rank {rank} is not in the device mesh') diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 5c3da5f2b9ff..5388fd901e09 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -9,12 +9,12 @@ from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.layout_converter import LayoutConverter -from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn -entire_shape = torch.Size((64, 32, 16)) +global_shape = torch.Size((64, 32, 16)) layout_converter = LayoutConverter() -physical_mesh_id = torch.arange(0, 4).reshape(2, 2) +physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (2, 2) sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec, - entire_shape=entire_shape) + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) @@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all) - layout_all2all = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_all2all, - entire_shape=entire_shape) + layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape) rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) @@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,R,R # device_mesh_shape: (4, 4) sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard) - shard_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_shard, - entire_shape=entire_shape) + shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape) rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) @@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) @@ -137,7 +122,7 @@ def check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].shard_dim == 0 assert comm_action_sequence[2].logical_process_axis == 1 - # checkout cached_spec_pairs_transform_path + # checkout chached_spec_pairs_transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence @@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) - original_tensor = torch.rand(entire_shape).cuda() + original_tensor = torch.rand(global_shape).cuda() # tensor_to_apply: [R, S01, R] tensor_to_apply = original_tensor.narrow(1, rank * 8, 8) diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 6fe9ee292cd0..859eef051256 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -1,9 +1,10 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern import torch -from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec + from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -physical_mesh_id = torch.arange(0, 16).reshape(2, 8) +physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index d66d4fec14d1..9bd9805e9b8f 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -26,7 +26,7 @@ def run_dist(rank, world_size, port): # the mesh is in the following topo # [[0, 1], # [2, 3]] - physical_mesh_id = torch.arange(0, 4).reshape(2, 2) + physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) row_id = rank // 2 diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 909c84ef0f0e..5007c4141849 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -5,7 +5,7 @@ def test_sharding_spec(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7],