
[FEATURE]: support multiple (partial) backward passes for zero #5601

@ver217

Description


Describe the feature

In some VAE training setups, users compute an adaptive loss weight, which requires computing the gradients of some parameters twice within a single step.

This triggers the backward hook registered on those parameters twice.
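The pattern can be sketched as follows. This is a minimal sketch that assumes a VQGAN-style adaptive weight, similar to the one in the taming-transformers VQGAN loss. The decoder is a hypothetical toy model, `count_grad` stands in for a ZeRO gradient hook, and `g_loss` is a placeholder for a real generator loss. In PyTorch 2.x, a hook registered with `Tensor.register_hook` fires whenever a gradient is computed for that tensor. This includes the case where the tensor is an input to `torch.autograd.grad`, so the last-layer hook runs three times in one step instead of once.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy decoder; its last layer is the one used for the adaptive weight.
decoder = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
last_layer = decoder[-1].weight

hook_calls = {"n": 0}

def count_grad(grad: torch.Tensor) -> None:
    # Hypothetical stand-in for a ZeRO-style hook registered with register_hook.
    hook_calls["n"] += 1

last_layer.register_hook(count_grad)

def adaptive_weight(nll_loss, g_loss, last_layer):
    # Two extra gradient computations w.r.t. the same parameter (VQGAN-style).
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    return torch.clamp(d_weight, 0.0, 1e4).detach()

x = torch.randn(16, 4)
recon = decoder(x)
nll_loss = (recon - x).abs().mean()
g_loss = -recon.mean()  # hypothetical placeholder for a generator (GAN) loss

d_weight = adaptive_weight(nll_loss, g_loss, last_layer)
loss = nll_loss + d_weight * g_loss
loss.backward()

# Two firings from autograd.grad plus one from backward().
print(hook_calls["n"])
```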

According to the PyTorch documentation, we may be able to solve this with a post-accumulate-grad hook (`torch.Tensor.register_post_accumulate_grad_hook`).


Metadata


Labels: enhancement (New feature or request)
