Skip to content

Commit

Permalink
Merge pull request #936 from fatualux/develop
Browse files Browse the repository at this point in the history
WebSocket Fix Proposal
  • Loading branch information
pieroit authored Oct 7, 2024
2 parents 6b251cb + 1cb3341 commit f51a40e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
7 changes: 3 additions & 4 deletions core/cat/auth/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,10 @@ async def get_user_stray(self, user: AuthUserInfo, connection: WebSocket) -> Str

if user.id in strays.keys():
stray = strays[user.id]
# Close previus ws connection
if stray._StrayCat__ws:
await stray._StrayCat__ws.close()
await stray.close_connection()

# Set new ws connection
stray._StrayCat__ws = connection
stray.reset_connection(connection)
log.info(
f"New websocket connection for user '{user.id}', the old one has been closed."
)
Expand Down
23 changes: 22 additions & 1 deletion core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cat.convo.messages import CatMessage, UserMessage, MessageWhy, Role, EmbedderModelInteraction
from cat.agents import AgentOutput
from cat import utils
from websockets.exceptions import ConnectionClosedOK

MSG_TYPES = Literal["notification", "chat", "error", "chat_token"]

Expand Down Expand Up @@ -484,7 +485,13 @@ def run(self, user_message_json, return_message=False):
if return_message:
return {"error": str(e)}
else:
self.send_error(e)
try:
self.send_error(e)
except ConnectionClosedOK as ex:
log.warning(ex)
if self.__ws:
del self.__ws
self.__ws = None

def classify(
self, sentence: str, labels: List[str] | Dict[str, List[str]]
Expand Down Expand Up @@ -599,6 +606,20 @@ def langchainfy_chat_history(self, latest_n: int = 5) -> List[BaseMessage]:

return langchain_chat_history

async def close_connection(self):
if self.__ws:
try:
await self.__ws.close()
except RuntimeError as ex:
log.warning(ex)
if self.__ws:
del self.__ws
self.__ws = None

def reset_connection(self, connection):
"""Reset the connection to the API service."""
self.__ws = connection

@property
def user_id(self):
return self.__user_id
Expand Down

0 comments on commit f51a40e

Please sign in to comment.