Resnet + DiagHessian #344

@antoniobalordi

Hi!

I trained a ResNet18 from torchvision.models on my dataset (CIFAR-10), and now I would like to compute the exact Hessian diagonal with DiagHessian, but I get the following error:

NotImplementedError: Extension saving to diag_h does not have an extension for Module <class 'backpack.custom_module.branching.SumModule'>

I followed all the instructions in the tutorial, so I extended the model in evaluation mode to avoid problems with the batch normalization layers.

Other extensions such as DiagGGNExact work, although very slowly, while DiagHessian raises this error.

The problem persists even if I avoid the torchvision ResNet and build the network by following the tutorial on residual networks, because that approach also makes use of SumModule.
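For reference, the quantity DiagHessian is meant to return can be cross-checked on a small model with plain torch.autograd, independently of BackPACK. The following is a minimal sketch under stated assumptions, not BackPACK's method: the helper name `hessian_diagonal` and the toy model and data are hypothetical stand-ins for the ResNet18 / CIFAR-10 setup. It needs one extra backward pass per scalar parameter, so it is only practical for small networks or a small subset of parameters, not for a full ResNet18.

```python
import torch
from torch import nn


def hessian_diagonal(loss, params):
    """Return the exact Hessian diagonal of `loss` for each tensor in `params`.

    Hypothetical helper: uses one extra backward pass per scalar parameter,
    so it is only practical for small models or a few selected parameters.
    """
    # First-order gradients, kept in the graph so they can be differentiated again.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    diagonals = []
    for grad, param in zip(grads, params):
        flat_grad = grad.reshape(-1)
        diag = torch.zeros_like(flat_grad)
        for i in range(flat_grad.numel()):
            # Differentiating gradient entry i w.r.t. the parameter gives
            # row i of that parameter's Hessian block; keep only entry i.
            (row,) = torch.autograd.grad(flat_grad[i], param, retain_graph=True)
            diag[i] = row.reshape(-1)[i]
        diagonals.append(diag.reshape(param.shape))
    return diagonals


torch.manual_seed(0)
# Toy model and random data (assumptions), standing in for the real setup.
model = nn.Sequential(nn.Linear(4, 3), nn.Sigmoid(), nn.Linear(3, 2)).eval()
inputs, targets = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss = nn.CrossEntropyLoss()(model(inputs), targets)

params = list(model.parameters())
diag_h = hessian_diagonal(loss, params)
for param, diag in zip(params, diag_h):
    # Each diagonal has the same shape as its parameter.
    print(tuple(param.shape), tuple(diag.shape))
```

Each returned tensor has the same shape as its parameter, which matches how per-parameter quantities are usually laid out, so a comparison against another implementation on the same toy model can be done entry by entry.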

Thank you!
