Commit
Merge branch 'main' into feature/lambda-function
saminegash authored Mar 1, 2024
2 parents 5e1ba44 + e20db87 commit 7f867f9
Showing 37 changed files with 2,501 additions and 1,907 deletions.
3 changes: 3 additions & 0 deletions src/.env-sample
@@ -7,6 +7,9 @@ FLASK_DEBUG=True # True or False
 # OpenAI. Mandatory. Enables language modeling.
 OPENAI_API_KEY= # OpenAI API key
 
+# Temperature configuration for OpenAI. Optional. Default is 0.
+TEMPERATURE= # Only applies to the legacy task agent
+
 # Serper.dev. Optional. Enables Google web search capability
 # SERPER_API_KEY= # Serper.dev API key
 
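The new TEMPERATURE value is exposed as cfg.TEMPERATURE by sherpa_ai.config (see bolt_app.py below). A minimal sketch of how an optional setting like this is typically read, assuming the usual os.environ/dotenv pattern; the real parsing lives in sherpa_ai/config and is not part of this diff:

import os

# Hypothetical reader for the optional TEMPERATURE variable; falls back to 0
# when it is unset, matching the "Default is 0" comment above.
TEMPERATURE = float(os.environ.get("TEMPERATURE", 0))
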
1,603 changes: 826 additions & 777 deletions src/apps/slackapp/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/apps/slackapp/pyproject.toml
@@ -12,6 +12,7 @@ flask-cors = "^4.0.0"
 flask = "^2.3.3"
 loguru = "^0.7.0"
 sherpa-ai = {path = "../..", develop = true}
+hydra-core = "^1.3.2"
 
 [tool.poetry.scripts]
 sherpa_slack = 'slackapp.bolt_app:main'
46 changes: 21 additions & 25 deletions src/apps/slackapp/slackapp/bolt_app.py
@@ -4,14 +4,16 @@
 ##############################################
 
 import time
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 from flask import Flask, request
 from langchain.schema import AIMessage, BaseMessage, HumanMessage
 from loguru import logger
 from omegaconf import OmegaConf
 from slack_bolt import App
 from slack_bolt.adapter.flask import SlackRequestHandler
+from slackapp.routes.whitelist import whitelist_blueprint
+from slackapp.utils import get_qa_agent_from_config_file
 
 import sherpa_ai.config as cfg
 from sherpa_ai.agents import QAAgent
@@ -92,7 +94,9 @@ def get_response(
     previous_messages: List[BaseMessage],
     verbose_logger: BaseVerboseLogger,
     bot_info: Dict[str, str],
-    llm: SherpaChatOpenAI = None,
+    llm=None,
+    team_id: Optional[str] = None,
+    user_id: Optional[str] = None,
 ) -> str:
     """
     Get response from the task agent for the question
@@ -103,6 +107,8 @@
         verbose_logger (BaseVerboseLogger): verbose logger to be used
         bot_info (Dict[str, str]): information of the Slack bot
         llm (SherpaChatOpenAI, optional): LLM to be used. Defaults to None.
+        team_id (str, optional): team id of the Slack workspace. Defaults to "".
+        user_id (str, optional): user id of the Slack user. Defaults to "".
     Returns:
         str: response from the task agent
@@ -120,6 +126,13 @@
     tools = get_tools(memory, agent_config)
 
     if agent_config.use_task_agent:
+        llm = SherpaChatOpenAI(
+            openai_api_key=cfg.OPENAI_API_KEY,
+            user_id=user_id,
+            team_id=team_id,
+            temperature=cfg.TEMPERATURE,
+        )
+
         verbose_logger.log("⚠️🤖 Use task agent (obsolete)...")
         task_agent = TaskAgent.from_llm_and_tools(
             ai_name=ai_name,
@@ -135,21 +148,11 @@
 
         response = error_handler.run_with_error_handling(task_agent.run, task=question)
     else:
-        memory = SharedMemory(objective="Answer the question")
-
+        agent = get_qa_agent_from_config_file("conf/config.yaml", team_id, user_id, llm)
         for message in previous_messages:
-            memory.add(EventType.result, message.type, message.content)
-        memory.add(EventType.task, "human", question)
-
-        agent = QAAgent(
-            llm=llm,
-            name=ai_name,
-            num_runs=1,
-            shared_memory=memory,
-            agent_config=agent_config,
-            require_meta=True,
-            verbose_logger=verbose_logger,
-        )
+            agent.shared_memory.add(EventType.result, message.type, message.content)
+        agent.shared_memory.add(EventType.task, "human", question)
+        agent.verbose_logger = verbose_logger
 
         error_handler = AgentErrorHandler()
         response = error_handler.run_with_error_handling(agent.run)
@@ -242,20 +245,13 @@ def event_test(client, say, event):
     )
     question = reconstructor.reconstruct_prompt()
 
-    llm = SherpaChatOpenAI(
-        openai_api_key=cfg.OPENAI_API_KEY,
-        user_id=user_id,
-        team_id=team_id,
-        verbose_logger = slack_verbose_logger,
-        temperature=cfg.TEMPRATURE,
-    )
-
     results = get_response(
         question,
         previous_messages,
         verbose_logger=slack_verbose_logger,
         bot_info=bot,
-        llm=llm,
+        team_id=team_id,
+        user_id=user_id,
     )
 
     say(results, thread_ts=thread_ts)
33 changes: 33 additions & 0 deletions src/apps/slackapp/slackapp/utils.py
@@ -0,0 +1,33 @@
from typing import Optional

from hydra.utils import instantiate
from langchain.base_language import BaseLanguageModel
from omegaconf import OmegaConf

from sherpa_ai.agents.qa_agent import QAAgent
from sherpa_ai.config.task_config import AgentConfig


def get_qa_agent_from_config_file(
config_path: str,
user_id: Optional[str] = None,
team_id: Optional[str] = None,
llm: Optional[BaseLanguageModel] = None,
) -> QAAgent:
config = OmegaConf.load(config_path)

agent_config: AgentConfig = instantiate(config.agent_config)
if user_id is not None:
config["user_id"] = user_id

if team_id is not None:
config["team_id"] = team_id

if llm is None:
qa_agent: QAAgent = instantiate(config.qa_agent, agent_config=agent_config)
else:
qa_agent: QAAgent = instantiate(
config.qa_agent, agent_config=agent_config, llm=llm
)

return qa_agent
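
A minimal usage sketch for this helper, mirroring how bolt_app.py and the test below call it. The question and Slack ids are illustrative, and EventType is assumed to come from sherpa_ai.events as elsewhere in the code base:

from slackapp.utils import get_qa_agent_from_config_file

from sherpa_ai.events import EventType

# Build a QAAgent from the OmegaConf config, overriding the Slack ids.
agent = get_qa_agent_from_config_file(
    "conf/config.yaml",
    user_id="U0000000",  # illustrative
    team_id="T0000000",  # illustrative
)

# Feed the task into shared memory and run, as get_response does in bolt_app.py.
agent.shared_memory.add(EventType.task, "human", "What is reinforcement learning?")
result = agent.run()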
51 changes: 51 additions & 0 deletions src/apps/slackapp/tests/data/test_get_agent.yaml
@@ -0,0 +1,51 @@
shared_memory:
_target_: sherpa_ai.memory.shared_memory.SharedMemory
objective: Answer the question

user_id: none
team_id: none

llm:
_target_: sherpa_ai.models.sherpa_base_chat_model.SherpaChatOpenAI
model_name: gpt-4
temperature: 0.7
user_id: ${user_id}
team_id: ${team_id}

agent_config:
_target_: sherpa_ai.config.task_config.AgentConfig

citation_validation:
_target_: sherpa_ai.output_parsers.citation_validation.CitationValidation
sequence_threshold: 0.8
jaccard_threshold: 0.7
token_overlap: 0.6

arxiv_search:
_target_: sherpa_ai.actions.arxiv_search.ArxivSearch
role_description: Act as a question answering agent
task: Question answering
llm: ${llm}
max_results: 3

google_search:
_target_: sherpa_ai.actions.GoogleSearch
role_description: Act as a question answering agent
task: Question answering
llm: ${llm}
include_metadata: true
config: ${agent_config}

qa_agent:
_target_: sherpa_ai.agents.qa_agent.QAAgent
llm: ${llm}
shared_memory: ${shared_memory}
name: QA Sherpa
description: Act as a question answering agent
agent_config: ${agent_config}
num_runs: 1
actions:
- ${arxiv_search}
- ${google_search}
validations:
- ${citation_validation}
20 changes: 20 additions & 0 deletions src/apps/slackapp/tests/test_get_agent.py
@@ -0,0 +1,20 @@
from slackapp.utils import get_qa_agent_from_config_file

from sherpa_ai.actions import ArxivSearch, GoogleSearch
from sherpa_ai.agents.qa_agent import QAAgent
from sherpa_ai.output_parsers.citation_validation import CitationValidation
from sherpa_ai.test_utils.data import get_test_data_file_path


def test_get_agent(get_test_data_file_path): # noqa: F811
config_filename = get_test_data_file_path(__file__, "test_get_agent.yaml")
agent = get_qa_agent_from_config_file(config_filename)

assert agent is not None
assert type(agent) is QAAgent

assert len(agent.actions) == 2
assert type(agent.actions[0]) is ArxivSearch
assert type(agent.actions[1]) is GoogleSearch
assert len(agent.validations) == 1
assert type(agent.validations[0]) is CitationValidation
55 changes: 55 additions & 0 deletions src/conf/config.yaml
@@ -0,0 +1,55 @@
shared_memory:
_target_: sherpa_ai.memory.shared_memory.SharedMemory
objective: Answer the question

user_id: none
team_id: none

llm:
_target_: sherpa_ai.models.sherpa_base_chat_model.SherpaChatOpenAI
model_name: gpt-3.5-turbo
temperature: 0
user_id: ${user_id}
team_id: ${team_id}

agent_config:
_target_: sherpa_ai.config.task_config.AgentConfig

citation_validation:
_target_: sherpa_ai.output_parsers.citation_validation.CitationValidation
sequence_threshold: 0.5
jaccard_threshold: 0.5
token_overlap: 0.5

number_validation:
_target_: sherpa_ai.output_parsers.number_validation.NumberValidation

arxiv_search:
_target_: sherpa_ai.actions.arxiv_search.ArxivSearch
role_description: Act as a question answering agent
task: Question answering
llm: ${llm}
max_results: 3

google_search:
_target_: sherpa_ai.actions.GoogleSearch
role_description: Act as a question answering agent
task: Question answering
llm: ${llm}
include_metadata: true
config: ${agent_config}

qa_agent:
_target_: sherpa_ai.agents.qa_agent.QAAgent
llm: ${llm}
shared_memory: ${shared_memory}
name: QA Sherpa
description: Act as a question answering agent
agent_config: ${agent_config}
num_runs: 1
validation_steps: 1
actions:
- ${google_search}
validations:
- ${number_validation}
- ${citation_validation}
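
A minimal sketch of turning this config into an agent outside the Slack app, following the same OmegaConf + hydra.utils.instantiate pattern as get_qa_agent_from_config_file above; the override values are illustrative and the path is relative to the repository root:

from hydra.utils import instantiate
from omegaconf import OmegaConf

config = OmegaConf.load("src/conf/config.yaml")

# ${user_id} and ${team_id} inside the llm node resolve against these
# top-level keys, so override them before instantiating.
config["user_id"] = "U0000000"  # illustrative
config["team_id"] = "T0000000"  # illustrative

agent_config = instantiate(config.agent_config)
qa_agent = instantiate(config.qa_agent, agent_config=agent_config)
print(type(qa_agent).__name__)  # QAAgent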