
Use consistent random number generation across hardware #1514

@antoche

Description


Is your feature request related to a problem? Please describe.

torch.randn is not consistent across hardware devices (see pytorch/pytorch#84234).

diffusers calls torch.randn on the device the computation runs on (typically 'cuda'). As a result, outputs produced with the exact same parameters will differ across machines.
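A minimal illustration of the problem (requires a CUDA device; seed, shape, and values are arbitrary):

import torch

# Seed CPU and CUDA generators identically.
generator_cpu = torch.Generator(device="cpu").manual_seed(0)
generator_cuda = torch.Generator(device="cuda").manual_seed(0)

a = torch.randn(4, generator=generator_cpu, device="cpu")
b = torch.randn(4, generator=generator_cuda, device="cuda")

# Despite identical seeds, the values generally differ, because the CPU and
# CUDA backends use different RNG algorithms.
print(torch.allclose(a, b.cpu()))  # typically False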

Describe the solution you'd like

Until the issue is resolved in pytorch itself, diffusers should use a deterministic RNG so results can be consistent across hardware.

One possible workaround is to keep using torch's RNG while forcing generation to happen on the CPU, which currently appears consistent regardless of hardware.

Here is an example solution:

import torch


def randn(size, generator=None, device=None, **kwargs):
    """
    Wrapper around torch.randn providing proper reproducibility.

    Generation is done on the given generator's device, then the result is
    moved to the given ``device``.

    Args:
        size: tensor size
        generator (torch.Generator): RNG generator
        device (torch.device): target device for the resulting tensor
    """
    # FIXME: generator RNG device is ignored and needs to be passed to torch.randn (torch issue #62451)
    rng_device = generator.device if generator is not None else device
    image = torch.randn(size, generator=generator, device=rng_device, **kwargs)
    # Move the sampled tensor to the target device; the values are unchanged.
    image = image.to(device=device)
    return image


def randn_like(tensor, generator=None, **kwargs):
    # Match dtype and layout, as torch.randn_like does.
    return randn(
        tensor.shape,
        dtype=tensor.dtype,
        layout=tensor.layout,
        generator=generator,
        device=tensor.device,
        **kwargs,
    )

Calling these functions instead of the torch ones, with a generator whose device is cpu, gives deterministic results while still allowing the rest of the computation to run on cuda. For example:
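import torch

# Hypothetical usage of the wrapper above (the shape is illustrative):
# the generator lives on the CPU, while the resulting tensor, and the rest
# of the pipeline, live on CUDA.
generator = torch.Generator(device="cpu").manual_seed(42)
latents = randn((1, 4, 64, 64), generator=generator, device=torch.device("cuda"))

assert latents.device.type == "cuda"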

This would also simplify and speed up the tests: they can simply use CPU-bound generators and leave the device set to cuda, even for tests that rely on RNG.

Describe alternatives you've considered

It's also possible to switch to numpy's RNG, which is deterministic across platforms; a sketch follows below. The above solution is more torch-native.
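For reference, a sketch of what the numpy-based alternative could look like (randn_np is a hypothetical helper, not part of diffusers):

import numpy as np
import torch


def randn_np(size, seed=0, device=None, dtype=torch.float32):
    # numpy's Generator produces the same stream on any platform for a
    # given seed.
    rng = np.random.default_rng(seed)
    sample = rng.standard_normal(size)  # float64 ndarray
    return torch.from_numpy(sample).to(device=device, dtype=dtype)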
