diff --git a/core/cat/auth/connection.py b/core/cat/auth/connection.py index d89661a2..702a15ae 100644 --- a/core/cat/auth/connection.py +++ b/core/cat/auth/connection.py @@ -140,11 +140,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str if user.id in strays.keys(): stray = strays[user.id] - # Close previus ws connection - if stray._StrayCat__ws: - await stray._StrayCat__ws.close() + await stray.close_connection() + # Set new ws connection - stray._StrayCat__ws = connection + stray.reset_connection(connection) log.info( f"New websocket connection for user '{user.id}', the old one has been closed." ) diff --git a/core/cat/looking_glass/stray_cat.py b/core/cat/looking_glass/stray_cat.py index 08a120d4..c06b9f3f 100644 --- a/core/cat/looking_glass/stray_cat.py +++ b/core/cat/looking_glass/stray_cat.py @@ -19,6 +19,7 @@ from cat.convo.messages import CatMessage, UserMessage, MessageWhy, Role, EmbedderModelInteraction from cat.agents import AgentOutput from cat import utils +from websockets.exceptions import ConnectionClosedOK MSG_TYPES = Literal["notification", "chat", "error", "chat_token"] @@ -484,7 +485,13 @@ def run(self, user_message_json, return_message=False): if return_message: return {"error": str(e)} else: - self.send_error(e) + try: + self.send_error(e) + except ConnectionClosedOK as ex: + log.warning(ex) + if self.__ws: + del self.__ws + self.__ws = None def classify( self, sentence: str, labels: List[str] | Dict[str, List[str]] @@ -599,6 +606,20 @@ def langchainfy_chat_history(self, latest_n: int = 5) -> List[BaseMessage]: return langchain_chat_history + async def close_connection(self): + if self.__ws: + try: + await self.__ws.close() + except RuntimeError as ex: + log.warning(ex) + if self.__ws: + del self.__ws + self.__ws = None + + def reset_connection(self, connection): + """Reset the connection to the API service.""" + self.__ws = connection + @property def user_id(self): return self.__user_id