Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New sessions table that allows for chat history and metadata greater than 400KB #584

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions lib/aws-genai-llm-chatbot-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
ragEngines,
messagesTopic: chatBotApi.messagesTopic,
sessionsTable: chatBotApi.sessionsTable,
byUserIdIndex: chatBotApi.byUserIdIndex,
});

// Route all incoming messages targeted to langchain to the langchain model interface queue
Expand Down Expand Up @@ -120,7 +119,6 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
config: props.config,
messagesTopic: chatBotApi.messagesTopic,
sessionsTable: chatBotApi.sessionsTable,
byUserIdIndex: chatBotApi.byUserIdIndex,
chatbotFilesBucket: chatBotApi.filesBucket,
createPrivateGateway: ideficsModels.length > 0,
});
Expand Down
12 changes: 5 additions & 7 deletions lib/chatbot-api/chatbot-dynamodb-tables/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ export class ChatBotDynamoDBTables extends Construct {
constructor(scope: Construct, id: string, props: ChatBotDynamoDBTablesProps) {
super(scope, id);

// Create the sessions table with a partition key (PK) of USER#<UUID>
// and a sort key (SK) of SESSION#<Unique Session ID>.
// No global secondary index is needed for this table.
const sessionsTable = new dynamodb.Table(this, "SessionsTable", {
partitionKey: {
name: "SessionId",
name: "PK",
type: dynamodb.AttributeType.STRING,
},
sortKey: {
name: "UserId",
name: "SK",
type: dynamodb.AttributeType.STRING,
},
billingMode: dynamodb.BillingMode.PAY_PER_REQUEST,
Expand All @@ -36,11 +39,6 @@ export class ChatBotDynamoDBTables extends Construct {
pointInTimeRecovery: true,
});

sessionsTable.addGlobalSecondaryIndex({
indexName: this.byUserIdIndex,
partitionKey: { name: "UserId", type: dynamodb.AttributeType.STRING },
});

this.sessionsTable = sessionsTable;
}
}
6 changes: 2 additions & 4 deletions lib/chatbot-api/functions/api-handler/routes/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def get_sessions():
return [
{
"id": session.get("SessionId"),
"title": session.get("History", [{}])[0]
.get("data", {})
.get("content", "<no title>"),
"title": session.get("Title", "<no title>"),
"startTime": f'{session.get("StartTime")}Z',
}
for session in sessions
Expand Down Expand Up @@ -76,7 +74,7 @@ def get_session(id: str):
"type": item.get("type"),
"content": item.get("data", {}).get("content"),
"metadata": json.dumps(
item.get("data", {}).get("additional_kwargs"),
item.get("data", {}).get("additional_kwargs", {}),
cls=genai_core.utils.json.CustomEncoder,
),
}
Expand Down
3 changes: 0 additions & 3 deletions lib/chatbot-api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ export class ChatBotApi extends Construct {
public readonly messagesTopic: sns.Topic;
public readonly outBoundQueue: sqs.Queue;
public readonly sessionsTable: dynamodb.Table;
public readonly byUserIdIndex: string;
public readonly filesBucket: s3.Bucket;
public readonly userFeedbackBucket: s3.Bucket;
public readonly graphqlApi: appsync.GraphqlApi;
Expand Down Expand Up @@ -120,7 +119,6 @@ export class ChatBotApi extends Construct {
const apiResolvers = new ApiResolvers(this, "RestApi", {
...props,
sessionsTable: chatTables.sessionsTable,
byUserIdIndex: chatTables.byUserIdIndex,
api,
userFeedbackBucket: chatBuckets.userFeedbackBucket,
filesBucket: chatBuckets.filesBucket,
Expand Down Expand Up @@ -158,7 +156,6 @@ export class ChatBotApi extends Construct {
this.messagesTopic = realtimeBackend.messagesTopic;
this.outBoundQueue = realtimeBackend.queue;
this.sessionsTable = chatTables.sessionsTable;
this.byUserIdIndex = chatTables.byUserIdIndex;
this.userFeedbackBucket = chatBuckets.userFeedbackBucket;
this.filesBucket = chatBuckets.filesBucket;
this.graphqlApi = api;
Expand Down
2 changes: 0 additions & 2 deletions lib/chatbot-api/rest-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ export interface ApiResolversProps {
readonly ragEngines?: RagEngines;
readonly userPool: cognito.UserPool;
readonly sessionsTable: dynamodb.Table;
readonly byUserIdIndex: string;
readonly filesBucket: s3.Bucket;
readonly userFeedbackBucket: s3.Bucket;
readonly modelsParameter: ssm.StringParameter;
Expand Down Expand Up @@ -70,7 +69,6 @@ export class ApiResolvers extends Construct {
props.shared.xOriginVerifySecret.secretArn,
API_KEYS_SECRETS_ARN: props.shared.apiKeysSecret.secretArn,
SESSIONS_TABLE_NAME: props.sessionsTable.tableName,
SESSIONS_BY_USER_ID_INDEX_NAME: props.byUserIdIndex,
USER_FEEDBACK_BUCKET_NAME: props.userFeedbackBucket?.bucketName ?? "",
UPLOAD_BUCKET_NAME: props.ragEngines?.uploadBucket?.bucketName ?? "",
CHATBOT_FILES_BUCKET_NAME: props.filesBucket.bucketName,
Expand Down
2 changes: 0 additions & 2 deletions lib/model-interfaces/idefics/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ interface IdeficsInterfaceProps {
readonly config: SystemConfig;
readonly messagesTopic: sns.Topic;
readonly sessionsTable: dynamodb.Table;
readonly byUserIdIndex: string;
readonly chatbotFilesBucket: s3.Bucket;
readonly createPrivateGateway: boolean;
}
Expand Down Expand Up @@ -68,7 +67,6 @@ export class IdeficsInterface extends Construct {
...props.shared.defaultEnvironmentVariables,
CONFIG_PARAMETER_NAME: props.shared.configParameter.parameterName,
SESSIONS_TABLE_NAME: props.sessionsTable.tableName,
SESSIONS_BY_USER_ID_INDEX_NAME: props.byUserIdIndex,
MESSAGES_TOPIC_ARN: props.messagesTopic.topicArn,
CHATBOT_FILES_BUCKET_NAME: props.chatbotFilesBucket.bucketName,
CHATBOT_FILES_PRIVATE_API: api?.url ?? "",
Expand Down
2 changes: 0 additions & 2 deletions lib/model-interfaces/langchain/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ interface LangChainInterfaceProps {
readonly ragEngines?: RagEngines;
readonly messagesTopic: sns.Topic;
readonly sessionsTable: dynamodb.Table;
readonly byUserIdIndex: string;
}

export class LangChainInterface extends Construct {
Expand Down Expand Up @@ -51,7 +50,6 @@ export class LangChainInterface extends Construct {
...props.shared.defaultEnvironmentVariables,
CONFIG_PARAMETER_NAME: props.shared.configParameter.parameterName,
SESSIONS_TABLE_NAME: props.sessionsTable.tableName,
SESSIONS_BY_USER_ID_INDEX_NAME: props.byUserIdIndex,
API_KEYS_SECRETS_ARN: props.shared.apiKeysSecret.secretArn,
MESSAGES_TOPIC_ARN: props.messagesTopic.topicArn,
WORKSPACES_TABLE_NAME:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
from decimal import Decimal
from datetime import datetime
from botocore.exceptions import ClientError
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from operator import itemgetter

from langchain.schema import BaseChatMessageHistory
from langchain.schema.messages import (
BaseMessage,
_message_to_dict,
messages_from_dict,
messages_to_dict,
_message_from_dict,
)
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from genai_core.sessions import delete_session

client = boto3.resource("dynamodb")
logger = Logger()
logger = Logger(level="DEBUG")


class DynamoDBChatMessageHistory(BaseChatMessageHistory):
Expand All @@ -25,87 +26,154 @@ def __init__(
table_name: str,
session_id: str,
user_id: str,
max_messages: int = None, # Added max_messages parameter
):
self.table = client.Table(table_name)
self.session_id = session_id
self.user_id = user_id
self.max_messages = max_messages # Store max_messages

def _get_full_history(self) -> List[BaseMessage]:
"""Query all messages from DynamoDB for the current session"""
response = self.table.query(
KeyConditionExpression=(
"#pk = :user_id AND begins_with(#sk, :session_prefix)"
),
FilterExpression="#itemType = :itemType",
ExpressionAttributeNames={
"#pk": "PK",
"#sk": "SK",
"#itemType": "ItemType",
},
ScanIndexForward=True,
ExpressionAttributeValues={
":user_id": f"USER#{self.user_id}",
":session_prefix": f"SESSION#{self.session_id}",
":itemType": "message",
},
)
items = response.get("Items", [])

return items

@property
def messages(self) -> List[BaseMessage]:
    """Return the most recent ``max_messages`` messages of the session.

    The stored items are fetched with ``_get_full_history`` and the
    ``History`` attribute of each selected item is converted with
    ``_message_from_dict``.

    * ``max_messages is None`` -> every stored message is returned.
    * ``max_messages <= 0``    -> an empty list is returned.
    * otherwise                -> the last ``max_messages`` messages.

    Raises:
        KeyError: if a selected item has no ``History`` attribute.
    """
    full_history_items = self._get_full_history()

    # The previous code wrote `self.max.messages = ...`, which raises
    # AttributeError whenever max_messages is None (its default). Reading a
    # property should also not change instance state, so the limit is only
    # interpreted here rather than stored back on the instance.
    if self.max_messages is None:
        relevant_items = full_history_items
    elif self.max_messages <= 0:
        # Guard explicitly: `items[-0:]` would return the whole list.
        relevant_items = []
    else:
        # Slice before converting so only the needed items are parsed.
        relevant_items = full_history_items[-self.max_messages :]

    return [_message_from_dict(item["History"]) for item in relevant_items]

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in DynamoDB"""
messages = messages_to_dict(self.messages)
if isinstance(message, AIMessageChunk):
# When streaming with RunnableWithMessageHistory,
# it would add a chunk to the history but it expects a text as content.
ai_message = ""
for c in message.content:
if "text" in c:
ai_message = ai_message + c.get("text")
_message = _message_to_dict(AIMessage(ai_message))
else:
_message = _message_to_dict(message)
messages.append(_message)

try:
current_time = datetime.now().isoformat()

# messages = messages_to_dict(self.messages)
if isinstance(message, AIMessageChunk):
# When streaming with RunnableWithMessageHistory,
# it would add a chunk to the history but it expects a text as content.
ai_message = ""
for c in message.content:
if "text" in c:
ai_message = ai_message + c.get("text")
_message = _message_to_dict(AIMessage(ai_message))
else:
_message = _message_to_dict(message)

try:
self.table.update_item(
Key={
"PK": f"USER#{self.user_id}",
"SK": f"SESSION#{self.session_id}",
},
UpdateExpression="SET LastUpdateTime = :time",
ConditionExpression="attribute_exists(PK)",
ExpressionAttributeValues={":time": current_time},
)
except ClientError as err:
if err.response["Error"]["Code"] == "ConditionalCheckFailedException":
# Session doesn't exist, so create a new one
self.table.put_item(
Item={
"PK": f"USER#{self.user_id}",
"SK": f"SESSION#{self.session_id}",
"Title": _message_to_dict(message)
.get("data", {})
.get("content", "<no title>"),
"StartTime": current_time,
"ItemType": "session",
"SessionId": self.session_id,
"LastUpdateTime": current_time,
}
)
else:
# If some other error occurs, re-raise the exception
raise

self.table.put_item(
Item={
"SessionId": self.session_id,
"UserId": self.user_id,
"StartTime": datetime.now().isoformat(),
"History": messages,
"PK": f"USER#{self.user_id}",
"SK": f"SESSION#{self.session_id}#{current_time}",
"StartTime": current_time,
"History": _message, # Store full history in DynamoDB
"ItemType": "message",
"Role": _message.get("type"),
}
)
except ClientError as err:
logger.exception(err)

def add_metadata(self, metadata: dict) -> None:
"""Add additional metadata to the last message"""
messages = messages_to_dict(self.messages)
if not messages:
full_history_items = self._get_full_history()
if not full_history_items:
return

metadata = json.loads(json.dumps(metadata), parse_float=Decimal)
messages[-1]["data"]["additional_kwargs"] = metadata

most_recent_history = full_history_items[-1]

most_recent_history["History"]["data"]["additional_kwargs"] = metadata

try:
self.table.put_item(
Item={
"SessionId": self.session_id,
"UserId": self.user_id,
"StartTime": datetime.now().isoformat(),
"History": messages,
}

# Perform the update operation
self.table.update_item(
Key={
"PK": f"USER#{self.user_id}",
"SK": (
f"SESSION#{self.session_id}"
f"#{most_recent_history['StartTime']}"
),
},
UpdateExpression="SET #data = :data",
ExpressionAttributeNames={
"#data": "History"
},
ExpressionAttributeValues={
":data": most_recent_history["History"]
},
)

except Exception as err:
logger.exception(err)
logger.exception(f"Failed to update metadata: {err}")

def clear(self) -> None:
"""Clear session memory from DynamoDB"""
try:
self.table.delete_item(
Key={"SessionId": self.session_id, "UserId": self.user_id}
)
delete_session(self.session_id, self.user_id)

except ClientError as err:
logger.exception(err)
Loading
Loading