From 07f194ea0cd316e474be99808e96faaa6b3a1d4e Mon Sep 17 00:00:00 2001
From: Piero Savastano
Date: Tue, 17 Oct 2023 17:43:41 +0200
Subject: [PATCH] cat.llm can optionally stream tokens

---
 core/cat/looking_glass/cheshire_cat.py | 26 ++++++++++++++++----------
 1 file changed, 16 insertions(+), 10 deletions(-)

diff --git a/core/cat/looking_glass/cheshire_cat.py b/core/cat/looking_glass/cheshire_cat.py
index c9c1bfa5..c29975c8 100644
--- a/core/cat/looking_glass/cheshire_cat.py
+++ b/core/cat/looking_glass/cheshire_cat.py
@@ -2,24 +2,24 @@
 from copy import deepcopy
 import traceback
 from typing import Literal, get_args
-import langchain
 import os
 import asyncio
 
+import langchain
+from langchain.llms import Cohere, OpenAI, OpenAIChat, AzureOpenAI, HuggingFaceTextGenInference, HuggingFaceHub
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
+from langchain.base_language import BaseLanguageModel
+
 from cat.log import log
+from cat.db import crud
 from cat.db.database import Database
 from cat.rabbit_hole import RabbitHole
 from cat.mad_hatter.mad_hatter import MadHatter
 from cat.memory.working_memory import WorkingMemoryList
 from cat.memory.long_term_memory import LongTermMemory
 from cat.looking_glass.agent_manager import AgentManager
-
-# TODO: natural language dependencies; move to another file
+from cat.looking_glass.callbacks import NewTokenHandler
 import cat.factory.llm as llms
 import cat.factory.embedder as embedders
-from cat.db import crud
-from langchain.llms import Cohere, OpenAI, OpenAIChat, AzureOpenAI, HuggingFaceTextGenInference, HuggingFaceHub
-from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
-from langchain.base_language import BaseLanguageModel
 from cat.factory.custom_llm import CustomOpenAI
 
@@ -315,7 +315,7 @@ def recall_relevant_memories_to_working_memory(self):
         # hook to modify/enrich retrieved memories
         self.mad_hatter.execute_hook("after_cat_recalls_memories")
 
-    def llm(self, prompt: str) -> str:
+    def llm(self, prompt: str, chat: bool = False, stream: bool = False) -> str:
         """Generate a response using the LLM model.
 
         This method is useful for generating a response with both a chat and a completion model using the same syntax
@@ -331,13 +331,19 @@ def llm(self, prompt: str) -> str:
             The generated response.
 
         """
+
+        # should we stream the tokens?
+        callbacks = []
+        if stream:
+            callbacks.append( NewTokenHandler(self) )
+
         # Check if self._llm is a completion model and generate a response
         if isinstance(self._llm, langchain.llms.base.BaseLLM):
-            return self._llm(prompt)
+            return self._llm(prompt, callbacks=callbacks)
 
         # Check if self._llm is a chat model and call it as a completion model
         if isinstance(self._llm, langchain.chat_models.base.BaseChatModel):
-            return self._llm.call_as_llm(prompt)
+            return self._llm.call_as_llm(prompt, callbacks=callbacks)
 
     def send_ws_message(self, content: str, msg_type: MSG_TYPES = "notification"):
         """Send a message via websocket.
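
Usage note (not part of the patch): a minimal sketch of how plugin code could
call the new signature, assuming a Cheshire Cat hook that receives the `cat`
instance; the hook name and prompts below are illustrative, not taken from
this diff.

    from cat.mad_hatter.decorators import hook

    @hook
    def before_cat_sends_message(message, cat):
        # Default call: callbacks stays empty, same behavior as before the patch.
        summary = cat.llm("Summarize the conversation so far.")

        # With stream=True, a NewTokenHandler is attached, so each token the
        # LLM produces is also pushed to the client over the websocket while
        # the full string is still returned at the end.
        message["content"] = cat.llm("Rephrase this summary for the user: " + summary, stream=True)
        return message

Note that the new `chat` flag is accepted but never read inside llm() in this
diff; only `stream` changes behavior here.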