More LLM provider config refactors
Robert Szefler committed Jun 11, 2024
1 parent ef24503 commit 3e10ef0
Showing 5 changed files with 206 additions and 219 deletions.
41 changes: 24 additions & 17 deletions holmes.py
@@ -5,15 +5,16 @@
 import re
 import warnings
 from pathlib import Path
-from typing import List, Optional, Pattern
+from typing import List, Optional
 
 import typer
 from rich.console import Console
 from rich.logging import RichHandler
 from rich.markdown import Markdown
 from rich.rule import Rule
 
-from holmes.config import LLMConfig, LLMProviderType
+from holmes.config import BaseLLMConfig, LLMProviderType
+from holmes.core.provider import LLMProvider
 from holmes.plugins.destinations import DestinationType
 from holmes.plugins.prompts import load_prompt
 from holmes import get_version
@@ -29,9 +30,10 @@
 
 
 # Common cli options
+llm_provider_names = ", ".join(str(tp) for tp in LLMProviderType)
 opt_llm: Optional[LLMProviderType] = typer.Option(
     LLMProviderType.OPENAI,
-    help="LLM provider ('openai' or 'azure')", # TODO list all
+    help=f"LLM provider (supported values: {llm_provider_names})"
 )
 opt_api_key: Optional[str] = typer.Option(
     None,
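
The hard-coded provider list in the help text is replaced by one generated from the enum, so newly added providers appear in the CLI help automatically. A minimal sketch of the pattern, assuming LLMProviderType is a string-valued Enum (the member set here is illustrative, not the full list from holmes.config):

from enum import Enum

class LLMProviderType(str, Enum):
    # Illustrative members; the real enum lives in holmes.config.
    OPENAI = "openai"
    AZURE = "azure"

    def __str__(self) -> str:
        return self.value

llm_provider_names = ", ".join(str(tp) for tp in LLMProviderType)
print(f"LLM provider (supported values: {llm_provider_names})")
# LLM provider (supported values: openai, azure)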
@@ -136,7 +138,7 @@ def ask(
     Ask any question and answer using available tools
     """
     console = init_logging(verbose)
-    config = LLMConfig.load_from_file(
+    config = BaseLLMConfig.load_from_file(
         config_file,
         api_key=api_key,
         llm=llm,
@@ -145,15 +147,17 @@
         max_steps=max_steps,
         custom_toolsets=custom_toolsets,
     )
+    provider = LLMProvider(config)
     system_prompt = load_prompt(system_prompt)
-    ai = config.create_toolcalling_llm(console, allowed_toolsets)
+    ai = provider.create_toolcalling_llm(console, allowed_toolsets)
     console.print("[bold yellow]User:[/bold yellow] " + prompt)
     response = ai.call(system_prompt, prompt)
     text_result = Markdown(response.result)
     if show_tool_output and response.tool_calls:
         for tool_call in response.tool_calls:
             console.print(f"[bold magenta]Used Tool:[/bold magenta]", end="")
-            # we need to print this separately with markup=False because it contains arbitrary text and we don't want console.print to interpret it
+            # we need to print this separately with markup=False because it contains arbitrary text
+            # and we don't want console.print to interpret it
             console.print(f"{tool_call.description}. Output=\n{tool_call.result}", markup=False)
     console.print(f"[bold green]AI:[/bold green]", end=" ")
     console.print(text_result, soft_wrap=True)
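
The rewrapped comment in this hunk points at a real rich pitfall: console.print interprets square-bracket markup by default, so untrusted tool output has to be printed with markup=False. A standalone illustration (the sample string is invented for the example):

from rich.console import Console

console = Console()
tool_output = "scaled deployment [bold]web[/bold] to 0 replicas"  # arbitrary, untrusted text

console.print(tool_output)                # rich renders [bold]...[/bold] as styling
console.print(tool_output, markup=False)  # printed verbatim, as the code above does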
@@ -195,7 +199,7 @@ def alertmanager(
     Investigate a Prometheus/Alertmanager alert
     """
     console = init_logging(verbose)
-    config = LLMConfig.load_from_file(
+    config = BaseLLMConfig.load_from_file(
         config_file,
         api_key=api_key,
         llm=llm,
@@ -210,17 +214,18 @@
         custom_toolsets=custom_toolsets,
         custom_runbooks=custom_runbooks
     )
-
+    provider = LLMProvider(config)
+
     if alertname:
         alertname = re.compile(alertname)
 
     system_prompt = load_prompt(system_prompt)
-    ai = config.create_issue_investigator(console, allowed_toolsets)
+    ai = provider.create_issue_investigator(console, allowed_toolsets)
 
-    source = config.create_alertmanager_source()
+    source = provider.create_alertmanager_source()
 
     if destination == DestinationType.SLACK:
-        slack = config.create_slack_destination()
+        slack = provider.create_slack_destination()
 
     try:
         issues = source.fetch_issues(alertname)
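
holmes/core/provider.py itself is among the changed files not shown in this view, but the call sites imply its shape: LLMProvider wraps a BaseLLMConfig and takes over the factory methods that previously lived on the config object. A hypothetical sketch consistent with the usage in this commit (signatures and bodies are guesses, not the actual implementation):

class LLMProvider:
    """Facade over a BaseLLMConfig that builds AIs, issue sources, and destinations."""

    def __init__(self, config):  # config: BaseLLMConfig
        self.config = config

    def create_toolcalling_llm(self, console, allowed_toolsets):
        ...  # LLM client with tool calling, built from self.config

    def create_issue_investigator(self, console, allowed_toolsets):
        ...  # tool-calling LLM plus runbook-driven investigation

    def create_alertmanager_source(self):
        ...  # issue source configured from self.config

    def create_jira_source(self):
        ...

    def create_github_source(self):
        ...

    def create_slack_destination(self):
        ...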
@@ -291,7 +296,7 @@ def jira(
     Investigate a Jira ticket
     """
     console = init_logging(verbose)
-    config = LLMConfig.load_from_file(
+    config = BaseLLMConfig.load_from_file(
         config_file,
         api_key=api_key,
         llm=llm,
@@ -305,10 +310,11 @@
         custom_toolsets=custom_toolsets,
         custom_runbooks=custom_runbooks
     )
+    provider = LLMProvider(config)
 
     system_prompt = load_prompt(system_prompt)
-    ai = config.create_issue_investigator(console, allowed_toolsets)
-    source = config.create_jira_source()
+    ai = provider.create_issue_investigator(console, allowed_toolsets)
+    source = provider.create_jira_source()
     try:
         # TODO: allow passing issue ID
         issues = source.fetch_issues()
@@ -380,7 +386,7 @@ def github(
     Investigate a GitHub issue
     """
     console = init_logging(verbose)
-    config = LLMConfig.load_from_file(
+    config = BaseLLMConfig.load_from_file(
         config_file,
         api_key=api_key,
         llm=llm,
@@ -395,10 +401,11 @@
         custom_toolsets=custom_toolsets,
         custom_runbooks=custom_runbooks
     )
+    provider = LLMProvider(config)
 
     system_prompt = load_prompt(system_prompt)
-    ai = config.create_issue_investigator(console, allowed_toolsets)
-    source = config.create_github_source()
+    ai = provider.create_issue_investigator(console, allowed_toolsets)
+    source = provider.create_github_source()
     try:
         issues = source.fetch_issues()
     except Exception as e:
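
Taken together, every command now follows the same two-step pattern: load a BaseLLMConfig, wrap it in an LLMProvider, and ask the provider for the AI, the issue source, and any destination. Condensed from the hunks above (the investigation call inside the loop sits outside the visible hunks and is a placeholder here):

config = BaseLLMConfig.load_from_file(config_file, api_key=api_key, llm=llm)
provider = LLMProvider(config)

ai = provider.create_issue_investigator(console, allowed_toolsets)
source = provider.create_jira_source()  # or create_alertmanager_source() / create_github_source()
for issue in source.fetch_issues():
    ...  # investigate `issue` with `ai`; not shown in this diff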
5 changes: 0 additions & 5 deletions holmes/common/env_vars.py
@@ -11,8 +11,3 @@
 STORE_API_KEY = os.environ.get("STORE_API_KEY", "")
 STORE_EMAIL = os.environ.get("STORE_EMAIL", "")
 STORE_PASSWORD = os.environ.get("STORE_PASSWORD", "")
-
-# Currently supports BUILTIN and ROBUSTA_AI
-AI_AGENT = os.environ.get("AI_AGENT", "BUILTIN")
-
-ROBUSTA_AI_URL = os.environ.get("ROBUSTA_AI_URL", "")
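
With AI_AGENT and ROBUSTA_AI_URL gone, agent selection no longer happens through environment variables; presumably the typed provider config carries that choice now. A hypothetical before/after (the config field and enum member on the "after" side are guesses, not names taken from this diff):

import os

# Before: stringly-typed selection via the environment.
AI_AGENT = os.environ.get("AI_AGENT", "BUILTIN")  # "BUILTIN" or "ROBUSTA_AI"

# After (assumed): the loaded config decides which provider to use.
# if config.llm_provider == LLMProviderType.ROBUSTA_AI:  # hypothetical field/member
#     ...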