From 894e95a4d242490a376a4f499de7045eb4ecfc9f Mon Sep 17 00:00:00 2001 From: hxwang Date: Wed, 12 Jun 2024 07:57:47 +0000 Subject: [PATCH 1/2] [gemini] quick fix on possible async operation --- colossalai/zero/gemini/gemini_hook.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 9e297c2a8a19..3231b08993a5 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -55,6 +55,15 @@ def pre_op(self, params): ) # prefetch + if self._gemini_manager.chunk_manager._prefetch_stream is not None: + # This covers the case when prefetch happens for the first time and there is no dist.Work to sync with: + # there is a possibility that the optimizer hasn't finished computation on the default stream, + # thus we might prefetch outdated chunks there. + # + # Other than that, self._gemini_manager.wait_chunks will have synced with default stream + # by calling dist.Work.wait(). + self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream()) + with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream): for chunk in chunks_fetch_async: maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True) From b77be86afd1400a48af04a30ba209ad496301fb5 Mon Sep 17 00:00:00 2001 From: hxwang Date: Wed, 12 Jun 2024 08:00:21 +0000 Subject: [PATCH 2/2] [gemini] quick fix on possible async operation --- colossalai/zero/gemini/gemini_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index 3231b08993a5..bf5faa0fe884 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -61,7 +61,7 @@ def pre_op(self, params): # thus we might prefetch outdated chunks there. # # Other than that, self._gemini_manager.wait_chunks will have synced with default stream - # by calling dist.Work.wait(). 
+ # by calling dist.Work.wait(), and this line makes no difference. self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream()) with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):