From a15a532eaf9bdd356ed44740050d0be09bcb5758 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 5 Aug 2025 15:32:26 -0700 Subject: [PATCH 1/3] types: relax type for tools --- ollama/_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ollama/_types.py b/ollama/_types.py index caf1e703..b21aa85b 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -313,7 +313,7 @@ class Function(SubscriptableBaseModel): class Tool(SubscriptableBaseModel): - type: Optional[Literal['function']] = 'function' + type: Optional[str] = None class Function(SubscriptableBaseModel): name: Optional[str] = None From c7ffe5caf13e7d34bec6afb8d5835962190b38f9 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 5 Aug 2025 15:47:54 -0700 Subject: [PATCH 2/3] types: improve tests and handling --- ollama/_types.py | 4 ++-- ollama/_utils.py | 3 ++- tests/test_client.py | 9 +++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/ollama/_types.py b/ollama/_types.py index b21aa85b..db928e54 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -79,7 +79,7 @@ def __contains__(self, key: str) -> bool: if key in self.model_fields_set: return True - if value := self.model_fields.get(key): + if value := self.__class__.model_fields.get(key): return value.default is not None return False @@ -313,7 +313,7 @@ class Function(SubscriptableBaseModel): class Tool(SubscriptableBaseModel): - type: Optional[str] = None + type: Optional[str] = 'function' class Function(SubscriptableBaseModel): name: Optional[str] = None diff --git a/ollama/_utils.py b/ollama/_utils.py index 653a04c7..15f1cc0c 100644 --- a/ollama/_utils.py +++ b/ollama/_utils.py @@ -79,11 +79,12 @@ def convert_function_to_tool(func: Callable) -> Tool: } tool = Tool( + type='function', function=Tool.Function( name=func.__name__, description=schema.get('description', ''), parameters=Tool.Function.Parameters(**schema), - ) + ), ) return Tool.model_validate(tool) diff --git a/tests/test_client.py b/tests/test_client.py index 1e66184b..15cd2dfa 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1136,10 +1136,11 @@ def func2(y: str) -> int: def test_tool_validation(): - # Raises ValidationError when used as it is a generator - with pytest.raises(ValidationError): - invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}} - list(_copy_tools([invalid_tool])) + arbitrary_tool = {'type': 'custom_type', 'function': {'name': 'test'}} + tools = list(_copy_tools([arbitrary_tool])) + assert len(tools) == 1 + assert tools[0].type == 'custom_type' + assert tools[0].function.name == 'test' def test_client_connection_error(): From f8054a0f4cedff0872bceadce11639e631a7b42f Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 5 Aug 2025 15:57:56 -0700 Subject: [PATCH 3/3] fix formatting --- tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_client.py b/tests/test_client.py index 15cd2dfa..6917edc2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,7 @@ import pytest from httpx import Response as httpxResponse -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from pytest_httpserver import HTTPServer, URIPattern from werkzeug.wrappers import Request, Response