-
Notifications
You must be signed in to change notification settings - Fork 6.7k
Description
Describe the bug
I'm following this tutorial for Flax Stable Diffusion (https://huggingface.co/CompVis/stable-diffusion) on Cloud TPU VM (v4-8).
Issue with newer version of diffusers
When running inference with the newest versions of diffusers and transformers (diffusers==0.14.0 and transformers==4.27.4) and loading the bfloat16 model:
# Load the "bf16" revision of the checkpoint, requesting dtype=jnp.bfloat16.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
)
I see this
>>> images = p_generate(prompt_ids.astype(jnp.bfloat16), p_params, rng, num_inference_steps, height, width, guidance_scale)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/api.py", line 2395, in cache_miss
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 1021, in xla_pmap_impl_lazy
compiled_fun, fingerprint = parallel_callable(
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 1297, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 1461, in lower_parallel_callable
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 1374, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2099, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/yeandy/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py", line 266, in _generate
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 1700, in fori_loop
(_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 260, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 246, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 60, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 54, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2029, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 1616, in scanned_fun
return (i + 1, body_fun()(i, x)), None
File "/home/yeandy/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py", line 251, in loop_body
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
File "/home/yeandy/.local/lib/python3.8/site-packages/diffusers/schedulers/scheduling_pndm_flax.py", line 273, in step
prev_sample, state = self.step_plms(state, model_output, timestep, sample)
File "/home/yeandy/.local/lib/python3.8/site-packages/diffusers/schedulers/scheduling_pndm_flax.py", line 438, in step_plms
cur_model_output=jax.lax.select_n(
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 981, in select_n
return select_n_p.bind(which, *cases)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1781, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1785, in default_process_primitive
out_avals, effects = primitive.abstract_eval(*avals, **params)
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/core.py", line 396, in abstract_eval_
return abstract_eval(*args, **kwargs), no_effects
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
dtype_rule(*avals, **kwargs), weak_type=weak_type,
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 3412, in _select_dtype_rule
_check_same_dtypes("select", False, *(c.dtype for c in cases))
File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4693, in _check_same_dtypes
raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.select requires arguments to have the same dtypes, got bfloat16, float32, float32, float32, float32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
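I think the error comes from the following code in step_plms (diffusers/schedulers/scheduling_pndm_flax.py), where the cases passed to jax.lax.select_n have mismatched dtypes: model_output (case 0) is bfloat16, while the four cases computed from state.ets are float32.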
state = state.replace(
cur_model_output=jax.lax.select_n(
jnp.clip(state.counter, 0, 4),
model_output, # counter 0
(model_output + state.ets[-1]) / 2, # counter 1
(3 * state.ets[-1] - state.ets[-2]) / 2, # counter 2
(23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12, # counter 3
(1 / 24)
* (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]), # counter >= 4
),
)
I tried using a float32 model, and the same Python code works.
# Load the "flax" revision of the checkpoint, requesting dtype=jnp.float32.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="flax",
    dtype=jnp.float32,
)
Testing an older version of diffusers works
With a version of diffusers from before this change, my code works with both bfloat16 and float32.
Install packages
pip install diffusers==0.7.2
pip install transformers==4.24.0
Python code
import time
import numpy as np
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
# Load the "bf16" revision of the checkpoint, requesting dtype=jnp.bfloat16.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
)


def create_key(seed=0):
    """Return a JAX PRNG key built from ``seed``."""
    return jax.random.PRNGKey(seed)


# Replicate the parameters for pmap and split the PRNG key into one key per
# device reported by jax.device_count().
p_params = replicate(params)
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]

# Turn the prompts into model inputs, then shard them for pmap.
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

# Only prompt ids, params and PRNG keys are passed to `_generate` here; no
# other generation arguments are supplied explicitly.
p_generate = pmap(pipeline._generate)

print("Sharded prompt ids has shape:", prompt_ids.shape)
images = p_generate(prompt_ids, p_params, rng)
# Wait for the computation to finish before the script continues/exits.
images = images.block_until_ready()
Reproduction
Install packages
pip install diffusers==0.14.0
pip install transformers==4.27.4
sudo pip install "jax[tpu]==0.4.6" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install flax
Run python3 test.py, where test.py is the following
import time
import numpy as np
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
# Load the "bf16" revision of the checkpoint, requesting dtype=jnp.bfloat16.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
)


def create_key(seed=0):
    """Return a JAX PRNG key built from ``seed``."""
    return jax.random.PRNGKey(seed)


# Replicate the parameters for pmap and split the PRNG key into one key per
# device reported by jax.device_count().
p_params = replicate(params)
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]

# Turn the prompts into model inputs, then shard them for pmap.
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

# Generation settings, passed positionally to `_generate` below.
num_inference_steps = 50
height = 512
width = 512
guidance_scale = 7.5

# Positional arguments 3-6 (num_inference_steps, height, width,
# guidance_scale) are plain Python values, so they are marked as static
# broadcast arguments instead of being mapped over the device axis.
p_generate = pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 6])

print("Sharded prompt ids has shape:", prompt_ids.shape)
images = p_generate(
    prompt_ids,
    p_params,
    rng,
    num_inference_steps,
    height,
    width,
    guidance_scale,
)
# Wait for the computation to finish before the script continues/exits.
images = images.block_until_ready()
Logs
No response
System Info
diffusers==0.14.0
transformers==4.27.4
jax[tpu]==0.4.6
GCP TPU VM, v4-8