Skip to content

Commit

Permalink
refactor & lint
Browse files Browse the repository at this point in the history
  • Loading branch information
maekawataiki committed Oct 10, 2024
1 parent d126065 commit 19c0895
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions packages/cdk/lambda/utils/bedrockAgentApi.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {
BedrockAgentClient,
GetAgentAliasCommand,
ListAgentActionGroupsCommand
ListAgentActionGroupsCommand,
} from '@aws-sdk/client-bedrock-agent';
import {
BedrockAgentRuntimeClient,
Expand Down Expand Up @@ -35,7 +35,7 @@ const s3Client = new S3Client({});
const agentMap: AgentMap = JSON.parse(process.env.AGENT_MAP || '{}');
type AgentInfo = {
codeInterpreterEnabled: boolean;
}
};
const agentInfoMap: { [aliasId: string]: AgentInfo } = {};

// s3://<BUCKET>/<PREFIX> から https://<BUCKET>.s3.amazonaws.com/<PREFIX> に変換する
Expand All @@ -51,35 +51,46 @@ const convertS3UriToUrl = (s3Uri: string): string => {
return '';
};

const getAgentInfo = async (agentId: string, agentAliasId: string) => {
// Get Agent Info if not cached
if (!agentInfoMap[agentAliasId]) {
// Get Agent Version
const agentAliasInfoRes = await agentClient.send(
new GetAgentAliasCommand({
agentId: agentId,
agentAliasId: agentAliasId,
})
);
const agentVersion =
agentAliasInfoRes.agentAlias?.routingConfiguration?.pop()?.agentVersion ??
'1';
// List Action Group
const actionGroups = await agentClient.send(
new ListAgentActionGroupsCommand({
agentId: agentId,
agentVersion: agentVersion,
})
);
// Cache Agent Info
agentInfoMap[agentAliasId] = {
codeInterpreterEnabled: !!actionGroups.actionGroupSummaries?.find(
(actionGroup) => actionGroup.actionGroupName === 'CodeInterpreterAction'
),
};
}
return agentInfoMap[agentAliasId];
};

const bedrockAgentApi: Pick<ApiInterface, 'invokeStream'> = {
invokeStream: async function* (model: Model, messages: UnrecordedMessage[]) {
try {
// Get Agent
if (!agentMap[model.modelId]) {
throw new Error('Agent not found');
}
const agentId = agentMap[model.modelId].agentId
const agentId = agentMap[model.modelId].agentId;
const agentAliasId = agentMap[model.modelId].aliasId;

// Get Agent Info if not cached
if (!agentInfoMap[agentAliasId]) {
// Get Agent Version
const agentAliasInfoRes = await agentClient.send(new GetAgentAliasCommand({
agentId: agentId,
agentAliasId: agentAliasId,
}));
const agentVersion = agentAliasInfoRes.agentAlias?.routingConfiguration?.pop()?.agentVersion ?? "1";
// List Action Group
const actionGroups = await agentClient.send(new ListAgentActionGroupsCommand({
agentId: agentId,
agentVersion: agentVersion,
}));
// Cache Agent Info
agentInfoMap[agentAliasId] = {
codeInterpreterEnabled: !!actionGroups.actionGroupSummaries?.find(actionGroup => actionGroup.actionGroupName === "CodeInterpreterAction"),
};
}
const agentInfo = agentInfoMap[agentAliasId]
const agentInfo = await getAgentInfo(agentId, agentAliasId);

// Invoke Agent
const command = new InvokeAgentCommand({
Expand All @@ -94,7 +105,9 @@ const bedrockAgentApi: Pick<ApiInterface, 'invokeStream'> = {
data: Buffer.from(file.source.data, 'base64'),
},
},
useCase: agentInfo.codeInterpreterEnabled ? 'CODE_INTERPRETER' : 'CHAT',
useCase: agentInfo.codeInterpreterEnabled
? 'CODE_INTERPRETER'
: 'CHAT',
})) || [],
},
agentId: agentId,
Expand Down

0 comments on commit 19c0895

Please sign in to comment.