
Making it work with DeepSpeed #8

@eliird

I know it's been a while since this repo was uploaded, but I was thinking of getting this to work with DeepSpeed. Did you by any chance try that? Internally, DeepSpeed calls self.module.parameters() for a lot of things, and the segment wrapper does not hold the original parameters. I tried writing a wrapper around the original model that contains both the original parameters and the dict in segment, like:

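import torch.nn as nn
# graph_forward and optimal_grad_checkpointing come from this repo; import them from your checkout

class Model(nn.Module):
    ...  # model definition and forward function

class OGCWrapper(Model):
    def __init__(self, info_dict):
        super().__init__()  # builds and registers the original model's parameters
        self.x = nn.Linear(10, 10)  # assigned to self so the layer is registered as a submodule
        # generated with the optimal_grad_checkpointing function and loaded from a pickle
        self.info_dict = info_dict

    def forward(self, x):
        return graph_forward(x, **self.info_dict)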

But with DeepSpeed, this takes more memory than the same model without checkpointing.
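For context, both symptoms (DeepSpeed not seeing the parameters, and the higher memory use) are consistent with two general PyTorch behaviors. The sketch below is a minimal illustration, not code from this repo: it assumes `info_dict` stores `nn.Module` objects (the exact structure returned by `optimal_grad_checkpointing` is not shown here), and `DictHolder` and `ModuleDictHolder` are hypothetical names. First, `nn.Module.parameters()` only walks submodules registered as attributes (directly or through containers such as `nn.ModuleDict`), so modules kept inside a plain Python dict are invisible to anything that calls `self.module.parameters()`. Second, unpickling a module allocates fresh tensors, so a model plus a graph loaded from a pickle holds two independent copies of every weight.

```python
import pickle

import torch.nn as nn


class DictHolder(nn.Module):
    """Hypothetical wrapper that keeps segment modules in a plain dict."""

    def __init__(self, segments):
        super().__init__()
        # A plain dict is stored as an ordinary attribute, so nn.Module does
        # not register the modules inside it.
        self.info_dict = {"segments": segments}


class ModuleDictHolder(nn.Module):
    """Hypothetical wrapper with the same layout, using nn.ModuleDict instead."""

    def __init__(self, segments):
        super().__init__()
        # nn.ModuleDict registers its values as submodules.
        self.info_dict = nn.ModuleDict({"segments": segments})


# Stand-in for the model's segments (assumption: the real info_dict holds nn.Modules).
segments = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))

# 1) parameters() only sees registered submodules.
plain = DictHolder(segments)
registered = ModuleDictHolder(segments)
print(len(list(plain.parameters())))       # 0
print(len(list(registered.parameters())))  # 4 (two weights, two biases)

# 2) A pickle round trip allocates new tensors: same values, separate memory.
restored = pickle.loads(pickle.dumps(segments))
print(restored[0].weight.data_ptr() == segments[0].weight.data_ptr())  # False
```

If this is what is happening, building the segment graph from the live model instance (rather than loading modules from a pickle) and registering those modules on the wrapper would leave DeepSpeed with a single set of parameters; that is an untested assumption, not something confirmed against this repo.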
