Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prompt Flows Chat Use Case #666

Merged
merged 5 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions docs/DEPLOY_OPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,44 @@ Knowledge Base プロンプト例: キーワードで検索し情報を取得し
}
```

### PromptFlow チャットユースケースの有効化

PromptFlow チャットユースケースでは、作成済みの Prompt Flow を呼び出すことができます。

プロジェクトのルートディレクトリにある `cdk.json` ファイルを開き、`context` セクション内に `promptFlows` 配列を追加または編集します。

[Prompt Flows の AWS コンソール画面](https://us-east-1.console.aws.amazon.com/bedrock/home#/prompt-flows) から手動で Prompt Flows を作成します。その後、Alias を作成し、作成済みの Prompt Flow の `flowId` と `aliasId`, `flowName` を追加します。`description` にはユーザーの入力を促すための説明文章を記載します。この説明文章は Prompt Flow チャットのテキストボックスに記載されます。以下はその例です。
tbrand marked this conversation as resolved.
Show resolved Hide resolved

**[packages/cdk/cdk.json](/packages/cdk/cdk.json) を編集**
```json
{
"context": {
"promptFlows": [
{
"flowId": "XXXXXXXXXX",
"aliasId": "YYYYYYYYYY",
"flowName": "WhatIsItFlow",
"description": "任意のキーワードをウェブ検索して、説明を返すフローです。文字を入力してください"
},
{
"flowId": "ZZZZZZZZZZ",
"aliasId": "OOOOOOOOOO",
"flowName": "RecipeFlow",
"description": "与えられたJSONをもとに、レシピを作成します。\n{\"dish\": \"カレーライス\", \"people\": 3} のように入力してください。"
},
{
"flowId": "PPPPPPPPPP",
"aliasId": "QQQQQQQQQQQ",
"flowName": "TravelPlanFlow",
"description": "与えられた配列をもとに、旅行計画を作成します。\n[{\"place\": \"東京\", \"day\": 3}, {\"place\": \"大阪\", \"day\": 2}] のように入力してください。"
}
]
}
}
```



### 映像分析ユースケースの有効化

映像分析ユースケースでは、映像の画像フレームとテキストを入力して画像の内容を LLM に分析させます。
Expand Down
1 change: 1 addition & 0 deletions packages/cdk/cdk.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"searchAgentEnabled": false,
"searchApiKey": "",
"agents": [],
"promptFlows": [],
"allowedIpV4AddressRanges": null,
"allowedIpV6AddressRanges": null,
"allowedCountryCodes": null,
Expand Down
43 changes: 43 additions & 0 deletions packages/cdk/lambda/invokeFlow.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { Handler, Context } from 'aws-lambda';
import { PromptFlowRequest } from 'generative-ai-use-cases-jp';
import bedrockFlowApi from './utils/bedrockFlowApi';

declare global {
namespace awslambda {
function streamifyResponse(
f: (
event: PromptFlowRequest,
responseStream: NodeJS.WritableStream,
context: Context
) => Promise<void>
): Handler;
}
}

export const handler = awslambda.streamifyResponse(
async (
event: PromptFlowRequest,
responseStream: NodeJS.WritableStream,
context: Context
) => {
try {
context.callbackWaitsForEmptyEventLoop = false;

for await (const token of bedrockFlowApi.invokeFlow({
flowIdentifier: event.flowIdentifier,
flowAliasIdentifier: event.flowAliasIdentifier,
document: event.document,
})) {
responseStream.write(token);
}

responseStream.end();
} catch (e) {
console.error('Error in handler:', e);
responseStream.write(
JSON.stringify({ error: 'An error occurred processing your request' })
);
responseStream.end();
}
}
);
3 changes: 2 additions & 1 deletion packages/cdk/lambda/utils/bedrockAgentApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
ServiceQuotaExceededException,
ThrottlingException,
} from '@aws-sdk/client-bedrock-agent-runtime';

import { PutObjectCommand, S3Client } from '@aws-sdk/client-s3';
import { v4 as uuidv4 } from 'uuid';
import {
Expand Down Expand Up @@ -36,7 +37,7 @@ const convertS3UriToUrl = (s3Uri: string): string => {
return '';
};

const bedrockAgentApi: Partial<ApiInterface> = {
const bedrockAgentApi: Pick<ApiInterface, 'invokeStream'> = {
invokeStream: async function* (model: Model, messages: UnrecordedMessage[]) {
try {
const command = new InvokeAgentCommand({
Expand Down
2 changes: 1 addition & 1 deletion packages/cdk/lambda/utils/bedrockApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ const extractOutputImage = (
return modelConfig.extractOutputImage(response);
};

const bedrockApi: ApiInterface = {
const bedrockApi: Omit<ApiInterface, 'invokeFlow'> = {
invoke: async (model, messages, id) => {
const client = await initBedrockClient();

Expand Down
73 changes: 73 additions & 0 deletions packages/cdk/lambda/utils/bedrockFlowApi.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import {
BedrockAgentRuntimeClient,
FlowInputContent,
InvokeFlowCommand,
InvokeFlowCommandInput,
ServiceQuotaExceededException,
ThrottlingException,
ValidationException,
} from '@aws-sdk/client-bedrock-agent-runtime';

const client = new BedrockAgentRuntimeClient({
region: process.env.MODEL_REGION,
});

type InvokeFlowGeneratorProps = {
flowIdentifier: string;
flowAliasIdentifier: string;
document: FlowInputContent.DocumentMember['document'];
};

const bedrockFlowApi = {
invokeFlow: async function* (props: InvokeFlowGeneratorProps) {
const input: InvokeFlowCommandInput = {
flowIdentifier: props.flowIdentifier,
flowAliasIdentifier: props.flowAliasIdentifier,
inputs: [
{
nodeName: 'FlowInputNode',
nodeOutputName: 'document',
content: {
document: props.document,
},
},
],
};

const command = new InvokeFlowCommand(input);

try {
const response = await client.send(command);

if (response.responseStream) {
for await (const event of response.responseStream) {
if (event.flowOutputEvent?.content?.document) {
const chunk =
event.flowOutputEvent.content.document.toString() + '\n';
yield chunk;
}

if (event.flowCompletionEvent?.completionReason === 'SUCCESS') {
break;
}
}
}

yield '\n';
} catch (e) {
if (
e instanceof ThrottlingException ||
e instanceof ServiceQuotaExceededException
) {
yield 'ただいまアクセスが集中しているため時間をおいて試してみてください。';
} else if (e instanceof ValidationException) {
yield `形式エラーです。\n ${e}`;
} else {
console.error(e);
yield 'エラーが発生しました。時間をおいて試してみてください。';
}
}
},
};

export default bedrockFlowApi;
26 changes: 26 additions & 0 deletions packages/cdk/lib/construct/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export interface BackendApiProps {
export class Api extends Construct {
readonly api: RestApi;
readonly predictStreamFunction: NodejsFunction;
readonly invokePromptFlowFunction: NodejsFunction;
readonly modelRegion: string;
readonly modelIds: string[];
readonly multiModalModelIds: string[];
Expand Down Expand Up @@ -212,6 +213,27 @@ export class Api extends Construct {
fileBucket.grantWrite(predictStreamFunction);
predictStreamFunction.grantInvoke(idPool.authenticatedRole);

// Prompt Flow Lambda Function の追加
const invokePromptFlowFunction = new NodejsFunction(
this,
'InvokePromptFlow',
{
runtime: Runtime.NODEJS_18_X,
entry: './lambda/invokeFlow.ts',
timeout: Duration.minutes(15),
bundling: {
nodeModules: [
'@aws-sdk/client-bedrock-runtime',
'@aws-sdk/client-bedrock-agent-runtime',
],
},
environment: {
MODEL_REGION: modelRegion,
},
}
);
invokePromptFlowFunction.grantInvoke(idPool.authenticatedRole);

const predictTitleFunction = new NodejsFunction(this, 'PredictTitle', {
runtime: Runtime.NODEJS_18_X,
entry: './lambda/predictTitle.ts',
Expand Down Expand Up @@ -267,6 +289,7 @@ export class Api extends Construct {
predictStreamFunction.role?.addToPrincipalPolicy(sagemakerPolicy);
predictTitleFunction.role?.addToPrincipalPolicy(sagemakerPolicy);
generateImageFunction.role?.addToPrincipalPolicy(sagemakerPolicy);
invokePromptFlowFunction.role?.addToPrincipalPolicy(sagemakerPolicy);
}

// Bedrock は常に権限付与
Expand All @@ -284,6 +307,7 @@ export class Api extends Construct {
predictFunction.role?.addToPrincipalPolicy(bedrockPolicy);
predictTitleFunction.role?.addToPrincipalPolicy(bedrockPolicy);
generateImageFunction.role?.addToPrincipalPolicy(bedrockPolicy);
invokePromptFlowFunction.role?.addToPrincipalPolicy(bedrockPolicy);
} else {
// crossAccountBedrockRoleArn が指定されている場合のポリシー
const logsPolicy = new PolicyStatement({
Expand All @@ -304,6 +328,7 @@ export class Api extends Construct {
predictFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
predictTitleFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
generateImageFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
invokePromptFlowFunction.role?.addToPrincipalPolicy(assumeRolePolicy);
}

const createChatFunction = new NodejsFunction(this, 'CreateChat', {
Expand Down Expand Up @@ -752,6 +777,7 @@ export class Api extends Construct {

this.api = api;
this.predictStreamFunction = predictStreamFunction;
this.invokePromptFlowFunction = invokePromptFlowFunction;
this.modelRegion = modelRegion;
this.modelIds = modelIds;
this.multiModalModelIds = multiModalModelIds;
Expand Down
6 changes: 6 additions & 0 deletions packages/cdk/lib/construct/web.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import * as s3 from 'aws-cdk-lib/aws-s3';
import { ARecord, HostedZone, RecordTarget } from 'aws-cdk-lib/aws-route53';
import { CloudFrontTarget } from 'aws-cdk-lib/aws-route53-targets';
import { ICertificate } from 'aws-cdk-lib/aws-certificatemanager';
import { PromptFlow } from 'generative-ai-use-cases-jp';

export interface WebProps {
apiEndpointUrl: string;
Expand All @@ -20,6 +21,8 @@ export interface WebProps {
ragEnabled: boolean;
ragKnowledgeBaseEnabled: boolean;
agentEnabled: boolean;
promptFlows?: PromptFlow[];
promptFlowStreamFunctionArn: string;
selfSignUpEnabled: boolean;
webAclId?: string;
modelRegion: string;
Expand Down Expand Up @@ -167,6 +170,9 @@ export class Web extends Construct {
VITE_APP_RAG_KNOWLEDGE_BASE_ENABLED:
props.ragKnowledgeBaseEnabled.toString(),
VITE_APP_AGENT_ENABLED: props.agentEnabled.toString(),
VITE_APP_PROMPT_FLOWS: JSON.stringify(props.promptFlows || []),
VITE_APP_PROMPT_FLOW_STREAM_FUNCTION_ARN:
props.promptFlowStreamFunctionArn,
VITE_APP_SELF_SIGN_UP_ENABLED: props.selfSignUpEnabled.toString(),
VITE_APP_MODEL_REGION: props.modelRegion,
VITE_APP_MODEL_IDS: JSON.stringify(props.modelIds),
Expand Down
14 changes: 13 additions & 1 deletion packages/cdk/lib/generative-ai-use-cases-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
import { CfnWebACLAssociation } from 'aws-cdk-lib/aws-wafv2';
import * as cognito from 'aws-cdk-lib/aws-cognito';
import { ICertificate } from 'aws-cdk-lib/aws-certificatemanager';
import { Agent } from 'generative-ai-use-cases-jp';
import { Agent, PromptFlow } from 'generative-ai-use-cases-jp';

const errorMessageForBooleanContext = (key: string) => {
return `${key} の設定でエラーになりました。原因として考えられるものは以下です。
Expand All @@ -34,6 +34,7 @@ interface GenerativeAiUseCasesStackProps extends StackProps {
domainName?: string;
hostedZoneId?: string;
agents?: Agent[];
promptFlows?: PromptFlow[];
knowledgeBaseId?: string;
knowledgeBaseDataSourceBucketName?: string;
guardrailIdentifier?: string;
Expand Down Expand Up @@ -69,6 +70,7 @@ export class GenerativeAiUseCasesStack extends Stack {
const samlCognitoFederatedIdentityProviderName: string =
this.node.tryGetContext('samlCognitoFederatedIdentityProviderName')!;
const agentEnabled = this.node.tryGetContext('agentEnabled') || false;
const promptFlows = this.node.tryGetContext('promptFlows') || [];
const recognizeFileEnabled: boolean = this.node.tryGetContext(
'recognizeFileEnabled'
)!;
Expand Down Expand Up @@ -149,6 +151,8 @@ export class GenerativeAiUseCasesStack extends Stack {
ragEnabled,
ragKnowledgeBaseEnabled,
agentEnabled,
promptFlows,
promptFlowStreamFunctionArn: api.invokePromptFlowFunction.functionArn,
selfSignUpEnabled,
webAclId: props.webAclId,
modelRegion: api.modelRegion,
Expand Down Expand Up @@ -238,6 +242,14 @@ export class GenerativeAiUseCasesStack extends Stack {
value: api.predictStreamFunction.functionArn,
});

new CfnOutput(this, 'InvokePromptFlowFunctionArn', {
value: api.invokePromptFlowFunction.functionArn,
});

new CfnOutput(this, 'PromptFlows', {
value: Buffer.from(JSON.stringify(promptFlows)).toString('base64'),
});

new CfnOutput(this, 'RagEnabled', {
value: ragEnabled.toString(),
});
Expand Down
7 changes: 7 additions & 0 deletions packages/types/src/message.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ export type Agent = {

export type AgentMap = Record<string, { agentId: string; aliasId: string }>;

export type PromptFlow = {
flowId: string;
aliasId: string;
flowName: string;
description: string;
};

export type MessageAttributes = {
messageId: string;
usecase: string;
Expand Down
12 changes: 10 additions & 2 deletions packages/types/src/protocol.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import {
QueryCommandOutput,
RetrieveCommandOutput,
} from '@aws-sdk/client-kendra';
import { RetrieveCommandOutput as RetrieveCommandOutputKnowledgeBase } from '@aws-sdk/client-bedrock-agent-runtime';
import {
FlowInputContent,
RetrieveCommandOutput as RetrieveCommandOutputKnowledgeBase,
} from '@aws-sdk/client-bedrock-agent-runtime';
import { GenerateImageParams } from './image';
import { ShareId, UserIdAndChatId } from './share';
import { MediaFormat } from '@aws-sdk/client-transcribe';
tbrand marked this conversation as resolved.
Show resolved Hide resolved

export type StreamingChunk = {
text: string;
Expand Down Expand Up @@ -83,6 +85,12 @@ export type PredictRequest = {

export type PredictResponse = string;

export type PromptFlowRequest = {
flowIdentifier: string;
flowAliasIdentifier: string;
document: FlowInputContent.DocumentMember['document'];
};

export type PredictTitleRequest = {
model: Model;
chat: Chat;
Expand Down
Loading