diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 39fa06a..b9efe8f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,19 +19,37 @@ uv run pytest tests -m "not integration" or, as a shortcut, -``` +```bash just test ``` Generally if you are not developing a new provider, you can test most functionality through mocking and the normal test suite. -However to ensure the providers work, we also have integration tests which actually require a credential and connect +However, to ensure the providers work, we also have integration tests which actually require a credential and connect to the provider endpoints. Those can be run with -``` +```bash uv run pytest tests -m integration -# or `just integration` +# or `just integration` +``` + +### Integration testing with Ollama + +To run integration tests against Ollama, you need the model that tests expect available locally. + +First, run ollama and pull the models you want to test. +```bash +ollama serve +# Then in another terminal, pull the model +OLLAMA_MODEL=$(uv run python -c "from src.exchange.providers.ollama import OLLAMA_MODEL; print(OLLAMA_MODEL)") +ollama pull $OLLAMA_MODEL +``` + +Finally, run ollama integration tests. +```bash +uv run pytest tests -m integration -k ollama +# or `just integration -k ollama` ``` ## Pull Requests diff --git a/src/exchange/providers/ollama.py b/src/exchange/providers/ollama.py index 37bf026..7570abf 100644 --- a/src/exchange/providers/ollama.py +++ b/src/exchange/providers/ollama.py @@ -1,99 +1,40 @@ import os -from typing import Any, Dict, List, Tuple, Type +from typing import Type import httpx -from exchange.message import Message -from exchange.providers.base import Provider, Usage -from exchange.providers.retry_with_back_off_decorator import retry_httpx_request -from exchange.providers.utils import ( - messages_to_openai_spec, - openai_response_to_message, - openai_single_message_context_length_exceeded, - raise_for_status, - tools_to_openai_spec, -) -from exchange.tool import Tool +from exchange.providers.openai import OpenAiProvider OLLAMA_HOST = "http://localhost:11434/" +OLLAMA_MODEL = "mistral-nemo" -# -# NOTE: this is experimental, best used with 70B model or larger if you can. -# Example profile config to try: -class OllamaProvider(Provider): +class OllamaProvider(OpenAiProvider): """Provides chat completions for models hosted by Ollama""" - """ + __doc__ += f""" - ollama: - provider: ollama - processor: llama3.1 - accelerator: llama3.1 - moderator: passive - toolkits: - - name: developer - requires: {} - """ +Here's an example profile configuration to try: + + ollama: + provider: ollama + processor: {OLLAMA_MODEL} + accelerator: {OLLAMA_MODEL} + moderator: passive + toolkits: + - name: developer + requires: {{}} +""" def __init__(self, client: httpx.Client) -> None: print("PLEASE NOTE: the ollama provider is experimental, use with care") - super().__init__() - self.client = client + super().__init__(client) @classmethod def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider": - url = os.environ["OLLAMA_HOST"] + url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST) client = httpx.Client( base_url=url, - headers={"Content-Type": "application/json"}, timeout=httpx.Timeout(60 * 10), ) return cls(client) - - @staticmethod - def get_usage(data: dict) -> Usage: - usage = data.get("usage", {}) - input_tokens = usage.get("prompt_tokens", 0) - output_tokens = usage.get("completion_tokens", 0) - total_tokens = usage.get("total_tokens", input_tokens + output_tokens) - - return Usage( - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=total_tokens, - ) - - def complete( - self, - model: str, - system: str, - messages: List[Message], - tools: Tuple[Tool], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: - payload = dict( - messages=[ - {"role": "system", "content": system}, - *messages_to_openai_spec(messages), - ], - model=model, - tools=tools_to_openai_spec(tools) if tools else [], - **kwargs, - ) - payload = {k: v for k, v in payload.items() if v} - response = self._send_request(payload) - - # Check for context_length_exceeded error for single, long input message - if "error" in response.json() and len(messages) == 1: - openai_single_message_context_length_exceeded(response.json()["error"]) - - data = raise_for_status(response).json() - - message = openai_response_to_message(data) - usage = self.get_usage(data) - return message, usage - - @retry_httpx_request() - def _send_request(self, payload: Any) -> httpx.Response: # noqa: ANN401 - return self.client.post("v1/chat/completions", json=payload) diff --git a/tests/providers/test_ollama.py b/tests/providers/test_ollama.py index 6666ea8..7812fe6 100644 --- a/tests/providers/test_ollama.py +++ b/tests/providers/test_ollama.py @@ -1,54 +1,12 @@ -import os -from unittest.mock import patch - import pytest -from exchange import Message, Text -from exchange.providers.ollama import OLLAMA_HOST, OllamaProvider - - -@pytest.fixture -@patch.dict(os.environ, {}) -def ollama_provider(): - os.environ["OLLAMA_HOST"] = OLLAMA_HOST - return OllamaProvider.from_env() - - -@patch("httpx.Client.post") -@patch("time.sleep", return_value=None) -@patch("logging.warning") -@patch("logging.error") -def test_ollama_completion(mock_error, mock_warning, mock_sleep, mock_post, ollama_provider): - mock_response = { - "choices": [{"message": {"role": "assistant", "content": "Hello!"}}], - } - - mock_post.return_value.json.return_value = mock_response - - model = "llama2" - system = "You are a helpful assistant." - messages = [Message.user("Hello")] - tools = () - - reply_message, _ = ollama_provider.complete(model=model, system=system, messages=messages, tools=tools) - - assert reply_message.content == [Text(text="Hello!")] - mock_post.assert_called_once_with( - "v1/chat/completions", - json={ - "messages": [ - {"role": "system", "content": system}, - {"role": "user", "content": "Hello"}, - ], - "model": model, - }, - ) +from exchange import Message +from exchange.providers.ollama import OllamaProvider, OLLAMA_MODEL @pytest.mark.integration def test_ollama_integration(): - os.environ["OLLAMA_HOST"] = OLLAMA_HOST provider = OllamaProvider.from_env() - model = "llama2" # specify a valid model + model = OLLAMA_MODEL system = "You are a helpful assistant." messages = [Message.user("Hello")] diff --git a/tests/test_integration.py b/tests/test_integration.py index 05078de..6aac78b 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,12 +1,15 @@ import pytest from exchange.exchange import Exchange from exchange.message import Message +from exchange.moderators import ContextTruncate from exchange.providers import get_provider +from exchange.providers.ollama import OLLAMA_MODEL from exchange.tool import Tool too_long_chars = "x" * (2**20 + 1) cases = [ + (get_provider("ollama"), OLLAMA_MODEL), (get_provider("openai"), "gpt-4o-mini"), (get_provider("databricks"), "databricks-meta-llama-3-70b-instruct"), (get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0"), @@ -21,6 +24,7 @@ def test_simple(provider, model): ex = Exchange( provider=provider, model=model, + moderator=ContextTruncate(model), system="You are a helpful assistant.", ) @@ -61,7 +65,7 @@ def read_file(filename: str) -> str: tools=(Tool.from_function(read_file),), ) - ex.add(Message.user(f"Can you read the contents of this file: {temp_file}")) + ex.add(Message.user(f"What are the contents of this file? {temp_file}")) response = ex.reply() @@ -80,6 +84,7 @@ def get_password() -> str: ex = Exchange( provider=provider, model=model, + moderator=ContextTruncate(model), system="You are a helpful assistant. Expect to need to authenticate using get_password.", tools=(Tool.from_function(get_password),), ) diff --git a/tests/test_vision.py b/tests/test_integration_vision.py similarity index 92% rename from tests/test_vision.py rename to tests/test_integration_vision.py index 95be50f..8635886 100644 --- a/tests/test_vision.py +++ b/tests/test_integration_vision.py @@ -2,9 +2,9 @@ from exchange.content import ToolResult, ToolUse from exchange.exchange import Exchange from exchange.message import Message +from exchange.moderators import ContextTruncate from exchange.providers import get_provider - cases = [ (get_provider("openai"), "gpt-4o-mini"), ] @@ -18,6 +18,7 @@ def test_simple(provider, model): ex = Exchange( provider=provider, model=model, + moderator=ContextTruncate(model), system="You are a helpful assistant.", )