From df1ad8698d8ddce15d9afa529c81a1ac277cf8c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tolga=20Cang=C3=B6z?=
Date: Thu, 20 Jun 2024 22:35:53 +0300
Subject: [PATCH 1/3] Fix multi-gpu case

---
 .../consistency_training/train_cm_ct_unconditional.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
index b7a1e2a545f8..e9cc94497be4 100644
--- a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
+++ b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
@@ -1195,7 +1195,7 @@ def unwrap_model(model):
 
                 # Resolve the c parameter for the Pseudo-Huber loss
                 if args.huber_c is None:
-                    args.huber_c = 0.00054 * args.resolution * math.sqrt(unet.config.in_channels)
+                    args.huber_c = 0.00054 * args.resolution * math.sqrt(accelerator.unwrap_model(unet).config.in_channels)
 
                 # Get current number of discretization steps N according to our discretization curriculum
                 current_discretization_steps = get_discretization_steps(

From 4b06b95428a09ec8e4678089fb648edd07d66138 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com>
Date: Wed, 17 Jul 2024 16:02:52 +0300
Subject: [PATCH 2/3] Prefer previously created `unwrap_model()` function

For `torch.compile()` generalizability
---
 .../consistency_training/train_cm_ct_unconditional.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
index e9cc94497be4..6c5ae59761c2 100644
--- a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
+++ b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
@@ -198,7 +198,7 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
 def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name="teacher"):
     logger.info("Running validation... ")
 
-    unet = accelerator.unwrap_model(unet)
+    unet = unwrap_model(unet)
     pipeline = ConsistencyModelPipeline(
         unet=unet,
         scheduler=scheduler,
@@ -1195,7 +1195,7 @@ def unwrap_model(model):
 
                 # Resolve the c parameter for the Pseudo-Huber loss
                 if args.huber_c is None:
-                    args.huber_c = 0.00054 * args.resolution * math.sqrt(accelerator.unwrap_model(unet).config.in_channels)
+                    args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)
 
                 # Get current number of discretization steps N according to our discretization curriculum
                 current_discretization_steps = get_discretization_steps(

From 62f1ff3a897175640ad1815a662be71c03007fa7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tolga=20Cang=C3=B6z?=
Date: Wed, 17 Jul 2024 16:13:58 +0300
Subject: [PATCH 3/3] `chore: update unwrap_model() function to use accelerator.unwrap_model()`

---
 .../consistency_training/train_cm_ct_unconditional.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
index 6c5ae59761c2..eccc539f230c 100644
--- a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
+++ b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
@@ -198,7 +198,7 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=
 def log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name="teacher"):
     logger.info("Running validation... ")
 
-    unet = unwrap_model(unet)
+    unet = accelerator.unwrap_model(unet)
     pipeline = ConsistencyModelPipeline(
         unet=unet,
         scheduler=scheduler,