
bfloat16/float32 type mismatch in lax.select in scheduling_pndm_flax.py #3039

@yeandy

Description

Describe the bug

I'm following this tutorial for Flax Stable Diffusion (https://huggingface.co/CompVis/stable-diffusion) on a Cloud TPU VM (v4-8).

Issue with newer version of diffusers

When running inference with the newest versions of diffusers/transformers (diffusers==0.14.0 and transformers==4.27.4) and loading the model in bfloat16

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16
)

I see this error:

>>> images = p_generate(prompt_ids.astype(jnp.bfloat16), p_params, rng, num_inference_steps, height, width, guidance_scale)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/api.py", line 2395, in cache_miss
    execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 1021, in xla_pmap_impl_lazy
    compiled_fun, fingerprint = parallel_callable(
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 1297, in parallel_callable
    pmap_computation = lower_parallel_callable(
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 1461, in lower_parallel_callable
    jaxpr, consts, replicas, parts, shards = stage_parallel_callable(
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 1374, in stage_parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2099, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yeandy/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py", line 266, in _generate
    latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 1700, in fori_loop
    (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 260, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 246, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 60, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 54, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2029, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 1616, in scanned_fun
    return (i + 1, body_fun()(i, x)), None
  File "/home/yeandy/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py", line 251, in loop_body
    latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
  File "/home/yeandy/.local/lib/python3.8/site-packages/diffusers/schedulers/scheduling_pndm_flax.py", line 273, in step
    prev_sample, state = self.step_plms(state, model_output, timestep, sample)
  File "/home/yeandy/.local/lib/python3.8/site-packages/diffusers/schedulers/scheduling_pndm_flax.py", line 438, in step_plms
    cur_model_output=jax.lax.select_n(
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 981, in select_n
    return select_n_p.bind(which, *cases)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1781, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1785, in default_process_primitive
    out_avals, effects = primitive.abstract_eval(*avals, **params)
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/core.py", line 396, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 67, in standard_abstract_eval
    dtype_rule(*avals, **kwargs), weak_type=weak_type,
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 3412, in _select_dtype_rule
    _check_same_dtypes("select", False, *(c.dtype for c in cases))
  File "/home/yeandy/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 4693, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: lax.select requires arguments to have the same dtypes, got bfloat16, float32, float32, float32, float32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).

I think the error comes from this line, where there is a dtype mismatch between the bfloat16 and float32 arguments to the jax.lax.select_n call:

        state = state.replace(
            cur_model_output=jax.lax.select_n(
                jnp.clip(state.counter, 0, 4),
                model_output,  # counter 0
                (model_output + state.ets[-1]) / 2,  # counter 1
                (3 * state.ets[-1] - state.ets[-2]) / 2,  # counter 2
                (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12,  # counter 3
                (1 / 24)
                * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]),  # counter >= 4
            ),
        )
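
The mismatch is easy to reproduce in isolation. lax.select_n requires all of its cases to share one dtype, while jnp.where (the tip in the error message) promotes its inputs instead. A minimal sketch, using stand-in arrays for model_output (bfloat16) and the state.ets entries (float32):

import jax
import jax.numpy as jnp

which = jnp.array([0, 1, 0])          # integer selector, like jnp.clip(state.counter, 0, 4)
a = jnp.zeros(3, dtype=jnp.bfloat16)  # stand-in for model_output
b = jnp.zeros(3, dtype=jnp.float32)   # stand-in for a state.ets entry

try:
    jax.lax.select_n(which, a, b)
except TypeError as e:
    print(e)  # lax.select requires arguments to have the same dtypes ...

# jnp.where promotes instead of raising:
print(jnp.where(which == 0, a, b).dtype)  # float32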

I tried using a float32 model, and the same Python code works:

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="flax",
    dtype=jnp.float32
)
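
Alternatively, keeping the bfloat16 weights and casting inside the scheduler step would also avoid the mismatch. A minimal sketch of such a workaround (an illustration only, not necessarily the proper upstream fix), assuming state.ets is a jnp array held in float32 while model_output is bfloat16:

        # Hypothetical workaround: cast the float32 history to the model
        # output's dtype so every select_n case shares one dtype.
        ets = state.ets.astype(model_output.dtype)
        state = state.replace(
            cur_model_output=jax.lax.select_n(
                jnp.clip(state.counter, 0, 4),
                model_output,  # counter 0
                (model_output + ets[-1]) / 2,  # counter 1
                (3 * ets[-1] - ets[-2]) / 2,  # counter 2
                (23 * ets[-1] - 16 * ets[-2] + 5 * ets[-3]) / 12,  # counter 3
                (1 / 24)
                * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]),  # counter >= 4
            ),
        )

Since Python scalars stay weakly typed in JAX, every case derived from the cast ets and the bfloat16 model_output then comes out as bfloat16, matching the first case.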

Testing an older version of diffusers works

When loading the version of diffusers before this change, my code works with both bfloat16 and float32:

Install packages

pip install diffusers==0.7.2
pip install transformers==4.24.0

Python code

import time

import numpy as np
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16
)

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

p_params = replicate(params)

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]

prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

p_generate = pmap(pipeline._generate)

print("Sharded prompt ids has shape:", prompt_ids.shape)

images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()

Reproduction

Install packages

pip install diffusers==0.14.0
pip install transformers==4.27.4
sudo pip install "jax[tpu]==0.4.6" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install flax

Run python3 test.py, where test.py is the following:

import time

import numpy as np
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16
)

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

p_params = replicate(params)

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]

prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

num_inference_steps = 50
height = 512
width = 512
guidance_scale = 7.5

p_generate = pmap(pipeline._generate, static_broadcasted_argnums=[3,4,5,6])
print("Sharded prompt ids has shape:", prompt_ids.shape)

images = p_generate(prompt_ids, p_params, rng, num_inference_steps, height, width, guidance_scale)
images = images.block_until_ready()

Logs

No response

System Info

diffusers==0.14.0
transformers==4.27.4
jax[tpu]==0.4.6
GCP TPU VM, v4-8
