From 9bf7ebb74c9d7861a0521670b62d9c9b440de17c Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Mon, 14 Apr 2025 21:23:03 +0800
Subject: [PATCH 1/5] Fix issue about 'find_unused_parameters' when DDP training.

---
 modules/commons/rotary_embedding_torch.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py
index 4efcb514..37773598 100644
--- a/modules/commons/rotary_embedding_torch.py
+++ b/modules/commons/rotary_embedding_torch.py
@@ -306,15 +306,16 @@ def forward(
             exists(self.cached_freqs) and \
             (offset + seq_len) <= self.cached_freqs_seq_len
         ):
-            return self.cached_freqs[offset:(offset + seq_len)].detach()
+            freqs = self.cached_freqs[offset:(offset + seq_len)].detach()
+        else:
+            freqs = self.freqs

-        freqs = self.freqs
+            freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
+            freqs = repeat(freqs, '... n -> ... (n r)', r = 2)

-        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
-        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
-
-        if should_cache and offset == 0:
-            self.cached_freqs[:seq_len] = freqs.detach()
-            self.cached_freqs_seq_len = seq_len
+            if should_cache and offset == 0:
+                self.cached_freqs[:seq_len] = freqs.detach()
+                self.cached_freqs_seq_len = seq_len

+        freqs = freqs + 0. * self.freqs.sum()
         return freqs

From 030d32952f25605b601073e5d32c2b52737d2a96 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Tue, 15 Apr 2025 21:52:04 +0800
Subject: [PATCH 2/5] annotation

---
 modules/commons/rotary_embedding_torch.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py
index 37773598..90bac8c9 100644
--- a/modules/commons/rotary_embedding_torch.py
+++ b/modules/commons/rotary_embedding_torch.py
@@ -317,5 +317,6 @@ def forward(
                 self.cached_freqs[:seq_len] = freqs.detach()
                 self.cached_freqs_seq_len = seq_len

+        # Fix issue about 'find_unused_parameters' when DDP training.
         freqs = freqs + 0. * self.freqs.sum()
         return freqs

From dbcd80bdb537d7dba1b223b4602bc444cb102781 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Tue, 15 Apr 2025 22:50:31 +0800
Subject: [PATCH 3/5] slim

---
 modules/commons/rotary_embedding_torch.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py
index 90bac8c9..8002e2c5 100644
--- a/modules/commons/rotary_embedding_torch.py
+++ b/modules/commons/rotary_embedding_torch.py
@@ -306,17 +306,17 @@ def forward(
             exists(self.cached_freqs) and \
             (offset + seq_len) <= self.cached_freqs_seq_len
         ):
-            freqs = self.cached_freqs[offset:(offset + seq_len)].detach()
-        else:
-            freqs = self.freqs
+            # Fix issue about 'find_unused_parameters' when DDP training.
+            freqs = self.cached_freqs[offset:(offset + seq_len)].detach() + 0. * self.freqs.sum()
+            return freqs

-            freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
-            freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
+        freqs = self.freqs

-            if should_cache and offset == 0:
-                self.cached_freqs[:seq_len] = freqs.detach()
-                self.cached_freqs_seq_len = seq_len
+        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
+        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
+
+        if should_cache and offset == 0:
+            self.cached_freqs[:seq_len] = freqs.detach()
+            self.cached_freqs_seq_len = seq_len

-        # Fix issue about 'find_unused_parameters' when DDP training.
-        freqs = freqs + 0. * self.freqs.sum()
         return freqs

From 656bae591282dba093bb2097010a7e33086634d9 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Mon, 14 Apr 2025 21:23:03 +0800
Subject: [PATCH 4/5] Fix issue about 'find_unused_parameters' when DDP training.

annotation

slim
---
 modules/commons/rotary_embedding_torch.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py
index 4efcb514..8002e2c5 100644
--- a/modules/commons/rotary_embedding_torch.py
+++ b/modules/commons/rotary_embedding_torch.py
@@ -306,7 +306,9 @@ def forward(
             exists(self.cached_freqs) and \
             (offset + seq_len) <= self.cached_freqs_seq_len
         ):
-            return self.cached_freqs[offset:(offset + seq_len)].detach()
+            # Fix issue about 'find_unused_parameters' when DDP training.
+            freqs = self.cached_freqs[offset:(offset + seq_len)].detach() + 0. * self.freqs.sum()
+            return freqs

         freqs = self.freqs

From 29886fa9122b20e2fb348e9e03199db409918b87 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Tue, 15 Apr 2025 22:55:29 +0800
Subject: [PATCH 5/5] Update rotary_embedding_torch.py

---
 modules/commons/rotary_embedding_torch.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/modules/commons/rotary_embedding_torch.py b/modules/commons/rotary_embedding_torch.py
index 8002e2c5..e0ab05f2 100644
--- a/modules/commons/rotary_embedding_torch.py
+++ b/modules/commons/rotary_embedding_torch.py
@@ -306,8 +306,9 @@ def forward(
             exists(self.cached_freqs) and \
             (offset + seq_len) <= self.cached_freqs_seq_len
         ):
-            # Fix issue about 'find_unused_parameters' when DDP training.
-            freqs = self.cached_freqs[offset:(offset + seq_len)].detach() + 0. * self.freqs.sum()
+            freqs = self.cached_freqs[offset:(offset + seq_len)].detach()
+            # Fix issue about 'find_unused_parameters' when DDP training.(#244)
+            freqs = freqs + 0. * self.freqs.sum()
             return freqs

         freqs = self.freqs
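
A note on the approach (illustrative addendum, not part of the patches):
torch.nn.parallel.DistributedDataParallel with its default
find_unused_parameters=False expects every trainable parameter to receive a
gradient on each backward pass. When the rotary embedding is built with
learned_freq=True, self.freqs is trainable, but the cached branch returns a
detached tensor, so self.freqs drops out of the autograd graph and DDP raises
a runtime error about parameters that were not used in producing the loss.
The term 0. * self.freqs.sum() changes nothing numerically, yet keeps
self.freqs in the graph, so a (zero) gradient is produced for it and DDP's
reducer marks the parameter as used.

Below is a minimal, self-contained sketch of the same pattern; ToyCache is a
hypothetical stand-in for the rotary embedding, not code from this repository:

    import torch
    from torch import nn

    class ToyCache(nn.Module):
        # Mimics the rotary embedding: a trainable parameter whose forward
        # pass sometimes returns a cached, detached tensor instead.
        def __init__(self, dim = 8):
            super().__init__()
            self.freqs = nn.Parameter(torch.randn(dim))  # like learned_freq = True
            self.register_buffer('cached', torch.randn(dim))

        def forward(self, x, use_cache = False):
            if use_cache:
                freqs = self.cached.detach()
                # The workaround from the patches: a zero-weighted sum ties
                # self.freqs into the graph without changing the output.
                freqs = freqs + 0. * self.freqs.sum()
            else:
                freqs = self.freqs
            return (x * freqs).sum()

    # Single-process check that the parameter always receives a gradient:
    m = ToyCache()
    m(torch.randn(2, 8), use_cache = True).backward()
    assert m.freqs.grad is not None  # zero-valued, but present

Without the zero-weighted sum, m.freqs.grad stays None on the cached path,
which is exactly the condition that trips DDP when find_unused_parameters is
left at False.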