class CustomOllama(Ollama):
    """Ollama LLM built from the keyword settings supplied by its configuration.

    Only the six settings read below are forwarded to ``Ollama``; any other
    keyword argument is ignored.
    """

    def __init__(self, **kwargs) -> None:
        """Forward the selected settings to ``Ollama``.

        Args:
            **kwargs: Must contain ``base_url``, ``model``, ``num_ctx``,
                ``repeat_last_n``, ``repeat_penalty`` and ``temperature``.

        Raises:
            KeyError: If any of the settings listed above is missing.
        """
        # Users often paste the server address with a trailing "/". Endpoint
        # paths are appended to the base URL, so a trailing slash would yield
        # "//api/..." and the request may be rejected. Strip it up front.
        base_url = kwargs["base_url"]
        if isinstance(base_url, str):
            base_url = base_url.rstrip("/")

        super().__init__(
            base_url=base_url,
            model=kwargs["model"],
            num_ctx=kwargs["num_ctx"],
            repeat_last_n=kwargs["repeat_last_n"],
            repeat_penalty=kwargs["repeat_penalty"],
            temperature=kwargs["temperature"],
        )
# Settings schema for the Ollama language model.
# The field names below are exactly the keyword arguments that CustomOllama
# reads in its constructor.
# NOTE(review): a `#` comment is used instead of a class docstring on purpose,
# so the text cannot end up in the generated JSON schema.
class LLMCustomOllama(LLMSettings):
    # Address of the Ollama server; no default, so it must always be provided.
    base_url: str
    # Name of the model to use (see the "link" entry below for the library).
    model: str = "llama2"
    # The remaining fields are passed through unchanged as generation options.
    # NOTE(review): presumably context size, repetition window/penalty and
    # sampling temperature -- confirm the exact semantics against Ollama docs.
    num_ctx: int = 2048
    repeat_last_n: int = 64
    repeat_penalty: float = 1.1
    temperature: float = 0.8

    # Class associated with this configuration.
    # NOTE(review): presumably instantiated by LLMSettings with the field
    # values as keyword arguments -- verify against the base class.
    _pyclass: Type = CustomOllama

    # Extra metadata attached to the JSON schema of this settings class.
    model_config = ConfigDict(
        json_schema_extra = {
            "humanReadableName": "Ollama",
            "description": "Configuration for Ollama",
            "link": "https://ollama.ai/library"
        }
    )