From fd9941d4ac0a566ed11bd7eb533fb2c8fee25a58 Mon Sep 17 00:00:00 2001 From: helllllllder Date: Wed, 12 Jul 2023 14:58:01 -0300 Subject: [PATCH] feat: add chat_completion action on the room endpoint --- chats/apps/api/v1/msgs/serializers.py | 23 ++++++++++++++++++++++- chats/apps/api/v1/rooms/viewsets.py | 24 +++++++++++++++++++++++- chats/apps/rooms/models.py | 4 ++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/chats/apps/api/v1/msgs/serializers.py b/chats/apps/api/v1/msgs/serializers.py index 81c76655..dc9d18bc 100644 --- a/chats/apps/api/v1/msgs/serializers.py +++ b/chats/apps/api/v1/msgs/serializers.py @@ -82,7 +82,6 @@ def create(self, validated_data): file_type.startswith("audio") or file_type.lower() in settings.UNPERMITTED_AUDIO_TYPES ): - export_conf = {"format": settings.AUDIO_TYPE_TO_CONVERT} if settings.AUDIO_CODEC_TO_CONVERT != "": export_conf["codec"] = settings.AUDIO_CODEC_TO_CONVERT @@ -210,3 +209,25 @@ class Meta: class MessageWSSerializer(MessageSerializer): pass + + +class ChatCompletionSerializer(serializers.ModelSerializer): + role = serializers.SerializerMethodField(read_only=True) + content = serializers.CharField(read_only=True, source="text") + + class Meta: + model = ChatMessage + fields = [ + "role", + "content", + ] + + extra_kwargs = { + "media_file": {"write_only": True}, + } + + def get_role(self, message: ChatMessage): + if message.contact: + return "user" + else: + return "assistant" diff --git a/chats/apps/api/v1/rooms/viewsets.py b/chats/apps/api/v1/rooms/viewsets.py index f5efab65..01389965 100644 --- a/chats/apps/api/v1/rooms/viewsets.py +++ b/chats/apps/api/v1/rooms/viewsets.py @@ -11,6 +11,8 @@ from rest_framework.viewsets import GenericViewSet from chats.apps.api.v1 import permissions as api_permissions +from chats.apps.api.v1.internal.rest_clients.openai_rest_client import OpenAIClient +from chats.apps.api.v1.msgs.serializers import ChatCompletionSerializer from chats.apps.api.v1.rooms import filters as room_filters from chats.apps.api.v1.rooms.serializers import ( RoomMessageStatusSerializer, @@ -237,4 +239,24 @@ def perform_destroy(self, instance): ) def chat_completion(self, request, *args, **kwargs): room = self.get_object() - pass + token = room.queue.sector.project.openai_token + if not token: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={"detail": "OpenAI token not found"}, + ) + messages = room.last_5_messages() + serialized_data = ChatCompletionSerializer(messages).data + sector = room.queue.sector + if sector.completion_context: + serialized_data.append( + {"role": "system", "content": sector.completion_context} + ) + + openai_client = OpenAIClient() + completion_response = openai_client.chat_completion( + token=token, messages=serialized_data + ) + return Response( + status=completion_response.status_code, data=completion_response.json() + ) diff --git a/chats/apps/rooms/models.py b/chats/apps/rooms/models.py index 5477b34c..9e82d72f 100644 --- a/chats/apps/rooms/models.py +++ b/chats/apps/rooms/models.py @@ -135,6 +135,10 @@ def serialized_ws_data(self): return RoomSerializer(self).data + @property + def last_5_messages(self): + return self.messages.exclude(text="").order_by("-created_on")[:5] + def close(self, tags: list = [], end_by: str = ""): self.is_active = False self.ended_at = timezone.now()