diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d40a3c9..7fce9ea 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -20,7 +20,7 @@ jobs: - name: Set up Python 3.11 uses: actions/setup-python@v3 with: - python-version: "3.11.8" + python-version: "3.11.4" - name: Cache dependencies id: cache diff --git a/.idea/misc.xml b/.idea/misc.xml index 3a82b7d..109df5b 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 7cbe1da..8b85c7d 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -4,15 +4,28 @@ - - - + + + + + + + + + + + + @@ -25,6 +38,19 @@ + + + + + + + + + + + + + + + + @@ -50,23 +83,131 @@ - { - "keyToString": { - "ASKED_ADD_EXTERNAL_FILES": "true", - "ASKED_SHARE_PROJECT_CONFIGURATION_FILES": "true", - "RunOnceActivity.OpenProjectViewOnStart": "true", - "RunOnceActivity.ShowReadmeOnStart": "true", - "git-widget-placeholder": "development", - "ignore.virus.scanning.warn.message": "true", - "node.js.detected.package.eslint": "true", - "node.js.detected.package.tslint": "true", - "node.js.selected.package.eslint": "(autodetect)", - "node.js.selected.package.tslint": "(autodetect)", - "nodejs_package_manager_path": "npm", - "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable", - "vue.rearranger.settings.migration": "true" + +}]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -88,6 +229,12 @@ + + + + + + - @@ -177,7 +348,13 @@ + - + + \ No newline at end of file diff --git a/app.py b/app.py index 535d8e5..502b03b 100644 --- a/app.py +++ b/app.py @@ -1,16 +1,9 @@ """ Flask PostgreSQL Process Handler """ - from src.app.flask_postgresql.configs import Config -from src.app.flask_postgresql.create_flask_postgresql_app import ( - create_flask_postgresql_app, -) -from src.infrastructure.loggers.logger_default import LoggerDefault - -logger = LoggerDefault() - +from src.app.flask_postgresql.create_flask_postgresql_app import create_flask_postgresql_app if __name__ == "__main__": - app = create_flask_postgresql_app(Config, logger) + app = create_flask_postgresql_app(Config) app.run(host="0.0.0.0", port=Config.PORT, debug=True) diff --git a/poetry.lock b/poetry.lock index 065f040..d50c685 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2303,4 +2303,4 @@ requests = ">=2.31" [metadata] lock-version = "2.0" python-versions = "3.11.4" -content-hash = "5756e8462df2bb105e395d1e952430f6c2972e4f3d97576a9c7a1a5d18da1671" +content-hash = "429381f71eba65b35eb1570849cd6163566b2eaaf9bec64a268a894aa954fc51" diff --git a/src/app/flask_postgresql/api/blueprints/__init__.py b/src/app/flask_postgresql/api/blueprints/__init__.py index b5cdd73..8a73e2d 100644 --- a/src/app/flask_postgresql/api/blueprints/__init__.py +++ b/src/app/flask_postgresql/api/blueprints/__init__.py @@ -4,7 +4,7 @@ from .main import blueprint as main_blueprint -def setup_blueprints(app: Flask) -> None: +def setup_blueprints(app: Flask) -> Flask: """ Register the necessary blueprints for the Flask app. @@ -12,7 +12,7 @@ def setup_blueprints(app: Flask) -> None: app (Flask): The Flask app instance. Returns: - None: This function does not return anything. + None: The Flask app instance with the blueprints registered. """ app.register_blueprint(callback_blueprint) app.register_blueprint(main_blueprint) diff --git a/src/app/flask_postgresql/api/blueprints/callback.py b/src/app/flask_postgresql/api/blueprints/callback.py index 4b4d04c..b4918bc 100644 --- a/src/app/flask_postgresql/api/blueprints/callback.py +++ b/src/app/flask_postgresql/api/blueprints/callback.py @@ -13,7 +13,7 @@ @blueprint.route("/callback", methods=["POST"]) def callback_blueprint(): """Function printing python version.""" - logger: LoggerInterface = current_app.config["logger"] + logger: LoggerInterface = current_app.config["container"].resolve("logger") signature = request.headers["X-Line-Signature"] body = request.get_data(as_text=True) diff --git a/src/app/flask_postgresql/api/container.py b/src/app/flask_postgresql/api/container.py new file mode 100644 index 0000000..3015309 --- /dev/null +++ b/src/app/flask_postgresql/api/container.py @@ -0,0 +1,20 @@ +from flask import Flask, g +from src.infrastructure.container.container import Container +from src.infrastructure.repositories.agent_chain.agent_chain_in_memory_repository import ( + AgentExecutorInMemoryRepository, +) +from src.infrastructure.repositories.window import WindowPostgresqlRepository +from src.infrastructure.loggers.logger_default import LoggerDefault + + +def setup_container(app: Flask) -> Flask: + """Setup Container for the app""" + container = Container() + + container.register("agent_repository", AgentExecutorInMemoryRepository()) + container.register("window_repository", WindowPostgresqlRepository()) + container.register("logger", LoggerDefault()) + + app.config["container"] = container + + return app diff --git a/src/app/flask_postgresql/api/error_handler.py b/src/app/flask_postgresql/api/error_handler.py index 353f8f7..f41692a 100644 --- a/src/app/flask_postgresql/api/error_handler.py +++ b/src/app/flask_postgresql/api/error_handler.py @@ -31,12 +31,12 @@ def format_marshmallow_validation_error(errors: Dict): return errors_message -def setup_error_handler(app: Flask) -> None: +def setup_error_handler(app: Flask) -> Flask: """ Function that will register all the specified error handlers for the app """ - logger: LoggerInterface = app.config["logger"] + logger: LoggerInterface = app.config["container"].resolve("logger") def error_handler(error): logger.log_exception("exception of type {} occurred".format(type(error))) diff --git a/src/app/flask_postgresql/api/event_handlers/__init__.py b/src/app/flask_postgresql/api/event_handlers/__init__.py index 151bf5b..62aa280 100644 --- a/src/app/flask_postgresql/api/event_handlers/__init__.py +++ b/src/app/flask_postgresql/api/event_handlers/__init__.py @@ -3,43 +3,38 @@ from flask import current_app from linebot.v3 import WebhookHandler from linebot.v3.messaging import Configuration -from linebot.v3.webhooks import FileMessageContent, MessageEvent, TextMessageContent +from linebot.v3.webhooks import MessageEvent, TextMessageContent from src.app.flask_postgresql.api.event_handlers.file_event_handler import FileEventHandler from src.app.flask_postgresql.api.event_handlers.text_event_handler import TextEventHandler from src.app.flask_postgresql.api.response import create_response from src.app.flask_postgresql.configs import Config -from src.infrastructure.repositories.agent_chain.agent_chain_in_memory_repository import ( - AgentExecutorInMemoryRepository, -) +from src.infrastructure.container.container import Container handler = WebhookHandler(Config.CHANNEL_SECRET) configuration = Configuration(access_token=Config.CHANNEL_ACCESS_TOKEN) -agent_repository = AgentExecutorInMemoryRepository() @handler.add(MessageEvent, message=TextMessageContent) def handle_text_message(event: MessageEvent): - handler = TextEventHandler( - logger=current_app.config["logger"], - agent_repository=agent_repository, - ) - handler.get_event_info(event) - result = handler.execute() + container: Container = current_app.config["container"] + text_handler = TextEventHandler(container) + text_handler.get_event_info(event) + result = text_handler.execute() return create_response(configuration, event.reply_token, result) -@handler.add(MessageEvent, message=FileMessageContent) -def handle_file_message(event: MessageEvent): - handler = FileEventHandler( - logger=current_app.config["logger"], - agent_repository=agent_repository, - configuration=configuration, - ) - result = handler.execute(event) - - return create_response(configuration, event.reply_token, result) +# @handler.add(MessageEvent, message=FileMessageContent) +# def handle_file_message(event: MessageEvent): +# file_handler = FileEventHandler( +# logger=current_app.config["logger"], +# agent_repository=agent_repository, +# configuration=configuration, +# ) +# result = file_handler.execute(event) +# +# return create_response(configuration, event.reply_token, result) __all__ = ["handler"] diff --git a/src/app/flask_postgresql/api/event_handlers/event_handler.py b/src/app/flask_postgresql/api/event_handlers/event_handler.py index 65c4647..3754d8d 100644 --- a/src/app/flask_postgresql/api/event_handlers/event_handler.py +++ b/src/app/flask_postgresql/api/event_handlers/event_handler.py @@ -1,38 +1,44 @@ """ This module implements the event handler for text message events. """ from abc import abstractmethod -from typing import Dict +from typing import Dict, cast from linebot.v3.webhooks import MessageEvent from src.app.flask_postgresql.configs import Config from src.app.flask_postgresql.interfaces.event_handler_interface import EventHandlerInterface from src.app.flask_postgresql.presenters.window_presenter import WindowPresenter -from src.infrastructure.repositories.agent_chain.agent_chain_in_memory_repository import ( - AgentExecutorInMemoryRepository, -) -from src.infrastructure.repositories.window.window_postgresql_repository import ( - WindowPostgresqlRepository, -) +from src.infrastructure.container.container import Container from src.interactor.dtos.event_dto import EventInputDto from src.interactor.dtos.window_dtos import CreateWindowInputDto, GetWindowInputDto from src.interactor.interfaces.logger.logger import LoggerInterface +from src.interactor.interfaces.repositories.agent_executor_repository import AgentExecutorRepositoryInterface +from src.interactor.interfaces.repositories.window_repository import WindowRepositoryInterface from src.interactor.use_cases.window.create_window import CreateWindowUseCase from src.interactor.use_cases.window.get_window import GetWindowUseCase class EventHandler(EventHandlerInterface): - def __init__(self, logger: LoggerInterface, agent_repository: AgentExecutorInMemoryRepository): - self.logger = logger - self.agent_repository = agent_repository - self.input_dto: EventInputDto + def __init__(self, container: Container): + self.container = container + self.logger = cast(LoggerInterface, container.resolve("logger")) + self.agent_repository = cast(AgentExecutorRepositoryInterface, container.resolve("agent_repository")) + self.window_repository = cast(WindowRepositoryInterface, container.resolve("window_repository")) + self.input_dto = None + self.event = None def get_event_info(self, event: MessageEvent): - if event.source.type == "user": + """ + Retrieves the information of an event. + """ + self.event = event + source_type = event.source.type + + if source_type == "user": window_id = event.source.user_id - elif event.source.type == "group": + elif source_type == "group": window_id = event.source.group_id - elif event.source.type == "room": + elif source_type == "room": window_id = event.source.room_id else: raise ValueError("Invalid event source type") @@ -45,6 +51,7 @@ def get_event_info(self, event: MessageEvent): self.input_dto = EventInputDto( window=self.get_window_info(window_id=window_id), user_input=user_input, + source_type=source_type, ) def get_window_info(self, window_id: str) -> Dict: @@ -66,9 +73,8 @@ def _get_window_info(self, window_id: str): :return: The result of executing the use case. """ - repository = WindowPostgresqlRepository() presenter = WindowPresenter() - use_case = GetWindowUseCase(presenter=presenter, repository=repository, logger=self.logger) + use_case = GetWindowUseCase(presenter=presenter, repository=self.window_repository, logger=self.logger) get_window_input_dto = GetWindowInputDto(window_id) result = use_case.execute(get_window_input_dto) return result @@ -80,7 +86,6 @@ def _create_window_info(self, window_id: str): Returns: The result of the create window use case execution. """ - repository = WindowPostgresqlRepository() presenter = WindowPresenter() create_window_input_dto = CreateWindowInputDto( window_id=window_id, @@ -90,7 +95,7 @@ def _create_window_info(self, window_id: str): temperature=0, ) use_case = CreateWindowUseCase( - presenter=presenter, repository=repository, logger=self.logger + presenter=presenter, repository=self.window_repository, logger=self.logger ) result = use_case.execute(create_window_input_dto) return result diff --git a/src/app/flask_postgresql/api/event_handlers/text_event_handler.py b/src/app/flask_postgresql/api/event_handlers/text_event_handler.py index 00329e8..51c8c53 100644 --- a/src/app/flask_postgresql/api/event_handlers/text_event_handler.py +++ b/src/app/flask_postgresql/api/event_handlers/text_event_handler.py @@ -3,22 +3,18 @@ from src.app.flask_postgresql.api.event_handlers.event_handler import EventHandler from src.app.flask_postgresql.presenters.message_reply_presenter import EventPresenter +from src.infrastructure.container.container import Container from src.interactor.dtos.event_dto import EventInputDto from src.interactor.use_cases.message.create_message_reply import CreateMessageReplyUseCase class TextEventHandler(EventHandler): - def __init__(self, logger, agent_repository): - self.logger = logger - self.agent_repository = agent_repository + def __init__(self, container: Container): + super().__init__(container) self.input_dto: EventInputDto def execute(self): presenter = EventPresenter() - use_case = CreateMessageReplyUseCase( - presenter=presenter, - repository=self.agent_repository, - logger=self.logger, - ) + use_case = CreateMessageReplyUseCase(presenter=presenter, container=self.container) result = use_case.execute(self.input_dto) return result diff --git a/src/app/flask_postgresql/api/response.py b/src/app/flask_postgresql/api/response.py index 838d64d..7d1732c 100644 --- a/src/app/flask_postgresql/api/response.py +++ b/src/app/flask_postgresql/api/response.py @@ -10,9 +10,9 @@ from linebot.v3.messaging.models.message import Message -def create_response( - configuration: Configuration, reply_token: str, messages: List[Message] -) -> ApiResponse: +def create_response(configuration: Configuration, reply_token: str, messages: List[Message]) -> ApiResponse: + if not messages: + return with ApiClient(configuration) as api_client: line_bot_api = MessagingApi(api_client) diff --git a/src/app/flask_postgresql/create_flask_postgresql_app.py b/src/app/flask_postgresql/create_flask_postgresql_app.py index 03bff80..17ec0eb 100644 --- a/src/app/flask_postgresql/create_flask_postgresql_app.py +++ b/src/app/flask_postgresql/create_flask_postgresql_app.py @@ -3,17 +3,17 @@ from flask import Flask from src.app.flask_postgresql.api.blueprints import setup_blueprints +from src.app.flask_postgresql.api.container import setup_container from src.app.flask_postgresql.api.error_handler import setup_error_handler from src.app.flask_postgresql.api.request_context import setup_request_context from src.infrastructure.databases.sql_alchemy import setup_sqlalchemy -from src.interactor.interfaces.logger.logger import LoggerInterface -def create_flask_postgresql_app(config, logger: LoggerInterface) -> Flask: +def create_flask_postgresql_app(config) -> Flask: """Create Main Flask PostgreSQL app""" app = Flask(__name__) app.config.from_object(config) - app.config["logger"] = logger + app = setup_container(app) app = setup_blueprints(app) app = setup_sqlalchemy(app) app = setup_error_handler(app) diff --git a/src/app/flask_postgresql/create_flask_postgresql_app_test.py b/src/app/flask_postgresql/create_flask_postgresql_app_test.py index 2b89c18..a5189d2 100644 --- a/src/app/flask_postgresql/create_flask_postgresql_app_test.py +++ b/src/app/flask_postgresql/create_flask_postgresql_app_test.py @@ -5,21 +5,17 @@ from flask.testing import FlaskClient from src.app.flask_postgresql.configs import Config -from src.infrastructure.loggers.logger_default import LoggerDefault with mock.patch("sqlalchemy.create_engine") as mock_create_engine, mock.patch( - "langchain.utilities.SerpAPIWrapper" + "langchain.utilities.SerpAPIWrapper" ) as mock_sessionmaker: from .create_flask_postgresql_app import create_flask_postgresql_app -logger = LoggerDefault() - - @pytest.fixture(name="flask_postgresql_app") def fixture_flask_postgresql_app(): """Fixture for flask app with blueprint""" - app: Flask = create_flask_postgresql_app(Config, logger) + app: Flask = create_flask_postgresql_app(Config) app.config.update( { "TESTING": True, @@ -56,8 +52,8 @@ def test_request_window(mocker, client_flask_postgresql_app: FlaskClient): def test_request_window_wrong_url_error( - mocker, - client_flask_postgresql_app, + mocker, + client_flask_postgresql_app, ): """Test request example""" headers_data = {"X-Line-Signature": "test"} diff --git a/src/infrastructure/container/container.py b/src/infrastructure/container/container.py new file mode 100644 index 0000000..3f84c3d --- /dev/null +++ b/src/infrastructure/container/container.py @@ -0,0 +1,15 @@ +class Container: + def __init__(self): + self._dependencies = {} + + def register(self, key, dependency): + self._dependencies[key] = dependency + + def resolve(self, key): + return self._dependencies[key] + + def __getitem__(self, key): + return self.resolve(key) + + def __setitem__(self, key, dependency): + self.register(key, dependency) diff --git a/src/infrastructure/repositories/agent_chain/agent_chain_in_memory_repository.py b/src/infrastructure/repositories/agent_chain/agent_chain_in_memory_repository.py index 3f4284e..633902f 100644 --- a/src/infrastructure/repositories/agent_chain/agent_chain_in_memory_repository.py +++ b/src/infrastructure/repositories/agent_chain/agent_chain_in_memory_repository.py @@ -61,7 +61,7 @@ def _create_agent( agent_language: str, temperature: float, memory_key: str, - tools: list, + tools_list: list, ) -> OpenAIFunctionsAgent: """ Creates an instance of the OpenAIFunctionsAgent class. @@ -82,10 +82,10 @@ def _create_agent( system_message=system_message, extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key), system_language], ) - return OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt) + return OpenAIFunctionsAgent(llm=llm, tools=tools_list, prompt=prompt) def _create_agent_executor( - self, agent: OpenAIFunctionsAgent, memory: BaseChatMemory, tools: list + self, agent: OpenAIFunctionsAgent, memory: BaseChatMemory, tools_list: list ) -> AgentExecutor: """ Creates an agent executor using the provided agent, memory, and other optional arguments. @@ -100,13 +100,13 @@ def _create_agent_executor( """ return AgentExecutor( agent=agent, - tools=tools, + tools=tools_list, memory=memory, verbose=True, max_iterations=3, ) - def get(self, window_id: str) -> AgentExecutor: + def get(self, window_id: str) -> AgentExecutor | None: """Get AgentExecutor by id :param window_id: str @@ -150,9 +150,9 @@ def create( agent_language=agent_language, temperature=temperature, memory_key=memory_key, - tools=tools, + tools_list=tools, ) - agent_executor = self._create_agent_executor(agent=agent, memory=memory, tools=tools) + agent_executor = self._create_agent_executor(agent=agent, memory=memory, tools_list=tools) self._data[window_id] = agent_executor return self._data[window_id] diff --git a/src/interactor/dtos/event_dto.py b/src/interactor/dtos/event_dto.py index cee0ac6..ff9d5bd 100644 --- a/src/interactor/dtos/event_dto.py +++ b/src/interactor/dtos/event_dto.py @@ -1,7 +1,6 @@ """ Module for Events Dtos """ - from dataclasses import asdict, dataclass from typing import Dict, List @@ -14,6 +13,7 @@ class EventInputDto: window: Dict user_input: str + source_type: str def to_dict(self): """Convert data into dictionary""" diff --git a/src/interactor/use_cases/message/cor/__init__.py b/src/interactor/use_cases/message/cor/__init__.py index c82a5f4..eb2bf5d 100644 --- a/src/interactor/use_cases/message/cor/__init__.py +++ b/src/interactor/use_cases/message/cor/__init__.py @@ -1,21 +1,33 @@ -from typing import List - -from linebot.v3.messaging.models.message import Message - +from src.infrastructure.container.container import Container from src.interactor.dtos.event_dto import EventInputDto -from src.interactor.interfaces.repositories.agent_executor_repository import ( - AgentExecutorRepositoryInterface, -) from src.interactor.use_cases.message.cor.addition_handler import AdditionHandler from src.interactor.use_cases.message.cor.default_handler import DefaultHandler from src.interactor.use_cases.message.cor.muting_handler import MutingHandler +from src.interactor.use_cases.message.cor.viki_handler import VikiHandler +from src.interactor.use_cases.message.cor.window_mutable_handler import WindowMutableHandler class ReplyMessagesCOR: - def __init__(self): - self._chain = MutingHandler(AdditionHandler(DefaultHandler())) + def __init__(self, container: Container): + self.container = container + self._chain = None + self._initialize_chain() + + def _initialize_chain(self): + # create the chain of responsibility + viki_handler = VikiHandler(self.container) + window_mutable_handler = WindowMutableHandler(self.container) + muting_handler = MutingHandler(self.container) + # addition_handler = AdditionHandler(self.container) + default_handler = DefaultHandler(self.container) + + # set the successor of each handler + viki_handler.set_successor(window_mutable_handler) + window_mutable_handler.set_successor(muting_handler) + muting_handler.set_successor(default_handler) + + # set the chain + self._chain = viki_handler - def handle(self, input_dto: EventInputDto, repository: AgentExecutorRepositoryInterface): - response: List[Message] = [] - self._chain.handle(input_dto, repository, response) - return response + def handle(self, input_dto: EventInputDto): + return self._chain.handle(input_dto) diff --git a/src/interactor/use_cases/message/cor/addition_handler.py b/src/interactor/use_cases/message/cor/addition_handler.py index 4304092..8a0deee 100644 --- a/src/interactor/use_cases/message/cor/addition_handler.py +++ b/src/interactor/use_cases/message/cor/addition_handler.py @@ -11,14 +11,13 @@ class AdditionHandler(Handler): - def handle( - self, - input_dto: EventInputDto, - repository: AgentExecutorRepositoryInterface, - response: List[Message], - ): - response.append(TextMessage(text="test handler")) + def handle(self, input_dto: EventInputDto): + + messages: List[Message] = [] + messages.extend([TextMessage(text="test handler")]) if self._successor is not None: - return self._successor.handle(input_dto, repository, response) + messages.extend(self._successor.handle(input_dto)) else: - response.append(TextMessage(text="靜悄悄的,什麼都沒有發生。")) + messages.extend([TextMessage(text="Something went wrong! >_<")]) + + return messages diff --git a/src/interactor/use_cases/message/cor/default_handler.py b/src/interactor/use_cases/message/cor/default_handler.py index 017c08c..f4f8682 100644 --- a/src/interactor/use_cases/message/cor/default_handler.py +++ b/src/interactor/use_cases/message/cor/default_handler.py @@ -1,29 +1,41 @@ -from typing import List - from langchain.agents import AgentExecutor from linebot.v3.messaging.models import TextMessage -from linebot.v3.messaging.models.message import Message from src.interactor.dtos.event_dto import EventInputDto -from src.interactor.interfaces.repositories.agent_executor_repository import ( - AgentExecutorRepositoryInterface, -) from src.interactor.use_cases.message.cor.handler_base import Handler class DefaultHandler(Handler): - def handle( - self, - input_dto: EventInputDto, - repository: AgentExecutorRepositoryInterface, - response: List[Message], - ): + """ + The last handler in the chain of responsibility. + + This handler uses the agent executor to process the user input. + """ + + def _get_agent_executor(self, input_dto: EventInputDto) -> AgentExecutor: + """ + Retrieves the agent executor associated with the current window. + + :param None: This function does not take any parameters. + :return: None + """ + window_id = input_dto.window.get("window_id") + + agent_executor = self.agent_repository.get( + window_id=window_id, + ) + if agent_executor is None: + agent_executor = self.agent_repository.create(window_id=window_id) + return agent_executor + + def handle(self, input_dto: EventInputDto): + messages = [] try: - agent_executor = self._get_agent_executor(input_dto, repository) + agent_executor = self._get_agent_executor(input_dto) result = agent_executor.run(input=input_dto.user_input) - response.append(TextMessage(text=result)) + messages.extend([TextMessage(text=result)]) except Exception as e: - print(e) - response.append(TextMessage(text="出現錯誤啦!請稍後再試。")) + self.logger.error(e) + messages.extend([TextMessage(text="Something went wrong! >_<")]) finally: - return response + return messages diff --git a/src/interactor/use_cases/message/cor/handler_base.py b/src/interactor/use_cases/message/cor/handler_base.py index fd0dcc1..38ef059 100644 --- a/src/interactor/use_cases/message/cor/handler_base.py +++ b/src/interactor/use_cases/message/cor/handler_base.py @@ -1,47 +1,28 @@ from abc import ABC, abstractmethod -from typing import List, Type +from typing import List -from langchain.agents import AgentExecutor from linebot.v3.messaging.models.message import Message +from src.infrastructure.container.container import Container from src.interactor.dtos.event_dto import EventInputDto +from src.interactor.interfaces.logger.logger import LoggerInterface from src.interactor.interfaces.repositories.agent_executor_repository import ( AgentExecutorRepositoryInterface, ) +from src.interactor.interfaces.repositories.window_repository import WindowRepositoryInterface class Handler(ABC): - def __init__(self, successor: Type["Handler"] = None): + def __init__(self, container: Container = None): + self._successor = None + self.container: Container = container + self.logger: LoggerInterface = container.resolve("logger") + self.agent_repository: AgentExecutorRepositoryInterface = container.resolve("agent_repository") + self.window_repository: WindowRepositoryInterface = container.resolve("window_repository") + + def set_successor(self, successor: "Handler"): self._successor = successor - def _get_agent_executor( - self, - input_dto: EventInputDto, - repository: AgentExecutorRepositoryInterface, - ) -> AgentExecutor: - """ - Retrieves the agent executor associated with the current window. - - :param None: This function does not take any parameters. - :return: None - """ - - window_id = input_dto.window.get("window_id") - - agent_executor = repository.get( - window_id=window_id, - ) - if agent_executor is None: - agent_executor = repository.create( - window_id=window_id, - ) - return agent_executor - @abstractmethod - def handle( - self, - input_dto: EventInputDto, - repository: AgentExecutorRepositoryInterface, - response: List[Message], - ) -> List[Message]: + def handle(self, input_dto: EventInputDto, ) -> List[Message]: pass diff --git a/src/interactor/use_cases/message/cor/muting_handler.py b/src/interactor/use_cases/message/cor/muting_handler.py index f8fa12f..96c9b02 100644 --- a/src/interactor/use_cases/message/cor/muting_handler.py +++ b/src/interactor/use_cases/message/cor/muting_handler.py @@ -4,22 +4,23 @@ from linebot.v3.messaging.models.message import Message from src.interactor.dtos.event_dto import EventInputDto -from src.interactor.interfaces.repositories.agent_executor_repository import ( - AgentExecutorRepositoryInterface, -) from src.interactor.use_cases.message.cor.handler_base import Handler class MutingHandler(Handler): - def handle( - self, - input_dto: EventInputDto, - repository: AgentExecutorRepositoryInterface, - response: List[Message], - ): + + def handle(self, input_dto: EventInputDto): + """ + Check if the window is currently muting. If it is, return an empty list. + """ + + messages: List[Message] = [] + if input_dto.window.get("is_muting") is True: - return response + return messages elif self._successor is not None: - return self._successor.handle(input_dto, repository, response) + messages.extend(self._successor.handle(input_dto)) else: - response.append(TextMessage(text="靜悄悄的,什麼都沒有發生。")) + messages.extend([TextMessage(text="Something went wrong! >_<")]) + + return messages diff --git a/src/interactor/use_cases/message/cor/test/default_handler_test.py b/src/interactor/use_cases/message/cor/test/default_handler_test.py index 05a4220..c4435f0 100644 --- a/src/interactor/use_cases/message/cor/test/default_handler_test.py +++ b/src/interactor/use_cases/message/cor/test/default_handler_test.py @@ -7,47 +7,49 @@ def test_default_handler(mocker: mock, fixture_window: dict): - repository_mock = mocker.patch( - "src.interactor.use_cases.message.cor.default_handler.AgentExecutorRepositoryInterface" - ) - agent_executor_mock = mocker.Mock(run=mocker.Mock(return_value="mock agent executor result")) + container_mock = mocker.MagicMock() + agent_executor_mock = mocker.MagicMock() + agent_executor_mock.run.return_value = "mock agent executor result" get_agent_executor_mock = mocker.patch( "src.interactor.use_cases.message.cor.default_handler.DefaultHandler._get_agent_executor" ) get_agent_executor_mock.return_value = agent_executor_mock - response = [] - input_dto_mock = EventInputDto( window=fixture_window, user_input="mock user input", + source_type="mock source type", ) - default_handler = DefaultHandler() - default_handler.handle(input_dto_mock, repository_mock, response) + default_handler = DefaultHandler(container_mock) + messages = default_handler.handle(input_dto_mock) - get_agent_executor_mock.assert_called_once_with(input_dto_mock, repository_mock) + get_agent_executor_mock.assert_called_once_with(input_dto_mock) agent_executor_mock.run.assert_called_once_with(input=input_dto_mock.user_input) - assert len(response) == 1 - assert response[0] == TextMessage(text="mock agent executor result") + assert len(messages) == 1 + assert messages[0] == TextMessage(text="mock agent executor result") def test_get_agent_executor_in_default_handler(mocker: mock, fixture_window: dict): - repository_mock = mocker.patch( - "src.interactor.use_cases.message.cor.default_handler.AgentExecutorRepositoryInterface" - ) - repository_mock.get.return_value = None - agent_executor_mock = mocker.Mock(run=mocker.Mock(return_value="mock agent executor result")) - repository_mock.create.return_value = mocker.Mock(return_value=agent_executor_mock) + container_mock = mocker.MagicMock() + agent_executor_repository_mock = mocker.MagicMock() + container_mock.resolve.return_value = agent_executor_repository_mock + agent_executor_mock = mocker.MagicMock() + agent_executor_mock.run.return_value = "mock agent executor result" + + agent_executor_repository_mock.get.return_value = None + agent_executor_repository_mock.create.return_value = mocker.MagicMock() + agent_executor_repository_mock.create.return_value.run.return_value = "mock agent executor result" input_dto_mock = EventInputDto( window=fixture_window, user_input="mock user input", + source_type="mock source type", ) - default_handler = DefaultHandler() - default_handler._get_agent_executor(input_dto_mock, repository_mock) + default_handler = DefaultHandler(container_mock) + default_handler._get_agent_executor(input_dto_mock) - repository_mock.get.assert_called_once_with(window_id=fixture_window["window_id"]) - repository_mock.create.assert_called_once_with(window_id=fixture_window["window_id"]) + agent_executor_repository_mock.get.assert_called_once_with(window_id=fixture_window["window_id"]) + agent_executor_repository_mock.create.assert_called_once_with(window_id=fixture_window["window_id"]) diff --git a/src/interactor/use_cases/message/cor/test/muting_handler_test.py b/src/interactor/use_cases/message/cor/test/muting_handler_test.py index 1a20a17..409fa48 100644 --- a/src/interactor/use_cases/message/cor/test/muting_handler_test.py +++ b/src/interactor/use_cases/message/cor/test/muting_handler_test.py @@ -7,61 +7,55 @@ def test_muting_handler_with_fixture_window_with_muting( - mocker: mock, fixture_window_with_muting: dict + mocker: mock, fixture_window_with_muting: dict ): - repository_mock = mocker.patch( - "src.interactor.use_cases.message.cor.muting_handler.AgentExecutorRepositoryInterface" - ) - response = [] + container_mock = mocker.MagicMock() input_dto_mock = EventInputDto( window=fixture_window_with_muting, user_input="mock user input", + source_type="mock source type", ) - muting_handler = MutingHandler() - muting_handler.handle(input_dto_mock, repository_mock, response) + muting_handler = MutingHandler(container_mock) + messages = muting_handler.handle(input_dto_mock) # self._successor.handle is not called - assert len(response) == 0 + assert len(messages) == 0 assert muting_handler._successor is None - assert response == [] + assert messages == [] def test_muting_handler_with_no_successor(mocker: mock, fixture_window: dict): - repository_mock = mocker.patch( - "src.interactor.use_cases.message.cor.muting_handler.AgentExecutorRepositoryInterface" - ) - response = [] - input_dto_mock = EventInputDto( window=fixture_window, user_input="mock user input", + source_type="mock source type", ) - - muting_handler = MutingHandler() - muting_handler.handle(input_dto_mock, repository_mock, response) + container_mock = mocker.MagicMock() + muting_handler = MutingHandler(container_mock) + messages = muting_handler.handle(input_dto_mock) # self._successor.handle is not called - assert len(response) == 1 + assert len(messages) == 1 assert muting_handler._successor is None - assert response[0] == TextMessage(text="靜悄悄的,什麼都沒有發生。") + assert messages[0] == TextMessage(text="Something went wrong! >_<") def test_muting_handler_with_successor(mocker: mock, fixture_window: dict): - repository_mock = mocker.patch( - "src.interactor.use_cases.message.cor.muting_handler.AgentExecutorRepositoryInterface" - ) - response = [] + + container_mock = mocker.MagicMock() input_dto_mock = EventInputDto( window=fixture_window, user_input="mock user input", + source_type="mock source type", ) - muting_handler = MutingHandler(mocker.Mock()) - muting_handler.handle(input_dto_mock, repository_mock, response) + muting_handler = MutingHandler(container_mock) + muting_handler.set_successor(mocker.MagicMock()) + messages = muting_handler.handle(input_dto_mock) # self._successor.handle is called - assert len(response) == 0 + assert len(messages) == 0 assert muting_handler._successor.handle.call_count == 1 diff --git a/src/interactor/use_cases/message/cor/test/reply_messages_cor_test.py b/src/interactor/use_cases/message/cor/test/reply_messages_cor_test.py index 6390ca0..3ca8489 100644 --- a/src/interactor/use_cases/message/cor/test/reply_messages_cor_test.py +++ b/src/interactor/use_cases/message/cor/test/reply_messages_cor_test.py @@ -4,32 +4,36 @@ def test_reply_message_cor_initial(mocker: mock): + container_mock = mocker.MagicMock() muting_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.MutingHandler") default_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.DefaultHandler") - addition_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.AdditionHandler") + # addition_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.AdditionHandler") + window_mutable_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.WindowMutableHandler") + viki_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.VikiHandler") - reply_messages_cor = ReplyMessagesCOR() + muting_handler_mock_instance = muting_handler_mock.return_value + # addition_handler_mock_instance = addition_handler_mock.return_value + default_handler_mock_instance = default_handler_mock.return_value + window_mutable_handler_mock_instance = window_mutable_handler_mock.return_value + viki_handler_mock_instance = viki_handler_mock.return_value - muting_handler_mock.assert_called_once_with(addition_handler_mock.return_value) - addition_handler_mock.assert_called_once_with(default_handler_mock.return_value) - default_handler_mock.assert_called_once_with() + reply_messages_cor = ReplyMessagesCOR(container_mock) - assert reply_messages_cor._chain == muting_handler_mock.return_value + viki_handler_mock_instance.set_successor.assert_called_once_with(window_mutable_handler_mock_instance) + window_mutable_handler_mock_instance.set_successor.assert_called_once_with(muting_handler_mock_instance) + muting_handler_mock_instance.set_successor.assert_called_once_with(default_handler_mock_instance) + default_handler_mock_instance.set_successor.assert_not_called() + assert reply_messages_cor._chain == viki_handler_mock_instance -def test_reply_message_cor_handle(mocker: mock): - mocker.patch("src.interactor.use_cases.message.cor.MutingHandler") - mocker.patch("src.interactor.use_cases.message.cor.DefaultHandler") - mocker.patch("src.interactor.use_cases.message.cor.AdditionHandler") - reply_messages_cor = ReplyMessagesCOR() +def test_reply_message_cor_handle(mocker: mock): + container_mock = mocker.MagicMock() + input_dto = mocker.MagicMock() - input_dto = mocker.Mock() - repository = mocker.Mock() - response = [] + reply_messages_cor = ReplyMessagesCOR(container_mock) reply_messages_cor._chain.handle = mocker.Mock() - reply_messages_cor.handle(input_dto, repository) - - reply_messages_cor._chain.handle.assert_called_once_with(input_dto, repository, response) + reply_messages_cor.handle(input_dto) + reply_messages_cor._chain.handle.assert_called_once_with(input_dto) diff --git a/src/interactor/use_cases/message/cor/viki_handler.py b/src/interactor/use_cases/message/cor/viki_handler.py new file mode 100644 index 0000000..bd645b6 --- /dev/null +++ b/src/interactor/use_cases/message/cor/viki_handler.py @@ -0,0 +1,34 @@ +from linebot.v3.messaging.models import TextMessage + +from src.infrastructure.container.container import Container +from src.interactor.dtos.event_dto import EventInputDto +from src.interactor.use_cases.message.cor.handler_base import Handler + + +class VikiHandler(Handler): + """ + The first handler in the chain of responsibility. + + This handler checks if the user input contains "Viki" or "viki". + If it does, it will return an empty list. + """ + + def __init__(self, container: Container): + super().__init__(container) + + # convert all words to lowercase and add in _mute_word_set + self._trigger_words = {"Viki", "viki"} + + def handle(self, input_dto: EventInputDto): + # if input_dto.user_input not include "Viki" or "viki", return, unless the source_type is "user" + if input_dto.source_type != "user" and not any(word in input_dto.user_input for word in self._trigger_words): + return [] + + messages = [] + + if self._successor is not None: + messages.extend(self._successor.handle(input_dto)) + else: + messages.extend([TextMessage(text="Something went wrong! >_<")]) + + return messages diff --git a/src/interactor/use_cases/message/cor/window_mutable_handler.py b/src/interactor/use_cases/message/cor/window_mutable_handler.py new file mode 100644 index 0000000..fbb708f --- /dev/null +++ b/src/interactor/use_cases/message/cor/window_mutable_handler.py @@ -0,0 +1,67 @@ +from typing import List + +from linebot.v3.messaging.models import TextMessage +from linebot.v3.messaging.models.message import Message + +from src.infrastructure.container.container import Container +from src.interactor.dtos.event_dto import EventInputDto +from src.interactor.use_cases.message.cor.handler_base import Handler + +mute_word_set = { + "Viki mute", + "Viki shut up", +} + +unmute_word_set = { + "Viki unmute" +} + + +class WindowMutableHandler(Handler): + def __init__(self, container: Container): + super().__init__(container) + + # convert all words to lowercase and add in _mute_word_set + self._all_mute_words = mute_word_set.union( + {word.lower() for word in mute_word_set} + ) + + # convert all words to lowercase and add in _unmute_word_set + self._all_unmute_words = unmute_word_set.union( + {word.lower() for word in unmute_word_set} + ) + + def handle(self, input_dto: EventInputDto): + """ + Check if the user input contains any mute words. If it does, set the window's is_muting to True. + """ + messages: List[Message] = [] + + # if input_dto.user_input include mute word, set is_muting to True + if any(word in input_dto.user_input for word in self._all_mute_words): + window_id = input_dto.window.get("window_id") + window = self.window_repository.get(window_id) + if window is not None: + window.is_muting = True + self.window_repository.update(window) + messages.extend([TextMessage(text="Okay, bye (´•̥̥̥ω•̥̥̥`)")]) + else: + messages.extend([TextMessage(text="Something went wrong! >_<")]) + + # if input_dto.user_input include unmute word, set is_muting to False + elif any(word in input_dto.user_input for word in self._all_unmute_words): + window_id = input_dto.window.get("window_id") + window = self.window_repository.get(window_id) + if window is not None: + window.is_muting = False + self.window_repository.update(window) + messages.extend([TextMessage(text="Hello again! ヾ(*´▽‘*)ノ")]) + else: + messages.extend([TextMessage(text="Something went wrong! >_<")]) + + elif self._successor is not None: + messages.extend(self._successor.handle(input_dto)) + else: + messages.extend([TextMessage(text="Something went wrong! >_<")]) + + return messages diff --git a/src/interactor/use_cases/message/create_message_reply.py b/src/interactor/use_cases/message/create_message_reply.py index c05b917..e79dae3 100644 --- a/src/interactor/use_cases/message/create_message_reply.py +++ b/src/interactor/use_cases/message/create_message_reply.py @@ -1,11 +1,10 @@ """ This module is responsible for creating a new window. """ + +from src.infrastructure.container.container import Container from src.interactor.dtos.event_dto import EventInputDto, EventOutputDto from src.interactor.interfaces.logger.logger import LoggerInterface from src.interactor.interfaces.presenters.message_reply_presenter import EventPresenterInterface -from src.interactor.interfaces.repositories.agent_executor_repository import ( - AgentExecutorRepositoryInterface, -) from src.interactor.use_cases.message.cor import ReplyMessagesCOR from src.interactor.validations.event_input_validator import EventInputDtoValidator @@ -13,15 +12,10 @@ class CreateMessageReplyUseCase: """This class is responsible for creating a new window.""" - def __init__( - self, - presenter: EventPresenterInterface, - repository: AgentExecutorRepositoryInterface, - logger: LoggerInterface, - ): + def __init__(self, presenter: EventPresenterInterface, container: Container): self.presenter = presenter - self.repository = repository - self.logger = logger + self.container = container + self.logger: LoggerInterface = container.resolve("logger") def execute(self, input_dto: EventInputDto): """ @@ -39,9 +33,9 @@ def execute(self, input_dto: EventInputDto): validator = EventInputDtoValidator(input_dto.to_dict()) validator.validate() - reply_messages_cor = ReplyMessagesCOR() + reply_messages_cor = ReplyMessagesCOR(self.container) - response = reply_messages_cor.handle(input_dto, self.repository) + response = reply_messages_cor.handle(input_dto) output_dto = EventOutputDto( window=input_dto.window, diff --git a/src/interactor/use_cases/message/test/create_message_reply_test.py b/src/interactor/use_cases/message/test/create_message_reply_test.py index a9a2bc5..ead2f02 100644 --- a/src/interactor/use_cases/message/test/create_message_reply_test.py +++ b/src/interactor/use_cases/message/test/create_message_reply_test.py @@ -11,6 +11,9 @@ def test_create_message_reply(mocker: mock, fixture_window): presenter_mock = mocker.patch.object(EventPresenterInterface, "present") presenter_mock.present.return_value = "Test output" + logger_mock = mocker.patch.object(LoggerInterface, "log_info") + container_mock = mocker.MagicMock() + container_mock.resolve.return_value = logger_mock reply_messages_cor_mock = mocker.patch( "src.interactor.use_cases.message.create_message_reply.ReplyMessagesCOR" @@ -18,17 +21,12 @@ def test_create_message_reply(mocker: mock, fixture_window): reply_messages_cor_instance = reply_messages_cor_mock.return_value reply_messages_cor_instance.handle.return_value = [TextMessage(text="Test output")] - repository_mock = mocker.patch( - "src.interactor.use_cases.message.create_message_reply.AgentExecutorRepositoryInterface" - ) - input_dto_validator_mock = mocker.patch( "src.interactor.use_cases.message.create_message_reply.EventInputDtoValidator" ) - logger_mock = mocker.patch.object(LoggerInterface, "log_info") - use_case = CreateMessageReplyUseCase(presenter_mock, repository_mock, logger_mock) - input_dto = EventInputDto(window=fixture_window, user_input="Test input") + use_case = CreateMessageReplyUseCase(presenter_mock, container_mock) + input_dto = EventInputDto(window=fixture_window, user_input="Test input", source_type="Test source") result = use_case.execute(input_dto) input_dto_validator_mock.assert_called_once_with(input_dto.to_dict()) @@ -37,7 +35,7 @@ def test_create_message_reply(mocker: mock, fixture_window): logger_mock.log_info.assert_called_once_with("Create reply successfully") - reply_messages_cor_instance.handle.assert_called_once_with(input_dto, repository_mock) + reply_messages_cor_instance.handle.assert_called_once_with(input_dto) output_dto = EventOutputDto( window=fixture_window, user_input="Test input", response=[TextMessage(text="Test output")] diff --git a/src/interactor/validations/event_input_validator.py b/src/interactor/validations/event_input_validator.py index 8bd76cb..7d12043 100644 --- a/src/interactor/validations/event_input_validator.py +++ b/src/interactor/validations/event_input_validator.py @@ -1,7 +1,6 @@ """ Defines the validator for the create window input data. """ - from typing import Dict from src.interactor.validations.base_input_validator import BaseInputValidator @@ -28,6 +27,7 @@ def __init__(self, input_data: Dict) -> None: }, }, "user_input": {"type": "string", "required": True}, + "source_type": {"type": "string", "required": True}, } def validate(self) -> None: diff --git a/src/interactor/validations/test/event_input_validator_test.py b/src/interactor/validations/test/event_input_validator_test.py index ce09c25..471240e 100644 --- a/src/interactor/validations/test/event_input_validator_test.py +++ b/src/interactor/validations/test/event_input_validator_test.py @@ -3,7 +3,7 @@ from src.interactor.validations.event_input_validator import EventInputDtoValidator -def test_get_window_validator_valid_data(mocker, fixture_window): +def test_event_input_validator_valid_data(mocker, fixture_window): mocker.patch("src.interactor.validations.base_input_validator.BaseInputValidator.verify") input_data = {"window": fixture_window, "user_input": "test"} schema = { @@ -19,16 +19,17 @@ def test_get_window_validator_valid_data(mocker, fixture_window): }, }, "user_input": {"type": "string", "required": True}, + "source_type": {"type": "string", "required": True}, } validator = EventInputDtoValidator(input_data) validator.validate() validator.verify.assert_called_once_with(schema) -def test_get_window_validator_without_user_input(fixture_window): +def test_event_input_validator_without_user_input(fixture_window): # We are doing just a simple test as the complete test is done in # base_input_validator_test.py - input_data = {"window": fixture_window} + input_data = {"window": fixture_window, "source_type": "test"} validator = EventInputDtoValidator(input_data) with pytest.raises(ValueError) as exception_info: validator.validate()