Skip to content

Commit

Permalink
chat history (#9)
Browse files Browse the repository at this point in the history
* feat:intents

* feat:intents

* fix:improve active persona handling

* feat: chat history

* feat: chat history
  • Loading branch information
JarbasAl authored Nov 13, 2024
1 parent 6b29c2e commit 089bcaf
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 23 deletions.
98 changes: 76 additions & 22 deletions ovos_persona/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Optional, Dict, List, Union

from ovos_bus_client.client import MessageBusClient
from ovos_bus_client.message import Message
from ovos_bus_client.message import Message, dig_for_message
from ovos_bus_client.session import SessionManager
from ovos_config.config import Configuration
from ovos_config.locations import get_xdg_config_save_path
from ovos_plugin_manager.persona import find_persona_plugins
Expand All @@ -17,6 +18,11 @@
from padacioso import IntentContainer

from ovos_persona.solvers import QuestionSolversService
try:
from ovos_plugin_manager.solvers import find_chat_solver_plugins
except ImportError:
def find_chat_solver_plugins():
return {}


class Persona:
Expand All @@ -31,26 +37,21 @@ def __init__(self, name, config, blacklist=None):
plugs[plug_name] = {"enabled": False}
else:
plugs[plug_name] = config.get(plug_name) or {"enabled": True}
for plug_name, plug in find_chat_solver_plugins().items():
if plug_name not in persona or plug_name in blacklist:
plugs[plug_name] = {"enabled": False}
else:
plugs[plug_name] = config.get(plug_name) or {"enabled": True}
self.solvers = QuestionSolversService(config=plugs)

def __repr__(self):
return f"Persona({self.name}:{list(self.solvers.loaded_modules.keys())})"

def chat(self, messages: list = None, lang: str = None) -> str:
# TODO - message history solver
# messages = [
# {"role": "system", "content": "You are a helpful assistant."},
# {"role": "user", "content": "Knock knock."},
# {"role": "assistant", "content": "Who's there?"},
# {"role": "user", "content": "Orange."},
# ]
prompt = messages[-1]["content"]
return self.solvers.spoken_answer(prompt, lang)
return self.solvers.chat_completion(messages, lang)


class PersonaService(PipelineStageConfidenceMatcher, OVOSAbstractApplication):
intents = ["ask.intent", "summon.intent"]
intent_matchers = {}

def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
config: Optional[Dict] = None):
Expand All @@ -59,13 +60,17 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
self, bus=bus or FakeBus(), skill_id="persona.openvoiceos",
resources_dir=f"{dirname(__file__)}")
PipelineStageConfidenceMatcher.__init__(self, bus, config)
self.sessions = {}
self.personas = {}
self.intent_matchers = {}
self.blacklist = self.config.get("persona_blacklist") or []
self.load_personas(self.config.get("personas_path"))
self.active_persona = None
self.add_event('persona:answer', self.handle_persona_answer)
self.add_event('persona:summon', self.handle_persona_summon)
self.add_event('persona:release', self.handle_persona_release)
self.add_event("speak", self.handle_speak)
self.add_event("recognizer_loop:utterance", self.handle_utterance)

@classmethod
def load_resource_files(cls):
Expand All @@ -78,26 +83,24 @@ def load_resource_files(cls):
if locale_folder is not None:
for f in os.listdir(locale_folder):
path = join(locale_folder, f)
if f in cls.intents:
if f in ["ask.intent", "summon.intent"]:
with open(path) as intent:
samples = intent.read().split("\n")
for idx, s in enumerate(samples):
samples[idx] = s.replace("{{", "{").replace("}}", "}")
intents[lang][f] = samples
return intents

@classmethod
def load_intent_files(cls):
intent_files = cls.load_resource_files()

def load_intent_files(self):
intent_files = self.load_resource_files()
for lang, intent_data in intent_files.items():
lang = standardize_lang_tag(lang)
cls.intent_matchers[lang] = IntentContainer()
for intent_name in cls.intents:
self.intent_matchers[lang] = IntentContainer()
for intent_name in ["ask.intent", "summon.intent"]:
samples = intent_data.get(intent_name)
if samples:
LOG.debug(f"registering OCP intent: {intent_name}")
cls.intent_matchers[lang].add_intent(
self.intent_matchers[lang].add_intent(
intent_name.replace(".intent", ""), samples)

@property
Expand Down Expand Up @@ -136,14 +139,51 @@ def deregister_persona(self, name):
self.personas.pop(name)

# Chatbot API
def chatbox_ask(self, prompt: str, persona: Optional[str] = None, lang: Optional[str] = None) -> Optional[str]:
def chatbox_ask(self, prompt: str,
persona: Optional[str] = None,
lang: Optional[str] = None,
message: Message = None) -> Optional[str]:
persona = persona or self.active_persona or self.default_persona
if persona not in self.personas:
LOG.error(f"unknown persona, choose one of {self.personas.keys()}")
return None
messages = [{"role": "user", "content": prompt}]
messages = []
message = message or dig_for_message()
if message:
for q, a in self._build_msg_history(message):
messages.append({"role": "user", "content": q})
messages.append({"role": "assistant", "content": a})
messages.append({"role": "user", "content": prompt})

return self.personas[persona].chat(messages, lang)

def _build_msg_history(self, message: Message):
sess = SessionManager.get(message)
if sess.session_id not in self.sessions:
return []
messages = [] # tuple of question, answer

q = None
ans = None
for m in self.sessions[sess.session_id]:
if m[0] == "user":
if ans is not None and q is not None:
# save previous q/a pair
messages.append((q, ans))
q = None
ans = None
q = m[1] # track question
elif m[0] == "ai":
if ans is None:
ans = m[1] # track answer
else: # merge multi speak answers
ans = f"{ans}. {m[1]}"

# save last q/a pair
if ans is not None and q is not None:
messages.append((q, ans))
return messages

# Abstract methods
def match_high(self, utterances: List[str], lang: Optional[str] = None,
message: Optional[Message] = None) -> Optional[IntentHandlerMatch]:
Expand Down Expand Up @@ -222,6 +262,20 @@ def match_low(self, utterances: List[str], lang: Optional[str] = None,
skill_id="persona.openvoiceos",
utterance=utterances[0])

# bus events
def handle_utterance(self, message):
utt = message.data.get("utterances")[0]
sess = SessionManager.get(message)
if sess.session_id not in self.sessions:
self.sessions[sess.session_id] = []
self.sessions[sess.session_id].append(("user", utt))

def handle_speak(self, message):
utt = message.data.get("utterance")
sess = SessionManager.get(message)
if sess.session_id in self.sessions:
self.sessions[sess.session_id].append(("ai", utt))

def handle_persona_answer(self, message):
utt = message.data["answer"]
self.speak(utt)
Expand Down
41 changes: 40 additions & 1 deletion ovos_persona/solvers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from typing import Optional
from typing import Optional, List, Dict

from ovos_config import Configuration
from ovos_plugin_manager.solvers import find_question_solver_plugins
from ovos_utils.log import LOG
from ovos_utils.messagebus import FakeBus
try:
from ovos_plugin_manager.solvers import find_chat_solver_plugins
from ovos_plugin_manager.templates.solvers import ChatMessageSolver
except ImportError:
# using outdated ovos-plugin-manager
class ChatMessageSolver:
pass

def find_chat_solver_plugins():
return {}


class QuestionSolversService:
Expand All @@ -26,6 +36,17 @@ def load_plugins(self):
except Exception as e:
LOG.exception(f"Failed to load question solver plugin: {plug_name}")

for plug_name, plug in find_chat_solver_plugins().items():
config = self.config.get(plug_name) or {}
if not config.get("enabled", True):
continue
try:
LOG.debug(f"loading chat plugin with cfg: {config}")
self.loaded_modules[plug_name] = plug(config=config)
LOG.info(f"loaded chat solver plugin: {plug_name}")
except Exception as e:
LOG.exception(f"Failed to load chat solver plugin: {plug_name}")

@property
def modules(self):
return sorted(self.loaded_modules.values(),
Expand All @@ -38,6 +59,24 @@ def shutdown(self):
except:
pass

def chat_completion(self, messages: List[Dict[str, str]],
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
for module in self.modules:
try:
if isinstance(module, ChatMessageSolver):
ans = module.get_chat_completion(messages=messages,
lang=lang)
else:
LOG.debug(f"{module} does not supported chat history!")
query = messages[-1]["content"]
ans = module.spoken_answer(query, lang=lang)
if ans:
return ans
except Exception as e:
LOG.error(e)
pass

def spoken_answer(self, query: str,
lang: Optional[str] = None,
units: Optional[str] = None) -> Optional[str]:
Expand Down

0 comments on commit 089bcaf

Please sign in to comment.