Skip to content
24 changes: 17 additions & 7 deletions src/diffusers/utils/loading_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
import os
from typing import Union
from typing import Callable, Union

import PIL.Image
import PIL.ImageOps
import requests


def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
def load_image(
image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
) -> PIL.Image.Image:
"""
Loads `image` to a PIL Image.

Args:
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
A conversion method to apply to the image after loading it.
When set to `None` the image will be converted "RGB".

Returns:
`PIL.Image.Image`:
A PIL Image.
Expand All @@ -24,14 +30,18 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
"Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
)

image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")

if convert_method is not None:
image = convert_method(image)
else:
image = image.convert("RGB")

return image