diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index e92d1e1ec77a..973ce7ab422c 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -22,6 +22,7 @@
 import tempfile
 import unittest
 import warnings
+from pathlib import Path
 
 import numpy as np
 import pytest
@@ -4985,6 +4986,28 @@ def test_custom_generate_requires_trust_remote_code(self):
         with self.assertRaises(ValueError):
             model.generate(**model_inputs, custom_generate="transformers-community/custom_generate_example")
 
+    def test_custom_generate_local_directory(self):
+        """Tests that custom_generate works with local directories containing importable relative modules"""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            custom_generate_dir = Path(tmp_dir) / "custom_generate"
+            custom_generate_dir.mkdir()
+            # generate.py pulls its result from a sibling module via a relative import, so the test
+            # only passes if custom_generate makes the local directory importable as a package.
+            (custom_generate_dir / "generate.py").write_text(
+                "from .helper import ret_success\ndef generate(*args, **kwargs):\n    return ret_success()\n"
+            )
+            (custom_generate_dir / "helper.py").write_text('def ret_success():\n    return "success"\n')
+            model = AutoModelForCausalLM.from_pretrained(
+                "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
+            )
+            tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+            model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
+            value = model.generate(
+                **model_inputs,
+                custom_generate=tmp_dir,
+                trust_remote_code=True,
+            )
+            assert value == "success"
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):