diff --git a/src/groundlight/images.py b/src/groundlight/images.py index 695968a1..a8d2714f 100644 --- a/src/groundlight/images.py +++ b/src/groundlight/images.py @@ -1,6 +1,6 @@ # pylint: disable=deprecated-module -import imghdr from io import BufferedReader, BytesIO, IOBase +from pathlib import Path from typing import Union from groundlight.optional_imports import Image, np @@ -37,10 +37,11 @@ def bytestream_from_filename(image_filename: str, jpeg_quality: int = DEFAULT_JP Only supports JPEG and PNG files for now. For PNG files, we convert to RGB format used in JPEGs. """ - if imghdr.what(image_filename) == "jpeg": + image_path = Path(image_filename) + if image_path.suffix.lower() in (".jpeg", ".jpg"): buffer = buffer_from_jpeg_file(image_filename) return ByteStreamWrapper(data=buffer) - if imghdr.what(image_filename) == "png": + if image_path.suffix.lower() == ".png": pil_img = Image.open(image_filename) # This chops off the alpha channel which can cause unexpected behavior, but handles minimal transparency well pil_img = pil_img.convert("RGB") @@ -53,7 +54,7 @@ def buffer_from_jpeg_file(image_filename: str) -> BufferedReader: For now, we only support JPEG files, and raise an ValueError otherwise. """ - if imghdr.what(image_filename) == "jpeg": + if Path(image_filename).suffix.lower() in (".jpeg", ".jpg"): # Note this will get fooled by truncated binaries since it only reads the header. # That's okay - the server will catch it. return open(image_filename, "rb") diff --git a/test/integration/test_groundlight.py b/test/integration/test_groundlight.py index 904f16fb..a47b6088 100644 --- a/test/integration/test_groundlight.py +++ b/test/integration/test_groundlight.py @@ -470,11 +470,16 @@ def test_submit_image_query_bad_filename(gl: Groundlight, detector: Detector): def test_submit_image_query_bad_jpeg_file(gl: Groundlight, detector: Detector): - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ApiException) as exc_info: _image_query = gl.submit_image_query( detector=detector.id, image="test/assets/blankfile.jpeg", human_review="NEVER" ) - assert "jpeg" in str(exc_info).lower() + assert "uploaded image is empty or corrupted" in exc_info.value.body.lower() + with pytest.raises(ValueError) as exc_info: + _image_query = gl.submit_image_query( + detector=detector.id, image="test/assets/blankfile.jpeeeg", human_review="NEVER" + ) + assert "we only support jpeg and png" in str(exc_info).lower() @pytest.mark.skipif(MISSING_PIL, reason="Needs pillow") # type: ignore