
hardcoded torch.float64 isn't supported on Metal (device="mps") #9224

@RoyLeviLangware

Description


Describe the bug

Running transformer_flux (through FluxPipeline) on macOS with device="mps" fails with the following error:
Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead
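The usual way around a hardcoded torch.float64 is to pick the dtype from the device: keep float64 on CPU/CUDA for precision and fall back to float32 on MPS. The sketch below assumes the float64 comes from a Flux-style rotary-embedding helper (such as a rope function in transformer_flux.py); this report does not confirm that location. The names pick_freqs_dtype and rope_mps_safe are hypothetical, not diffusers APIs.

```python
import torch


def pick_freqs_dtype(device: torch.device) -> torch.dtype:
    # Hypothetical helper: MPS has no float64 support, so use float32 there; keep float64 elsewhere.
    return torch.float32 if device.type == "mps" else torch.float64


def rope_mps_safe(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Hypothetical MPS-safe rotary embedding in the style of a Flux rope helper.

    pos: (batch, seq) positions. Returns (batch, seq, dim // 2, 2, 2) float32 rotation matrices.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    freqs_dtype = pick_freqs_dtype(pos.device)
    # Frequencies theta ** (-2i / dim), created directly in a dtype the device supports.
    scale = torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)
    # Outer product of positions and frequencies via broadcasting: (batch, seq, dim // 2).
    angles = pos.to(freqs_dtype).unsqueeze(-1) * omega
    cos, sin = torch.cos(angles), torch.sin(angles)
    # One 2x2 rotation matrix [[cos, -sin], [sin, cos]] per position and frequency.
    stacked = torch.stack([cos, -sin, sin, cos], dim=-1)
    return stacked.view(*pos.shape, dim // 2, 2, 2).float()


if __name__ == "__main__":
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    pos = torch.arange(4, dtype=torch.float32, device=device).unsqueeze(0)  # (1, 4)
    emb = rope_mps_safe(pos, dim=8, theta=10_000)
    print(emb.shape, emb.dtype)  # torch.Size([1, 4, 4, 2, 2]) torch.float32
```

Computing in float32 on MPS trades some precision in the high-frequency terms for being able to run at all; other devices keep the float64 path unchanged.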

Reproduction

Run the following on a Mac:

import torch
from diffusers import FluxPipeline


def main():
    # Make "mps" the default device for newly created tensors.
    torch.set_default_device("mps")
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(device="mps")
    # pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

    prompt = "A cat holding a sign that says hello world"
    image = pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=16,
        max_sequence_length=256,
        generator=torch.Generator("mps").manual_seed(0),
    ).images[0]
    image.save("flux-schnell.png")


if __name__ == '__main__':
    main()
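A much smaller reproduction that skips downloading the Flux weights, as a sketch assuming the failure comes from a tensor created with a hardcoded dtype=torch.float64 on the MPS device. The helper name try_float64_arange is hypothetical.

```python
import torch


def try_float64_arange(device: str) -> str:
    """Create a float64 tensor on `device`, mimicking a hardcoded dtype=torch.float64.

    Returns "ok" on success, otherwise the error message raised by the backend.
    """
    try:
        torch.arange(0, 8, 2, dtype=torch.float64, device=device)
        return "ok"
    except (TypeError, RuntimeError) as exc:
        return str(exc)


if __name__ == "__main__":
    print("cpu:", try_float64_arange("cpu"))
    if torch.backends.mps.is_available():
        # On MPS this is expected to report the "Cannot convert a MPS Tensor to float64" error.
        print("mps:", try_float64_arange("mps"))
```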

Logs

Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead

System Info

diffusers==0.30.0
torch==2.4.0

Who can help?

No response

Labels: bug (Something isn't working), stale (Issues that haven't received updates)
