From 2b3ce36a8362327e078e151ee5f5a9d210144c3c Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 24 Apr 2024 16:03:54 +0800 Subject: [PATCH] [gemini] fix buffer cast --- colossalai/zero/gemini/gemini_ddp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c79422171f1b..b25de1d68613 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -840,6 +840,7 @@ def _cast_buffers(self): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() + for buffer in self.module.buffers(): buffer.data = buffer.to(get_accelerator().get_current_device()) if torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision)