From 0126f416cbf004fb75146a7c25d5c463237ed420 Mon Sep 17 00:00:00 2001 From: Ishan Date: Wed, 23 Apr 2025 15:45:06 +0530 Subject: [PATCH] [train_dreambooth_flux] Add LANCZOS as the default interpolation mode for image resizing --- examples/dreambooth/train_dreambooth_flux.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index fcc1737e6588..6fe24634b5e3 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -618,6 +618,15 @@ def parse_args(input_args=None): ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -737,7 +746,10 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {interpolation=}.") + train_resize = transforms.Resize(size, interpolation=interpolation) train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose(