From 04a9b487b8704a184f508cc4a875e468e883e2ca Mon Sep 17 00:00:00 2001 From: Lee Penkman Date: Tue, 8 Jul 2025 16:07:04 +1200 Subject: [PATCH 1/2] Switch to Gemma model --- README.md | 5 +- questions/constants.py | 9 ++- .../inference_server/inference_server.py | 1 + questions/link_enricher.py | 63 +++++++++++-------- questions/test_gemma_multimodal.py | 40 ++++++++++++ questions/test_inf_moon.py | 47 -------------- questions/text_generator_inference.py | 8 +-- scripts/example_moodream.py | 63 ++++++++----------- tests/unit/questions/test_link_enricher.py | 4 ++ tests/unit/test_audio_model.py | 5 ++ tests/unit/test_document_processor.py | 4 ++ 11 files changed, 129 insertions(+), 120 deletions(-) create mode 100644 questions/test_gemma_multimodal.py delete mode 100644 questions/test_inf_moon.py diff --git a/README.md b/README.md index a2b6047..4ba0b7a 100644 --- a/README.md +++ b/README.md @@ -113,10 +113,11 @@ make download-punkt # download the punkt dataset for NLTK ### Models -Download models from huggingface. +Download the Gemma model from HuggingFace. The model is small enough for local +testing but still supports multimodal input. ```shell -huggingface-cli download HuggingFaceTB/SmolLM2-1.7B-Instruct --local-dir models/SmolLM-1.7B +huggingface-cli download google/gemma-3n-e4b-it --local-dir models/gemma wget -P models https://huggingface.co/geneing/Kokoro/resolve/f610f07c62f8baa30d4ed731530e490230e4ee83/kokoro-v0_19.pth ``` diff --git a/questions/constants.py b/questions/constants.py index dc8b12c..0f4e865 100644 --- a/questions/constants.py +++ b/questions/constants.py @@ -1,4 +1,7 @@ import os -weights_path_tgz = os.getenv("WEIGHTS_PATH_TGZ", "models/SmolLM-1.7B") -weights_path_tgc = os.getenv("WEIGHTS_PATH_TGC", "models/SmolLM-1.7B") -weights_path_tg = os.getenv("WEIGHTS_PATH", "models/SmolLM-1.7B") +# Default to the multimodal Gemma model which can handle both text generation +# and image description. The environment variables allow overriding the model +# path, but when unset a small Gemma checkpoint from HuggingFace is used. +weights_path_tgz = os.getenv("WEIGHTS_PATH_TGZ", "google/gemma-3n-e4b-it") +weights_path_tgc = os.getenv("WEIGHTS_PATH_TGC", "google/gemma-3n-e4b-it") +weights_path_tg = os.getenv("WEIGHTS_PATH", "google/gemma-3n-e4b-it") diff --git a/questions/inference_server/inference_server.py b/questions/inference_server/inference_server.py index c6a8736..d949450 100644 --- a/questions/inference_server/inference_server.py +++ b/questions/inference_server/inference_server.py @@ -1343,3 +1343,4 @@ def tts_demo(request: Request): # return HTTPException(status_code=500, detail=f"Error generating text: {str(e)}") if __name__ == "__main__": + pass diff --git a/questions/link_enricher.py b/questions/link_enricher.py index e17ed40..e35fc27 100755 --- a/questions/link_enricher.py +++ b/questions/link_enricher.py @@ -15,8 +15,11 @@ setup_logging() logger = logging.getLogger(__name__) from requests_futures.sessions import FuturesSession -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +try: + import torch +except ModuleNotFoundError: # pragma: no cover - optional dependency + torch = None +from transformers import pipeline from questions.inference_server.model_cache import ModelCache # change into OFA dir @@ -97,42 +100,48 @@ def get_title_from_html(html, long_description=False): LINK_MODEL_CACHE = ModelCache() -def load_moondream_model(): - """Load the Moondream model for image captioning""" - model_id = "vikhyatk/moondream2" - revision = "2024-08-26" - - model = AutoModelForCausalLM.from_pretrained( - model_id, - trust_remote_code=True, - revision=revision +def load_gemma_pipe(): + """Load the Gemma model for image captioning and text generation.""" + model_id = "google/gemma-3n-e4b-it" + device = 0 if torch and torch.cuda.is_available() else -1 + return pipeline( + "image-text-to-text", + model=model_id, + device=device, + torch_dtype=torch.bfloat16 if torch and torch.cuda.is_available() else None, ) - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision - ) - return model, tokenizer def get_caption_for_image_response(response, prompt="Describe this image."): - """Get image caption using Moondream model""" + """Generate a caption for an image using Gemma.""" response.raw.decode_content = True - image_bytes = response.content - - img = Image.open(BytesIO(image_bytes)) - + img = Image.open(BytesIO(response.content)) + with log_time("image captioning"): - model, tokenizer = LINK_MODEL_CACHE.add_or_get("moondream_model", load_moondream_model) - enc_image = model.encode_image(img) - caption = model.answer_question(enc_image, prompt, tokenizer) - + pipe = LINK_MODEL_CACHE.add_or_get("gemma_pipe", load_gemma_pipe) + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "image", "image": img}, + {"type": "text", "text": prompt}, + ], + }, + ] + output = pipe(text=messages, max_new_tokens=100) + caption = output[0]["generated_text"][-1]["content"] + if debug: logger.info(f"Image description: {caption}") - + if any(ocr_tag in caption.lower() for ocr_tag in ocr_tags): with log_time("OCR"): ocr_data = ocr_tess(img) caption += " " + ocr_data - + return caption diff --git a/questions/test_gemma_multimodal.py b/questions/test_gemma_multimodal.py new file mode 100644 index 0000000..4d9c657 --- /dev/null +++ b/questions/test_gemma_multimodal.py @@ -0,0 +1,40 @@ +import os +import importlib.util +import pytest +from PIL import Image +from transformers import pipeline +from questions.logging_config import setup_logging + +# Skip the tests if torch is not available as the transformers pipelines +# require it for model execution. +if importlib.util.find_spec("torch") is None: + pytest.skip("torch is required for Gemma tests", allow_module_level=True) + +setup_logging() + +def test_gemma_image_captioning(): + model_id = "yujiepan/gemma-3n-tiny-random" + pipe = pipeline( + "image-text-to-text", + model=model_id, + device=-1, + ) + + image_path = "static/img/me.jpg" + assert os.path.exists(image_path) + img = Image.open(image_path) + + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + {"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": "Describe this image."}]}, + ] + output = pipe(text=messages, max_new_tokens=5) + assert isinstance(output[0]["generated_text"][-1]["content"], str) + + +def test_gemma_text_generation(): + model_id = "yujiepan/gemma-3n-tiny-random" + text_pipe = pipeline("text-generation", model=model_id, device=-1) + out = text_pipe("Hello", max_new_tokens=5) + assert isinstance(out[0]["generated_text"], str) + diff --git a/questions/test_inf_moon.py b/questions/test_inf_moon.py deleted file mode 100644 index 1314b69..0000000 --- a/questions/test_inf_moon.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -from transformers import AutoModelForCausalLM, AutoTokenizer -from PIL import Image -import os -import logging -from questions.logging_config import setup_logging - -setup_logging() -logger = logging.getLogger(__name__) - -def test_moondream(): - model_id = "vikhyatk/moondream2" - revision = "2024-08-26" - - model = AutoModelForCausalLM.from_pretrained( - model_id, - trust_remote_code=True, - revision=revision - ) - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision - ) - - # Test with a sample image - image_path = "static/img/me.jpg" - assert os.path.exists(image_path), f"Test image not found at {image_path}" - - image = Image.open(image_path) - enc_image = model.encode_image(image) - - # Test basic image description - response = model.answer_question(enc_image, "Describe this image.", tokenizer) - logger.info(f"Image description: {response}") - assert isinstance(response, str) - assert len(response) > 0 - - # Test specific question - question = "What colors are present in this image?" - response = model.answer_question(enc_image, question, tokenizer) - logger.info(f"Question: {question}\nAnswer: {response}") - assert isinstance(response, str) - assert len(response) > 0 - - -if __name__ == "__main__": - test_moondream() \ No newline at end of file diff --git a/questions/text_generator_inference.py b/questions/text_generator_inference.py index 1ce6b10..fb2685b 100755 --- a/questions/text_generator_inference.py +++ b/questions/text_generator_inference.py @@ -97,10 +97,10 @@ def load_model(weights_path): if (Path("/" + weights_path) / "config.json").exists(): weights_path = str(Path("/") / weights_path) - # if (Path("/models")).exists(): # prefer to save in ramdisk - # weights_path = "/" + weights_path - - if not (Path(weights_path) / "config.json").exists(): + # Only attempt to download weights if a local directory was specified and no + # config file is present. When using a HuggingFace model id the directory + # will not exist and downloading is handled by `from_pretrained`. + if os.path.isdir(weights_path) and not (Path(weights_path) / "config.json").exists(): download_model(weights_path, weights_path) logger.info(f"loading model from {weights_path}") diff --git a/scripts/example_moodream.py b/scripts/example_moodream.py index 9dec469..d61cd51 100644 --- a/scripts/example_moodream.py +++ b/scripts/example_moodream.py @@ -1,14 +1,6 @@ import os from PIL import Image -from transformers import AutoModelForCausalLM, AutoTokenizer - -def ensure_model_downloaded(): - """Ensure model is downloaded to models directory""" - model_path = "models/moondream2" - if not os.path.exists(model_path): - print(f"Model not found in {model_path}, downloading...") - os.makedirs(model_path, exist_ok=True) - return model_path +from transformers import pipeline def load_image(): """Load the local chrome icon image""" @@ -18,45 +10,42 @@ def load_image(): return Image.open(image_path) def main(): - # Initialize model and tokenizer - model_id = "vikhyatk/moondream2" - revision = "2024-08-26" - model_path = ensure_model_downloaded() - - print("Loading model and tokenizer...") - model = AutoModelForCausalLM.from_pretrained( - model_id, - trust_remote_code=True, - revision=revision, - cache_dir=model_path, - force_download=True - ) - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - cache_dir=model_path, - force_download=True + """Run a simple caption generation using the Gemma model.""" + model_id = "google/gemma-3n-e4b-it" + pipe = pipeline( + "image-text-to-text", + model=model_id, + device=0 if torch.cuda.is_available() else -1, + torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None, ) # Load and process image print("Loading local image...") image = load_image() - print("Encoding image...") - enc_image = model.encode_image(image) - - # Ask questions about the image questions = [ "Describe this image.", "What colors are prominent in this image?", - "Is this an icon or logo? If so, describe its design." + "Is this an icon or logo? If so, describe its design.", ] - + print("\nAsking questions about the image:") for question in questions: - print(f"\nQ: {question}") - answer = model.answer_question(enc_image, question, tokenizer) - print(f"A: {answer}") + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant."}], + }, + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": question}, + ], + }, + ] + answer = pipe(text=messages, max_new_tokens=100)[0]["generated_text"][-1]["content"] + print(f"\nQ: {question}\nA: {answer}") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/unit/questions/test_link_enricher.py b/tests/unit/questions/test_link_enricher.py index b0a023a..c74935e 100644 --- a/tests/unit/questions/test_link_enricher.py +++ b/tests/unit/questions/test_link_enricher.py @@ -1,6 +1,10 @@ import logging +import importlib.util import pytest +if importlib.util.find_spec("transformers") is None: + pytest.skip("transformers not installed", allow_module_level=True) + bs4 = pytest.importorskip("bs4", reason="bs4 required for link enrichment tests") from questions.link_enricher import get_urls, enrich_links from questions.logging_config import setup_logging diff --git a/tests/unit/test_audio_model.py b/tests/unit/test_audio_model.py index 84fc6e5..d8e0ff8 100644 --- a/tests/unit/test_audio_model.py +++ b/tests/unit/test_audio_model.py @@ -1,5 +1,10 @@ import builtins from unittest import mock +import importlib.util +import pytest + +if importlib.util.find_spec("torch") is None: + pytest.skip("torch not installed", allow_module_level=True) import questions.inference_server.inference_server as server diff --git a/tests/unit/test_document_processor.py b/tests/unit/test_document_processor.py index f1be354..35b2321 100644 --- a/tests/unit/test_document_processor.py +++ b/tests/unit/test_document_processor.py @@ -1,5 +1,9 @@ import requests import pytest +import importlib.util + +if importlib.util.find_spec("docx") is None: + pytest.skip("python-docx not installed", allow_module_level=True) from questions.document_processor import convert_to_markdown, convert_documents_async From 77a061d4f6973d0998fa6f833c575c5b122e6e17 Mon Sep 17 00:00:00 2001 From: Lee Penkman Date: Tue, 8 Jul 2025 16:48:15 +1200 Subject: [PATCH 2/2] Add configurable Gemma pipeline --- README.md | 8 +++++++ questions/link_enricher.py | 12 +++++----- questions/test_gemma_multimodal.py | 1 - scripts/example_moodream.py | 4 ---- scripts/run_gemma.py | 35 ++++++++++++++++++++++++++++++ tests/unit/test_gemma_env.py | 12 ++++++++++ 6 files changed, 61 insertions(+), 11 deletions(-) create mode 100755 scripts/run_gemma.py create mode 100644 tests/unit/test_gemma_env.py diff --git a/README.md b/README.md index 4ba0b7a..4d26c4c 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,14 @@ models/tg a general model accessible with model=multilingual models/tgz an instruct model accessible with model=instruct models/tgc a chat model accessible with model=chat +For the Gemma image pipeline you can override the default model or device using +the following environment variables: + +``` +GEMMA_MODEL_ID=myorg/my-gemma-checkpoint +GEMMA_DEVICE=0 # set to -1 for CPU +``` + model=best is configured to figure out which model to use based on the prompt being scored based on perplexity of each model. This needs tuning for the avg and std deviation of the perplexity as each model has different ideas about how confidenti it is. Overtrained models are more confident about all text being in the dataset (tend to generate text verbatim from the dataset). diff --git a/questions/link_enricher.py b/questions/link_enricher.py index e35fc27..c19cf91 100755 --- a/questions/link_enricher.py +++ b/questions/link_enricher.py @@ -102,8 +102,12 @@ def get_title_from_html(html, long_description=False): def load_gemma_pipe(): """Load the Gemma model for image captioning and text generation.""" - model_id = "google/gemma-3n-e4b-it" - device = 0 if torch and torch.cuda.is_available() else -1 + model_id = os.getenv("GEMMA_MODEL_ID", "google/gemma-3n-e4b-it") + device_env = os.getenv("GEMMA_DEVICE") + if device_env is not None: + device = int(device_env) + else: + device = 0 if torch and torch.cuda.is_available() else -1 return pipeline( "image-text-to-text", model=model_id, @@ -119,10 +123,6 @@ def get_caption_for_image_response(response, prompt="Describe this image."): with log_time("image captioning"): pipe = LINK_MODEL_CACHE.add_or_get("gemma_pipe", load_gemma_pipe) messages = [ - { - "role": "system", - "content": [{"type": "text", "text": "You are a helpful assistant."}], - }, { "role": "user", "content": [ diff --git a/questions/test_gemma_multimodal.py b/questions/test_gemma_multimodal.py index 4d9c657..4b7116f 100644 --- a/questions/test_gemma_multimodal.py +++ b/questions/test_gemma_multimodal.py @@ -25,7 +25,6 @@ def test_gemma_image_captioning(): img = Image.open(image_path) messages = [ - {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, {"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": "Describe this image."}]}, ] output = pipe(text=messages, max_new_tokens=5) diff --git a/scripts/example_moodream.py b/scripts/example_moodream.py index d61cd51..0bdb637 100644 --- a/scripts/example_moodream.py +++ b/scripts/example_moodream.py @@ -32,10 +32,6 @@ def main(): print("\nAsking questions about the image:") for question in questions: messages = [ - { - "role": "system", - "content": [{"type": "text", "text": "You are a helpful assistant."}], - }, { "role": "user", "content": [ diff --git a/scripts/run_gemma.py b/scripts/run_gemma.py new file mode 100755 index 0000000..2e94c9a --- /dev/null +++ b/scripts/run_gemma.py @@ -0,0 +1,35 @@ +import argparse +from PIL import Image +from transformers import pipeline +import os + + +def main(): + parser = argparse.ArgumentParser(description="Generate captions using Gemma") + parser.add_argument("image", help="Path to image") + parser.add_argument("prompt", nargs="?", default="Describe this image.", help="Prompt for the model") + args = parser.parse_args() + + model_id = os.getenv("GEMMA_MODEL_ID", "google/gemma-3n-e4b-it") + device_env = os.getenv("GEMMA_DEVICE") + if device_env is not None: + device = int(device_env) + else: + try: + import torch + device = 0 if torch.cuda.is_available() else -1 + except ModuleNotFoundError: + device = -1 + + pipe = pipeline("image-text-to-text", model=model_id, device=device) + + img = Image.open(args.image) + messages = [ + {"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": args.prompt}]}, + ] + output = pipe(text=messages, max_new_tokens=100) + print(output[0]["generated_text"][-1]["content"]) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_gemma_env.py b/tests/unit/test_gemma_env.py new file mode 100644 index 0000000..61da69c --- /dev/null +++ b/tests/unit/test_gemma_env.py @@ -0,0 +1,12 @@ +import os +import importlib +import pytest + +pytest.importorskip("transformers") + + +def test_load_gemma_pipe_env(monkeypatch): + monkeypatch.setenv("GEMMA_MODEL_ID", "yujiepan/gemma-3n-tiny-random") + module = importlib.import_module("questions.link_enricher") + pipe = module.load_gemma_pipe() + assert pipe.model.name_or_path == "yujiepan/gemma-3n-tiny-random"