diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index ebca0ee0ee57..61c9d1438cdf 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -534,96 +534,96 @@ def save_sharded_optimizer(
                 f"index located at {final_index_file_path}."
             )

-    # def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
-    #     """
-    #     Load sharded optimizer with the given path to index file of checkpoint folder.
-
-    #     Args:
-    #         optimizer (OptimizerWrapper): The optimizer to be loaded.
-    #         checkpoint_index_file (str): Path to the index file of checkpointing folder.
-    #         prefix (str): Not used.
-    #     """
-    #     assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
-
-    #     def _get_param_id_from_optimizer_param(
-    #         param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
-    #     ):
-    #         if master_to_working_map is not None:
-    #             working_param = master_to_working_map[id(param)]
-    #         else:
-    #             working_param = param
-    #         return optimizer.param_info["param2id"][id(working_param)]
-
-    #     # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
-    #     # When Zero is used, the mapped parameter objects should be fp32 master parameters.
-    #     # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
-    #     id_map = {}
-    #     master_to_working_map = optimizer.get_master_to_working_map()
-    #     for pg in optimizer.optim.param_groups:
-    #         for param in pg["params"]:
-    #             param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
-    #             id_map[param_id] = param
-
-    #     # Read checkpoint index file.
-    #     ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
-    #     ckpt_root_path = ckpt_index_file.root_path
-    #     weight_map = ckpt_index_file.weight_map
-    #     weight_map = {int(k): v for k, v in weight_map.items()}  # convert saved id from str to int
-
-    #     # Load param_groups
-    #     param_group_path = ckpt_index_file.get_param_group_filename()
-    #     if param_group_path is None:
-    #         raise RuntimeError(
-    #             f"Invalid index file path {checkpoint_index_file} for an optimizer. \
-    #                Lacking param group file under current directory."
-    #         )
-    #     saved_groups = torch.load(param_group_path)
-
-    #     updated_groups = []
-    #     for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
-    #         # obtain updated param group
-    #         new_pg = copy.deepcopy(saved_pg)
-    #         new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
-    #         updated_groups.append(new_pg)
-    #     optimizer.optim.__dict__.update({"param_groups": updated_groups})
-
-    #     # Load saved states to optimizer.
-    #     # Keep a record of loaded files so that file will not be repeatedly loaded.
-    #     loaded_file = set()
-    #     for pg in optimizer.optim.param_groups:
-    #         for param in pg["params"]:
-    #             if param is None:
-    #                 continue
-    #             param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
-    #             if param_id not in weight_map:
-    #                 continue
-    #             filename = weight_map[param_id]
-
-    #             # If this param's states has been loaded before, directly return.
-    #             if filename in loaded_file:
-    #                 continue
-
-    #             file_path = os.path.join(ckpt_root_path, filename)
-    #             state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
-    #             load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
-    #             loaded_file.add(filename)
-
-    #     # Then shard the loaded optimizer states if using tp/zero.
-    #     for param, state in optimizer.optim.state.items():
-    #         device = param.device
-    #         if master_to_working_map is not None:
-    #             working_param = master_to_working_map[id(param)]
-    #         else:
-    #             working_param = param
-    #         original_shape = optimizer.param_info["param2shape"][id(working_param)]
-    #         sharded_state = self.shard_from_complete_optimizer_state(
-    #             state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-    #         )
-    #         optimizer.optim.state[param] = sharded_state
-
-    #     sharded_optimizer_loading_epilogue(optimizer.optim)
-    #     if self.verbose and self.coordinator.is_master():
-    #         logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+        """
+        Load sharded optimizer with the given path to index file of checkpoint folder.
+
+        Args:
+            optimizer (OptimizerWrapper): The optimizer to be loaded.
+            checkpoint_index_file (str): Path to the index file of checkpointing folder.
+            prefix (str): Not used.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+        def _get_param_id_from_optimizer_param(
+            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+        ):
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+            return optimizer.param_info["param2id"][id(working_param)]
+
+        # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
+        # When Zero is used, the mapped parameter objects should be fp32 master parameters.
+        # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
+        id_map = {}
+        master_to_working_map = optimizer.get_master_to_working_map()
+        for pg in optimizer.optim.param_groups:
+            for param in pg["params"]:
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                id_map[param_id] = param
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+        ckpt_root_path = ckpt_index_file.root_path
+        weight_map = ckpt_index_file.weight_map
+        weight_map = {int(k): v for k, v in weight_map.items()}  # convert saved id from str to int
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(
+                f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+                   Lacking param group file under current directory."
+            )
+        saved_groups = torch.load(param_group_path)
+
+        updated_groups = []
+        for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+            # obtain updated param group
+            new_pg = copy.deepcopy(saved_pg)
+            new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
+            updated_groups.append(new_pg)
+        optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+        # Load saved states to optimizer.
+        # Keep a record of loaded files so that file will not be repeatedly loaded.
+        loaded_file = set()
+        for pg in optimizer.optim.param_groups:
+            for param in pg["params"]:
+                if param is None:
+                    continue
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                if param_id not in weight_map:
+                    continue
+                filename = weight_map[param_id]
+
+                # If this param's states has been loaded before, directly return.
+                if filename in loaded_file:
+                    continue
+
+                file_path = os.path.join(ckpt_root_path, filename)
+                state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+                load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+                loaded_file.add(filename)
+
+        # Then shard the loaded optimizer states if using tp/zero.
+        for param, state in optimizer.optim.state.items():
+            device = param.device
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
+            else:
+                working_param = param
+            original_shape = optimizer.param_info["param2shape"][id(working_param)]
+            sharded_state = self.shard_from_complete_optimizer_state(
+                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
+            )
+            optimizer.optim.state[param] = sharded_state
+
+        sharded_optimizer_loading_epilogue(optimizer.optim)
+        if self.verbose and self.coordinator.is_master():
+            logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

     def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 0b3126a92953..2fbc34302cde 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -13,7 +13,7 @@
     MoeCausalLMOutputWithPast,
     load_balancing_loss_func,
 )
-from transformers.utils import logging
+from transformers.utils import is_flash_attn_2_available, logging

 from colossalai.lazy import LazyInitContext
 from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
@@ -218,7 +218,7 @@ def mixtral_model_forward(

         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
-        if self._use_flash_attention_2:
+        if is_flash_attn_2_available():
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
         else:
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 12369289f5a7..8fe18f69bcd1 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -65,7 +65,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             working_p = sharded_optimizer.master_to_working_param[id(p2)]
             grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
             grad_index = (
-                0 if sharded_optimizer._partition_grads else sharded_optimizer.pid_to_bucket_store[id(p2)].local_rank
+                0
+                if sharded_optimizer._partition_grads
+                else sharded_optimizer.pid_to_bucket_store[id(working_p)].local_rank
             )
             grad = grads[grad_index]
             sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
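
For context, the load_sharded_optimizer path restored in the first file is normally reached through ColossalAI's Booster checkpoint API rather than called directly. Below is a minimal sketch of that flow, assuming a HybridParallelPlugin setup; the plugin arguments, the "ckpt/optimizer" path, and the pre-existing model/optimizer/criterion/dataloader objects are placeholder assumptions, not part of this patch.

# Hypothetical usage sketch: exercising HybridParallelCheckpointIO.load_sharded_optimizer()
# through the Booster API. Assumes the distributed environment has already been launched
# (e.g. via colossalai.launch_from_torch) and that `model`, `optimizer`, `criterion`,
# and `dataloader` are already constructed; the checkpoint path is a placeholder.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1)
booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

# Save a sharded optimizer checkpoint: shard files plus an index file under ckpt/optimizer/.
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True)

# Load it back; with HybridParallelPlugin this dispatches to the
# load_sharded_optimizer() method re-enabled in the diff above.
booster.load_optimizer(optimizer, "ckpt/optimizer")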