Skip to content

Commit

Permalink
ai-brain: implement thread mapping with an internal cache
Browse files Browse the repository at this point in the history
  • Loading branch information
vaijab committed Feb 29, 2024
1 parent 68f1496 commit 36c53a4
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 29 deletions.
93 changes: 64 additions & 29 deletions internal/source/ai-brain/assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
)

const (
cacheTTL = 8 * time.Hour
openAIPollInterval = 2 * time.Second
maxToolExecutionRetries = 3
)
Expand All @@ -32,12 +33,12 @@ type Payload struct {
}

type assistant struct {
log logrus.FieldLogger
out chan<- source.Event
openaiClient *openai.Client
assistID string
tools map[string]tool
threadMapping map[string]string
log logrus.FieldLogger
out chan<- source.Event
openaiClient *openai.Client
assistID string
tools map[string]tool
cache *cache
}

func newAssistant(cfg *Config, log logrus.FieldLogger, out chan source.Event, kubeConfigPath string) *assistant {
Expand All @@ -50,11 +51,11 @@ func newAssistant(cfg *Config, log logrus.FieldLogger, out chan source.Event, ku
config.BaseURL = cfg.OpenAICloudServiceURL

return &assistant{
log: log,
out: out,
openaiClient: openai.NewClientWithConfig(config),
assistID: cfg.OpenAIAssistantID,
threadMapping: make(map[string]string),
log: log,
out: out,
openaiClient: openai.NewClientWithConfig(config),
assistID: cfg.OpenAIAssistantID,
cache: newCache(cacheTTL),
tools: map[string]tool{
"kubectlGetPods": kcRunner.GetPods,
"kubectlGetSecrets": kcRunner.GetSecrets,
Expand Down Expand Up @@ -91,25 +92,31 @@ func (i *assistant) handle(in source.ExternalRequestInput) (api.Message, error)

// handleThread creates a new OpenAI assistant thread and handles the conversation.
func (i *assistant) handleThread(ctx context.Context, p *Payload) error {
run, err := i.openaiClient.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{
RunRequest: openai.RunRequest{
AssistantID: i.assistID,
},
Thread: openai.ThreadRequest{
Metadata: map[string]any{
"messageId": p.MessageID,
"instanceId": os.Getenv(remote.ProviderIdentifierEnvKey),
},
Messages: []openai.ThreadMessage{
{
Role: openai.ThreadMessageRoleUser,
Content: p.Prompt,
},
},
},
var thread openai.Thread
var err error

// First we check if we have a cached thread ID for the given message ID.
threadID, ok := i.cache.Get(p.MessageID)
if ok {
err = i.createNewMessage(ctx, threadID, p)
if err != nil {
return fmt.Errorf("while creating a new message on a thread: %w", err)
}
} else {
thread, err = i.createNewThread(ctx, p)
if err != nil {
return fmt.Errorf("while creating a new thread: %w", err)
}
threadID = thread.ID
}

i.cache.Set(p.MessageID, threadID)

run, err := i.openaiClient.CreateRun(ctx, threadID, openai.RunRequest{
AssistantID: i.assistID,
})
if err != nil {
return fmt.Errorf("while creating thread and run: %w", err)
return fmt.Errorf("while creating a thread run: %w", err)
}

toolsRetries := 0
Expand All @@ -123,10 +130,15 @@ func (i *assistant) handleThread(ctx context.Context, p *Payload) error {
"messageId": p.MessageID,
"runStatus": run.Status,
}).Debug("retrieved assistant thread run")

switch run.Status {
case openai.RunStatusCancelling, openai.RunStatusFailed, openai.RunStatusExpired:
case openai.RunStatusCancelling, openai.RunStatusFailed:
return false, fmt.Errorf("got unexpected status: %s", run.Status)

case openai.RunStatusExpired:
i.cache.Delete(p.MessageID)
return true, nil

case openai.RunStatusQueued, openai.RunStatusInProgress:
return false, nil // continue

Expand All @@ -146,6 +158,29 @@ func (i *assistant) handleThread(ctx context.Context, p *Payload) error {
})
}

func (i *assistant) createNewThread(ctx context.Context, p *Payload) (openai.Thread, error) {
return i.openaiClient.CreateThread(ctx, openai.ThreadRequest{
Metadata: map[string]any{
"messageId": p.MessageID,
"instanceId": os.Getenv(remote.ProviderIdentifierEnvKey),
},
Messages: []openai.ThreadMessage{
{
Role: openai.ThreadMessageRoleUser,
Content: p.Prompt,
},
},
})
}

func (i *assistant) createNewMessage(ctx context.Context, threadID string, p *Payload) error {
_, err := i.openaiClient.CreateMessage(ctx, threadID, openai.MessageRequest{
Role: openai.ChatMessageRoleUser,
Content: p.Prompt,
})
return err
}

func (i *assistant) handleStatusCompleted(ctx context.Context, run openai.Run, p *Payload) error {
limit := 1
msgList, err := i.openaiClient.ListMessage(ctx, run.ThreadID, &limit, nil, nil, nil)
Expand Down
4 changes: 4 additions & 0 deletions internal/source/ai-brain/brain.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ func (s *Source) Stream(ctx context.Context, in source.StreamInput) (source.Stre
instance := newAssistant(cfg, s.log, streamOutput.Event, kubeConfigPath)
s.instances.Store(sourceName, instance)

// Start assistant thread mapping cache cleanup. Technically the cache won't
// grow that much because botkube agent will be eventually restarted anyway.
go instance.cache.Cleanup()

s.log.Infof("Setup successful for source configuration %q", sourceName)
return streamOutput, nil
}
Expand Down
64 changes: 64 additions & 0 deletions internal/source/ai-brain/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package aibrain

import (
"sync"
"time"
)

type cache struct {
m sync.Map
ttl time.Duration
}

type entry struct {
threadID string
expireTime time.Time
}

func newCache(ttl time.Duration) *cache {
return &cache{
m: sync.Map{},
ttl: ttl,
}
}

func (c *cache) Set(messageID, threadID string) {
c.m.Store(messageID, &entry{
threadID: threadID,
expireTime: time.Now().Add(c.ttl),
})
}

func (c *cache) Get(messageID string) (string, bool) {
item, ok := c.m.Load(messageID)
if !ok {
return "", false
}

e, ok := item.(*entry)
if !ok || time.Now().After(e.expireTime) {
c.m.Delete(messageID)
return "", false
}

return e.threadID, true
}

func (c *cache) Delete(messageID string) {
c.m.Delete(messageID)
}

func (c *cache) Cleanup() {
ticker := time.NewTicker(c.ttl / 16)
defer ticker.Stop()

for range ticker.C {
c.m.Range(func(k interface{}, v interface{}) bool {
e, ok := v.(*entry)
if !ok || time.Now().After(e.expireTime) {
c.m.Delete(k)
}
return true
})
}
}

0 comments on commit 36c53a4

Please sign in to comment.