Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,7 @@ def unwrap_model(model):

# Resolve the c parameter for the Pseudo-Huber loss
if args.huber_c is None:
args.huber_c = 0.00054 * args.resolution * math.sqrt(unet.config.in_channels)
args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)

# Get current number of discretization steps N according to our discretization curriculum
current_discretization_steps = get_discretization_steps(
Expand Down