diff --git a/examples/web-search-crawl.py b/examples/web-search-crawl.py new file mode 100644 index 00000000..30f9be4f --- /dev/null +++ b/examples/web-search-crawl.py @@ -0,0 +1,84 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "rich", +# ] +# /// +import os +from typing import Union + +from rich import print + +from ollama import Client, WebCrawlResponse, WebSearchResponse + + +def format_tool_results(results: Union[WebSearchResponse, WebCrawlResponse]): + if isinstance(results, WebSearchResponse): + if not results.success: + error_msg = ', '.join(results.errors) if results.errors else 'Unknown error' + return f'Web search failed: {error_msg}' + + output = [] + for query, search_results in results.results.items(): + output.append(f'Search results for "{query}":') + for i, result in enumerate(search_results, 1): + output.append(f'{i}. {result.title}') + output.append(f' URL: {result.url}') + output.append(f' Content: {result.content}') + output.append('') + + return '\n'.join(output).rstrip() + + elif isinstance(results, WebCrawlResponse): + if not results.success: + error_msg = ', '.join(results.errors) if results.errors else 'Unknown error' + return f'Web crawl failed: {error_msg}' + + output = [] + for url, crawl_results in results.results.items(): + output.append(f'Crawl results for "{url}":') + for i, result in enumerate(crawl_results, 1): + output.append(f'{i}. 
{result.title}') + output.append(f' URL: {result.url}') + output.append(f' Content: {result.content}') + if result.links: + output.append(f' Links: {", ".join(result.links)}') + output.append('') + + return '\n'.join(output).rstrip() + + +client = Client(headers={'Authorization': (os.getenv('OLLAMA_API_KEY'))}) +available_tools = {'web_search': client.web_search, 'web_crawl': client.web_crawl} + +query = "ollama's new engine" +print('Query: ', query) + +messages = [{'role': 'user', 'content': query}] +while True: + response = client.chat(model='qwen3', messages=messages, tools=[client.web_search, client.web_crawl], think=True) + if response.message.thinking: + print('Thinking: ') + print(response.message.thinking + '\n\n') + if response.message.content: + print('Content: ') + print(response.message.content + '\n') + + messages.append(response.message) + + if response.message.tool_calls: + for tool_call in response.message.tool_calls: + function_to_call = available_tools.get(tool_call.function.name) + if function_to_call: + result: WebSearchResponse | WebCrawlResponse = function_to_call(**tool_call.function.arguments) + print('Result from tool call name: ', tool_call.function.name, 'with arguments: ', tool_call.function.arguments) + print('Result: ', format_tool_results(result)[:200]) + + # caps the result at ~2000 tokens + messages.append({'role': 'tool', 'content': format_tool_results(result)[: 2000 * 4], 'tool_name': tool_call.function.name}) + else: + print(f'Tool {tool_call.function.name} not found') + messages.append({'role': 'tool', 'content': f'Tool {tool_call.function.name} not found', 'tool_name': tool_call.function.name}) + else: + # no more tool calls, we can stop the loop + break diff --git a/ollama/__init__.py b/ollama/__init__.py index afe8ce71..85d8bce7 100644 --- a/ollama/__init__.py +++ b/ollama/__init__.py @@ -15,6 +15,8 @@ ShowResponse, StatusResponse, Tool, + WebCrawlResponse, + WebSearchResponse, ) __all__ = [ @@ -35,6 +37,8 @@ 'ShowResponse', 
'StatusResponse', 'Tool', + 'WebCrawlResponse', + 'WebSearchResponse', ] _client = Client() @@ -51,3 +55,5 @@ copy = _client.copy show = _client.show ps = _client.ps +web_search = _client.web_search +web_crawl = _client.web_crawl diff --git a/ollama/_client.py b/ollama/_client.py index 0a85a74a..3cc41b85 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -66,6 +66,10 @@ ShowResponse, StatusResponse, Tool, + WebCrawlRequest, + WebCrawlResponse, + WebSearchRequest, + WebSearchResponse, ) T = TypeVar('T') @@ -622,6 +626,46 @@ def ps(self) -> ProcessResponse: '/api/ps', ) + def web_search(self, queries: Sequence[str], max_results: int = 3) -> WebSearchResponse: + """ + Performs a web search + + Args: + queries: The queries to search for + max_results: The maximum number of results to return. + + Returns: + WebSearchResponse with the search results + """ + return self._request( + WebSearchResponse, + 'POST', + 'https://ollama.com/api/web_search', + json=WebSearchRequest( + queries=queries, + max_results=max_results, + ).model_dump(exclude_none=True), + ) + + def web_crawl(self, urls: Sequence[str]) -> WebCrawlResponse: + """ + Gets the content of web pages for the provided URLs. + + Args: + urls: The URLs to crawl + + Returns: + WebCrawlResponse with the crawl results + """ + return self._request( + WebCrawlResponse, + 'POST', + 'https://ollama.com/api/web_crawl', + json=WebCrawlRequest( + urls=urls, + ).model_dump(exclude_none=True), + ) + class AsyncClient(BaseClient): def __init__(self, host: Optional[str] = None, **kwargs) -> None: @@ -691,6 +735,46 @@ async def inner(): return cls(**(await self._request_raw(*args, **kwargs)).json()) + async def web_search(self, queries: Sequence[str], max_results: int = 3) -> WebSearchResponse: + """ + Performs a web search + + Args: + queries: The queries to search for + max_results: The maximum number of results to return. 
+ + Returns: + WebSearchResponse with the search results + """ + return await self._request( + WebSearchResponse, + 'POST', + 'https://ollama.com/api/web_search', + json=WebSearchRequest( + queries=queries, + max_results=max_results, + ).model_dump(exclude_none=True), + ) + + async def web_crawl(self, urls: Sequence[str]) -> WebCrawlResponse: + """ + Gets the content of web pages for the provided URLs. + + Args: + urls: The URLs to crawl + + Returns: + WebCrawlResponse with the crawl results + """ + return await self._request( + WebCrawlResponse, + 'POST', + 'https://ollama.com/api/web_crawl', + json=WebCrawlRequest( + urls=urls, + ).model_dump(exclude_none=True), + ) + @overload async def generate( self, diff --git a/ollama/_types.py b/ollama/_types.py index 04822875..2fb30a34 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -538,6 +538,40 @@ class Model(SubscriptableBaseModel): models: Sequence[Model] +class WebSearchRequest(SubscriptableBaseModel): + queries: Sequence[str] + max_results: Optional[int] = None + + +class WebSearchResult(SubscriptableBaseModel): + title: str + url: str + content: str + + +class WebCrawlResult(SubscriptableBaseModel): + title: str + url: str + content: str + links: Optional[Sequence[str]] = None + + +class WebSearchResponse(SubscriptableBaseModel): + results: Mapping[str, Sequence[WebSearchResult]] + success: bool + errors: Optional[Sequence[str]] = None + + +class WebCrawlRequest(SubscriptableBaseModel): + urls: Sequence[str] + + +class WebCrawlResponse(SubscriptableBaseModel): + results: Mapping[str, Sequence[WebCrawlResult]] + success: bool + errors: Optional[Sequence[str]] = None + + class RequestError(Exception): """ Common class for request errors.