Skip to content
8 changes: 6 additions & 2 deletions src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,15 @@ def enable_model_cpu_offload(self, gpu_id=0):
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

model_sequence = [
self.text_encoder,
self.text_encoder.text_model,
self.text_encoder.text_projection,
self.text_encoder_2,
self.projection_model,
self.language_model,
self.unet,
self.vae,
self.vocoder,
self.text_encoder,
]

hook = None
Expand Down Expand Up @@ -909,7 +912,8 @@ def __call__(
encoder_hidden_states=generated_prompt_embeds,
encoder_hidden_states_1=prompt_embeds,
encoder_attention_mask_1=attention_mask,
).sample
return_dict=False,
)[0]

# perform guidance
if do_classifier_free_guidance:
Expand Down
29 changes: 24 additions & 5 deletions tests/pipelines/audioldm2/test_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import is_xformers_available, slow, torch_device
from diffusers.utils import is_accelerate_available, is_accelerate_version, is_xformers_available, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism

from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
Expand Down Expand Up @@ -477,7 +477,6 @@ def test_to_dtype(self):
# The method component.dtype returns the dtype of the first parameter registered in the model, not the
# dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(model_dtypes["text_encoder"] == torch.float64)

# Without the logit scale parameters, everything is float32
model_dtypes.pop("text_encoder")
Expand All @@ -492,6 +491,26 @@ def test_to_dtype(self):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.17.0"),
reason="CPU offload is only available with CUDA and `accelerate v0.17.0` or higher",
)
def test_model_cpu_offload(self, expected_max_diff=2e-4):
components = self.get_dummy_components()
audioldm_pipe = AudioLDM2Pipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(torch_device)
output_without_offload = audioldm_pipe(**inputs)[0]

audioldm_pipe.enable_model_cpu_offload()
inputs = self.get_dummy_inputs(torch_device)
output_with_offload = audioldm_pipe(**inputs)[0]

max_diff = np.abs(output_with_offload - output_without_offload).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")


@slow
class AudioLDM2PipelineSlowTests(unittest.TestCase):
Expand All @@ -514,7 +533,7 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0
return inputs

def test_audioldm2(self):
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2")
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)

Expand All @@ -532,7 +551,7 @@ def test_audioldm2(self):
assert max_diff < 1e-3

def test_audioldm2_lms(self):
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2")
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
Expand All @@ -552,7 +571,7 @@ def test_audioldm2_lms(self):
assert max_diff < 1e-3

def test_audioldm2_large(self):
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2-large")
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large")
audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)

Expand Down