From 089bcaff860ebee16a475a8382dcc018b1622292 Mon Sep 17 00:00:00 2001 From: JarbasAI <33701864+JarbasAl@users.noreply.github.com> Date: Wed, 13 Nov 2024 15:37:32 +0000 Subject: [PATCH] chat history (#9) * feat:intents * feat:intents * fix:improve active persona handling * feat: chat history * feat: chat history --- ovos_persona/__init__.py | 98 +++++++++++++++++++++++++++++++--------- ovos_persona/solvers.py | 41 ++++++++++++++++- 2 files changed, 116 insertions(+), 23 deletions(-) diff --git a/ovos_persona/__init__.py b/ovos_persona/__init__.py index 78280fd..0b0ef4f 100644 --- a/ovos_persona/__init__.py +++ b/ovos_persona/__init__.py @@ -4,7 +4,8 @@ from typing import Optional, Dict, List, Union from ovos_bus_client.client import MessageBusClient -from ovos_bus_client.message import Message +from ovos_bus_client.message import Message, dig_for_message +from ovos_bus_client.session import SessionManager from ovos_config.config import Configuration from ovos_config.locations import get_xdg_config_save_path from ovos_plugin_manager.persona import find_persona_plugins @@ -17,6 +18,11 @@ from padacioso import IntentContainer from ovos_persona.solvers import QuestionSolversService +try: + from ovos_plugin_manager.solvers import find_chat_solver_plugins +except ImportError: + def find_chat_solver_plugins(): + return {} class Persona: @@ -31,26 +37,21 @@ def __init__(self, name, config, blacklist=None): plugs[plug_name] = {"enabled": False} else: plugs[plug_name] = config.get(plug_name) or {"enabled": True} + for plug_name, plug in find_chat_solver_plugins().items(): + if plug_name not in persona or plug_name in blacklist: + plugs[plug_name] = {"enabled": False} + else: + plugs[plug_name] = config.get(plug_name) or {"enabled": True} self.solvers = QuestionSolversService(config=plugs) def __repr__(self): return f"Persona({self.name}:{list(self.solvers.loaded_modules.keys())})" def chat(self, messages: list = None, lang: str = None) -> str: - # TODO - message history solver - # messages = [ - # {"role": "system", "content": "You are a helpful assistant."}, - # {"role": "user", "content": "Knock knock."}, - # {"role": "assistant", "content": "Who's there?"}, - # {"role": "user", "content": "Orange."}, - # ] - prompt = messages[-1]["content"] - return self.solvers.spoken_answer(prompt, lang) + return self.solvers.chat_completion(messages, lang) class PersonaService(PipelineStageConfidenceMatcher, OVOSAbstractApplication): - intents = ["ask.intent", "summon.intent"] - intent_matchers = {} def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, config: Optional[Dict] = None): @@ -59,13 +60,17 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, self, bus=bus or FakeBus(), skill_id="persona.openvoiceos", resources_dir=f"{dirname(__file__)}") PipelineStageConfidenceMatcher.__init__(self, bus, config) + self.sessions = {} self.personas = {} + self.intent_matchers = {} self.blacklist = self.config.get("persona_blacklist") or [] self.load_personas(self.config.get("personas_path")) self.active_persona = None self.add_event('persona:answer', self.handle_persona_answer) self.add_event('persona:summon', self.handle_persona_summon) self.add_event('persona:release', self.handle_persona_release) + self.add_event("speak", self.handle_speak) + self.add_event("recognizer_loop:utterance", self.handle_utterance) @classmethod def load_resource_files(cls): @@ -78,7 +83,7 @@ def load_resource_files(cls): if locale_folder is not None: for f in os.listdir(locale_folder): path = join(locale_folder, f) - if f in cls.intents: + if f in ["ask.intent", "summon.intent"]: with open(path) as intent: samples = intent.read().split("\n") for idx, s in enumerate(samples): @@ -86,18 +91,16 @@ def load_resource_files(cls): intents[lang][f] = samples return intents - @classmethod - def load_intent_files(cls): - intent_files = cls.load_resource_files() - + def load_intent_files(self): + intent_files = self.load_resource_files() for lang, intent_data in intent_files.items(): lang = standardize_lang_tag(lang) - cls.intent_matchers[lang] = IntentContainer() - for intent_name in cls.intents: + self.intent_matchers[lang] = IntentContainer() + for intent_name in ["ask.intent", "summon.intent"]: samples = intent_data.get(intent_name) if samples: LOG.debug(f"registering OCP intent: {intent_name}") - cls.intent_matchers[lang].add_intent( + self.intent_matchers[lang].add_intent( intent_name.replace(".intent", ""), samples) @property @@ -136,14 +139,51 @@ def deregister_persona(self, name): self.personas.pop(name) # Chatbot API - def chatbox_ask(self, prompt: str, persona: Optional[str] = None, lang: Optional[str] = None) -> Optional[str]: + def chatbox_ask(self, prompt: str, + persona: Optional[str] = None, + lang: Optional[str] = None, + message: Message = None) -> Optional[str]: persona = persona or self.active_persona or self.default_persona if persona not in self.personas: LOG.error(f"unknown persona, choose one of {self.personas.keys()}") return None - messages = [{"role": "user", "content": prompt}] + messages = [] + message = message or dig_for_message() + if message: + for q, a in self._build_msg_history(message): + messages.append({"role": "user", "content": q}) + messages.append({"role": "assistant", "content": a}) + messages.append({"role": "user", "content": prompt}) + return self.personas[persona].chat(messages, lang) + def _build_msg_history(self, message: Message): + sess = SessionManager.get(message) + if sess.session_id not in self.sessions: + return [] + messages = [] # tuple of question, answer + + q = None + ans = None + for m in self.sessions[sess.session_id]: + if m[0] == "user": + if ans is not None and q is not None: + # save previous q/a pair + messages.append((q, ans)) + q = None + ans = None + q = m[1] # track question + elif m[0] == "ai": + if ans is None: + ans = m[1] # track answer + else: # merge multi speak answers + ans = f"{ans}. {m[1]}" + + # save last q/a pair + if ans is not None and q is not None: + messages.append((q, ans)) + return messages + # Abstract methods def match_high(self, utterances: List[str], lang: Optional[str] = None, message: Optional[Message] = None) -> Optional[IntentHandlerMatch]: @@ -222,6 +262,20 @@ def match_low(self, utterances: List[str], lang: Optional[str] = None, skill_id="persona.openvoiceos", utterance=utterances[0]) + # bus events + def handle_utterance(self, message): + utt = message.data.get("utterances")[0] + sess = SessionManager.get(message) + if sess.session_id not in self.sessions: + self.sessions[sess.session_id] = [] + self.sessions[sess.session_id].append(("user", utt)) + + def handle_speak(self, message): + utt = message.data.get("utterance") + sess = SessionManager.get(message) + if sess.session_id in self.sessions: + self.sessions[sess.session_id].append(("ai", utt)) + def handle_persona_answer(self, message): utt = message.data["answer"] self.speak(utt) diff --git a/ovos_persona/solvers.py b/ovos_persona/solvers.py index 36c1559..d3d68f4 100644 --- a/ovos_persona/solvers.py +++ b/ovos_persona/solvers.py @@ -1,9 +1,19 @@ -from typing import Optional +from typing import Optional, List, Dict from ovos_config import Configuration from ovos_plugin_manager.solvers import find_question_solver_plugins from ovos_utils.log import LOG from ovos_utils.messagebus import FakeBus +try: + from ovos_plugin_manager.solvers import find_chat_solver_plugins + from ovos_plugin_manager.templates.solvers import ChatMessageSolver +except ImportError: + # using outdated ovos-plugin-manager + class ChatMessageSolver: + pass + + def find_chat_solver_plugins(): + return {} class QuestionSolversService: @@ -26,6 +36,17 @@ def load_plugins(self): except Exception as e: LOG.exception(f"Failed to load question solver plugin: {plug_name}") + for plug_name, plug in find_chat_solver_plugins().items(): + config = self.config.get(plug_name) or {} + if not config.get("enabled", True): + continue + try: + LOG.debug(f"loading chat plugin with cfg: {config}") + self.loaded_modules[plug_name] = plug(config=config) + LOG.info(f"loaded chat solver plugin: {plug_name}") + except Exception as e: + LOG.exception(f"Failed to load chat solver plugin: {plug_name}") + @property def modules(self): return sorted(self.loaded_modules.values(), @@ -38,6 +59,24 @@ def shutdown(self): except: pass + def chat_completion(self, messages: List[Dict[str, str]], + lang: Optional[str] = None, + units: Optional[str] = None) -> Optional[str]: + for module in self.modules: + try: + if isinstance(module, ChatMessageSolver): + ans = module.get_chat_completion(messages=messages, + lang=lang) + else: + LOG.debug(f"{module} does not supported chat history!") + query = messages[-1]["content"] + ans = module.spoken_answer(query, lang=lang) + if ans: + return ans + except Exception as e: + LOG.error(e) + pass + def spoken_answer(self, query: str, lang: Optional[str] = None, units: Optional[str] = None) -> Optional[str]: