-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
source: iteration of the ai-brain mainly
- Loading branch information
Showing
7 changed files
with
507 additions
and
256 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,98 +1,23 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
_ "embed" | ||
"fmt" | ||
"sync" | ||
|
||
"github.com/kubeshop/botkube-cloud-plugins/internal/auth" | ||
|
||
"github.com/hashicorp/go-plugin" | ||
|
||
aibrain "github.com/kubeshop/botkube-cloud-plugins/internal/source/ai-brain" | ||
"github.com/kubeshop/botkube/pkg/api" | ||
"github.com/kubeshop/botkube/pkg/api/source" | ||
"github.com/kubeshop/botkube/pkg/loggerx" | ||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
const (
	// pluginName is the key under which this source plugin is registered.
	pluginName = "ai-brain"
	// description is surfaced to users in the plugin metadata.
	description = "Calls AI engine with incoming webhook prompts and streams the response."
)

// version is set via ldflags by GoReleaser.
var version = "dev"
|
||
// AI implements the Botkube source plugin. Its zero value is ready to use.
type AI struct {
	// incomingPrompts maps a source configuration name to the processor
	// handling prompts for that source (stored as *aibrain.Processor).
	incomingPrompts sync.Map
}
|
||
// Metadata returns details about plugin. | ||
func (*AI) Metadata(context.Context) (api.MetadataOutput, error) { | ||
return api.MetadataOutput{ | ||
Version: version, | ||
Description: description, | ||
Recommended: true, | ||
JSONSchema: api.JSONSchema{ | ||
Value: aibrain.ConfigJSONSchema, | ||
}, | ||
ExternalRequest: api.ExternalRequestMetadata{ | ||
Payload: api.ExternalRequestPayload{ | ||
JSONSchema: api.JSONSchema{ | ||
Value: aibrain.IncomingWebhookJSONSchema, | ||
}, | ||
}, | ||
}, | ||
}, nil | ||
} | ||
|
||
// Stream implements Botkube source plugin. | ||
func (a *AI) Stream(_ context.Context, in source.StreamInput) (source.StreamOutput, error) { | ||
cfg, err := aibrain.MergeConfigs(in.Configs) | ||
if err != nil { | ||
return source.StreamOutput{}, fmt.Errorf("while merging configuration: %w", err) | ||
} | ||
|
||
log := loggerx.New(cfg.Log) | ||
out := source.StreamOutput{ | ||
Event: make(chan source.Event), | ||
} | ||
go a.processPrompts(in.Context.SourceName, out.Event, log) | ||
|
||
log.Infof("Setup successful for source configuration %q", in.Context.SourceName) | ||
return out, nil | ||
} | ||
|
||
func (a *AI) processPrompts(sourceName string, event chan<- source.Event, log logrus.FieldLogger) { | ||
a.incomingPrompts.Store(sourceName, aibrain.NewProcessor(log, event)) | ||
} | ||
|
||
// HandleExternalRequest handles incoming payload and returns an event based on it. | ||
func (a *AI) HandleExternalRequest(_ context.Context, in source.ExternalRequestInput) (source.ExternalRequestOutput, error) { | ||
brain, ok := a.incomingPrompts.Load(in.Context.SourceName) | ||
if !ok { | ||
return source.ExternalRequestOutput{}, fmt.Errorf("source %q not found", in.Context.SourceName) | ||
} | ||
quickResponse, err := brain.(*aibrain.Processor).Process(in.Payload) | ||
if err != nil { | ||
return source.ExternalRequestOutput{}, fmt.Errorf("while processing payload: %w", err) | ||
} | ||
|
||
return source.ExternalRequestOutput{ | ||
Event: source.Event{ | ||
Message: quickResponse, | ||
}, | ||
}, nil | ||
} | ||
|
||
func main() { | ||
source.Serve(map[string]plugin.Plugin{ | ||
pluginName: &source.Plugin{ | ||
Source: auth.NewProtectedSource(&AI{ | ||
incomingPrompts: sync.Map{}, | ||
}), | ||
aibrain.PluginName: &source.Plugin{ | ||
Source: auth.NewProtectedSource(aibrain.NewSource(version)), | ||
}, | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
package aibrain | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"strings" | ||
"time" | ||
|
||
"github.com/kubeshop/botkube/pkg/api" | ||
"github.com/kubeshop/botkube/pkg/api/source" | ||
"github.com/sashabaranov/go-openai" | ||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
// Payload represents the incoming webhook payload.
type Payload struct {
	// Prompt is the user question forwarded to the AI assistant.
	Prompt string `json:"prompt"`
	// MessageID identifies the originating chat message; it is echoed back
	// as ParentActivityID so responses thread under the original message.
	MessageID string `json:"messageId"`
}
|
||
// handle is simplified - don't do that this way! | ||
func (i *sourceInstance) handle(in source.ExternalRequestInput) (api.Message, error) { | ||
p := new(Payload) | ||
err := json.Unmarshal(in.Payload, p) | ||
if err != nil { | ||
return api.Message{}, fmt.Errorf("while unmarshalling payload: %w", err) | ||
} | ||
|
||
// TODO: why is the Prompt prefixed with `ai-face`? | ||
if p.Prompt == "ai-face" { | ||
return api.NewPlaintextMessage("Please clarify your question.", false), nil | ||
} | ||
|
||
// Cleanup the prompt. | ||
p.Prompt = strings.TrimPrefix(p.Prompt, "ai-face") | ||
|
||
// TODO: needs better goroutine management with persistent thread mapping. | ||
go func() { | ||
_ = i.handleThread(context.Background(), p) | ||
}() | ||
|
||
return api.Message{ | ||
ParentActivityID: p.MessageID, | ||
Sections: []api.Section{ | ||
{ | ||
// TODO: remove? | ||
Base: api.Base{ | ||
Body: api.Body{Plaintext: "Let me figure this out.."}, | ||
}, | ||
}, | ||
}, | ||
}, nil | ||
} | ||
|
||
// handleThread creates a new OpenAI assistant thread and handles the conversation. | ||
func (i *sourceInstance) handleThread(ctx context.Context, p *Payload) error { | ||
// Start a new thread and run it. | ||
run, err := i.openaiClient.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{ | ||
RunRequest: openai.RunRequest{ | ||
AssistantID: i.assistID, | ||
}, | ||
Thread: openai.ThreadRequest{ | ||
Metadata: map[string]any{ | ||
"messageId": p.MessageID, | ||
}, | ||
Messages: []openai.ThreadMessage{ | ||
{ | ||
Role: openai.ThreadMessageRoleUser, | ||
Content: p.Prompt, | ||
}, | ||
}, | ||
}, | ||
}) | ||
if err != nil { | ||
return fmt.Errorf("while creating thread and run: %w", err) | ||
} | ||
|
||
for { | ||
// Wait a little bit before polling. OpenAI assistant api does not support streaming yet. | ||
time.Sleep(2 * time.Second) | ||
|
||
// Get the run. | ||
run, err = i.openaiClient.RetrieveRun(ctx, run.ThreadID, run.ID) | ||
if err != nil { | ||
i.log.WithError(err).Error("while retrieving assistant thread run") | ||
continue | ||
} | ||
|
||
i.log.WithFields(logrus.Fields{ | ||
"messageId": p.MessageID, | ||
"runStatus": run.Status, | ||
}).Debug("retrieved assistant thread run") | ||
|
||
switch run.Status { | ||
case openai.RunStatusCancelling, openai.RunStatusFailed, openai.RunStatusExpired: | ||
// TODO tell the user that the assistant has stopped processing the request. | ||
continue | ||
|
||
// We have to wait. Here we could tell the user that we are waiting. | ||
case openai.RunStatusQueued, openai.RunStatusInProgress: | ||
continue | ||
|
||
// Fetch and return the response. | ||
case openai.RunStatusCompleted: | ||
if err = i.handleStatusCompleted(ctx, run, p); err != nil { | ||
i.log.WithError(err).Error("while handling completed case") | ||
continue | ||
} | ||
return nil | ||
|
||
// The assistant is attempting to call a function. | ||
case openai.RunStatusRequiresAction: | ||
if err = i.handleStatusRequiresAction(ctx, run); err != nil { | ||
return fmt.Errorf("while handling requires action: %w", err) | ||
} | ||
} | ||
} | ||
} | ||
|
||
func (i *sourceInstance) handleStatusCompleted(ctx context.Context, run openai.Run, p *Payload) error { | ||
msgList, err := i.openaiClient.ListMessage(ctx, run.ThreadID, nil, nil, nil, nil) | ||
if err != nil { | ||
return fmt.Errorf("while getting assistant messages response") | ||
} | ||
|
||
if len(msgList.Messages) == 0 { | ||
i.log.Debug("no response messages were found") | ||
i.out <- source.Event{ | ||
Message: api.Message{ | ||
ParentActivityID: p.MessageID, | ||
Sections: []api.Section{ | ||
{ | ||
Base: api.Base{ | ||
Body: api.Body{Plaintext: "I am sorry, but I don't have a good answer."}, | ||
}, | ||
}, | ||
}, | ||
}, | ||
} | ||
|
||
return nil | ||
} | ||
|
||
i.out <- source.Event{ | ||
Message: api.Message{ | ||
ParentActivityID: p.MessageID, | ||
Sections: []api.Section{ | ||
{ | ||
Base: api.Base{ | ||
Body: api.Body{Plaintext: msgList.Messages[0].Content[0].Text.Value}, | ||
}, | ||
}, | ||
}, | ||
}, | ||
} | ||
return nil | ||
} | ||
|
||
func (i *sourceInstance) handleStatusRequiresAction(ctx context.Context, run openai.Run) error { | ||
for _, t := range run.RequiredAction.SubmitToolOutputs.ToolCalls { | ||
if t.Type != openai.ToolTypeFunction { | ||
continue | ||
} | ||
|
||
switch t.Function.Name { | ||
case "kubectlGetPods": | ||
args := &kubectlGetPodsArgs{} | ||
if err := json.Unmarshal([]byte(t.Function.Arguments), args); err != nil { | ||
return err | ||
} | ||
|
||
out, err := kubectlGetPods(args) | ||
if err != nil { | ||
return err | ||
} | ||
// Submit tool output. | ||
_, err = i.openaiClient.SubmitToolOutputs(ctx, run.ThreadID, run.ID, openai.SubmitToolOutputsRequest{ | ||
ToolOutputs: []openai.ToolOutput{ | ||
{ | ||
ToolCallID: t.ID, | ||
Output: string(out), | ||
}, | ||
}, | ||
}) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
case "kubectlGetSecrets": | ||
args := &kubectlGetSecretsArgs{} | ||
if err := json.Unmarshal([]byte(t.Function.Arguments), args); err != nil { | ||
return err | ||
} | ||
|
||
out, err := kubectlGetSecrets(args) | ||
if err != nil { | ||
return err | ||
} | ||
// Submit tool output. | ||
_, err = i.openaiClient.SubmitToolOutputs(ctx, run.ThreadID, run.ID, openai.SubmitToolOutputsRequest{ | ||
ToolOutputs: []openai.ToolOutput{ | ||
{ | ||
ToolCallID: t.ID, | ||
Output: string(out), | ||
}, | ||
}, | ||
}) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
case "kubectlDescribePod": | ||
args := &kubectlDescribePodArgs{} | ||
if err := json.Unmarshal([]byte(t.Function.Arguments), args); err != nil { | ||
return err | ||
} | ||
|
||
out, err := kubectlDescribePod(args) | ||
if err != nil { | ||
return err | ||
} | ||
// Submit tool output. | ||
_, err = i.openaiClient.SubmitToolOutputs(ctx, run.ThreadID, run.ID, openai.SubmitToolOutputsRequest{ | ||
ToolOutputs: []openai.ToolOutput{ | ||
{ | ||
ToolCallID: t.ID, | ||
Output: string(out), | ||
}, | ||
}, | ||
}) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
case "kubectlLogs": | ||
args := &kubectlLogsArgs{} | ||
if err := json.Unmarshal([]byte(t.Function.Arguments), args); err != nil { | ||
return err | ||
} | ||
|
||
out, err := kubectlLogs(args) | ||
if err != nil { | ||
return err | ||
} | ||
// Submit tool output. | ||
_, err = i.openaiClient.SubmitToolOutputs(ctx, run.ThreadID, run.ID, openai.SubmitToolOutputsRequest{ | ||
ToolOutputs: []openai.ToolOutput{ | ||
{ | ||
ToolCallID: t.ID, | ||
Output: string(out), | ||
}, | ||
}, | ||
}) | ||
if err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
|
||
return nil | ||
} |
Oops, something went wrong.