diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index fb31215db..96f49e0a3 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -18,8 +18,10 @@ def __init__(self, model, schedule="linear", **kwargs): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): + if attr.device != torch.device("cuda") and torch.cuda.is_available(): attr = attr.to(torch.device("cuda")) + else: + attr = attr.to(torch.device("cpu")) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): @@ -238,4 +240,4 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) - return x_dec \ No newline at end of file + return x_dec diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 78eeb1003..63a6d46c5 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -17,8 +17,10 @@ def __init__(self, model, schedule="linear", **kwargs): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): + if attr.device != torch.device("cuda") and torch.cuda.is_available(): attr = attr.to(torch.device("cuda")) + else: + attr = attr.to(torch.device("cpu")) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):