diff --git a/packages/cdk/lambda/utils/bedrockAgentApi.ts b/packages/cdk/lambda/utils/bedrockAgentApi.ts index ba54a38e..60f04ada 100644 --- a/packages/cdk/lambda/utils/bedrockAgentApi.ts +++ b/packages/cdk/lambda/utils/bedrockAgentApi.ts @@ -1,7 +1,7 @@ import { BedrockAgentClient, GetAgentAliasCommand, - ListAgentActionGroupsCommand + ListAgentActionGroupsCommand, } from '@aws-sdk/client-bedrock-agent'; import { BedrockAgentRuntimeClient, @@ -35,7 +35,7 @@ const s3Client = new S3Client({}); const agentMap: AgentMap = JSON.parse(process.env.AGENT_MAP || '{}'); type AgentInfo = { codeInterpreterEnabled: boolean; -} +}; const agentInfoMap: { [aliasId: string]: AgentInfo } = {}; // s3:/// から https://.s3.amazonaws.com/ に変換する @@ -51,6 +51,36 @@ const convertS3UriToUrl = (s3Uri: string): string => { return ''; }; +const getAgentInfo = async (agentId: string, agentAliasId: string) => { + // Get Agent Info if not cached + if (!agentInfoMap[agentAliasId]) { + // Get Agent Version + const agentAliasInfoRes = await agentClient.send( + new GetAgentAliasCommand({ + agentId: agentId, + agentAliasId: agentAliasId, + }) + ); + const agentVersion = + agentAliasInfoRes.agentAlias?.routingConfiguration?.pop()?.agentVersion ?? + '1'; + // List Action Group + const actionGroups = await agentClient.send( + new ListAgentActionGroupsCommand({ + agentId: agentId, + agentVersion: agentVersion, + }) + ); + // Cache Agent Info + agentInfoMap[agentAliasId] = { + codeInterpreterEnabled: !!actionGroups.actionGroupSummaries?.find( + (actionGroup) => actionGroup.actionGroupName === 'CodeInterpreterAction' + ), + }; + } + return agentInfoMap[agentAliasId]; +}; + const bedrockAgentApi: Pick = { invokeStream: async function* (model: Model, messages: UnrecordedMessage[]) { try { @@ -58,28 +88,9 @@ const bedrockAgentApi: Pick = { if (!agentMap[model.modelId]) { throw new Error('Agent not found'); } - const agentId = agentMap[model.modelId].agentId + const agentId = agentMap[model.modelId].agentId; const agentAliasId = agentMap[model.modelId].aliasId; - - // Get Agent Info if not cached - if (!agentInfoMap[agentAliasId]) { - // Get Agent Version - const agentAliasInfoRes = await agentClient.send(new GetAgentAliasCommand({ - agentId: agentId, - agentAliasId: agentAliasId, - })); - const agentVersion = agentAliasInfoRes.agentAlias?.routingConfiguration?.pop()?.agentVersion ?? "1"; - // List Action Group - const actionGroups = await agentClient.send(new ListAgentActionGroupsCommand({ - agentId: agentId, - agentVersion: agentVersion, - })); - // Cache Agent Info - agentInfoMap[agentAliasId] = { - codeInterpreterEnabled: !!actionGroups.actionGroupSummaries?.find(actionGroup => actionGroup.actionGroupName === "CodeInterpreterAction"), - }; - } - const agentInfo = agentInfoMap[agentAliasId] + const agentInfo = await getAgentInfo(agentId, agentAliasId); // Invoke Agent const command = new InvokeAgentCommand({ @@ -94,7 +105,9 @@ const bedrockAgentApi: Pick = { data: Buffer.from(file.source.data, 'base64'), }, }, - useCase: agentInfo.codeInterpreterEnabled ? 'CODE_INTERPRETER' : 'CHAT', + useCase: agentInfo.codeInterpreterEnabled + ? 'CODE_INTERPRETER' + : 'CHAT', })) || [], }, agentId: agentId,