Skip to content

Commit

Permalink
Adds semaphores, file size limit, error reporting
Browse files Browse the repository at this point in the history
  • Loading branch information
kgrofelnik committed Mar 28, 2024
1 parent e70e308 commit 49d4bd5
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 50 deletions.
71 changes: 67 additions & 4 deletions contracts/contracts/ChatGpt.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ interface IOracle {
function createLlmCall(
uint promptId
) external returns (uint);

function createKnowledgeBaseQuery(
uint kbQueryCallbackId,
string memory cid,
string memory query,
uint32 num_documents
) external returns (uint i);
}

contract ChatGpt {
Expand All @@ -30,13 +37,14 @@ contract ChatGpt {

address private owner;
address public oracleAddress;
string public knowledgeBase;

event OracleAddressUpdated(address indexed newOracleAddress);

constructor(address initialOracleAddress) {
constructor(address initialOracleAddress, string memory knowledgeBaseCID) {
owner = msg.sender;
oracleAddress = initialOracleAddress;
chatRunsCount = 0;
knowledgeBase = knowledgeBaseCID;
}

modifier onlyOwner() {
Expand Down Expand Up @@ -67,7 +75,18 @@ contract ChatGpt {
uint currentId = chatRunsCount;
chatRunsCount = chatRunsCount + 1;

IOracle(oracleAddress).createLlmCall(currentId);
// If there is a knowledge base, create a knowledge base query
if (bytes(knowledgeBase).length > 0) {
IOracle(oracleAddress).createKnowledgeBaseQuery(
currentId,
knowledgeBase,
message,
3
);
} else {
// Otherwise, create an LLM call
IOracle(oracleAddress).createLlmCall(currentId);
}
emit ChatCreated(msg.sender, currentId);

return currentId;
Expand All @@ -91,6 +110,39 @@ contract ChatGpt {
run.messagesCount++;
}

function onOracleKnowledgeBaseQueryResponse(
uint runId,
string [] memory documents,
string memory errorMessage
) public onlyOracle {
ChatRun storage run = chatRuns[runId];
require(
keccak256(abi.encodePacked(run.messages[run.messagesCount - 1].role)) == keccak256(abi.encodePacked("user")),
"No message to add context to"
);
// Retrieve the last user message
Message storage lastMessage = run.messages[run.messagesCount - 1];

// Start with the original message content
string memory newContent = lastMessage.content;

// Append "Relevant context:\n" only if there are documents
if (documents.length > 0) {
newContent = string(abi.encodePacked(newContent, "\n\nRelevant context:\n"));
}

// Iterate through the documents and append each to the newContent
for (uint i = 0; i < documents.length; i++) {
newContent = string(abi.encodePacked(newContent, documents[i], "\n"));
}

// Finally, set the lastMessage content to the newly constructed string
lastMessage.content = newContent;

// Call LLM
IOracle(oracleAddress).createLlmCall(runId);
}

function addMessage(string memory message, uint runId) public {
ChatRun storage run = chatRuns[runId];
require(
Expand All @@ -106,7 +158,18 @@ contract ChatGpt {
newMessage.role = "user";
run.messages.push(newMessage);
run.messagesCount++;
IOracle(oracleAddress).createLlmCall(runId);
// If there is a knowledge base, create a knowledge base query
if (bytes(knowledgeBase).length > 0) {
IOracle(oracleAddress).createKnowledgeBaseQuery(
runId,
knowledgeBase,
message,
3
);
} else {
// Otherwise, create an LLM call
IOracle(oracleAddress).createLlmCall(runId);
}
}

function getMessageHistoryContents(uint chatId) public view returns (string[] memory) {
Expand Down
19 changes: 14 additions & 5 deletions contracts/contracts/ChatOracle.sol
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ interface IOracleTypes {
uint32 promptTokens;
uint32 totalTokens;
}

struct KnowledgeBaseQueryRequest {
string cid;
string query;
uint32 num_documents;
}
}

interface IChatGpt {
Expand Down Expand Up @@ -182,8 +188,7 @@ contract ChatOracle {
mapping(string => string) public kbIndexes;
uint public kbIndexingRequestCount;

mapping(uint => string) public kbQueryCids;
mapping(uint => string) public kbQueries;
mapping(uint => IOracleTypes.KnowledgeBaseQueryRequest) public kbQueries;
mapping(uint => address) public kbQueryCallbackAddresses;
mapping(uint => uint) public kbQueryCallbackIds;
mapping(uint => bool) public isKbQueryProcessed;
Expand Down Expand Up @@ -413,12 +418,16 @@ contract ChatOracle {
function createKnowledgeBaseQuery(
uint kbQueryCallbackId,
string memory cid,
string memory query
string memory query,
uint32 num_documents
) public returns (uint i) {
require(bytes(kbIndexes[cid]).length > 0, "Index not available for this CID");
require(bytes(query).length > 0, "Query cannot be empty");
require(num_documents > 0, "Number of documents should be greater than 0");
uint kbQueryId = kbQueryCount;
kbQueryCids[kbQueryId] = cid;
kbQueries[kbQueryId] = query;
kbQueries[kbQueryId].cid = cid;
kbQueries[kbQueryId].query = query;
kbQueries[kbQueryId].num_documents = num_documents;
kbQueryCallbackIds[kbQueryId] = kbQueryCallbackId;

kbQueryCallbackAddresses[kbQueryId] = msg.sender;
Expand Down
14 changes: 12 additions & 2 deletions contracts/scripts/deployAll.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ async function main() {
await deployVitailik(oracleAddress);
await deployAgent(oracleAddress);
console.log()

for (let contractName of ["ChatGpt", "OpenAiChatGpt", "GroqChatGpt"]) {
await deployChatGptWithKnowledgeBase("ChatGpt", oracleAddress, "");
for (let contractName of ["OpenAiChatGpt", "GroqChatGpt"]) {
await deployChatGpt(contractName, oracleAddress)
}
}
Expand Down Expand Up @@ -86,6 +86,16 @@ async function deployChatGpt(contractName: string, oracleAddress: string) {
);
}

async function deployChatGptWithKnowledgeBase(contractName: string, oracleAddress: string, knowledgeBaseCID: string) {
const agent = await ethers.deployContract(contractName, [oracleAddress, knowledgeBaseCID], {});

await agent.waitForDeployment();

console.log(
`${contractName} deployed to ${agent.target} with knowledge base "${knowledgeBaseCID}"`
);
}


// We recommend this pattern to be able to use async/await everywhere
// and properly handle errors.
Expand Down
67 changes: 41 additions & 26 deletions oracles/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
KB_QUERY_TASKS = {}
MAX_CONCURRENT_CHATS = 5
MAX_CONCURRENT_FUNCTION_CALLS = 5
MAX_CONCURRENT_INDEXING = 5
MAX_CONCURRENT_KB_QUERIES = 5


async def _answer_chat(chat: Chat, semaphore: Semaphore):
Expand All @@ -45,7 +47,7 @@ async def _answer_chat(chat: Chat, semaphore: Semaphore):
print(
f"Chat {chat.id} {'' if success else 'not '}"
f"replied, tx: {chat.transaction_receipt}",
flush=True
flush=True,
)
except Exception as ex:
print(f"Failed to answer chat {chat.id}, exc: {ex}", flush=True)
Expand Down Expand Up @@ -109,7 +111,7 @@ async def _call_function(function_call: FunctionCall, semaphore: Semaphore):
print(
f"Function {function_call.id} {'' if success else 'not '}"
f"called, tx: {function_call.transaction_receipt}",
flush=True
flush=True,
)
except Exception as ex:
print(f"Failed to call function {function_call.id}, exc: {ex}", flush=True)
Expand All @@ -131,7 +133,10 @@ async def _process_function_calls():
try:
await FUNCTION_TASKS[index]
except Exception as e:
print(f"Task for function {index} raised an exception: {e}", flush=True)
print(
f"Task for function {index} raised an exception: {e}",
flush=True,
)
del FUNCTION_TASKS[index]
except Exception as exc:
print(f"Function loop raised an exception: {exc}")
Expand All @@ -142,27 +147,29 @@ async def _index_knowledgebase_function(
request: KnowledgeBaseIndexingRequest,
ipfs_repository: IpfsRepository,
kb_repository: KnowledgeBaseRepository,
semaphore: Semaphore,
):
try:
indexing_result = await index_knowledge_base_use_case.execute(
request, ipfs_repository, kb_repository
)
success = await repository.send_kb_indexing_response(
request,
index_cid=indexing_result.index_cid,
error_message=indexing_result.error,
)
print(
f"Knowledge base indexing {request.id} {'' if success else 'not '} indexed, tx: {request.transaction_receipt}"
)
async with semaphore:
indexing_result = await index_knowledge_base_use_case.execute(
request, ipfs_repository, kb_repository
)
success = await repository.send_kb_indexing_response(
request,
index_cid=indexing_result.index_cid,
error_message=indexing_result.error,
)
print(
f"Knowledge base indexing {request.id} {'' if success else 'not '} indexed, tx: {request.transaction_receipt}"
)
except Exception as ex:
print(
f"Failed to index knowledge base {request.id}, cid {request.cid}, exc: {ex}"
)


async def _process_knowledge_base_indexing():

semaphore = asyncio.Semaphore(MAX_CONCURRENT_INDEXING)
while True:
try:
kb_indexing_requests = await repository.get_unindexed_knowledge_bases()
Expand All @@ -173,7 +180,10 @@ async def _process_knowledge_base_indexing():
)
task = asyncio.create_task(
_index_knowledgebase_function(
kb_indexing_request, ipfs_repository, kb_repository
kb_indexing_request,
ipfs_repository,
kb_repository,
semaphore,
)
)
KB_INDEXING_TASKS[kb_indexing_request.id] = task
Expand All @@ -197,17 +207,19 @@ async def _query_knowledge_base(
request: KnowledgeBaseQuery,
ipfs_repository: IpfsRepository,
kb_repository: KnowledgeBaseRepository,
semaphore: Semaphore,
):
try:
query_result = await query_knowledge_base_use_case.execute(
request, ipfs_repository, kb_repository
)
success = await repository.send_kb_query_response(
request, query_result.documents, error_message=query_result.error
)
print(
f"Knowledge base query {request.id} {'' if success else 'not '} answered, tx: {request.transaction_receipt}"
)
async with semaphore:
query_result = await query_knowledge_base_use_case.execute(
request, ipfs_repository, kb_repository
)
success = await repository.send_kb_query_response(
request, query_result.documents, error_message=query_result.error
)
print(
f"Knowledge base query {request.id} {'' if success else 'not '} answered, tx: {request.transaction_receipt}"
)
except Exception as ex:
print(
f"Failed to query knowledge base {request.id}, cid {request.index_cid}, exc: {ex}"
Expand All @@ -216,6 +228,7 @@ async def _query_knowledge_base(

async def _process_knowledge_base_queries():
ipfs_repository = IpfsRepository()
semaphore = asyncio.Semaphore(MAX_CONCURRENT_KB_QUERIES)
while True:
try:
kb_queries = await repository.get_unanswered_kb_queries()
Expand All @@ -225,7 +238,9 @@ async def _process_knowledge_base_queries():
f"Querying knowledge base {kb_query.id}, cid {kb_query.cid}, index_cid {kb_query.index_cid}"
)
task = asyncio.create_task(
_query_knowledge_base(kb_query, ipfs_repository, kb_repository)
_query_knowledge_base(
kb_query, ipfs_repository, kb_repository, semaphore
)
)
KB_QUERY_TASKS[kb_query.id] = task
completed_tasks = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def execute(
error="",
)
except Exception as e:
print(e)
print(e, flush=True)
return KnowledgeBaseIndexingResult(
index_cid="",
error=str(e),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ async def execute(
await kb_repository.deserialize(
request.cid, documents=documents, data=index
)
documents = await kb_repository.query(request.cid, request.query)
return KnowledgeBaseQueryResult(documents=documents, error="")
documents = await kb_repository.query(
request.cid, request.query, request.num_documents
)
document_texts = [document.page_content for document in documents]
return KnowledgeBaseQueryResult(documents=document_texts, error="")
except Exception as e:
print(e)
print(e, flush=True)
return KnowledgeBaseQueryResult(documents=[], error=str(e))


Expand Down
3 changes: 1 addition & 2 deletions oracles/src/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,5 @@ class KnowledgeBaseQuery:
cid: str
index_cid: str
query: str
documents: List[str] = field(default_factory=list)
error_message: Optional[str] = None
num_documents: int
transaction_receipt: dict = None
4 changes: 3 additions & 1 deletion oracles/src/repositories/ipfs_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ async def read_file(self, cid: str, max_bytes: int = 0) -> bytes:
break
data += chunk
if max_bytes > 0 and len(data) > max_bytes:
raise Exception(f"File exceeded the maximum allowed size of {max_bytes} bytes.")
raise Exception(
f"File exceeded the maximum allowed size of {max_bytes} bytes."
)
return data

async def write_file(self, data: Union[str, bytes]) -> str:
Expand Down
7 changes: 4 additions & 3 deletions oracles/src/repositories/knowledge_base_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ async def _add_knowledge_base(
async def create(self, name: str, documents: List[Document]):
embeddings = []
for i in range(0, len(documents), BATCH_SIZE):
print(i)
batch = [
document.page_content for document in documents[i : i + BATCH_SIZE]
]
Expand All @@ -65,15 +64,17 @@ async def deserialize(self, name: str, documents: List[Document], data: bytes):
print(f"KB: Deserialized {name}", flush=True)
await self._add_knowledge_base(name, index, documents)

async def query(self, name: str, query: str, k: int = 1) -> List[str]:
async def query(self, name: str, query: str, k: int = 1) -> List[Document]:
async with self.lock:
self.indexes.move_to_end(name)
index, time = self.indexes[name]
doc_store = self.document_stores[name]
query_embedding = await self._create_embedding([query])
query_vector = np.array([query_embedding[0]]).astype("float32")
_, indexes = index.search(query_vector, k)
results = [doc_store[indexes[0][i]] for i in range(len(indexes[0]))]
results = [
doc_store[indexes[0][i]] for i in range(len(indexes[0]))
]
return results

async def exists(self, name: str) -> bool:
Expand Down
Loading

0 comments on commit 49d4bd5

Please sign in to comment.