fast_weights = OrderedDict(
    (name, param - torch.mul(meta_step_size, torch.clamp(grad, 0 - clip_value, clip_value)))
    for (name, param), grad in zip(model.named_parameters(), grads)
)

learner = copy.deepcopy(model)
learner.load_state_dict(fast_weights, strict=False)

output_outer = learner(image_freq)
del fast_weights
loss_outer = loss_fun(output_outer, class_l)

loss = loss_inner + loss_outer
Your meta-learning implementation is wrong: load_state_dict() copies tensor values without preserving the computation graph, so loss_outer will never propagate gradients back to the original model parameters. Be careful :)
CCST/federated/fed_run.py
Lines 112 to 122 in 87c3e44
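A minimal sketch of a graph-preserving alternative using torch.func.functional_call (PyTorch >= 2.0): the fast weights stay as tensors inside the same autograd graph, and the functional forward pass makes loss_outer differentiable with respect to the original parameters. The model, data, and hyperparameter values below are placeholders for illustration, not the repository's actual variables.

```python
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.func import functional_call

# Placeholder model, loss, and hyperparameters (not the repo's actual setup).
model = nn.Linear(4, 2)
loss_fun = nn.CrossEntropyLoss()
meta_step_size, clip_value = 0.1, 1.0

x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))

# Inner loss and gradients w.r.t. the original parameters.
# create_graph=True keeps the update step itself differentiable.
loss_inner = loss_fun(model(x), y)
params = OrderedDict(model.named_parameters())
grads = torch.autograd.grad(loss_inner, list(params.values()), create_graph=True)

# Fast weights remain nodes in the same graph -- no load_state_dict().
fast_weights = OrderedDict(
    (name, p - meta_step_size * torch.clamp(g, -clip_value, clip_value))
    for (name, p), g in zip(params.items(), grads)
)

# Functional forward pass: loss_outer now depends on the original params.
loss_outer = loss_fun(functional_call(model, fast_weights, (x,)), y)
(loss_inner + loss_outer).backward()

assert all(p.grad is not None for p in model.parameters())
```

With copy.deepcopy plus load_state_dict, the same assertion would fail for the loss_outer term, because the copied learner's parameters are fresh leaf tensors disconnected from model.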