From e203139b40dcd3f24e6be2910f27e66a910a8e24 Mon Sep 17 00:00:00 2001 From: Ke Chen Date: Fri, 9 Feb 2024 22:05:54 +0800 Subject: [PATCH] feat: add system prompt support --- apis/record/infer.go | 21 +++++++++++++-------- models/config.go | 1 + 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/apis/record/infer.go b/apis/record/infer.go index 17cd463..b052f1f 100644 --- a/apis/record/infer.go +++ b/apis/record/infer.go @@ -69,15 +69,19 @@ func InferOpenAI( openaiConfig.BaseURL = model.Url client := openai.NewClientWithConfig(openaiConfig) + var messages = make([]openai.ChatCompletionMessage, 0, len(postRecord)+2) + messages = append(messages, openai.ChatCompletionMessage{ + Role: "system", + Content: model.OpenAISystemPrompt, + }) + messages = append(messages, postRecord.ToOpenAIMessages()...) + messages = append(messages, openai.ChatCompletionMessage{ + Role: "user", + Content: record.Request, + }) request := openai.ChatCompletionRequest{ - Model: model.OpenAIModelName, - Messages: append( - postRecord.ToOpenAIMessages(), - openai.ChatCompletionMessage{ - Role: "user", - Content: record.Request, - }, - ), + Model: model.OpenAIModelName, + Messages: messages, } if ctx == nil { @@ -137,6 +141,7 @@ func InferOpenAI( resultBuilder.WriteString(response.Choices[0].Delta.Content) nowOutput = resultBuilder.String() + before, _, found := CutLastAny(nowOutput, ",.?!\n,。?!") if !found || before == detectedOutput { continue diff --git a/models/config.go b/models/config.go index e49b528..bbb770a 100644 --- a/models/config.go +++ b/models/config.go @@ -24,6 +24,7 @@ type ModelConfig struct { CallbackUrl string `json:"callback_url"` APIType APIType `json:"api_type"` OpenAIModelName string `json:"openai_model_name"` + OpenAISystemPrompt string `json:"openai_system_prompt"` } type ModelConfigs = []*ModelConfig