From 18b0999f5fbf68fec938ad4515963d8476d92331 Mon Sep 17 00:00:00 2001 From: tongyu <119610311+tongyu0924@users.noreply.github.com> Date: Sun, 27 Apr 2025 15:21:02 +0800 Subject: [PATCH 1/2] Update train_text_to_image_lora.py --- .../text_to_image/train_text_to_image_lora.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index ba733efe6003..3b28e84e85c8 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -418,6 +418,15 @@ def parse_args(): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -649,10 +658,17 @@ def tokenize_captions(examples, is_train=True): ) return inputs.input_ids - # Preprocessing the datasets. + # Get the specified interpolation method from the args + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + + # Raise an error if the interpolation method is invalid + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.") + + # Data preprocessing transformations train_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(), From 93cd97ba96ef6c19acbf3bfadcf0493f81356821 Mon Sep 17 00:00:00 2001 From: tongyu0924 Date: Mon, 28 Apr 2025 21:17:02 +0800 Subject: [PATCH 2/2] update_train_text_to_image_lora --- examples/text_to_image/train_text_to_image_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 3b28e84e85c8..480f4b36df61 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -660,11 +660,11 @@ def tokenize_captions(examples, is_train=True): # Get the specified interpolation method from the args interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) - + # Raise an error if the interpolation method is invalid if interpolation is None: raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.") - + # Data preprocessing transformations train_transforms = transforms.Compose( [