Change robusta-ai auth to be based on ui token (#149)
arikalon1 authored Oct 6, 2024
1 parent 046b065 commit 7dc06c3
Showing 5 changed files with 60 additions and 8 deletions.
34 changes: 32 additions & 2 deletions holmes/core/supabase_dal.py
@@ -2,12 +2,16 @@
import json
import logging
import os
import threading
from typing import Dict, Optional, List
from uuid import uuid4

import yaml
from postgrest.types import ReturnMethod
from supabase import create_client
from supabase.lib.client_options import ClientOptions
from pydantic import BaseModel
from cachetools import TTLCache

from holmes.common.env_vars import (ROBUSTA_CONFIG_PATH, ROBUSTA_ACCOUNT_ID, STORE_URL, STORE_API_KEY, STORE_EMAIL,
STORE_PASSWORD)
@@ -19,6 +23,7 @@
ISSUES_TABLE = "Issues"
EVIDENCE_TABLE = "Evidence"
RUNBOOKS_TABLE = "HolmesRunbooks"
SESSION_TOKENS_TABLE = "AuthTokens"

class RobustaConfig(BaseModel):
    sinks_config: List[Dict[str, Dict]]
@@ -42,7 +47,10 @@ def __init__(self):
        logging.info(f"Initializing robusta store for account {self.account_id}")
        options = ClientOptions(postgrest_client_timeout=SUPABASE_TIMEOUT_SECONDS)
        self.client = create_client(self.url, self.api_key, options)
        self.sign_in()
        self.user_id = self.sign_in()
        ttl = int(os.environ.get("SAAS_SESSION_TOKEN_TTL_SEC", "82800"))  # 23 hours
        self.token_cache = TTLCache(maxsize=1, ttl=ttl)
        self.lock = threading.Lock()

    @staticmethod
    def __load_robusta_config() -> Optional[RobustaToken]:
@@ -87,11 +95,12 @@ def __init_config(self) -> bool:
        # valid only if all store parameters are provided
        return all([self.account_id, self.url, self.api_key, self.email, self.password])

    def sign_in(self):
    def sign_in(self) -> str:
        logging.info("Supabase DAL login")
        res = self.client.auth.sign_in_with_password({"email": self.email, "password": self.password})
        self.client.auth.set_session(res.session.access_token, res.session.refresh_token)
        self.client.postgrest.auth(res.session.access_token)
        return res.user.id

    def get_issue_data(self, issue_id: str) -> Optional[Dict]:
        # TODO this could be done in a single atomic SELECT, but there is no
@@ -147,6 +156,27 @@ def get_resource_instructions(self, type: str, name: str) -> List[str]:

        return []

    def create_session_token(self) -> str:
        token = str(uuid4())
        self.client.table(SESSION_TOKENS_TABLE).insert(
            {
                "account_id": self.account_id,
                "user_id": self.user_id,
                "token": token,
                "type": "HOLMES",
            }, returning=ReturnMethod.minimal  # must use this, because the user cannot read this table
        ).execute()
        return token

    def get_ai_credentials(self) -> (str, str):
        with self.lock:
            session_token = self.token_cache.get("session_token")
            if not session_token:
                session_token = self.create_session_token()
                self.token_cache["session_token"] = session_token

        return self.account_id, session_token

    def get_workload_issues(self, resource: dict, since_hours: float) -> List[str]:
        if not self.enabled or not resource:
            return []
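Note: the following is a minimal, self-contained sketch (not part of this commit) of the caching pattern that get_ai_credentials relies on: a cachetools TTLCache holds a single session token behind a lock, so a new token is only minted on the first call or after the TTL (23 hours by default here) expires. TokenProvider and create_token are hypothetical names used only for illustration; create_token stands in for the database insert done by create_session_token.

import threading
from uuid import uuid4

from cachetools import TTLCache


def create_token() -> str:
    # hypothetical stand-in for the database insert done by create_session_token
    return str(uuid4())


class TokenProvider:
    def __init__(self, ttl_seconds: int = 82800):  # 23 hours, mirroring the SAAS_SESSION_TOKEN_TTL_SEC default
        # maxsize=1: only one session token is ever cached at a time
        self.cache = TTLCache(maxsize=1, ttl=ttl_seconds)
        self.lock = threading.Lock()

    def get(self) -> str:
        with self.lock:
            token = self.cache.get("session_token")
            if not token:
                # first call, or the TTL expired: mint and cache a fresh token
                token = create_token()
                self.cache["session_token"] = token
            return token


provider = TokenProvider()
assert provider.get() == provider.get()  # calls within the TTL reuse the cached token
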
2 changes: 1 addition & 1 deletion holmes/core/tool_calling_llm.py
@@ -86,7 +86,7 @@ def check_llm(self, model, api_key):
        #if not litellm.supports_function_calling(model=model):
        #    raise Exception(f"model {model} does not support function calling. You must use HolmesGPT with a model that supports function calling.")
    def get_context_window_size(self) -> int:
        return litellm.model_cost[self.model]['max_input_tokens']
        return litellm.model_cost[self.model]['max_input_tokens']

    def count_tokens_for_message(self, messages: list[dict]) -> int:
        return litellm.token_counter(model=self.model,
13 changes: 12 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ urllib3 = "^1.26.19"
boto3 = "^1.34.145"
setuptools = "^72.1.0"
aiohttp = "^3.10.2"
cachetools = "^5.5.0"

[build-system]
requires = ["poetry-core"]
18 changes: 14 additions & 4 deletions server.py
@@ -6,7 +6,9 @@
print("added custom certificate")

# DO NOT ADD ANY IMPORTS OR CODE ABOVE THIS LINE
# IMPORTING ABOVE MIGHT INITIALIZE AN HTTPS CLIENT THAT DOESN'T TRUST THE CUSTOM CERTIFICATEE
# IMPORTING ABOVE MIGHT INITIALIZE AN HTTPS CLIENT THAT DOESN'T TRUST THE CUSTOM CERTIFICATE


import jinja2
import logging
import uvicorn
@@ -15,7 +17,7 @@
from typing import Dict, Callable
from litellm.exceptions import AuthenticationError
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pydantic import SecretStr
from rich.console import Console

from holmes.common.env_vars import (
@@ -40,7 +42,6 @@
)
from holmes.plugins.prompts import load_and_render_prompt
from holmes.core.tool_calling_llm import ToolCallingLLM
import jinja2


def init_logging():
@@ -69,9 +70,16 @@ def init_logging():
config = Config.load_from_env()


def load_robusta_api_key():
    if os.environ.get("ROBUSTA_AI"):
        account_id, token = dal.get_ai_credentials()
        config.api_key = SecretStr(f"{account_id} {token}")


@app.post("/api/investigate")
def investigate_issues(investigate_request: InvestigateRequest):
    try:
        load_robusta_api_key()
        context = dal.get_issue_data(
            investigate_request.context.get("robusta_issue_id")
        )
@@ -112,7 +120,7 @@ def investigate_issues(investigate_request: InvestigateRequest):

@app.post("/api/workload_health_check")
def workload_health_check(request: WorkloadHealthRequest):

    load_robusta_api_key()
    try:
        resource = request.resource
        workload_alerts: list[str] = []
@@ -149,6 +157,7 @@ def workload_health_check(request: WorkloadHealthRequest):
def handle_issue_conversation(
    conversation_request: ConversationRequest, ai: ToolCallingLLM
):
    load_robusta_api_key()
    context_window = ai.get_context_window_size()
    number_of_tools = len(
        conversation_request.context.investigation_result.tools
@@ -240,6 +249,7 @@ def handle_issue_conversation(
@app.post("/api/conversation")
def converstation(conversation_request: ConversationRequest):
    try:
        load_robusta_api_key()
        ai = config.create_toolcalling_llm(console, allowed_toolsets=ALLOWED_TOOLSETS)

        handler = conversation_type_handlers.get(conversation_request.conversation_type)
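For context, the server changes call load_robusta_api_key() at the top of each request handler rather than once at startup, so that when the cached session token expires the next request transparently picks up a fresh one and rebuilds the "{account_id} {token}" API key. Below is a minimal, hypothetical sketch of that per-request refresh pattern (not the project's actual endpoints): get_credentials stands in for dal.get_ai_credentials(), and the /ping route exists only to show the call shape.

import os
from typing import Optional, Tuple

from fastapi import FastAPI
from pydantic import SecretStr

app = FastAPI()
api_key: Optional[SecretStr] = None


def get_credentials() -> Tuple[str, str]:
    # hypothetical stand-in for dal.get_ai_credentials(): returns (account_id, session_token)
    return "example-account-id", "example-session-token"


def load_robusta_api_key():
    # called at the start of every handler; cheap because the token itself is TTL-cached upstream
    global api_key
    if os.environ.get("ROBUSTA_AI"):
        account_id, token = get_credentials()
        api_key = SecretStr(f"{account_id} {token}")


@app.post("/ping")  # hypothetical route, used only to illustrate the call pattern
def ping():
    load_robusta_api_key()
    return {"robusta_ai_configured": api_key is not None}
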
