
local variable 'train_dataset' referenced before assignment #2957

@RissyRan


Describe the bug

Hi team,

We have recently been running a Diffusion model on JAX and hit an UnboundLocalError for the local variable train_dataset. It could be a simple fix, as in this forked repo.

Reproduction

Code to reproduce:

git clone https://github.com/huggingface/diffusers.git && cd diffusers && pip install . && pip install tensorflow clu && pip install -U -r examples/text_to_image/requirements_flax.txt

cd examples/text_to_image && TPU_LIBRARY_PATH=/lib/libtpu.so TPU_PREMAPPED_BUFFER_SIZE=4294967296 JAX_PLATFORMS=tpu,cpu python3 train_text_to_image_flax.py --pretrained_model_name_or_path=duongna/stable-diffusion-v1-4-flax --fake_data=1 --resolution=128 --center_crop --random_flip --train_batch_size=4 --mixed_precision=fp16 --max_train_steps=1500 --learning_rate=1e-05 --max_grad_norm=1 --output_dir=sd-pokemon-model

Logs

UnboundLocalError: local variable 'train_dataset' referenced before assignment
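For context, Python raises this error when a function reads a local name that was only bound on some code paths. A common way this happens in training scripts is when the variable is assigned inside a conditional branch, for example one that only runs on the first process of a multi-host job. Whether that is the exact cause in train_text_to_image_flax.py is an assumption here. The sketch below uses hypothetical function names and a plain `process_index` argument standing in for a per-process check; it is not the script's actual code:

```python
# Hypothetical illustration of the failure mode, NOT code from
# train_text_to_image_flax.py. `process_index` stands in for a
# per-process check such as jax.process_index() (assumption).


def build_dataset_buggy(process_index, dataset):
    if process_index == 0:
        # Only process 0 ever binds the name, so on any other process
        # the return statement below raises UnboundLocalError.
        train_dataset = [x * 2 for x in dataset]
    return train_dataset


def build_dataset_fixed(process_index, dataset):
    # Bind the name unconditionally so every process gets a dataset.
    train_dataset = [x * 2 for x in dataset]
    return train_dataset


if __name__ == "__main__":
    data = [1, 2, 3]
    print(build_dataset_fixed(1, data))  # [2, 4, 6]
    try:
        build_dataset_buggy(1, data)
    except UnboundLocalError as err:
        # Python 3.11+ words the message differently ("cannot access local
        # variable ..."), so only the exception type is printed here.
        print(type(err).__name__)  # UnboundLocalError
```

The fix pattern is to make sure the name is assigned on every path before it is used, either by moving the assignment out of the conditional or by giving it a value in an else branch.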

System Info

Cloud TPU (Ref)

Labels: bug (Something isn't working), stale (Issues that haven't received updates)
