From 92b4b9af292de57cadc3434fb065fd79297b2e7c Mon Sep 17 00:00:00 2001
From: lclgy
Date: Wed, 12 Jul 2023 16:59:51 +0800
Subject: [PATCH 1/5] optimize the optimizer step time

---
 colossalai/zero/low_level/low_level_optim.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 023db122fd33..682cf888e5aa 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -414,13 +414,13 @@ def step(self, closure=None):
 
         # update working partition updated by the current rank
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]['params']
-
             for idx, splited_param in enumerate(master_working_param):
-                full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
-                dist.all_gather(full_master_param, splited_param.cuda(), group=self.dp_pg)
                 working_param = real_working_params[group_id][idx]
-                full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
-                working_param.data.copy_(full_master_param)
+                padding_left = self._local_rank * splited_param.numel()
+                padding_right = working_param.numel() - (self._local_rank+1) * splited_param.numel()
+                print(working_param.shape, padding_left, padding_right, self._local_rank, splited_param.numel())
+                working_param.data.copy_(torch.nn.functional.pad(splited_param, (padding_left, padding_right)).reshape_as(working_param))
+                dist.all_reduce(working_param, group=self.dp_pg)
 
         self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]

From b102a0312396bbc07281402a2c9be1421ff58bbf Mon Sep 17 00:00:00 2001
From: lclgy
Date: Wed, 12 Jul 2023 17:18:10 +0800
Subject: [PATCH 2/5] fix corner case

---
 colossalai/zero/low_level/low_level_optim.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 682cf888e5aa..8dfc5431c9ef 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -417,9 +417,9 @@ def step(self, closure=None):
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
                 padding_left = self._local_rank * splited_param.numel()
-                padding_right = working_param.numel() - (self._local_rank+1) * splited_param.numel()
-                print(working_param.shape, padding_left, padding_right, self._local_rank, splited_param.numel())
-                working_param.data.copy_(torch.nn.functional.pad(splited_param, (padding_left, padding_right)).reshape_as(working_param))
+                padding_right = splited_param.numel() * (self._world_size - self._local_rank - 1)
+                with torch.no_grad():
+                    working_param.data.copy_(torch.nn.functional.pad(splited_param, (padding_left, padding_right))[:working_param.numel()].reshape_as(working_param))
                 dist.all_reduce(working_param, group=self.dp_pg)
 
         self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
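Note on the scheme in patches 1 and 2: each rank zero-pads its own master shard into its slot of the full flat parameter, and the all-reduce then sums the per-rank contributions, so every element of the working parameter is written by exactly one rank. The corner case fixed by patch 2 is the trailing padding: when the flat parameter was padded before splitting, patch 1's padding_right (computed from working_param.numel()) goes negative on the last rank, while patch 2 computes it from the shard size and slices the result back down to working_param.numel(). A minimal single-process sketch of that arithmetic, with the cross-rank all-reduce emulated by a local sum (names like shard_numel are illustrative, not from the source):

    import torch
    import torch.nn.functional as F

    world_size = 4
    working_param = torch.arange(10, dtype=torch.float32)                 # full flat parameter, 10 elements
    shard_numel = (working_param.numel() + world_size - 1) // world_size  # 3 per rank, padded total = 12

    padded = F.pad(working_param, (0, shard_numel * world_size - working_param.numel()))
    shards = list(padded.chunk(world_size))                               # the per-rank master shards

    contributions = []
    for rank, shard in enumerate(shards):
        padding_left = rank * shard_numel
        padding_right = shard_numel * (world_size - rank - 1)  # patch 2's form: never negative
        contributions.append(F.pad(shard, (padding_left, padding_right)))

    # dist.all_reduce would sum these across ranks; here the sum is local
    reconstructed = torch.stack(contributions).sum(dim=0)[:working_param.numel()]
    assert torch.equal(reconstructed, working_param)

With patch 1's formula, rank 3 in this example would get padding_right = 10 - 4 * 3 = -2, which is exactly the corner case the fix avoids.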
From 949d446d11ac74b2e9ff8d908b8a2663a3704808 Mon Sep 17 00:00:00 2001
From: lclgy
Date: Wed, 12 Jul 2023 17:21:14 +0800
Subject: [PATCH 3/5] polish

---
 colossalai/zero/low_level/low_level_optim.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 8dfc5431c9ef..e8abaacb29e4 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -419,7 +419,10 @@ def step(self, closure=None):
                 padding_left = self._local_rank * splited_param.numel()
                 padding_right = splited_param.numel() * (self._world_size - self._local_rank - 1)
                 with torch.no_grad():
-                    working_param.data.copy_(torch.nn.functional.pad(splited_param, (padding_left, padding_right))[:working_param.numel()].reshape_as(working_param))
+                    working_param.data.copy_(
+                        torch.nn.functional.pad(
+                            splited_param,
+                            (padding_left, padding_right))[:working_param.numel()].reshape_as(working_param))
                 dist.all_reduce(working_param, group=self.dp_pg)
 
         self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]

From 79eb0bb03551f3c79437a2cec0fca081992c3275 Mon Sep 17 00:00:00 2001
From: lclgy
Date: Wed, 12 Jul 2023 18:04:45 +0800
Subject: [PATCH 4/5] replace all-reduce with all-gather

---
 colossalai/zero/low_level/low_level_optim.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index e8abaacb29e4..91c09e90b669 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -412,18 +412,17 @@ def step(self, closure=None):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
         # update working partition updated by the current rank
+        dtype = real_working_params[0][0].dtype
+        device = real_working_params[0][0].device
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]['params']
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
-                padding_left = self._local_rank * splited_param.numel()
-                padding_right = splited_param.numel() * (self._world_size - self._local_rank - 1)
-                with torch.no_grad():
-                    working_param.data.copy_(
-                        torch.nn.functional.pad(
-                            splited_param,
-                            (padding_left, padding_right))[:working_param.numel()].reshape_as(working_param))
-                dist.all_reduce(working_param, group=self.dp_pg)
+                all_splited_param = [
+                    torch.zeros(splited_param.shape, device=device, dtype=dtype) for _ in range(self._world_size)
+                ]
+                dist.all_gather(all_splited_param, splited_param.to(device).to(dtype), group=self.dp_pg)
+                working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param))
 
         self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
 
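Why patch 4 swaps the all-reduce back for an all-gather: an all-reduce is typically implemented as a reduce-scatter plus an all-gather over the full-size tensor, so summing world_size zero-padded full tensors moves roughly twice the data of gathering just the shards, and it also reduces through working_param in place. Gathering the shards and copying once avoids both costs. A self-contained sketch of that path, assuming an already-initialized process group; torch's _flatten_dense_tensors stands in for the module's flatten helper, and update_working_param is an illustrative name:

    import torch
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors as flatten

    def update_working_param(working_param, splited_param, world_size, pg=None):
        # Gather every rank's master shard, then copy the concatenation,
        # truncated to the real size, back into the working parameter.
        dtype, device = working_param.dtype, working_param.device
        gathered = [torch.zeros(splited_param.shape, device=device, dtype=dtype)
                    for _ in range(world_size)]
        dist.all_gather(gathered, splited_param.to(device).to(dtype), group=pg)
        working_param.data.copy_(
            flatten(gathered)[:working_param.numel()].reshape_as(working_param))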
From 28acdf3c6d0f5e21ba7b0918de4ab2418b9102dc Mon Sep 17 00:00:00 2001
From: lc_pro
Date: Tue, 18 Jul 2023 10:20:34 +0800
Subject: [PATCH 5/5] set comm device to cuda

---
 colossalai/zero/low_level/low_level_optim.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 91c09e90b669..2b3f50ed4fd4 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -413,15 +413,14 @@ def step(self, closure=None):
 
         # update working partition updated by the current rank
         dtype = real_working_params[0][0].dtype
-        device = real_working_params[0][0].device
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]['params']
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
                 all_splited_param = [
-                    torch.zeros(splited_param.shape, device=device, dtype=dtype) for _ in range(self._world_size)
+                    torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
                 ]
-                dist.all_gather(all_splited_param, splited_param.to(device).to(dtype), group=self.dp_pg)
+                dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param))
 
         self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
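Patch 5 is orthogonal to the arithmetic: NCCL collectives require CUDA tensors, and the working params' .device is not guaranteed to be the communication device, so the gather buffers and the input shard are pinned to "cuda" explicitly. The flatten-and-truncate reconstruction itself can still be checked locally without any process group; a small sketch emulating what the gather returns for a 3-rank split:

    import torch
    import torch.nn.functional as F

    world_size = 3
    working_param = torch.randn(2, 4)              # 8 elements
    flat = F.pad(working_param.flatten(), (0, 1))  # pad 8 -> 9 = 3 shards of 3
    shards = list(flat.chunk(world_size))          # what all_gather would return
    rebuilt = torch.cat(shards)[:working_param.numel()].reshape_as(working_param)
    assert torch.equal(rebuilt, working_param)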