diff --git a/contracts/contracts/ChatOracle.sol b/contracts/contracts/ChatOracle.sol index 2f4ec49..5aecca5 100644 --- a/contracts/contracts/ChatOracle.sol +++ b/contracts/contracts/ChatOracle.sol @@ -266,6 +266,7 @@ contract ChatOracle { string memory errorMessage ) public onlyWhitelisted { require(!isPromptProcessed[promptId], "Prompt already processed"); + isPromptProcessed[promptId] = true; IChatGpt(callbackAddresses[promptId]).onOracleLlmResponse( promptCallBackId, response, @@ -318,6 +319,7 @@ contract ChatOracle { string memory errorMessage ) public onlyWhitelisted { require(!isFunctionProcessed[functionId], "Function already processed"); + isFunctionProcessed[functionId] = true; IChatGpt(functionCallbackAddresses[functionId]).onOracleFunctionResponse( functionCallBackId, response, @@ -351,6 +353,7 @@ contract ChatOracle { string memory errorMessage ) public onlyWhitelisted { require(!isPromptProcessed[promptId], "Prompt already processed"); + isPromptProcessed[promptId] = true; IChatGpt(callbackAddresses[promptId]).onOracleOpenAiLlmResponse( promptCallBackId, response, @@ -384,6 +387,7 @@ contract ChatOracle { string memory errorMessage ) public onlyWhitelisted { require(!isPromptProcessed[promptId], "Prompt already processed"); + isPromptProcessed[promptId] = true; IChatGpt(callbackAddresses[promptId]).onOracleGroqLlmResponse( promptCallBackId, response, diff --git a/oracles/src/repositories/oracle_repository.py b/oracles/src/repositories/oracle_repository.py index 9caf47a..0af0350 100644 --- a/oracles/src/repositories/oracle_repository.py +++ b/oracles/src/repositories/oracle_repository.py @@ -23,6 +23,7 @@ from src.entities import PromptType from src.entities import KnowledgeBaseIndexingRequest from src.entities import KnowledgeBaseQuery +from web3.types import TxReceipt class OracleRepository: @@ -48,10 +49,8 @@ async def _index_new_chats(self): chats_count = await self.oracle_contract.functions.promptsCount().call() config = None if chats_count > self.last_chats_count: - print( - f"Indexing new prompts from {self.last_chats_count} to {chats_count}", - flush=True, - ) + print(f"Indexing new prompts from {self.last_chats_count} to {chats_count}", + flush=True) for i in range(self.last_chats_count, chats_count): callback_id = await self.oracle_contract.functions.promptCallbackIds( i @@ -103,18 +102,40 @@ async def send_chat_response(self, chat: Chat) -> bool: except Exception as e: chat.is_processed = True chat.transaction_receipt = {"error": str(e)} + await self.mark_as_done(chat) return False - signed_tx = self.web3_client.eth.account.sign_transaction( - tx, private_key=self.account.key - ) - tx_hash = await self.web3_client.eth.send_raw_transaction( - signed_tx.rawTransaction - ) - tx_receipt = await self.web3_client.eth.wait_for_transaction_receipt(tx_hash) + tx_receipt = await self._sign_and_send_tx(tx) chat.transaction_receipt = tx_receipt chat.is_processed = bool(tx_receipt.get("status")) return bool(tx_receipt.get("status")) + async def mark_as_done(self, chat: Chat): + nonce = await self.web3_client.eth.get_transaction_count(self.account.address) + tx_data = { + "from": self.account.address, + "nonce": nonce, + # TODO: pick gas amount in a better way + # "gas": 1000000, + "maxFeePerGas": self.web3_client.to_wei("2", "gwei"), + "maxPriorityFeePerGas": self.web3_client.to_wei("1", "gwei"), + } + if chain_id := settings.CHAIN_ID: + tx_data["chainId"] = int(chain_id) + + if chat.prompt_type == PromptType.OPENAI: + tx = await self.oracle_contract.functions.markOpenAiPromptAsProcessed( + chat.id, + ).build_transaction(tx_data) + elif chat.prompt_type == PromptType.GROQ: + tx = await self.oracle_contract.functions.markGroqPromptAsProcessed( + chat.id, + ).build_transaction(tx_data) + else: + tx = await self.oracle_contract.functions.markPromptAsProcessed( + chat.id, + ).build_transaction(tx_data) + return await self._sign_and_send_tx(tx) + async def _build_response_tx(self, chat: Chat): nonce = await self.web3_client.eth.get_transaction_count(self.account.address) tx_data = { @@ -217,18 +238,31 @@ async def send_function_call_response( except Exception as e: function_call.is_processed = True function_call.transaction_receipt = {"error": str(e)} + await self.mark_function_call_as_done(function_call) return False - signed_tx = self.web3_client.eth.account.sign_transaction( - tx, private_key=self.account.key - ) - tx_hash = await self.web3_client.eth.send_raw_transaction( - signed_tx.rawTransaction - ) - tx_receipt = await self.web3_client.eth.wait_for_transaction_receipt(tx_hash) + tx_receipt = await self._sign_and_send_tx(tx) function_call.transaction_receipt = tx_receipt function_call.is_processed = bool(tx_receipt.get("status")) return bool(tx_receipt.get("status")) + async def mark_function_call_as_done(self, function_call: FunctionCall): + nonce = await self.web3_client.eth.get_transaction_count(self.account.address) + tx_data = { + "from": self.account.address, + "nonce": nonce, + # TODO: pick gas amount in a better way + # "gas": 1000000, + "maxFeePerGas": self.web3_client.to_wei("2", "gwei"), + "maxPriorityFeePerGas": self.web3_client.to_wei("1", "gwei"), + } + if chain_id := settings.CHAIN_ID: + tx_data["chainId"] = int(chain_id) + + tx = await self.oracle_contract.functions.markFunctionAsProcessed( + function_call.id, + ).build_transaction(tx_data) + return await self._sign_and_send_tx(tx) + async def _index_new_kb_index_requests(self): kb_index_request_count = ( await self.oracle_contract.functions.kbIndexingRequestCount().call() @@ -394,11 +428,8 @@ async def _get_openai_config(self, i: int) -> Optional[OpenAiConfig]: temperature=_parse_float_from_int(config[8], 0, 20), top_p=_parse_float_from_int(config[9], 0, 100, decimals=2), tools=_parse_tools(config[10]), - tool_choice=( - config[11] - if (config[11] and config[11] in get_args(OpenaiToolChoiceType)) - else None - ), + tool_choice=config[11] if (config[11] and config[11] in get_args( + OpenaiToolChoiceType)) else None, user=_value_or_none(config[12]), ) except: @@ -428,8 +459,7 @@ async def _get_groq_config(self, i: int) -> Optional[GroqConfig]: async def _get_prompt_type(self, i) -> PromptType: prompt_type: Optional[str] = await self.oracle_contract.functions.promptType( - i - ).call() + i).call() if not prompt_type: return PromptType.DEFAULT try: @@ -437,6 +467,15 @@ async def _get_prompt_type(self, i) -> PromptType: except: return PromptType.DEFAULT + async def _sign_and_send_tx(self, tx) -> TxReceipt: + signed_tx = self.web3_client.eth.account.sign_transaction( + tx, private_key=self.account.key + ) + tx_hash = await self.web3_client.eth.send_raw_transaction( + signed_tx.rawTransaction + ) + await self.web3_client.eth.wait_for_transaction_receipt(tx_hash) + async def _get_openai_config(self, i: int) -> Optional[OpenAiConfig]: config = await self.oracle_contract.functions.openAiConfigurations(i).call() if not config or not config[0] or not config[0] in get_args(OpenAiModelType): @@ -471,11 +510,13 @@ def _value_or_none(value: Any) -> Optional[Any]: def _parse_float_from_int( - value: Optional[float], min_value: int, max_value: int, decimals: int = 1 + value: Optional[float], + min_value: int, + max_value: int, + decimals: int = 1 ) -> Optional[int]: - return ( - round(value / (10**decimals), 1) if (min_value <= value <= max_value) else None - ) + return round(value / (10 ** decimals), 1) if ( + min_value <= value <= max_value) else None def _parse_json_string(value: Optional[str]) -> Optional[Dict]: