Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

# Example: create a new model from an existing base model.
# (Reconstructed post-diff version: the rendered diff duplicated every
# keyword argument, which is a SyntaxError in real Python.)
client = Client()
response = client.create(
    model='my-assistant',
    from_='llama3.2',
    system='You are mario from Super Mario Bros.',
    stream=False,
)
print(response.status)
15 changes: 11 additions & 4 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,22 @@ def __init__(
)


CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'


class Client(BaseClient):
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
    """Create a synchronous client.

    Args:
      host: Base URL of the Ollama server; presumably falls back to a
        default host when None — confirm in BaseClient.
      **kwargs: Forwarded to the underlying ``httpx.Client`` via BaseClient.
    """
    super().__init__(httpx.Client, host, **kwargs)

def _request_raw(self, *args, **kwargs):
    """Issue an HTTP request and return the raw httpx response.

    Raises:
      ResponseError: the server responded with a non-2xx status; carries the
        server's error text and the HTTP status code.
      ConnectionError: the Ollama server could not be reached at all.
    """
    # Reconstructed post-diff version: the rendered diff interleaved the
    # removed pre-try request line and a duplicate `return r` into the body.
    try:
        r = self._client.request(*args, **kwargs)
        r.raise_for_status()
        return r
    except httpx.HTTPStatusError as e:
        # Re-raise as the package's own error type; `from None` drops the
        # httpx traceback so users see only the server's message.
        raise ResponseError(e.response.text, e.response.status_code) from None
    except httpx.ConnectError:
        # Most likely Ollama is not installed or not running locally.
        raise ConnectionError(CONNECTION_ERROR_MESSAGE) from None

@overload
def _request(
Expand Down Expand Up @@ -613,12 +618,14 @@ def __init__(self, host: Optional[str] = None, **kwargs) -> None:
super().__init__(httpx.AsyncClient, host, **kwargs)

async def _request_raw(self, *args, **kwargs):
    """Issue an HTTP request and return the raw httpx response (async).

    Raises:
      ResponseError: the server responded with a non-2xx status; carries the
        server's error text and the HTTP status code.
      ConnectionError: the Ollama server could not be reached at all.
    """
    # Reconstructed post-diff version: the rendered diff interleaved the
    # removed pre-try request line and a duplicate `return r` into the body.
    try:
        r = await self._client.request(*args, **kwargs)
        r.raise_for_status()
        return r
    except httpx.HTTPStatusError as e:
        # Mirror the sync client: surface the server's message without the
        # httpx traceback.
        raise ResponseError(e.response.text, e.response.status_code) from None
    except httpx.ConnectError:
        # Most likely Ollama is not installed or not running locally.
        raise ConnectionError(CONNECTION_ERROR_MESSAGE) from None

@overload
async def _request(
Expand Down
3 changes: 3 additions & 0 deletions ollama/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,6 @@ def __init__(self, error: str, status_code: int = -1):

self.status_code = status_code
'HTTP status code of the response.'

def __str__(self) -> str:
    """Render the error as '<message> (status code: <code>)' so the HTTP
    status is visible when the exception is printed."""
    return '{0} (status code: {1})'.format(self.error, self.status_code)
Comment on lines +539 to +540
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

status code does not get printed currently - this exposes it

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it currently output?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No code, just the message from the server:
image

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vs with the change:
image

29 changes: 28 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response

from ollama._client import AsyncClient, Client, _copy_tools
from ollama._client import CONNECTION_ERROR_MESSAGE, AsyncClient, Client, _copy_tools

PNG_BASE64 = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'
PNG_BYTES = base64.b64decode(PNG_BASE64)
Expand Down Expand Up @@ -1112,3 +1112,30 @@ def test_tool_validation():
with pytest.raises(ValidationError):
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
list(_copy_tools([invalid_tool]))


def test_client_connection_error():
    """Every sync endpoint raises ConnectionError with the shared help
    message when no server is listening on the target port."""
    client = Client('http://localhost:1234')

    # The original had the chat assertion duplicated verbatim; one is enough.
    with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
        client.chat('model', messages=[{'role': 'user', 'content': 'prompt'}])
    with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
        client.generate('model', 'prompt')
    with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
        client.show('model')


@pytest.mark.asyncio
async def test_async_client_connection_error():
    """Async mirror of test_client_connection_error: every endpoint raises
    ConnectionError with the shared help message when unreachable.

    Uses the imported CONNECTION_ERROR_MESSAGE constant (as the sync test
    does) instead of hard-coding the message three times, so the test stays
    in sync if the message changes.
    """
    client = AsyncClient('http://localhost:1234')

    with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
        await client.chat('model', messages=[{'role': 'user', 'content': 'prompt'}])
    with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
        await client.generate('model', 'prompt')
    with pytest.raises(ConnectionError, match=CONNECTION_ERROR_MESSAGE):
        await client.show('model')