Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ai-brain: thread mapping cache and better formatting of messages #11

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 64 additions & 29 deletions internal/source/ai-brain/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
)

const (
// cacheTTL is the time-to-live passed to newCache when the assistant's
// message-ID -> thread-ID cache is created.
cacheTTL = 8 * time.Hour
// openAIPollInterval is presumably the delay between two polls of an OpenAI
// run's status; the polling loop is not fully visible here — TODO confirm.
openAIPollInterval = 2 * time.Second
// maxToolExecutionRetries presumably caps how often tool execution is retried
// for a single run; its use is not visible here — TODO confirm.
maxToolExecutionRetries = 3
)
Expand All @@ -32,12 +33,12 @@ type Payload struct {
}

type assistant struct {
log logrus.FieldLogger
out chan<- source.Event
openaiClient *openai.Client
assistID string
tools map[string]tool
threadMapping map[string]string
log logrus.FieldLogger
out chan<- source.Event
openaiClient *openai.Client
assistID string
tools map[string]tool
cache *cache
}

func newAssistant(cfg *Config, log logrus.FieldLogger, out chan source.Event, kubeConfigPath string) *assistant {
Expand All @@ -50,11 +51,11 @@ func newAssistant(cfg *Config, log logrus.FieldLogger, out chan source.Event, ku
config.BaseURL = cfg.OpenAICloudServiceURL

return &assistant{
log: log,
out: out,
openaiClient: openai.NewClientWithConfig(config),
assistID: cfg.OpenAIAssistantID,
threadMapping: make(map[string]string),
log: log,
out: out,
openaiClient: openai.NewClientWithConfig(config),
assistID: cfg.OpenAIAssistantID,
cache: newCache(cacheTTL),
tools: map[string]tool{
"kubectlGetPods": kcRunner.GetPods,
"kubectlGetSecrets": kcRunner.GetSecrets,
Expand Down Expand Up @@ -91,25 +92,31 @@ func (i *assistant) handle(in source.ExternalRequestInput) (api.Message, error)

// handleThread creates a new OpenAI assistant thread and handles the conversation.
func (i *assistant) handleThread(ctx context.Context, p *Payload) error {
run, err := i.openaiClient.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{
RunRequest: openai.RunRequest{
AssistantID: i.assistID,
},
Thread: openai.ThreadRequest{
Metadata: map[string]any{
"messageId": p.MessageID,
"instanceId": os.Getenv(remote.ProviderIdentifierEnvKey),
},
Messages: []openai.ThreadMessage{
{
Role: openai.ThreadMessageRoleUser,
Content: p.Prompt,
},
},
},
var thread openai.Thread
var err error

// First we check if we have a cached thread ID for the given message ID.
threadID, ok := i.cache.Get(p.MessageID)
if ok {
err = i.createNewMessage(ctx, threadID, p)
if err != nil {
return fmt.Errorf("while creating a new message on a thread: %w", err)
}
} else {
thread, err = i.createNewThread(ctx, p)
if err != nil {
return fmt.Errorf("while creating a new thread: %w", err)
}
threadID = thread.ID
}

i.cache.Set(p.MessageID, threadID)

run, err := i.openaiClient.CreateRun(ctx, threadID, openai.RunRequest{
AssistantID: i.assistID,
})
if err != nil {
return fmt.Errorf("while creating thread and run: %w", err)
return fmt.Errorf("while creating a thread run: %w", err)
}

toolsRetries := 0
Expand All @@ -123,10 +130,15 @@ func (i *assistant) handleThread(ctx context.Context, p *Payload) error {
"messageId": p.MessageID,
"runStatus": run.Status,
}).Debug("retrieved assistant thread run")

switch run.Status {
case openai.RunStatusCancelling, openai.RunStatusFailed, openai.RunStatusExpired:
case openai.RunStatusCancelling, openai.RunStatusFailed:
return false, fmt.Errorf("got unexpected status: %s", run.Status)

case openai.RunStatusExpired:
i.cache.Delete(p.MessageID)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to send a user-facing message:

i.out <- source.Event{
	Message: msgExpiredAIAnswer(p.MessageID),
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about that, but an expired thread is by definition old (10 or more minutes), so the response could arrive out of the blue and be confusing. How about this: I'll make sure we add the run status as an attribute in Honeycomb, and then we can see how often it happens, if at all.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side topic: should we even keep polling such runs? Maybe we should wrap the input ctx
and have something like:

ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)

That way polling is canceled after 5 minutes (we can then also call CancelRun), and we can tell the user that we were not able to process the prompt. Otherwise, the user will only see:
Screenshot 2024-03-01 at 00 22 15

return true, nil

case openai.RunStatusQueued, openai.RunStatusInProgress:
return false, nil // continue

Expand All @@ -146,6 +158,29 @@ func (i *assistant) handleThread(ctx context.Context, p *Payload) error {
})
}

// createNewThread asks the OpenAI client to create a thread whose first (and
// only) message is the payload's prompt, sent with the user role. The thread is
// tagged with the originating message ID and with the instance ID read from the
// environment variable named by remote.ProviderIdentifierEnvKey.
func (i *assistant) createNewThread(ctx context.Context, p *Payload) (openai.Thread, error) {
	metadata := map[string]any{
		"messageId":  p.MessageID,
		"instanceId": os.Getenv(remote.ProviderIdentifierEnvKey),
	}
	prompt := openai.ThreadMessage{
		Role:    openai.ThreadMessageRoleUser,
		Content: p.Prompt,
	}
	req := openai.ThreadRequest{
		Metadata: metadata,
		Messages: []openai.ThreadMessage{prompt},
	}

	return i.openaiClient.CreateThread(ctx, req)
}

// createNewMessage appends the payload's prompt, with the user role, to the
// existing thread identified by threadID. The created message itself is
// discarded; only the client's error is reported to the caller.
func (i *assistant) createNewMessage(ctx context.Context, threadID string, p *Payload) error {
	userMsg := openai.MessageRequest{
		Role:    openai.ChatMessageRoleUser,
		Content: p.Prompt,
	}
	if _, err := i.openaiClient.CreateMessage(ctx, threadID, userMsg); err != nil {
		return err
	}
	return nil
}

func (i *assistant) handleStatusCompleted(ctx context.Context, run openai.Run, p *Payload) error {
limit := 1
msgList, err := i.openaiClient.ListMessage(ctx, run.ThreadID, &limit, nil, nil, nil)
Expand Down
4 changes: 4 additions & 0 deletions internal/source/ai-brain/brain.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ func (s *Source) Stream(ctx context.Context, in source.StreamInput) (source.Stre
instance := newAssistant(cfg, s.log, streamOutput.Event, kubeConfigPath)
s.instances.Store(sourceName, instance)

// Start the assistant thread-mapping cache cleanup. In practice the cache won't
// grow much, because the Botkube agent is eventually restarted anyway.
go instance.cache.Cleanup()

s.log.Infof("Setup successful for source configuration %q", sourceName)
return streamOutput, nil
}
Expand Down
64 changes: 64 additions & 0 deletions internal/source/ai-brain/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package aibrain

import (
"sync"
"time"
)

// cache maps a chat message ID to the OpenAI thread ID created for it and
// forgets each mapping once its TTL has elapsed. All state lives in a sync.Map,
// so the methods may be called from multiple goroutines.
type cache struct {
	m   sync.Map      // messageID (string) -> *entry
	ttl time.Duration // lifetime given to an entry on every Set
}

// entry is a cached thread ID together with the instant it stops being valid.
type entry struct {
	threadID   string
	expireTime time.Time
}

// newCache returns an empty cache whose entries live for ttl after each Set.
func newCache(ttl time.Duration) *cache {
	return &cache{ttl: ttl}
}

// Set stores (or overwrites) the thread ID for messageID and restarts its TTL.
func (c *cache) Set(messageID, threadID string) {
	c.m.Store(messageID, &entry{
		threadID:   threadID,
		expireTime: time.Now().Add(c.ttl),
	})
}

// Get returns the thread ID cached for messageID. The boolean is false when
// there is no mapping or when the mapping has expired; an expired mapping is
// removed as a side effect.
func (c *cache) Get(messageID string) (string, bool) {
	item, ok := c.m.Load(messageID)
	if !ok {
		return "", false
	}

	e, ok := item.(*entry)
	if !ok {
		// Only Set writes to the map, so a foreign value should never appear;
		// drop it defensively.
		c.m.Delete(messageID)
		return "", false
	}
	if time.Now().After(e.expireTime) {
		// Remove only the exact entry we inspected. A plain Delete could wipe
		// out a fresh entry stored by a concurrent Set in the meantime.
		c.m.CompareAndDelete(messageID, item)
		return "", false
	}

	return e.threadID, true
}

// Delete removes the mapping for messageID, if any.
func (c *cache) Delete(messageID string) {
	c.m.Delete(messageID)
}

// Cleanup periodically evicts expired entries. It blocks forever and is meant
// to be started in its own goroutine.
func (c *cache) Cleanup() {
	ticker := time.NewTicker(c.cleanupInterval())
	defer ticker.Stop()

	for range ticker.C {
		c.sweep()
	}
}

// cleanupInterval returns how often Cleanup sweeps the map: one sixteenth of
// the TTL, or a one-second fallback when that would not be positive
// (time.NewTicker panics on a non-positive interval).
func (c *cache) cleanupInterval() time.Duration {
	const fallback = time.Second

	if d := c.ttl / 16; d > 0 {
		return d
	}
	return fallback
}

// sweep makes a single pass over the map and evicts every expired entry.
func (c *cache) sweep() {
	c.m.Range(func(k, v any) bool {
		e, ok := v.(*entry)
		switch {
		case !ok:
			c.m.Delete(k)
		case time.Now().After(e.expireTime):
			// Compare-and-delete so an entry refreshed by a concurrent Set
			// after Range read it is left in place.
			c.m.CompareAndDelete(k, v)
		}
		return true
	})
}
97 changes: 73 additions & 24 deletions internal/source/ai-brain/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,56 @@ package aibrain

import (
"math/rand"
"regexp"
"strings"
"time"

"github.com/kubeshop/botkube/pkg/api"
)

// quickResponses lists short acknowledgement texts.
// NOTE(review): presumably pickQuickResponse selects one of them at random; its
// body is not visible here — confirm.
var quickResponses = []string{
"Just a moment, please...",
"Thinking about this one...",
"Let me check on that for you.",
"Processing your request...",
"Working on it!",
"This one needs some extra thought.",
"I'm carefully considering your request.",
"Consulting my super-smart brain...",
"Cogs are turning...",
"Accessing the knowledge archives...",
"Running calculations at lightning speed!",
"Hold on tight, I'm diving into the details.",
"I'm here to help!",
"Happy to look into this for you.",
"Always learning to do this better.",
"I want to get this right for you.",
"Let me see what I can find out.",
"My circuits are buzzing!",
"Let me consult with my owl advisor...",
"Consider it done (or at least, I'll try my best!)",
"I'll get back to you with the best possible answer.",
}
const (
	// teamsMessageIDSubstr is the substring msgAIAnswer looks for in a message
	// ID to decide that the answer should be formatted for Teams.
	teamsMessageIDSubstr = "thread.tacv2"
)

var (
	// quickResponses lists short acknowledgement texts.
	// NOTE(review): presumably pickQuickResponse selects one of them at
	// random; its body is not visible here — confirm.
	quickResponses = []string{
		"Just a moment, please...",
		"Thinking about this one...",
		"Let me check on that for you.",
		"Processing your request...",
		"Working on it!",
		"This one needs some extra thought.",
		"I'm carefully considering your request.",
		"Consulting my super-smart brain...",
		"Cogs are turning...",
		"Accessing the knowledge archives...",
		"Running calculations at lightning speed!",
		"Hold on tight, I'm diving into the details.",
		"I'm here to help!",
		"Happy to look into this for you.",
		"Always learning to do this better.",
		"I want to get this right for you.",
		"Let me see what I can find out.",
		"My circuits are buzzing!",
		"Let me consult with my owl advisor...",
		"Consider it done (or at least, I'll try my best!)",
		"I'll get back to you with the best possible answer.",
	}

	// Markdown patterns consumed by markdownToSlack. The heading pattern uses
	// (?m) so that ^ and $ anchor at every line; without it a heading is only
	// recognised when the whole message is that single line.
	reSlackEscapeCodeBlocks  = regexp.MustCompile("`(.*?)`")
	reSlackBoldText          = regexp.MustCompile(`\*\*(.*?)\*\*`)
	reSlackItalicText        = regexp.MustCompile(`_([^_]+)_`)
	reSlackStrikethroughText = regexp.MustCompile(`~~([^~]+)~~`)
	reSlackHeadings          = regexp.MustCompile(`(?m)^#+ (.*)$`)
	reSlackLinks             = regexp.MustCompile(`\[([^\]]+)]\(([^)]+)\)`)
	reSlackImages            = regexp.MustCompile(`!\[([^\]]+)]\(([^)]+)\)`)

	// Markdown patterns consumed by markdownToTeams. The heading pattern is
	// multi-line for the same reason as the Slack one.
	reTeamsBold          = regexp.MustCompile(`\*\*(.*?)\*\*`)
	reTeamsItalic        = regexp.MustCompile(`_([^_]+)_`)
	reTeamsStrikethrough = regexp.MustCompile(`~~([^~]+)~~`)
	reTeamsHeading       = regexp.MustCompile(`(?m)^#+ (.*)$`)
	reTeamsImageLink     = regexp.MustCompile(`!\[([^\]]+)]\(([^)]+)\)`)
)

func pickQuickResponse(messageID string) api.Message {
rand.New(rand.NewSource(time.Now().UnixNano())) // #nosec G404
Expand Down Expand Up @@ -74,12 +96,18 @@ func msgNoAIAnswer(messageID string) api.Message {
}

func msgAIAnswer(messageID, text string) api.Message {
convertedText := markdownToSlack(text)

if strings.Contains(messageID, teamsMessageIDSubstr) {
convertedText = markdownToTeams(text)
}

return api.Message{
ParentActivityID: messageID,
Sections: []api.Section{
{
Base: api.Base{
Body: api.Body{Plaintext: text},
Body: api.Body{Plaintext: convertedText},
},
Context: []api.ContextItem{
{Text: "AI-generated content may be incorrect."},
Expand All @@ -88,3 +116,24 @@ func msgAIAnswer(messageID, text string) api.Message {
},
}
}

func markdownToSlack(markdownText string) string {
mszostok marked this conversation as resolved.
Show resolved Hide resolved
text := reSlackEscapeCodeBlocks.ReplaceAllString(markdownText, "`$1`")
text = reSlackBoldText.ReplaceAllString(text, "*$1*")
text = reSlackItalicText.ReplaceAllString(text, "_$1_")
text = reSlackStrikethroughText.ReplaceAllString(text, "~$1~")
text = reSlackHeadings.ReplaceAllString(text, "*$1*")
text = reSlackLinks.ReplaceAllString(text, "<$2|$1>")
return reSlackImages.ReplaceAllString(text, "<$2|$1>")
}

// markdownToTeams rewrites the Markdown constructs matched by the reTeams*
// patterns for display in Teams and returns the result: headings become bold
// lines, ~~strike~~ becomes <s>strike</s>, and images are reduced to
// "(Image: url)". **bold** and _italic_ are reproduced unchanged.
func markdownToTeams(markdownText string) string {
	// Group references are written as ${n}. In Go replacement templates the
	// name after $ extends over letters, digits and underscores, so "_$1_"
	// would refer to an undefined group named "1_" and expand to nothing.
	text := reTeamsBold.ReplaceAllString(markdownText, "**${1}**")
	text = reTeamsItalic.ReplaceAllString(text, "_${1}_")
	text = reTeamsStrikethrough.ReplaceAllString(text, "<s>${1}</s>")
	text = reTeamsHeading.ReplaceAllString(text, "**${1}**")

	// Images in Teams require specific upload mechanisms;
	// here, we'll just preserve the image link
	return reTeamsImageLink.ReplaceAllString(text, "(Image: ${2})")
}
Loading