Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: llm engine driver #411

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions constraint/go.mod
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
module github.com/open-policy-agent/frameworks/constraint

go 1.18
go 1.21

toolchain go1.22.1

require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc
github.com/golang/glog v1.2.0
github.com/google/go-cmp v0.6.0
github.com/onsi/gomega v1.31.1
github.com/open-policy-agent/opa v0.62.1
github.com/sashabaranov/go-openai v1.20.4
github.com/sethvargo/go-retry v0.2.4
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/walles/env v0.0.4
golang.org/x/net v0.22.0
k8s.io/api v0.29.3
k8s.io/apiextensions-apiserver v0.29.3
Expand All @@ -25,12 +30,13 @@ require (
github.com/OneOfOne/xxhash v1.2.8 // indirect
github.com/agnivade/levenshtein v1.1.1 // indirect
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df // indirect
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a // indirect
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/evanphx/json-patch v5.6.0+incompatible // indirect
github.com/evanphx/json-patch/v5 v5.8.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
Expand All @@ -46,6 +52,7 @@ require (
github.com/google/cel-go v0.17.7 // indirect
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/mux v1.8.1 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 // indirect
Expand All @@ -54,6 +61,7 @@ require (
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/miekg/dns v1.1.58 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
Expand All @@ -63,20 +71,23 @@ require (
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/stoewer/go-strcase v1.2.0 // indirect
github.com/tchap/go-patricia/v2 v2.3.1 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/yashtewari/glob-intersection v0.2.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 // indirect
go.opentelemetry.io/otel v1.21.0 // indirect
go.etcd.io/etcd/api/v3 v3.5.12 // indirect
go.etcd.io/etcd/client/pkg/v3 v3.5.11 // indirect
go.etcd.io/etcd/client/v3 v3.5.11 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0 // indirect
go.opentelemetry.io/otel v1.22.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.21.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.21.0 // indirect
go.opentelemetry.io/otel/metric v1.21.0 // indirect
go.opentelemetry.io/otel/metric v1.22.0 // indirect
go.opentelemetry.io/otel/sdk v1.21.0 // indirect
go.opentelemetry.io/otel/trace v1.21.0 // indirect
go.opentelemetry.io/otel/trace v1.22.0 // indirect
go.opentelemetry.io/proto/otlp v1.0.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/oauth2 v0.16.0 // indirect
Expand Down
81 changes: 62 additions & 19 deletions constraint/go.sum

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions constraint/pkg/client/drivers/llm/args.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package llm

// Arg is a functional option that mutates a Driver during construction.
// New applies each Arg in order and aborts on the first non-nil error.
type Arg func(*Driver) error
266 changes: 266 additions & 0 deletions constraint/pkg/client/drivers/llm/driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
package llm

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"

openai "github.com/sashabaranov/go-openai"

apiconstraints "github.com/open-policy-agent/frameworks/constraint/pkg/apis/constraints"
"github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers"
llmSchema "github.com/open-policy-agent/frameworks/constraint/pkg/client/drivers/llm/schema"
"github.com/open-policy-agent/frameworks/constraint/pkg/core/templates"
"github.com/open-policy-agent/frameworks/constraint/pkg/types"
"github.com/open-policy-agent/opa/storage"
"github.com/sethvargo/go-retry"
flag "github.com/spf13/pflag"
"github.com/walles/env"
admissionv1 "k8s.io/api/admission/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)

const (
	// maxRetries bounds the exponential-backoff retry loop in Query.
	maxRetries = 10
	// need minimum of 2023-12-01-preview for JSON mode.
	azureOpenAIAPIVersion = "2024-03-01-preview"
	// azureOpenAIURL is the domain suffix used by newLLMClients to detect
	// an Azure OpenAI Service endpoint.
	azureOpenAIURL = "openai.azure.com"
	// systemPrompt instructs the model to behave as a policy engine and emit
	// only a JSON object with a boolean 'decision' and an optional 'reason'.
	systemPrompt = "You are a policy engine for Kubernetes designed to output JSON. Input will be a policy definition, Kubernetes AdmissionRequest object, and parameters to apply to the policy if applicable. Output JSON should only have a 'decision' field with a boolean value and a 'reason' field with a string value explaining the decision, only if decision is false. Only output valid JSON."
)

var (
	// openAIAPIURLv1 is the default public OpenAI API base URL.
	openAIAPIURLv1 = "https://api.openai.com/v1"

	// Each flag falls back to the matching environment variable when unset.
	openAIDeploymentName = flag.String("openai-deployment-name", env.GetOr("OPENAI_DEPLOYMENT_NAME", env.String, "gpt-3.5-turbo-0301"), "The deployment name used for the model in OpenAI service.")
	openAIAPIKey         = flag.String("openai-api-key", env.GetOr("OPENAI_API_KEY", env.String, ""), "The API key for the OpenAI service. This is required.")
	openAIEndpoint       = flag.String("openai-endpoint", env.GetOr("OPENAI_ENDPOINT", env.String, openAIAPIURLv1), "The endpoint for OpenAI service. Defaults to"+openAIAPIURLv1+". Set this to Azure OpenAI Service or OpenAI compatible API endpoint, if needed.")
)

// Driver evaluates constraints by sending a template's prompt, the incoming
// Kubernetes AdmissionRequest, and the constraint's parameters to an LLM and
// parsing the model's JSON verdict.
type Driver struct {
	// prompts maps a constraint template's name (which must match the
	// lowercased constraint kind) to its prompt text.
	prompts map[string]string
}

// Compile-time assertion that Driver implements drivers.Driver.
var _ drivers.Driver = &Driver{}

// Decision records the LLM's verdict for a single constraint.
type Decision struct {
	// Name is the name of the evaluated constraint.
	Name string
	// Constraint is the constraint object the decision applies to.
	Constraint *unstructured.Unstructured
	// Decision is the model's boolean verdict; only false (deny) decisions
	// are reported as violations by Query.
	Decision bool
	// Reason explains a false decision; empty when the decision is true.
	Reason string
}

// ARGetter is implemented by review objects that can expose the underlying
// Kubernetes AdmissionRequest; Query requires its review argument to satisfy it.
type ARGetter interface {
	GetAdmissionRequest() *admissionv1.AdmissionRequest
}

// Name returns the driver's unique name, as declared in the llm schema package.
func (d *Driver) Name() string {
	return llmSchema.Name
}

func (d *Driver) AddTemplate(_ context.Context, ct *templates.ConstraintTemplate) error {
source, err := llmSchema.GetSourceFromTemplate(ct)
if err != nil {
return err
}

prompt, err := source.GetPrompt()
if err != nil {
return err
}
if prompt == "" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

other than empty string, are there other prompts that we may want to prevent? could any prompts be used to exploit or manipulate the system?

return fmt.Errorf("prompt is empty for template: %q", ct.Name)
}

d.prompts[ct.Name] = prompt
return nil
}

// RemoveTemplate drops the prompt registered for the given template.
// Removing a template that was never added is a no-op.
func (d *Driver) RemoveTemplate(_ context.Context, ct *templates.ConstraintTemplate) error {
	delete(d.prompts, ct.Name)
	return nil
}

// AddConstraint verifies that a prompt matching the constraint's lowercased
// kind has been registered via AddTemplate; the driver keeps no per-constraint
// state beyond that check.
func (d *Driver) AddConstraint(_ context.Context, constraint *unstructured.Unstructured) error {
	key := strings.ToLower(constraint.GetKind())
	if _, ok := d.prompts[key]; !ok {
		return fmt.Errorf("no promptName with name: %q", key)
	}
	return nil
}

// RemoveConstraint is a no-op: this driver stores no per-constraint state
// (constraints are only validated against registered prompts in AddConstraint).
func (d *Driver) RemoveConstraint(_ context.Context, _ *unstructured.Unstructured) error {
	return nil
}

// AddData is a no-op: the llm driver does not cache external data.
func (d *Driver) AddData(_ context.Context, _ string, _ storage.Path, _ interface{}) error {
	return nil
}

// RemoveData is a no-op: the llm driver does not cache external data.
func (d *Driver) RemoveData(_ context.Context, _ string, _ storage.Path) error {
	return nil
}

// Query evaluates the review's AdmissionRequest against every constraint
// whose kind has a registered prompt. For each matching constraint it asks
// the LLM for a JSON verdict and collects denials into types.Result entries.
// Returns (nil, nil) when no constraint produced a violation.
func (d *Driver) Query(ctx context.Context, _ string, constraints []*unstructured.Unstructured, review interface{}, _ ...drivers.QueryOpt) (*drivers.QueryResponse, error) {
	llmc, err := newLLMClients()
	if err != nil {
		return nil, err
	}

	arGetter, ok := review.(ARGetter)
	if !ok {
		return nil, errors.New("cannot convert review to ARGetter")
	}
	aRequest := arGetter.GetAdmissionRequest()

	var allDecisions []*Decision
	for _, constraint := range constraints {
		promptName := strings.ToLower(constraint.GetKind())
		prompt, found := d.prompts[promptName]
		if !found {
			// No template registered for this constraint kind; skip it.
			continue
		}

		paramsStruct, _, err := unstructured.NestedFieldNoCopy(constraint.Object, "spec", "parameters")
		if err != nil {
			return nil, err
		}

		params, err := json.Marshal(paramsStruct)
		if err != nil {
			return nil, err
		}

		llmPrompt := fmt.Sprintf("policy: %s\nadmission request: %s\nparameters: %s", prompt, string(aRequest.Object.Raw), string(params))

		var resp string
		r := retry.WithMaxRetries(maxRetries, retry.NewExponential(1*time.Second))
		if err := retry.Do(ctx, r, func(ctx context.Context) error {
			resp, err = llmc.openaiGptChatCompletion(ctx, llmPrompt)
			requestErr := &openai.APIError{}
			if errors.As(err, &requestErr) {
				switch requestErr.HTTPStatusCode {
				case http.StatusTooManyRequests, http.StatusRequestTimeout, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
					// Transient throttling/server failures: retry with
					// exponential backoff.
					return retry.RetryableError(err)
				}
			}
			// BUGFIX: propagate non-retryable errors instead of returning
			// nil. Previously a permanent failure was swallowed here, which
			// left resp empty and surfaced later as a confusing JSON
			// unmarshal error rather than the real API error.
			return err
		}); err != nil {
			return nil, err
		}

		var decision Decision
		err = json.Unmarshal([]byte(resp), &decision)
		if err != nil {
			return nil, err
		}

		// Only denials are reported; allowed decisions are dropped.
		if !decision.Decision {
			llmDecision := &Decision{
				Decision:   decision.Decision,
				Name:       constraint.GetName(),
				Constraint: constraint,
				Reason:     decision.Reason,
			}
			allDecisions = append(allDecisions, llmDecision)
		}
	}
	if len(allDecisions) == 0 {
		return nil, nil
	}

	results := make([]*types.Result, len(allDecisions))
	for i, llmDecision := range allDecisions {
		enforcementAction, found, err := unstructured.NestedString(llmDecision.Constraint.Object, "spec", "enforcementAction")
		if err != nil {
			return nil, err
		}
		if !found {
			// Constraints without an explicit enforcementAction default to deny.
			enforcementAction = apiconstraints.EnforcementActionDeny
		}

		results[i] = &types.Result{
			Metadata: map[string]interface{}{
				"name": llmDecision.Name,
			},
			Constraint:        llmDecision.Constraint,
			Msg:               llmDecision.Reason,
			EnforcementAction: enforcementAction,
		}
	}
	return &drivers.QueryResponse{Results: results}, nil
}

// Dump is not yet implemented for the llm driver and panics if called.
func (d *Driver) Dump(_ context.Context) (string, error) {
	panic("implement me")
}

// GetDescriptionForStat is not yet implemented for the llm driver and panics
// if called.
func (d *Driver) GetDescriptionForStat(_ string) (string, error) {
	panic("implement me")
}

// llmClients bundles the backing LLM API clients used by Query.
type llmClients struct {
	// openAIClient talks to OpenAI, Azure OpenAI Service, or an
	// OpenAI-compatible endpoint, depending on configuration.
	openAIClient openai.Client
}

// newLLMClients builds the OpenAI client from the flag/environment
// configuration. It selects Azure OpenAI Service configuration when the
// endpoint contains the Azure domain, treats any other non-default endpoint
// as an OpenAI-compatible API or proxy, and otherwise uses the public
// OpenAI API.
func newLLMClients() (llmClients, error) {
	// Default to the public OpenAI API.
	config := openai.DefaultConfig(*openAIAPIKey)

	// BUGFIX: compare endpoint values, not pointers. openAIEndpoint is a
	// *string allocated by flag.String and could never equal
	// &openAIAPIURLv1, so the original pointer comparison was always true
	// and reconfigured the client even for the default endpoint.
	if *openAIEndpoint != openAIAPIURLv1 {
		if strings.Contains(*openAIEndpoint, azureOpenAIURL) {
			// Azure OpenAI Service.
			config = openai.DefaultAzureConfig(*openAIAPIKey, *openAIEndpoint)
		} else {
			// OpenAI API compatible endpoint or proxy.
			config.BaseURL = *openAIEndpoint
		}
		// Minimum API version supporting JSON mode; only meaningful for
		// the Azure API type, ignored by plain OpenAI-compatible endpoints.
		config.APIVersion = azureOpenAIAPIVersion
	}

	clients := llmClients{
		openAIClient: *openai.NewClientWithConfig(config),
	}
	return clients, nil
}

// openaiGptChatCompletion sends the system prompt plus the caller's prompt to
// the configured chat-completion model in JSON mode and returns the content
// of the single expected choice.
func (c *llmClients) openaiGptChatCompletion(ctx context.Context, prompt string) (string, error) {
	messages := []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleSystem, Content: systemPrompt},
		{Role: openai.ChatMessageRoleUser, Content: prompt},
	}
	req := openai.ChatCompletionRequest{
		Model:       *openAIDeploymentName,
		Messages:    messages,
		N:           1, // request exactly one completion
		Temperature: 0, // 0 is more deterministic
		ResponseFormat: &openai.ChatCompletionResponseFormat{
			Type: openai.ChatCompletionResponseFormatTypeJSONObject,
		},
	}

	resp, err := c.openAIClient.CreateChatCompletion(ctx, req)
	if err != nil {
		return "", err
	}
	if got := len(resp.Choices); got != 1 {
		return "", fmt.Errorf("expected choices to be 1 but received: %d", got)
	}
	return resp.Choices[0].Message.Content, nil
}
13 changes: 13 additions & 0 deletions constraint/pkg/client/drivers/llm/new.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package llm

// New constructs a Driver with an initialized prompt registry, applying each
// functional option in order; the first option returning an error aborts
// construction.
func New(args ...Arg) (*Driver, error) {
	d := &Driver{
		prompts: map[string]string{},
	}
	for _, apply := range args {
		if err := apply(d); err != nil {
			return nil, err
		}
	}
	return d, nil
}
Loading
Loading