Hi, I want to compute the diagonal of the Hessian for my custom nn.Module, which replaces nn.Conv2d with a CP decomposition of the kernel, and I get this warning:
/data/miniforge3/envs/fairmae/lib/python3.9/site-packages/backpack/custom_module/graph_utils.py:86: UserWarning: Encountered node that may break second-order extensions: op=get_attr, target=V.1. If you encounter this problem, please open an issue at https://github.com/f-dangel/backpack/issues.
The architecture of my module is defined below:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CPConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, rank, stride=1, padding=0, bias=True):
        super(CPConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Normalize an int kernel size to a (height, width) tuple
        self.kernel_size = (kernel_size, kernel_size) if isinstance(
            kernel_size, int) else kernel_size
        self.rank = rank
        self.stride = stride
        self.padding = padding
        self.bias = bias
        # U: (rank,)
        # V: [(out_channels, rank), (in_channels, rank), (kernel_size[0], rank), (kernel_size[1], rank)]
        self.U = nn.Parameter(torch.randn(rank))
        self.V = nn.ParameterList([
            nn.Parameter(torch.randn(out_channels, rank)),
            nn.Parameter(torch.randn(in_channels, rank)),
            # Index the normalized tuple; the raw argument may be an int
            nn.Parameter(torch.randn(self.kernel_size[0], rank)),
            nn.Parameter(torch.randn(self.kernel_size[1], rank))
        ])
        if bias:
            self.b = nn.Parameter(torch.randn(out_channels))
        else:
            self.register_parameter('b', None)

    def forward(self, x):
        # Rebuild the full (out, in, kh, kw) kernel from the CP factors
        W = torch.einsum('r,or,ir,kr,lr->oikl',
                         self.U, self.V[0], self.V[1], self.V[2], self.V[3])
        return F.conv2d(x, W, self.b, self.stride, self.padding)
Could you tell me how to compute diag_h for self.U? Thank you!
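
For reference, the quantity I am after can be computed without BackPACK by double backward in plain PyTorch. Below is a minimal sketch under a few assumptions: it uses a compact copy of the module (no bias, default stride and padding), an MSE loss on random data chosen only for illustration, and a hypothetical helper `hessian_diag` that is not part of BackPACK or PyTorch. It needs one extra backward pass per entry of U, which should be acceptable as long as the rank is small.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CPConv2d(nn.Module):
    """Compact copy of the module above (bias, stride, padding omitted)."""

    def __init__(self, in_channels, out_channels, kernel_size, rank):
        super().__init__()
        kh, kw = ((kernel_size, kernel_size)
                  if isinstance(kernel_size, int) else kernel_size)
        self.U = nn.Parameter(torch.randn(rank))
        self.V = nn.ParameterList([
            nn.Parameter(torch.randn(out_channels, rank)),
            nn.Parameter(torch.randn(in_channels, rank)),
            nn.Parameter(torch.randn(kh, rank)),
            nn.Parameter(torch.randn(kw, rank)),
        ])

    def forward(self, x):
        W = torch.einsum('r,or,ir,kr,lr->oikl',
                         self.U, self.V[0], self.V[1], self.V[2], self.V[3])
        return F.conv2d(x, W)


def hessian_diag(loss, param):
    """Hypothetical helper: diagonal of d^2 loss / d param^2.

    Differentiates each gradient entry once more and keeps only the
    matching entry, so the cost is one backward pass per element.
    """
    (grad,) = torch.autograd.grad(loss, param, create_graph=True)
    flat_grad = grad.reshape(-1)
    diag = torch.zeros_like(flat_grad)
    if not flat_grad.requires_grad:
        # The gradient does not depend on any parameter: Hessian is zero.
        return diag.reshape(param.shape)
    for i in range(flat_grad.numel()):
        (row,) = torch.autograd.grad(
            flat_grad[i], param, retain_graph=True, allow_unused=True)
        if row is not None:
            diag[i] = row.reshape(-1)[i]
    return diag.reshape(param.shape)


torch.manual_seed(0)
model = CPConv2d(3, 4, 3, rank=2).double()
x = torch.randn(2, 3, 8, 8, dtype=torch.float64)   # illustrative input
y = torch.randn(2, 4, 6, 6, dtype=torch.float64)   # illustrative target
loss = F.mse_loss(model(x), y)

diag_h_U = hessian_diag(loss, model.U)
print(diag_h_U.shape)  # torch.Size([2])
```

This only covers self.U; the same helper can be called on each entry of self.V. It does not tell me whether BackPACK's second-order extensions can handle the ParameterList access that triggers the warning.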