From af605b1423e46583cfe2ef7344d54596825e33d7 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Mon, 16 Sep 2024 12:44:13 -0700
Subject: [PATCH 1/3] update ptxla training script to fix step time calc.

---
 .../research_projects/pytorch_xla/README.md   | 10 +----
 .../pytorch_xla/train_text_to_image_xla.py    | 39 +++++++++----------
 2 files changed, 19 insertions(+), 30 deletions(-)

diff --git a/examples/research_projects/pytorch_xla/README.md b/examples/research_projects/pytorch_xla/README.md
index a6901d5ada9d..f84e1ef9b20e 100644
--- a/examples/research_projects/pytorch_xla/README.md
+++ b/examples/research_projects/pytorch_xla/README.md
@@ -5,15 +5,7 @@ The `train_text_to_image_xla.py` script shows how to fine-tune stable diffusion
 It has been tested on v4 and v5p TPU versions. Training code has been tested on multi-host.
 
 This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
-where we shard the input batches over the TPU devices.
-
-As of 9-11-2024, these are some expected step times.
-
-| accelerator | global batch size | step time (seconds) |
-| ----------- | ----------------- | --------- |
-| v5p-128 | 1024 | 0.245 |
-| v5p-256 | 2048 | 0.234 |
-| v5p-512 | 4096 | 0.2498 |
+where we shard the input batches over the TPU devices. The script works on both single-host and multi-host setups.
 
 ## Create TPU
 
diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
index 5d9d8c540f11..ffd2b4750ae0 100644
--- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
+++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
@@ -42,7 +42,6 @@
 }
 
 PORT = 9012
-
 
 def save_model_card(
     args,
     repo_id: str,
@@ -140,33 +139,31 @@ def run_optimizer(self):
         self.optimizer.step()
 
     def start_training(self):
-        times = []
-        last_time = time.time()
-        step = 0
-        while True:
-            if self.global_step >= self.args.max_train_steps:
-                xm.mark_step()
-                break
-            if step == 4 and PROFILE_DIR is not None:
-                xm.wait_device_ops()
-                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+        dataloader_exception = False
+        measure_start_step = 10
+        assert measure_start_step < self.args.max_train_steps
+        total_time = 0
+        for step in range(0, self.args.max_train_steps):
             try:
                 batch = next(self.dataloader)
             except Exception as e:
+                dataloader_exception = True
                 print(e)
                 break
+            if step == measure_start_step and PROFILE_DIR is not None:
+                xm.wait_device_ops()
+                xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration)
+                last_time = time.time()
             loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
-            step_time = time.time() - last_time
-            if step >= 10:
-                times.append(step_time)
-            print(f"step: {step}, step_time: {step_time}")
-            if step % 5 == 0:
-                print(f"step: {step}, loss: {loss}")
-            last_time = time.time()
             self.global_step += 1
-            step += 1
-        # print(f"Average step time: {sum(times)/len(times)}")
-        xm.wait_device_ops()
+            xm.mark_step()
+        if not dataloader_exception:
+            xm.wait_device_ops()
+            total_time = time.time() - last_time
+            print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
+        else:
+            print("dataloader exception happened, skipping result")
+        return
 
     def step_fn(
         self,
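For context on why the timing in this patch is structured the way it is: PyTorch/XLA executes lazily, so wall-clock time measured around `step_fn` alone captures graph tracing rather than device execution, and a meaningful step time requires draining pending device work with `xm.wait_device_ops()` before reading the clock. Below is a minimal sketch of that pattern, separate from the patch itself; `measure_avg_step_time`, `step_fn`, and `batches` are hypothetical names, and only the `xm` calls are real torch_xla APIs.

```python
# Illustrative sketch (not part of the patch): timing XLA training steps.
import time

import torch_xla.core.xla_model as xm


def measure_avg_step_time(step_fn, batches, max_steps, warmup_steps=10):
    """Average post-warm-up step time; assumes max_steps > warmup_steps."""
    start = None
    for step, batch in zip(range(max_steps), batches):
        if step == warmup_steps:
            xm.wait_device_ops()  # drain compilation-heavy warm-up work first
            start = time.time()
        step_fn(batch)
        xm.mark_step()  # cut the lazy graph and dispatch it to the device
    xm.wait_device_ops()  # block until all dispatched device work finishes
    return (time.time() - start) / (max_steps - warmup_steps)
```

In the patch as written, `last_time` is only assigned under the `PROFILE_DIR` guard, so the end-of-run average depends on profiling being enabled; the next commit replaces it with windowed reporting.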
From 3524cdf8a2b5f2a581ddcfed8a32445c8653627f Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Mon, 16 Sep 2024 21:15:40 +0000
Subject: [PATCH 2/3] update step time calculations.

---
 examples/research_projects/pytorch_xla/README.md |  2 +-
 .../pytorch_xla/train_text_to_image_xla.py       | 11 +++++++----
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/examples/research_projects/pytorch_xla/README.md b/examples/research_projects/pytorch_xla/README.md
index f84e1ef9b20e..78a7e4226e81 100644
--- a/examples/research_projects/pytorch_xla/README.md
+++ b/examples/research_projects/pytorch_xla/README.md
@@ -85,7 +85,7 @@ export PROFILE_DIR=/tmp/
 export CACHE_DIR=/tmp/
 export DATASET_NAME=lambdalabs/naruto-blip-captions
 export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
-export TRAIN_STEPS=50
+export TRAIN_STEPS=200
 export OUTPUT_DIR=/tmp/trained-model/
 
 python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'

diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
index ffd2b4750ae0..5065e5c51b7a 100644
--- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
+++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
@@ -142,7 +142,9 @@ def start_training(self):
         dataloader_exception = False
         measure_start_step = 10
         assert measure_start_step < self.args.max_train_steps
-        total_time = 0
+        losses = []
+        start_time = time.time()
+        print_every = 10
         for step in range(0, self.args.max_train_steps):
             try:
                 batch = next(self.dataloader)
@@ -153,14 +155,15 @@ def start_training(self):
             if step == measure_start_step and PROFILE_DIR is not None:
                 xm.wait_device_ops()
                 xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration)
-                last_time = time.time()
             loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
+            losses.append(loss)
+            if step > measure_start_step and step % print_every == 0:
+                print(f"step: {step}, avg. time: {((time.time() - start_time) / print_every)}, avg loss: {sum(losses) / len(losses)}")
+                start_time = time.time()
             self.global_step += 1
             xm.mark_step()
         if not dataloader_exception:
             xm.wait_device_ops()
-            total_time = time.time() - last_time
-            print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
         else:
             print("dataloader exception happened, skipping result")
         return
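The windowed reporting added above can be read in isolation. A minimal sketch of the same scheme, assuming a hypothetical `train_step(batch)` that returns a float loss (the real script's `step_fn` returns an XLA tensor):

```python
# Minimal sketch of windowed reporting: every print_every steps, report
# the average step time over the window and the running average loss,
# then reset the timing window.
import time


def train_with_windowed_logging(train_step, batches, max_steps, measure_start_step=10, print_every=10):
    losses = []
    start_time = time.time()
    for step, batch in zip(range(max_steps), batches):
        losses.append(train_step(batch))
        if step > measure_start_step and step % print_every == 0:
            avg_time = (time.time() - start_time) / print_every
            print(f"step: {step}, avg. time: {avg_time}, avg loss: {sum(losses) / len(losses)}")
            start_time = time.time()  # reset the window after each report
```

One caveat carried over from the patch: the timer starts at step 0 but the first report only lands after `measure_start_step`, so the first window spans more than `print_every` steps and overstates the average step time (it includes compilation-heavy warm-up).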
From 7cb0e3d2884952e8cf27ff39b058ab8233b325ff Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Mon, 16 Sep 2024 21:18:07 +0000
Subject: [PATCH 3/3] make style and make quality.

---
 .../pytorch_xla/train_text_to_image_xla.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
index 5065e5c51b7a..0320567ae166 100644
--- a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
+++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py
@@ -42,6 +42,7 @@
 }
 
 PORT = 9012
+
 
 def save_model_card(
     args,
     repo_id: str,
@@ -152,13 +153,15 @@ def start_training(self):
                 dataloader_exception = True
                 print(e)
                 break
-            if step == measure_start_step and PROFILE_DIR is not None: 
+            if step == measure_start_step and PROFILE_DIR is not None:
                 xm.wait_device_ops()
-                xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration)
+                xp.trace_detached("localhost:9012", PROFILE_DIR, duration_ms=args.profile_duration)
             loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
             losses.append(loss)
             if step > measure_start_step and step % print_every == 0:
-                print(f"step: {step}, avg. time: {((time.time() - start_time) / print_every)}, avg loss: {sum(losses) / len(losses)}")
+                print(
+                    f"step: {step}, avg. time: {((time.time() - start_time) / print_every)}, avg loss: {sum(losses) / len(losses)}"
+                )
                 start_time = time.time()
             self.global_step += 1
             xm.mark_step()
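The `xp.trace_detached` call that this series keeps is the client half of torch_xla's on-demand profiler; the training script is expected to start the serving half (its module-level `PORT = 9012` suggests as much). A sketch of how the two calls pair up; `"/tmp/profile"` is an assumed output directory, 80000 ms mirrors the README's `--profile_duration=80000`, and both `start_server` and `trace_detached` are real `torch_xla.debug.profiler` APIs:

```python
# Sketch of the two halves of torch_xla's on-demand profiler. The training
# process serves profiling data; trace_detached connects to that server and
# writes duration_ms worth of trace into a TensorBoard-readable log dir.
import torch_xla.debug.profiler as xp

PORT = 9012
server = xp.start_server(PORT)  # start serving profiler data in this process

# Later, once steady-state steps are running (step == measure_start_step):
xp.trace_detached(f"localhost:{PORT}", "/tmp/profile", duration_ms=80000)
```

Since `PORT` is already defined at module scope, interpolating it as the pre-series code did (`f"localhost:{PORT}"`) would keep the address and the server port in sync, instead of the hardcoded `"localhost:9012"` that the make-style pass otherwise leaves in place.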