From 4b10dee2b2a73242bb3ac89332f634cda2b6f633 Mon Sep 17 00:00:00 2001
From: Parth Sareen
Date: Thu, 5 Dec 2024 15:40:49 -0800
Subject: [PATCH] Structured outputs support with examples (#354)

---
 README.md                            |   3 -
 examples/README.md                   |   6 +
 examples/async-structured-outputs.py |  32 ++++
 examples/structured-outputs-image.py |  50 ++++++
 examples/structured-outputs.py       |  26 +++
 ollama/_client.py                    |  27 ++--
 ollama/_types.py                     |   3 +-
 tests/test_client.py                 | 226 ++++++++++++++++++++++++++-
 8 files changed, 355 insertions(+), 18 deletions(-)
 create mode 100644 examples/async-structured-outputs.py
 create mode 100644 examples/structured-outputs-image.py
 create mode 100644 examples/structured-outputs.py
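At a glance: the `format` parameter of `chat` and `generate`, previously limited to `''` or `'json'`, now also accepts a JSON schema dict, typically produced by Pydantic's `model_json_schema()`. A minimal sketch of the new call shape, with an illustrative model name and prompt (the complete versions are the examples added below):

```python
from pydantic import BaseModel

from ollama import chat


class Pet(BaseModel):
  name: str
  age: int


# Constrain the response to the Pet schema, then validate it back into a Pet.
response = chat(
  model='llama3.1:8b',  # illustrative; any locally available model works
  messages=[{'role': 'user', 'content': 'Luna is a 4 year old cat. Respond in JSON.'}],
  format=Pet.model_json_schema(),
)
print(Pet.model_validate_json(response.message.content))
```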
diff --git a/README.md b/README.md
index 454c159..b6ab33b 100644
--- a/README.md
+++ b/README.md
@@ -37,9 +37,6 @@ See [_types.py](ollama/_types.py) for more information on the response types.
 
 Response streaming can be enabled by setting `stream=True`.
 
-> [!NOTE]
-> Streaming Tool/Function calling is not yet supported.
-
 ```python
 from ollama import chat
 
diff --git a/examples/README.md b/examples/README.md
index a455c60..b55fd4e 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -30,6 +30,12 @@ python3 examples/<example>.py
 
 - [multimodal_generate.py](multimodal_generate.py)
 
+### Structured Outputs - Generate structured outputs with a model
+- [structured-outputs.py](structured-outputs.py)
+- [async-structured-outputs.py](async-structured-outputs.py)
+- [structured-outputs-image.py](structured-outputs-image.py)
+
+
 ### Ollama List - List all downloaded models and their properties
 - [list.py](list.py)
 
diff --git a/examples/async-structured-outputs.py b/examples/async-structured-outputs.py
new file mode 100644
index 0000000..b2c8dac
--- /dev/null
+++ b/examples/async-structured-outputs.py
@@ -0,0 +1,32 @@
+from pydantic import BaseModel
+from ollama import AsyncClient
+import asyncio
+
+
+# Define the schema for the response
+class FriendInfo(BaseModel):
+  name: str
+  age: int
+  is_available: bool
+
+
+class FriendList(BaseModel):
+  friends: list[FriendInfo]
+
+
+async def main():
+  client = AsyncClient()
+  response = await client.chat(
+    model='llama3.1:8b',
+    messages=[{'role': 'user', 'content': 'I have two friends. The first is Ollama, 22 years old, busy saving the world, and the second is Alonso, 23 years old, who wants to hang out. Return a list of friends in JSON format'}],
+    format=FriendList.model_json_schema(),  # Use Pydantic to generate the schema
+    options={'temperature': 0},  # Make responses more deterministic
+  )
+
+  # Use Pydantic to validate the response
+  friends_response = FriendList.model_validate_json(response.message.content)
+  print(friends_response)
+
+
+if __name__ == '__main__':
+  asyncio.run(main())
diff --git a/examples/structured-outputs-image.py b/examples/structured-outputs-image.py
new file mode 100644
index 0000000..73d09cc
--- /dev/null
+++ b/examples/structured-outputs-image.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+from pydantic import BaseModel
+from typing import List, Optional, Literal
+from ollama import chat
+from rich import print
+
+
+# Define the schema for image objects
+class Object(BaseModel):
+  name: str
+  confidence: float
+  attributes: Optional[dict] = None
+
+
+class ImageDescription(BaseModel):
+  summary: str
+  objects: List[Object]
+  scene: str
+  colors: List[str]
+  time_of_day: Literal['Morning', 'Afternoon', 'Evening', 'Night']
+  setting: Literal['Indoor', 'Outdoor', 'Unknown']
+  text_content: Optional[str] = None
+
+
+# Get path from user input
+path = input('Enter the path to your image: ')
+path = Path(path)
+
+# Verify the file exists
+if not path.exists():
+  raise FileNotFoundError(f'Image not found at: {path}')
+
+# Set up chat as usual
+response = chat(
+  model='llama3.2-vision',
+  format=ImageDescription.model_json_schema(),  # Pass in the schema for the response
+  messages=[
+    {
+      'role': 'user',
+      'content': 'Analyze this image and return a detailed JSON description including objects, scene, colors and any text detected. If you cannot determine certain details, leave those fields empty.',
+      'images': [path],
+    },
+  ],
+  options={'temperature': 0},  # Set temperature to 0 for more deterministic output
+)
+
+
+# Convert received content to the schema
+image_analysis = ImageDescription.model_validate_json(response.message.content)
+print(image_analysis)
diff --git a/examples/structured-outputs.py b/examples/structured-outputs.py
new file mode 100644
index 0000000..cb28ccd
--- /dev/null
+++ b/examples/structured-outputs.py
@@ -0,0 +1,26 @@
+from ollama import chat
+from pydantic import BaseModel
+
+
+# Define the schema for the response
+class FriendInfo(BaseModel):
+  name: str
+  age: int
+  is_available: bool
+
+
+class FriendList(BaseModel):
+  friends: list[FriendInfo]
+
+
+# schema = {'type': 'object', 'properties': {'friends': {'type': 'array', 'items': {'type': 'object', 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}, 'is_available': {'type': 'boolean'}}, 'required': ['name', 'age', 'is_available']}}}, 'required': ['friends']}
+response = chat(
+  model='llama3.1:8b',
+  messages=[{'role': 'user', 'content': 'I have two friends. The first is Ollama, 22 years old, busy saving the world, and the second is Alonso, 23 years old, who wants to hang out. Return a list of friends in JSON format'}],
+  format=FriendList.model_json_schema(),  # Use Pydantic to generate the schema or format=schema
+  options={'temperature': 0},  # Make responses more deterministic
+)
+
+# Use Pydantic to validate the response
+friends_response = FriendList.model_validate_json(response.message.content)
+print(friends_response)
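The examples lean on Pydantic, but as the commented-out `schema` dict in structured-outputs.py hints, a hand-written JSON schema works in the same position. A sketch of that variant, under the same local-model assumption as the examples above (the prompt is illustrative):

```python
from ollama import chat

# Hand-written equivalent of FriendList.model_json_schema() from the example.
schema = {
  'type': 'object',
  'properties': {
    'friends': {
      'type': 'array',
      'items': {
        'type': 'object',
        'properties': {
          'name': {'type': 'string'},
          'age': {'type': 'integer'},
          'is_available': {'type': 'boolean'},
        },
        'required': ['name', 'age', 'is_available'],
      },
    },
  },
  'required': ['friends'],
}

response = chat(
  model='llama3.1:8b',
  messages=[{'role': 'user', 'content': 'Invent two friends with ages and availability. Return JSON.'}],
  format=schema,  # a plain dict is accepted wherever a Pydantic-generated schema is
  options={'temperature': 0},
)
print(response.message.content)
```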
""" - return self._request( ChatResponse, 'POST', @@ -689,7 +690,7 @@ async def generate( context: Optional[Sequence[int]] = None, stream: Literal[False] = False, raw: bool = False, - format: Optional[Literal['', 'json']] = None, + format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, images: Optional[Sequence[Union[str, bytes]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, @@ -707,7 +708,7 @@ async def generate( context: Optional[Sequence[int]] = None, stream: Literal[True] = True, raw: bool = False, - format: Optional[Literal['', 'json']] = None, + format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, images: Optional[Sequence[Union[str, bytes]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, @@ -724,7 +725,7 @@ async def generate( context: Optional[Sequence[int]] = None, stream: bool = False, raw: Optional[bool] = None, - format: Optional[Literal['', 'json']] = None, + format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, images: Optional[Sequence[Union[str, bytes]]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, @@ -767,7 +768,7 @@ async def chat( *, tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: Literal[False] = False, - format: Optional[Literal['', 'json']] = None, + format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> ChatResponse: ... @@ -780,7 +781,7 @@ async def chat( *, tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: Literal[True] = True, - format: Optional[Literal['', 'json']] = None, + format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> AsyncIterator[ChatResponse]: ... @@ -792,7 +793,7 @@ async def chat( *, tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: bool = False, - format: Optional[Literal['', 'json']] = None, + format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]: diff --git a/ollama/_types.py b/ollama/_types.py index 589c7aa..11a0a59 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import Any, Mapping, Optional, Union, Sequence +from pydantic.json_schema import JsonSchemaValue from typing_extensions import Annotated, Literal from pydantic import ( @@ -150,7 +151,7 @@ class BaseGenerateRequest(BaseStreamableRequest): options: Optional[Union[Mapping[str, Any], Options]] = None 'Options to use for the request.' - format: Optional[Literal['', 'json']] = None + format: Optional[Union[Literal['json'], JsonSchemaValue]] = None 'Format of the response.' 
diff --git a/tests/test_client.py b/tests/test_client.py
index fbd01bd..aab2f2e 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,7 +1,7 @@
 import os
 import io
 import json
-from pydantic import ValidationError
+from pydantic import ValidationError, BaseModel
 import pytest
 import tempfile
 from pathlib import Path
@@ -122,6 +122,128 @@ def test_client_chat_images(httpserver: HTTPServer):
   assert response['message']['content'] == "I don't know."
 
 
+def test_client_chat_format_json(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
+      'tools': [],
+      'format': 'json',
+      'stream': False,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': '{"answer": "Because of Rayleigh scattering"}',
+      },
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
+  assert response['model'] == 'dummy'
+  assert response['message']['role'] == 'assistant'
+  assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
+
+
+def test_client_chat_format_pydantic(httpserver: HTTPServer):
+  class ResponseFormat(BaseModel):
+    answer: str
+    confidence: float
+
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
+      'tools': [],
+      'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
+      'stream': False,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
+      },
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
+  assert response['model'] == 'dummy'
+  assert response['message']['role'] == 'assistant'
+  assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
+
+
+@pytest.mark.asyncio
+async def test_async_client_chat_format_json(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
+      'tools': [],
+      'format': 'json',
+      'stream': False,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': '{"answer": "Because of Rayleigh scattering"}',
+      },
+    }
+  )
+
+  client = AsyncClient(httpserver.url_for('/'))
+  response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json')
+  assert response['model'] == 'dummy'
+  assert response['message']['role'] == 'assistant'
+  assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering"}'
+
+
+@pytest.mark.asyncio
+async def test_async_client_chat_format_pydantic(httpserver: HTTPServer):
+  class ResponseFormat(BaseModel):
+    answer: str
+    confidence: float
+
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
+      'tools': [],
+      'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
+      'stream': False,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
+      },
+    }
+  )
+
+  client = AsyncClient(httpserver.url_for('/'))
+  response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format=ResponseFormat.model_json_schema())
+  assert response['model'] == 'dummy'
+  assert response['message']['role'] == 'assistant'
+  assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
+
+
 def test_client_generate(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/generate',
@@ -205,6 +327,108 @@ def test_client_generate_images(httpserver: HTTPServer):
   assert response['response'] == 'Because it is.'
 
 
+def test_client_generate_format_json(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why is the sky blue?',
+      'format': 'json',
+      'stream': False,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': '{"answer": "Because of Rayleigh scattering"}',
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy', 'Why is the sky blue?', format='json')
+  assert response['model'] == 'dummy'
+  assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
+
+
+def test_client_generate_format_pydantic(httpserver: HTTPServer):
+  class ResponseFormat(BaseModel):
+    answer: str
+    confidence: float
+
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why is the sky blue?',
+      'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']},
+      'stream': False,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}',
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema())
+  assert response['model'] == 'dummy'
+  assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
+
+
+@pytest.mark.asyncio
+async def test_async_client_generate_format_json(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why is the sky blue?',
+      'format': 'json',
+      'stream': False,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': '{"answer": "Because of Rayleigh scattering"}',
+    }
+  )
+
+  client = AsyncClient(httpserver.url_for('/'))
+  response = await client.generate('dummy', 'Why is the sky blue?', format='json')
+  assert response['model'] == 'dummy'
+  assert response['response'] == '{"answer": "Because of Rayleigh scattering"}'
+
+
+@pytest.mark.asyncio
+async def test_async_client_generate_format_pydantic(httpserver: HTTPServer):
+  class ResponseFormat(BaseModel):
+    answer: str
+    confidence: float
+
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+ json={ + 'model': 'dummy', + 'prompt': 'Why is the sky blue?', + 'format': {'title': 'ResponseFormat', 'type': 'object', 'properties': {'answer': {'title': 'Answer', 'type': 'string'}, 'confidence': {'title': 'Confidence', 'type': 'number'}}, 'required': ['answer', 'confidence']}, + 'stream': False, + }, + ).respond_with_json( + { + 'model': 'dummy', + 'response': '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}', + } + ) + + client = AsyncClient(httpserver.url_for('/')) + response = await client.generate('dummy', 'Why is the sky blue?', format=ResponseFormat.model_json_schema()) + assert response['model'] == 'dummy' + assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}' + + def test_client_pull(httpserver: HTTPServer): httpserver.expect_ordered_request( '/api/pull',