From 02e87d69d3a9103e14a706713ca5f683b3b2a979 Mon Sep 17 00:00:00 2001
From: hxwang
Date: Mon, 17 Jun 2024 07:58:15 +0000
Subject: [PATCH] [zero] fix missing hook removal

---
 colossalai/zero/low_level/low_level_strategy.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_strategy.py b/colossalai/zero/low_level/low_level_strategy.py
index 1d01494654a3..7f7daaed3fec 100644
--- a/colossalai/zero/low_level/low_level_strategy.py
+++ b/colossalai/zero/low_level/low_level_strategy.py
@@ -1,4 +1,5 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
+import weakref
 from abc import ABC, abstractmethod
 from copy import deepcopy
 from functools import partial
@@ -94,20 +95,27 @@ def __init__(
         # reduction hook is only used if overlapping communication
         # or stage 2 is used
         # if it is stage 1 without overlapping, no hook will be attached
+        self.grad_handles = []
         if self._overlap_communication or self._partition_grad:
             # we iterate over the working params
             # on each param, we register a hook to its AccumulateGrad object
             param_group = self.working_param_group
             for param in param_group:
                 if param.requires_grad:
+                    self_weak_proxy = weakref.proxy(self)
+                    param_weak_proxy = weakref.proxy(param)
 
-                    def _grad_handler(grad, param):
+                    def _grad_handler(grad):
                         # if run with no_sync context, would not sync grad when backward
-                        if self.require_grad_sync:
-                            self._add_to_bucket(param)
+                        if self_weak_proxy.require_grad_sync:
+                            self_weak_proxy._add_to_bucket(param_weak_proxy)
                         return grad
 
-                    param.register_hook(partial(_grad_handler, param=param))
+                    self.grad_handles.append(param.register_post_accumulate_grad_hook(partial(_grad_handler)))
+
+    def __del__(self):
+        for handle in self.grad_handles:
+            handle.remove()
 
     def _create_master_param_current_rank(self, param_list):
         # split each param evenly by world size