diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index d40a3c9..7fce9ea 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -20,7 +20,7 @@ jobs:
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
- python-version: "3.11.8"
+ python-version: "3.11.4"
- name: Cache dependencies
id: cache
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 3a82b7d..109df5b 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,5 +3,5 @@
-
+
\ No newline at end of file
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 7cbe1da..8b85c7d 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -4,15 +4,28 @@
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
@@ -25,6 +38,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
@@ -50,23 +83,131 @@
- {
- "keyToString": {
- "ASKED_ADD_EXTERNAL_FILES": "true",
- "ASKED_SHARE_PROJECT_CONFIGURATION_FILES": "true",
- "RunOnceActivity.OpenProjectViewOnStart": "true",
- "RunOnceActivity.ShowReadmeOnStart": "true",
- "git-widget-placeholder": "development",
- "ignore.virus.scanning.warn.message": "true",
- "node.js.detected.package.eslint": "true",
- "node.js.detected.package.tslint": "true",
- "node.js.selected.package.eslint": "(autodetect)",
- "node.js.selected.package.tslint": "(autodetect)",
- "nodejs_package_manager_path": "npm",
- "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
- "vue.rearranger.settings.migration": "true"
+
+}]]>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -88,6 +229,12 @@
+
+
+
+
+
+
@@ -153,7 +300,31 @@
1713211062073
-
+
+
+ 1714022158591
+
+
+
+ 1714022158591
+
+
+
+ 1714022396944
+
+
+
+ 1714022396944
+
+
+
+ 1714022444231
+
+
+
+ 1714022444231
+
+
@@ -177,7 +348,13 @@
+
-
+
+
+
+
+
+
\ No newline at end of file
diff --git a/app.py b/app.py
index 535d8e5..502b03b 100644
--- a/app.py
+++ b/app.py
@@ -1,16 +1,9 @@
""" Flask PostgreSQL Process Handler
"""
-
from src.app.flask_postgresql.configs import Config
-from src.app.flask_postgresql.create_flask_postgresql_app import (
- create_flask_postgresql_app,
-)
-from src.infrastructure.loggers.logger_default import LoggerDefault
-
-logger = LoggerDefault()
-
+from src.app.flask_postgresql.create_flask_postgresql_app import create_flask_postgresql_app
if __name__ == "__main__":
- app = create_flask_postgresql_app(Config, logger)
+ app = create_flask_postgresql_app(Config)
app.run(host="0.0.0.0", port=Config.PORT, debug=True)
diff --git a/poetry.lock b/poetry.lock
index 065f040..d50c685 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2303,4 +2303,4 @@ requests = ">=2.31"
[metadata]
lock-version = "2.0"
python-versions = "3.11.4"
-content-hash = "5756e8462df2bb105e395d1e952430f6c2972e4f3d97576a9c7a1a5d18da1671"
+content-hash = "429381f71eba65b35eb1570849cd6163566b2eaaf9bec64a268a894aa954fc51"
diff --git a/src/app/flask_postgresql/api/blueprints/__init__.py b/src/app/flask_postgresql/api/blueprints/__init__.py
index b5cdd73..8a73e2d 100644
--- a/src/app/flask_postgresql/api/blueprints/__init__.py
+++ b/src/app/flask_postgresql/api/blueprints/__init__.py
@@ -4,7 +4,7 @@
from .main import blueprint as main_blueprint
-def setup_blueprints(app: Flask) -> None:
+def setup_blueprints(app: Flask) -> Flask:
"""
Register the necessary blueprints for the Flask app.
@@ -12,7 +12,7 @@ def setup_blueprints(app: Flask) -> None:
app (Flask): The Flask app instance.
Returns:
- None: This function does not return anything.
+ None: The Flask app instance with the blueprints registered.
"""
app.register_blueprint(callback_blueprint)
app.register_blueprint(main_blueprint)
diff --git a/src/app/flask_postgresql/api/blueprints/callback.py b/src/app/flask_postgresql/api/blueprints/callback.py
index 4b4d04c..b4918bc 100644
--- a/src/app/flask_postgresql/api/blueprints/callback.py
+++ b/src/app/flask_postgresql/api/blueprints/callback.py
@@ -13,7 +13,7 @@
@blueprint.route("/callback", methods=["POST"])
def callback_blueprint():
"""Function printing python version."""
- logger: LoggerInterface = current_app.config["logger"]
+ logger: LoggerInterface = current_app.config["container"].resolve("logger")
signature = request.headers["X-Line-Signature"]
body = request.get_data(as_text=True)
diff --git a/src/app/flask_postgresql/api/container.py b/src/app/flask_postgresql/api/container.py
new file mode 100644
index 0000000..3015309
--- /dev/null
+++ b/src/app/flask_postgresql/api/container.py
@@ -0,0 +1,20 @@
+from flask import Flask, g
+from src.infrastructure.container.container import Container
+from src.infrastructure.repositories.agent_chain.agent_chain_in_memory_repository import (
+ AgentExecutorInMemoryRepository,
+)
+from src.infrastructure.repositories.window import WindowPostgresqlRepository
+from src.infrastructure.loggers.logger_default import LoggerDefault
+
+
+def setup_container(app: Flask) -> Flask:
+ """Setup Container for the app"""
+ container = Container()
+
+ container.register("agent_repository", AgentExecutorInMemoryRepository())
+ container.register("window_repository", WindowPostgresqlRepository())
+ container.register("logger", LoggerDefault())
+
+ app.config["container"] = container
+
+ return app
diff --git a/src/app/flask_postgresql/api/error_handler.py b/src/app/flask_postgresql/api/error_handler.py
index 353f8f7..f41692a 100644
--- a/src/app/flask_postgresql/api/error_handler.py
+++ b/src/app/flask_postgresql/api/error_handler.py
@@ -31,12 +31,12 @@ def format_marshmallow_validation_error(errors: Dict):
return errors_message
-def setup_error_handler(app: Flask) -> None:
+def setup_error_handler(app: Flask) -> Flask:
"""
Function that will register all the specified error handlers for the app
"""
- logger: LoggerInterface = app.config["logger"]
+ logger: LoggerInterface = app.config["container"].resolve("logger")
def error_handler(error):
logger.log_exception("exception of type {} occurred".format(type(error)))
diff --git a/src/app/flask_postgresql/api/event_handlers/__init__.py b/src/app/flask_postgresql/api/event_handlers/__init__.py
index 151bf5b..62aa280 100644
--- a/src/app/flask_postgresql/api/event_handlers/__init__.py
+++ b/src/app/flask_postgresql/api/event_handlers/__init__.py
@@ -3,43 +3,38 @@
from flask import current_app
from linebot.v3 import WebhookHandler
from linebot.v3.messaging import Configuration
-from linebot.v3.webhooks import FileMessageContent, MessageEvent, TextMessageContent
+from linebot.v3.webhooks import MessageEvent, TextMessageContent
from src.app.flask_postgresql.api.event_handlers.file_event_handler import FileEventHandler
from src.app.flask_postgresql.api.event_handlers.text_event_handler import TextEventHandler
from src.app.flask_postgresql.api.response import create_response
from src.app.flask_postgresql.configs import Config
-from src.infrastructure.repositories.agent_chain.agent_chain_in_memory_repository import (
- AgentExecutorInMemoryRepository,
-)
+from src.infrastructure.container.container import Container
handler = WebhookHandler(Config.CHANNEL_SECRET)
configuration = Configuration(access_token=Config.CHANNEL_ACCESS_TOKEN)
-agent_repository = AgentExecutorInMemoryRepository()
@handler.add(MessageEvent, message=TextMessageContent)
def handle_text_message(event: MessageEvent):
- handler = TextEventHandler(
- logger=current_app.config["logger"],
- agent_repository=agent_repository,
- )
- handler.get_event_info(event)
- result = handler.execute()
+ container: Container = current_app.config["container"]
+ text_handler = TextEventHandler(container)
+ text_handler.get_event_info(event)
+ result = text_handler.execute()
return create_response(configuration, event.reply_token, result)
-@handler.add(MessageEvent, message=FileMessageContent)
-def handle_file_message(event: MessageEvent):
- handler = FileEventHandler(
- logger=current_app.config["logger"],
- agent_repository=agent_repository,
- configuration=configuration,
- )
- result = handler.execute(event)
-
- return create_response(configuration, event.reply_token, result)
+# @handler.add(MessageEvent, message=FileMessageContent)
+# def handle_file_message(event: MessageEvent):
+# file_handler = FileEventHandler(
+# logger=current_app.config["logger"],
+# agent_repository=agent_repository,
+# configuration=configuration,
+# )
+# result = file_handler.execute(event)
+#
+# return create_response(configuration, event.reply_token, result)
__all__ = ["handler"]
diff --git a/src/app/flask_postgresql/api/event_handlers/event_handler.py b/src/app/flask_postgresql/api/event_handlers/event_handler.py
index 65c4647..3754d8d 100644
--- a/src/app/flask_postgresql/api/event_handlers/event_handler.py
+++ b/src/app/flask_postgresql/api/event_handlers/event_handler.py
@@ -1,38 +1,44 @@
""" This module implements the event handler for text message events.
"""
from abc import abstractmethod
-from typing import Dict
+from typing import Dict, cast
from linebot.v3.webhooks import MessageEvent
from src.app.flask_postgresql.configs import Config
from src.app.flask_postgresql.interfaces.event_handler_interface import EventHandlerInterface
from src.app.flask_postgresql.presenters.window_presenter import WindowPresenter
-from src.infrastructure.repositories.agent_chain.agent_chain_in_memory_repository import (
- AgentExecutorInMemoryRepository,
-)
-from src.infrastructure.repositories.window.window_postgresql_repository import (
- WindowPostgresqlRepository,
-)
+from src.infrastructure.container.container import Container
from src.interactor.dtos.event_dto import EventInputDto
from src.interactor.dtos.window_dtos import CreateWindowInputDto, GetWindowInputDto
from src.interactor.interfaces.logger.logger import LoggerInterface
+from src.interactor.interfaces.repositories.agent_executor_repository import AgentExecutorRepositoryInterface
+from src.interactor.interfaces.repositories.window_repository import WindowRepositoryInterface
from src.interactor.use_cases.window.create_window import CreateWindowUseCase
from src.interactor.use_cases.window.get_window import GetWindowUseCase
class EventHandler(EventHandlerInterface):
- def __init__(self, logger: LoggerInterface, agent_repository: AgentExecutorInMemoryRepository):
- self.logger = logger
- self.agent_repository = agent_repository
- self.input_dto: EventInputDto
+ def __init__(self, container: Container):
+ self.container = container
+ self.logger = cast(LoggerInterface, container.resolve("logger"))
+ self.agent_repository = cast(AgentExecutorRepositoryInterface, container.resolve("agent_repository"))
+ self.window_repository = cast(WindowRepositoryInterface, container.resolve("window_repository"))
+ self.input_dto = None
+ self.event = None
def get_event_info(self, event: MessageEvent):
- if event.source.type == "user":
+ """
+ Retrieves the information of an event.
+ """
+ self.event = event
+ source_type = event.source.type
+
+ if source_type == "user":
window_id = event.source.user_id
- elif event.source.type == "group":
+ elif source_type == "group":
window_id = event.source.group_id
- elif event.source.type == "room":
+ elif source_type == "room":
window_id = event.source.room_id
else:
raise ValueError("Invalid event source type")
@@ -45,6 +51,7 @@ def get_event_info(self, event: MessageEvent):
self.input_dto = EventInputDto(
window=self.get_window_info(window_id=window_id),
user_input=user_input,
+ source_type=source_type,
)
def get_window_info(self, window_id: str) -> Dict:
@@ -66,9 +73,8 @@ def _get_window_info(self, window_id: str):
:return: The result of executing the use case.
"""
- repository = WindowPostgresqlRepository()
presenter = WindowPresenter()
- use_case = GetWindowUseCase(presenter=presenter, repository=repository, logger=self.logger)
+ use_case = GetWindowUseCase(presenter=presenter, repository=self.window_repository, logger=self.logger)
get_window_input_dto = GetWindowInputDto(window_id)
result = use_case.execute(get_window_input_dto)
return result
@@ -80,7 +86,6 @@ def _create_window_info(self, window_id: str):
Returns:
The result of the create window use case execution.
"""
- repository = WindowPostgresqlRepository()
presenter = WindowPresenter()
create_window_input_dto = CreateWindowInputDto(
window_id=window_id,
@@ -90,7 +95,7 @@ def _create_window_info(self, window_id: str):
temperature=0,
)
use_case = CreateWindowUseCase(
- presenter=presenter, repository=repository, logger=self.logger
+ presenter=presenter, repository=self.window_repository, logger=self.logger
)
result = use_case.execute(create_window_input_dto)
return result
diff --git a/src/app/flask_postgresql/api/event_handlers/text_event_handler.py b/src/app/flask_postgresql/api/event_handlers/text_event_handler.py
index 00329e8..51c8c53 100644
--- a/src/app/flask_postgresql/api/event_handlers/text_event_handler.py
+++ b/src/app/flask_postgresql/api/event_handlers/text_event_handler.py
@@ -3,22 +3,18 @@
from src.app.flask_postgresql.api.event_handlers.event_handler import EventHandler
from src.app.flask_postgresql.presenters.message_reply_presenter import EventPresenter
+from src.infrastructure.container.container import Container
from src.interactor.dtos.event_dto import EventInputDto
from src.interactor.use_cases.message.create_message_reply import CreateMessageReplyUseCase
class TextEventHandler(EventHandler):
- def __init__(self, logger, agent_repository):
- self.logger = logger
- self.agent_repository = agent_repository
+ def __init__(self, container: Container):
+ super().__init__(container)
self.input_dto: EventInputDto
def execute(self):
presenter = EventPresenter()
- use_case = CreateMessageReplyUseCase(
- presenter=presenter,
- repository=self.agent_repository,
- logger=self.logger,
- )
+ use_case = CreateMessageReplyUseCase(presenter=presenter, container=self.container)
result = use_case.execute(self.input_dto)
return result
diff --git a/src/app/flask_postgresql/api/response.py b/src/app/flask_postgresql/api/response.py
index 838d64d..7d1732c 100644
--- a/src/app/flask_postgresql/api/response.py
+++ b/src/app/flask_postgresql/api/response.py
@@ -10,9 +10,9 @@
from linebot.v3.messaging.models.message import Message
-def create_response(
- configuration: Configuration, reply_token: str, messages: List[Message]
-) -> ApiResponse:
+def create_response(configuration: Configuration, reply_token: str, messages: List[Message]) -> ApiResponse:
+ if not messages:
+ return
with ApiClient(configuration) as api_client:
line_bot_api = MessagingApi(api_client)
diff --git a/src/app/flask_postgresql/create_flask_postgresql_app.py b/src/app/flask_postgresql/create_flask_postgresql_app.py
index 03bff80..17ec0eb 100644
--- a/src/app/flask_postgresql/create_flask_postgresql_app.py
+++ b/src/app/flask_postgresql/create_flask_postgresql_app.py
@@ -3,17 +3,17 @@
from flask import Flask
from src.app.flask_postgresql.api.blueprints import setup_blueprints
+from src.app.flask_postgresql.api.container import setup_container
from src.app.flask_postgresql.api.error_handler import setup_error_handler
from src.app.flask_postgresql.api.request_context import setup_request_context
from src.infrastructure.databases.sql_alchemy import setup_sqlalchemy
-from src.interactor.interfaces.logger.logger import LoggerInterface
-def create_flask_postgresql_app(config, logger: LoggerInterface) -> Flask:
+def create_flask_postgresql_app(config) -> Flask:
"""Create Main Flask PostgreSQL app"""
app = Flask(__name__)
app.config.from_object(config)
- app.config["logger"] = logger
+ app = setup_container(app)
app = setup_blueprints(app)
app = setup_sqlalchemy(app)
app = setup_error_handler(app)
diff --git a/src/app/flask_postgresql/create_flask_postgresql_app_test.py b/src/app/flask_postgresql/create_flask_postgresql_app_test.py
index 2b89c18..a5189d2 100644
--- a/src/app/flask_postgresql/create_flask_postgresql_app_test.py
+++ b/src/app/flask_postgresql/create_flask_postgresql_app_test.py
@@ -5,21 +5,17 @@
from flask.testing import FlaskClient
from src.app.flask_postgresql.configs import Config
-from src.infrastructure.loggers.logger_default import LoggerDefault
with mock.patch("sqlalchemy.create_engine") as mock_create_engine, mock.patch(
- "langchain.utilities.SerpAPIWrapper"
+ "langchain.utilities.SerpAPIWrapper"
) as mock_sessionmaker:
from .create_flask_postgresql_app import create_flask_postgresql_app
-logger = LoggerDefault()
-
-
@pytest.fixture(name="flask_postgresql_app")
def fixture_flask_postgresql_app():
"""Fixture for flask app with blueprint"""
- app: Flask = create_flask_postgresql_app(Config, logger)
+ app: Flask = create_flask_postgresql_app(Config)
app.config.update(
{
"TESTING": True,
@@ -56,8 +52,8 @@ def test_request_window(mocker, client_flask_postgresql_app: FlaskClient):
def test_request_window_wrong_url_error(
- mocker,
- client_flask_postgresql_app,
+ mocker,
+ client_flask_postgresql_app,
):
"""Test request example"""
headers_data = {"X-Line-Signature": "test"}
diff --git a/src/infrastructure/container/container.py b/src/infrastructure/container/container.py
new file mode 100644
index 0000000..3f84c3d
--- /dev/null
+++ b/src/infrastructure/container/container.py
@@ -0,0 +1,15 @@
+class Container:
+ def __init__(self):
+ self._dependencies = {}
+
+ def register(self, key, dependency):
+ self._dependencies[key] = dependency
+
+ def resolve(self, key):
+ return self._dependencies[key]
+
+ def __getitem__(self, key):
+ return self.resolve(key)
+
+ def __setitem__(self, key, dependency):
+ self.register(key, dependency)
diff --git a/src/infrastructure/repositories/agent_chain/agent_chain_in_memory_repository.py b/src/infrastructure/repositories/agent_chain/agent_chain_in_memory_repository.py
index 3f4284e..633902f 100644
--- a/src/infrastructure/repositories/agent_chain/agent_chain_in_memory_repository.py
+++ b/src/infrastructure/repositories/agent_chain/agent_chain_in_memory_repository.py
@@ -61,7 +61,7 @@ def _create_agent(
agent_language: str,
temperature: float,
memory_key: str,
- tools: list,
+ tools_list: list,
) -> OpenAIFunctionsAgent:
"""
Creates an instance of the OpenAIFunctionsAgent class.
@@ -82,10 +82,10 @@ def _create_agent(
system_message=system_message,
extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key), system_language],
)
- return OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+ return OpenAIFunctionsAgent(llm=llm, tools=tools_list, prompt=prompt)
def _create_agent_executor(
- self, agent: OpenAIFunctionsAgent, memory: BaseChatMemory, tools: list
+ self, agent: OpenAIFunctionsAgent, memory: BaseChatMemory, tools_list: list
) -> AgentExecutor:
"""
Creates an agent executor using the provided agent, memory, and other optional arguments.
@@ -100,13 +100,13 @@ def _create_agent_executor(
"""
return AgentExecutor(
agent=agent,
- tools=tools,
+ tools=tools_list,
memory=memory,
verbose=True,
max_iterations=3,
)
- def get(self, window_id: str) -> AgentExecutor:
+ def get(self, window_id: str) -> AgentExecutor | None:
"""Get AgentExecutor by id
:param window_id: str
@@ -150,9 +150,9 @@ def create(
agent_language=agent_language,
temperature=temperature,
memory_key=memory_key,
- tools=tools,
+ tools_list=tools,
)
- agent_executor = self._create_agent_executor(agent=agent, memory=memory, tools=tools)
+ agent_executor = self._create_agent_executor(agent=agent, memory=memory, tools_list=tools)
self._data[window_id] = agent_executor
return self._data[window_id]
diff --git a/src/interactor/dtos/event_dto.py b/src/interactor/dtos/event_dto.py
index cee0ac6..ff9d5bd 100644
--- a/src/interactor/dtos/event_dto.py
+++ b/src/interactor/dtos/event_dto.py
@@ -1,7 +1,6 @@
""" Module for Events Dtos
"""
-
from dataclasses import asdict, dataclass
from typing import Dict, List
@@ -14,6 +13,7 @@ class EventInputDto:
window: Dict
user_input: str
+ source_type: str
def to_dict(self):
"""Convert data into dictionary"""
diff --git a/src/interactor/use_cases/message/cor/__init__.py b/src/interactor/use_cases/message/cor/__init__.py
index c82a5f4..eb2bf5d 100644
--- a/src/interactor/use_cases/message/cor/__init__.py
+++ b/src/interactor/use_cases/message/cor/__init__.py
@@ -1,21 +1,33 @@
-from typing import List
-
-from linebot.v3.messaging.models.message import Message
-
+from src.infrastructure.container.container import Container
from src.interactor.dtos.event_dto import EventInputDto
-from src.interactor.interfaces.repositories.agent_executor_repository import (
- AgentExecutorRepositoryInterface,
-)
from src.interactor.use_cases.message.cor.addition_handler import AdditionHandler
from src.interactor.use_cases.message.cor.default_handler import DefaultHandler
from src.interactor.use_cases.message.cor.muting_handler import MutingHandler
+from src.interactor.use_cases.message.cor.viki_handler import VikiHandler
+from src.interactor.use_cases.message.cor.window_mutable_handler import WindowMutableHandler
class ReplyMessagesCOR:
- def __init__(self):
- self._chain = MutingHandler(AdditionHandler(DefaultHandler()))
+ def __init__(self, container: Container):
+ self.container = container
+ self._chain = None
+ self._initialize_chain()
+
+ def _initialize_chain(self):
+ # create the chain of responsibility
+ viki_handler = VikiHandler(self.container)
+ window_mutable_handler = WindowMutableHandler(self.container)
+ muting_handler = MutingHandler(self.container)
+ # addition_handler = AdditionHandler(self.container)
+ default_handler = DefaultHandler(self.container)
+
+ # set the successor of each handler
+ viki_handler.set_successor(window_mutable_handler)
+ window_mutable_handler.set_successor(muting_handler)
+ muting_handler.set_successor(default_handler)
+
+ # set the chain
+ self._chain = viki_handler
- def handle(self, input_dto: EventInputDto, repository: AgentExecutorRepositoryInterface):
- response: List[Message] = []
- self._chain.handle(input_dto, repository, response)
- return response
+ def handle(self, input_dto: EventInputDto):
+ return self._chain.handle(input_dto)
diff --git a/src/interactor/use_cases/message/cor/addition_handler.py b/src/interactor/use_cases/message/cor/addition_handler.py
index 4304092..8a0deee 100644
--- a/src/interactor/use_cases/message/cor/addition_handler.py
+++ b/src/interactor/use_cases/message/cor/addition_handler.py
@@ -11,14 +11,13 @@
class AdditionHandler(Handler):
- def handle(
- self,
- input_dto: EventInputDto,
- repository: AgentExecutorRepositoryInterface,
- response: List[Message],
- ):
- response.append(TextMessage(text="test handler"))
+ def handle(self, input_dto: EventInputDto):
+
+ messages: List[Message] = []
+ messages.extend([TextMessage(text="test handler")])
if self._successor is not None:
- return self._successor.handle(input_dto, repository, response)
+ messages.extend(self._successor.handle(input_dto))
else:
- response.append(TextMessage(text="靜悄悄的,什麼都沒有發生。"))
+ messages.extend([TextMessage(text="Something went wrong! >_<")])
+
+ return messages
diff --git a/src/interactor/use_cases/message/cor/default_handler.py b/src/interactor/use_cases/message/cor/default_handler.py
index 017c08c..f4f8682 100644
--- a/src/interactor/use_cases/message/cor/default_handler.py
+++ b/src/interactor/use_cases/message/cor/default_handler.py
@@ -1,29 +1,41 @@
-from typing import List
-
from langchain.agents import AgentExecutor
from linebot.v3.messaging.models import TextMessage
-from linebot.v3.messaging.models.message import Message
from src.interactor.dtos.event_dto import EventInputDto
-from src.interactor.interfaces.repositories.agent_executor_repository import (
- AgentExecutorRepositoryInterface,
-)
from src.interactor.use_cases.message.cor.handler_base import Handler
class DefaultHandler(Handler):
- def handle(
- self,
- input_dto: EventInputDto,
- repository: AgentExecutorRepositoryInterface,
- response: List[Message],
- ):
+ """
+ The last handler in the chain of responsibility.
+
+ This handler uses the agent executor to process the user input.
+ """
+
+ def _get_agent_executor(self, input_dto: EventInputDto) -> AgentExecutor:
+ """
+ Retrieves the agent executor associated with the current window.
+
+ :param None: This function does not take any parameters.
+ :return: None
+ """
+ window_id = input_dto.window.get("window_id")
+
+ agent_executor = self.agent_repository.get(
+ window_id=window_id,
+ )
+ if agent_executor is None:
+ agent_executor = self.agent_repository.create(window_id=window_id)
+ return agent_executor
+
+ def handle(self, input_dto: EventInputDto):
+ messages = []
try:
- agent_executor = self._get_agent_executor(input_dto, repository)
+ agent_executor = self._get_agent_executor(input_dto)
result = agent_executor.run(input=input_dto.user_input)
- response.append(TextMessage(text=result))
+ messages.extend([TextMessage(text=result)])
except Exception as e:
- print(e)
- response.append(TextMessage(text="出現錯誤啦!請稍後再試。"))
+ self.logger.error(e)
+ messages.extend([TextMessage(text="Something went wrong! >_<")])
finally:
- return response
+ return messages
diff --git a/src/interactor/use_cases/message/cor/handler_base.py b/src/interactor/use_cases/message/cor/handler_base.py
index fd0dcc1..38ef059 100644
--- a/src/interactor/use_cases/message/cor/handler_base.py
+++ b/src/interactor/use_cases/message/cor/handler_base.py
@@ -1,47 +1,28 @@
from abc import ABC, abstractmethod
-from typing import List, Type
+from typing import List
-from langchain.agents import AgentExecutor
from linebot.v3.messaging.models.message import Message
+from src.infrastructure.container.container import Container
from src.interactor.dtos.event_dto import EventInputDto
+from src.interactor.interfaces.logger.logger import LoggerInterface
from src.interactor.interfaces.repositories.agent_executor_repository import (
AgentExecutorRepositoryInterface,
)
+from src.interactor.interfaces.repositories.window_repository import WindowRepositoryInterface
class Handler(ABC):
- def __init__(self, successor: Type["Handler"] = None):
+ def __init__(self, container: Container = None):
+ self._successor = None
+ self.container: Container = container
+ self.logger: LoggerInterface = container.resolve("logger")
+ self.agent_repository: AgentExecutorRepositoryInterface = container.resolve("agent_repository")
+ self.window_repository: WindowRepositoryInterface = container.resolve("window_repository")
+
+ def set_successor(self, successor: "Handler"):
self._successor = successor
- def _get_agent_executor(
- self,
- input_dto: EventInputDto,
- repository: AgentExecutorRepositoryInterface,
- ) -> AgentExecutor:
- """
- Retrieves the agent executor associated with the current window.
-
- :param None: This function does not take any parameters.
- :return: None
- """
-
- window_id = input_dto.window.get("window_id")
-
- agent_executor = repository.get(
- window_id=window_id,
- )
- if agent_executor is None:
- agent_executor = repository.create(
- window_id=window_id,
- )
- return agent_executor
-
@abstractmethod
- def handle(
- self,
- input_dto: EventInputDto,
- repository: AgentExecutorRepositoryInterface,
- response: List[Message],
- ) -> List[Message]:
+ def handle(self, input_dto: EventInputDto, ) -> List[Message]:
pass
diff --git a/src/interactor/use_cases/message/cor/muting_handler.py b/src/interactor/use_cases/message/cor/muting_handler.py
index f8fa12f..96c9b02 100644
--- a/src/interactor/use_cases/message/cor/muting_handler.py
+++ b/src/interactor/use_cases/message/cor/muting_handler.py
@@ -4,22 +4,23 @@
from linebot.v3.messaging.models.message import Message
from src.interactor.dtos.event_dto import EventInputDto
-from src.interactor.interfaces.repositories.agent_executor_repository import (
- AgentExecutorRepositoryInterface,
-)
from src.interactor.use_cases.message.cor.handler_base import Handler
class MutingHandler(Handler):
- def handle(
- self,
- input_dto: EventInputDto,
- repository: AgentExecutorRepositoryInterface,
- response: List[Message],
- ):
+
+ def handle(self, input_dto: EventInputDto):
+ """
+ Check if the window is currently muting. If it is, return an empty list.
+ """
+
+ messages: List[Message] = []
+
if input_dto.window.get("is_muting") is True:
- return response
+ return messages
elif self._successor is not None:
- return self._successor.handle(input_dto, repository, response)
+ messages.extend(self._successor.handle(input_dto))
else:
- response.append(TextMessage(text="靜悄悄的,什麼都沒有發生。"))
+ messages.extend([TextMessage(text="Something went wrong! >_<")])
+
+ return messages
diff --git a/src/interactor/use_cases/message/cor/test/default_handler_test.py b/src/interactor/use_cases/message/cor/test/default_handler_test.py
index 05a4220..c4435f0 100644
--- a/src/interactor/use_cases/message/cor/test/default_handler_test.py
+++ b/src/interactor/use_cases/message/cor/test/default_handler_test.py
@@ -7,47 +7,49 @@
def test_default_handler(mocker: mock, fixture_window: dict):
- repository_mock = mocker.patch(
- "src.interactor.use_cases.message.cor.default_handler.AgentExecutorRepositoryInterface"
- )
- agent_executor_mock = mocker.Mock(run=mocker.Mock(return_value="mock agent executor result"))
+ container_mock = mocker.MagicMock()
+ agent_executor_mock = mocker.MagicMock()
+ agent_executor_mock.run.return_value = "mock agent executor result"
get_agent_executor_mock = mocker.patch(
"src.interactor.use_cases.message.cor.default_handler.DefaultHandler._get_agent_executor"
)
get_agent_executor_mock.return_value = agent_executor_mock
- response = []
-
input_dto_mock = EventInputDto(
window=fixture_window,
user_input="mock user input",
+ source_type="mock source type",
)
- default_handler = DefaultHandler()
- default_handler.handle(input_dto_mock, repository_mock, response)
+ default_handler = DefaultHandler(container_mock)
+ messages = default_handler.handle(input_dto_mock)
- get_agent_executor_mock.assert_called_once_with(input_dto_mock, repository_mock)
+ get_agent_executor_mock.assert_called_once_with(input_dto_mock)
agent_executor_mock.run.assert_called_once_with(input=input_dto_mock.user_input)
- assert len(response) == 1
- assert response[0] == TextMessage(text="mock agent executor result")
+ assert len(messages) == 1
+ assert messages[0] == TextMessage(text="mock agent executor result")
def test_get_agent_executor_in_default_handler(mocker: mock, fixture_window: dict):
- repository_mock = mocker.patch(
- "src.interactor.use_cases.message.cor.default_handler.AgentExecutorRepositoryInterface"
- )
- repository_mock.get.return_value = None
- agent_executor_mock = mocker.Mock(run=mocker.Mock(return_value="mock agent executor result"))
- repository_mock.create.return_value = mocker.Mock(return_value=agent_executor_mock)
+ container_mock = mocker.MagicMock()
+ agent_executor_repository_mock = mocker.MagicMock()
+ container_mock.resolve.return_value = agent_executor_repository_mock
+ agent_executor_mock = mocker.MagicMock()
+ agent_executor_mock.run.return_value = "mock agent executor result"
+
+ agent_executor_repository_mock.get.return_value = None
+ agent_executor_repository_mock.create.return_value = mocker.MagicMock()
+ agent_executor_repository_mock.create.return_value.run.return_value = "mock agent executor result"
input_dto_mock = EventInputDto(
window=fixture_window,
user_input="mock user input",
+ source_type="mock source type",
)
- default_handler = DefaultHandler()
- default_handler._get_agent_executor(input_dto_mock, repository_mock)
+ default_handler = DefaultHandler(container_mock)
+ default_handler._get_agent_executor(input_dto_mock)
- repository_mock.get.assert_called_once_with(window_id=fixture_window["window_id"])
- repository_mock.create.assert_called_once_with(window_id=fixture_window["window_id"])
+ agent_executor_repository_mock.get.assert_called_once_with(window_id=fixture_window["window_id"])
+ agent_executor_repository_mock.create.assert_called_once_with(window_id=fixture_window["window_id"])
diff --git a/src/interactor/use_cases/message/cor/test/muting_handler_test.py b/src/interactor/use_cases/message/cor/test/muting_handler_test.py
index 1a20a17..409fa48 100644
--- a/src/interactor/use_cases/message/cor/test/muting_handler_test.py
+++ b/src/interactor/use_cases/message/cor/test/muting_handler_test.py
@@ -7,61 +7,55 @@
def test_muting_handler_with_fixture_window_with_muting(
- mocker: mock, fixture_window_with_muting: dict
+ mocker: mock, fixture_window_with_muting: dict
):
- repository_mock = mocker.patch(
- "src.interactor.use_cases.message.cor.muting_handler.AgentExecutorRepositoryInterface"
- )
- response = []
+ container_mock = mocker.MagicMock()
input_dto_mock = EventInputDto(
window=fixture_window_with_muting,
user_input="mock user input",
+ source_type="mock source type",
)
- muting_handler = MutingHandler()
- muting_handler.handle(input_dto_mock, repository_mock, response)
+ muting_handler = MutingHandler(container_mock)
+ messages = muting_handler.handle(input_dto_mock)
# self._successor.handle is not called
- assert len(response) == 0
+ assert len(messages) == 0
assert muting_handler._successor is None
- assert response == []
+ assert messages == []
def test_muting_handler_with_no_successor(mocker: mock, fixture_window: dict):
- repository_mock = mocker.patch(
- "src.interactor.use_cases.message.cor.muting_handler.AgentExecutorRepositoryInterface"
- )
- response = []
-
input_dto_mock = EventInputDto(
window=fixture_window,
user_input="mock user input",
+ source_type="mock source type",
)
-
- muting_handler = MutingHandler()
- muting_handler.handle(input_dto_mock, repository_mock, response)
+ container_mock = mocker.MagicMock()
+ muting_handler = MutingHandler(container_mock)
+ messages = muting_handler.handle(input_dto_mock)
# self._successor.handle is not called
- assert len(response) == 1
+ assert len(messages) == 1
assert muting_handler._successor is None
- assert response[0] == TextMessage(text="靜悄悄的,什麼都沒有發生。")
+ assert messages[0] == TextMessage(text="Something went wrong! >_<")
def test_muting_handler_with_successor(mocker: mock, fixture_window: dict):
- repository_mock = mocker.patch(
- "src.interactor.use_cases.message.cor.muting_handler.AgentExecutorRepositoryInterface"
- )
- response = []
+
+ container_mock = mocker.MagicMock()
input_dto_mock = EventInputDto(
window=fixture_window,
user_input="mock user input",
+ source_type="mock source type",
)
- muting_handler = MutingHandler(mocker.Mock())
- muting_handler.handle(input_dto_mock, repository_mock, response)
+ muting_handler = MutingHandler(container_mock)
+ muting_handler.set_successor(mocker.MagicMock())
+ messages = muting_handler.handle(input_dto_mock)
# self._successor.handle is called
- assert len(response) == 0
+ assert len(messages) == 0
assert muting_handler._successor.handle.call_count == 1
diff --git a/src/interactor/use_cases/message/cor/test/reply_messages_cor_test.py b/src/interactor/use_cases/message/cor/test/reply_messages_cor_test.py
index 6390ca0..3ca8489 100644
--- a/src/interactor/use_cases/message/cor/test/reply_messages_cor_test.py
+++ b/src/interactor/use_cases/message/cor/test/reply_messages_cor_test.py
@@ -4,32 +4,36 @@
def test_reply_message_cor_initial(mocker: mock):
+ container_mock = mocker.MagicMock()
muting_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.MutingHandler")
default_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.DefaultHandler")
- addition_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.AdditionHandler")
+ # addition_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.AdditionHandler")
+ window_mutable_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.WindowMutableHandler")
+ viki_handler_mock = mocker.patch("src.interactor.use_cases.message.cor.VikiHandler")
- reply_messages_cor = ReplyMessagesCOR()
+ muting_handler_mock_instance = muting_handler_mock.return_value
+ # addition_handler_mock_instance = addition_handler_mock.return_value
+ default_handler_mock_instance = default_handler_mock.return_value
+ window_mutable_handler_mock_instance = window_mutable_handler_mock.return_value
+ viki_handler_mock_instance = viki_handler_mock.return_value
- muting_handler_mock.assert_called_once_with(addition_handler_mock.return_value)
- addition_handler_mock.assert_called_once_with(default_handler_mock.return_value)
- default_handler_mock.assert_called_once_with()
+ reply_messages_cor = ReplyMessagesCOR(container_mock)
- assert reply_messages_cor._chain == muting_handler_mock.return_value
+ viki_handler_mock_instance.set_successor.assert_called_once_with(window_mutable_handler_mock_instance)
+ window_mutable_handler_mock_instance.set_successor.assert_called_once_with(muting_handler_mock_instance)
+ muting_handler_mock_instance.set_successor.assert_called_once_with(default_handler_mock_instance)
+ default_handler_mock_instance.set_successor.assert_not_called()
+ assert reply_messages_cor._chain == viki_handler_mock_instance
-def test_reply_message_cor_handle(mocker: mock):
- mocker.patch("src.interactor.use_cases.message.cor.MutingHandler")
- mocker.patch("src.interactor.use_cases.message.cor.DefaultHandler")
- mocker.patch("src.interactor.use_cases.message.cor.AdditionHandler")
- reply_messages_cor = ReplyMessagesCOR()
+def test_reply_message_cor_handle(mocker: mock):
+ container_mock = mocker.MagicMock()
+ input_dto = mocker.MagicMock()
- input_dto = mocker.Mock()
- repository = mocker.Mock()
- response = []
+ reply_messages_cor = ReplyMessagesCOR(container_mock)
reply_messages_cor._chain.handle = mocker.Mock()
- reply_messages_cor.handle(input_dto, repository)
-
- reply_messages_cor._chain.handle.assert_called_once_with(input_dto, repository, response)
+ reply_messages_cor.handle(input_dto)
+ reply_messages_cor._chain.handle.assert_called_once_with(input_dto)
diff --git a/src/interactor/use_cases/message/cor/viki_handler.py b/src/interactor/use_cases/message/cor/viki_handler.py
new file mode 100644
index 0000000..bd645b6
--- /dev/null
+++ b/src/interactor/use_cases/message/cor/viki_handler.py
@@ -0,0 +1,34 @@
+from linebot.v3.messaging.models import TextMessage
+
+from src.infrastructure.container.container import Container
+from src.interactor.dtos.event_dto import EventInputDto
+from src.interactor.use_cases.message.cor.handler_base import Handler
+
+
+class VikiHandler(Handler):
+ """
+ The first handler in the chain of responsibility.
+
+ This handler checks if the user input contains "Viki" or "viki".
+ If it does, it will return an empty list.
+ """
+
+ def __init__(self, container: Container):
+ super().__init__(container)
+
+ # convert all words to lowercase and add in _mute_word_set
+ self._trigger_words = {"Viki", "viki"}
+
+ def handle(self, input_dto: EventInputDto):
+ # if input_dto.user_input not include "Viki" or "viki", return, unless the source_type is "user"
+ if input_dto.source_type != "user" and not any(word in input_dto.user_input for word in self._trigger_words):
+ return []
+
+ messages = []
+
+ if self._successor is not None:
+ messages.extend(self._successor.handle(input_dto))
+ else:
+ messages.extend([TextMessage(text="Something went wrong! >_<")])
+
+ return messages
diff --git a/src/interactor/use_cases/message/cor/window_mutable_handler.py b/src/interactor/use_cases/message/cor/window_mutable_handler.py
new file mode 100644
index 0000000..fbb708f
--- /dev/null
+++ b/src/interactor/use_cases/message/cor/window_mutable_handler.py
@@ -0,0 +1,67 @@
+from typing import List
+
+from linebot.v3.messaging.models import TextMessage
+from linebot.v3.messaging.models.message import Message
+
+from src.infrastructure.container.container import Container
+from src.interactor.dtos.event_dto import EventInputDto
+from src.interactor.use_cases.message.cor.handler_base import Handler
+
+mute_word_set = {
+ "Viki mute",
+ "Viki shut up",
+}
+
+unmute_word_set = {
+ "Viki unmute"
+}
+
+
+class WindowMutableHandler(Handler):
+ def __init__(self, container: Container):
+ super().__init__(container)
+
+ # convert all words to lowercase and add in _mute_word_set
+ self._all_mute_words = mute_word_set.union(
+ {word.lower() for word in mute_word_set}
+ )
+
+ # convert all words to lowercase and add in _unmute_word_set
+ self._all_unmute_words = unmute_word_set.union(
+ {word.lower() for word in unmute_word_set}
+ )
+
+ def handle(self, input_dto: EventInputDto):
+ """
+ Check if the user input contains any mute words. If it does, set the window's is_muting to True.
+ """
+ messages: List[Message] = []
+
+ # if input_dto.user_input include mute word, set is_muting to True
+ if any(word in input_dto.user_input for word in self._all_mute_words):
+ window_id = input_dto.window.get("window_id")
+ window = self.window_repository.get(window_id)
+ if window is not None:
+ window.is_muting = True
+ self.window_repository.update(window)
+ messages.extend([TextMessage(text="Okay, bye (´•̥̥̥ω•̥̥̥`)")])
+ else:
+ messages.extend([TextMessage(text="Something went wrong! >_<")])
+
+ # if input_dto.user_input include unmute word, set is_muting to False
+ elif any(word in input_dto.user_input for word in self._all_unmute_words):
+ window_id = input_dto.window.get("window_id")
+ window = self.window_repository.get(window_id)
+ if window is not None:
+ window.is_muting = False
+ self.window_repository.update(window)
+ messages.extend([TextMessage(text="Hello again! ヾ(*´▽‘*)ノ")])
+ else:
+ messages.extend([TextMessage(text="Something went wrong! >_<")])
+
+ elif self._successor is not None:
+ messages.extend(self._successor.handle(input_dto))
+ else:
+ messages.extend([TextMessage(text="Something went wrong! >_<")])
+
+ return messages
diff --git a/src/interactor/use_cases/message/create_message_reply.py b/src/interactor/use_cases/message/create_message_reply.py
index c05b917..e79dae3 100644
--- a/src/interactor/use_cases/message/create_message_reply.py
+++ b/src/interactor/use_cases/message/create_message_reply.py
@@ -1,11 +1,10 @@
""" This module is responsible for creating a new window.
"""
+
+from src.infrastructure.container.container import Container
from src.interactor.dtos.event_dto import EventInputDto, EventOutputDto
from src.interactor.interfaces.logger.logger import LoggerInterface
from src.interactor.interfaces.presenters.message_reply_presenter import EventPresenterInterface
-from src.interactor.interfaces.repositories.agent_executor_repository import (
- AgentExecutorRepositoryInterface,
-)
from src.interactor.use_cases.message.cor import ReplyMessagesCOR
from src.interactor.validations.event_input_validator import EventInputDtoValidator
@@ -13,15 +12,10 @@
class CreateMessageReplyUseCase:
"""This class is responsible for creating a new window."""
- def __init__(
- self,
- presenter: EventPresenterInterface,
- repository: AgentExecutorRepositoryInterface,
- logger: LoggerInterface,
- ):
+ def __init__(self, presenter: EventPresenterInterface, container: Container):
self.presenter = presenter
- self.repository = repository
- self.logger = logger
+ self.container = container
+ self.logger: LoggerInterface = container.resolve("logger")
def execute(self, input_dto: EventInputDto):
"""
@@ -39,9 +33,9 @@ def execute(self, input_dto: EventInputDto):
validator = EventInputDtoValidator(input_dto.to_dict())
validator.validate()
- reply_messages_cor = ReplyMessagesCOR()
+ reply_messages_cor = ReplyMessagesCOR(self.container)
- response = reply_messages_cor.handle(input_dto, self.repository)
+ response = reply_messages_cor.handle(input_dto)
output_dto = EventOutputDto(
window=input_dto.window,
diff --git a/src/interactor/use_cases/message/test/create_message_reply_test.py b/src/interactor/use_cases/message/test/create_message_reply_test.py
index a9a2bc5..ead2f02 100644
--- a/src/interactor/use_cases/message/test/create_message_reply_test.py
+++ b/src/interactor/use_cases/message/test/create_message_reply_test.py
@@ -11,6 +11,9 @@
def test_create_message_reply(mocker: mock, fixture_window):
presenter_mock = mocker.patch.object(EventPresenterInterface, "present")
presenter_mock.present.return_value = "Test output"
+ logger_mock = mocker.patch.object(LoggerInterface, "log_info")
+ container_mock = mocker.MagicMock()
+ container_mock.resolve.return_value = logger_mock
reply_messages_cor_mock = mocker.patch(
"src.interactor.use_cases.message.create_message_reply.ReplyMessagesCOR"
@@ -18,17 +21,12 @@ def test_create_message_reply(mocker: mock, fixture_window):
reply_messages_cor_instance = reply_messages_cor_mock.return_value
reply_messages_cor_instance.handle.return_value = [TextMessage(text="Test output")]
- repository_mock = mocker.patch(
- "src.interactor.use_cases.message.create_message_reply.AgentExecutorRepositoryInterface"
- )
-
input_dto_validator_mock = mocker.patch(
"src.interactor.use_cases.message.create_message_reply.EventInputDtoValidator"
)
- logger_mock = mocker.patch.object(LoggerInterface, "log_info")
- use_case = CreateMessageReplyUseCase(presenter_mock, repository_mock, logger_mock)
- input_dto = EventInputDto(window=fixture_window, user_input="Test input")
+ use_case = CreateMessageReplyUseCase(presenter_mock, container_mock)
+ input_dto = EventInputDto(window=fixture_window, user_input="Test input", source_type="Test source")
result = use_case.execute(input_dto)
input_dto_validator_mock.assert_called_once_with(input_dto.to_dict())
@@ -37,7 +35,7 @@ def test_create_message_reply(mocker: mock, fixture_window):
logger_mock.log_info.assert_called_once_with("Create reply successfully")
- reply_messages_cor_instance.handle.assert_called_once_with(input_dto, repository_mock)
+ reply_messages_cor_instance.handle.assert_called_once_with(input_dto)
output_dto = EventOutputDto(
window=fixture_window, user_input="Test input", response=[TextMessage(text="Test output")]
diff --git a/src/interactor/validations/event_input_validator.py b/src/interactor/validations/event_input_validator.py
index 8bd76cb..7d12043 100644
--- a/src/interactor/validations/event_input_validator.py
+++ b/src/interactor/validations/event_input_validator.py
@@ -1,7 +1,6 @@
""" Defines the validator for the create window input data.
"""
-
from typing import Dict
from src.interactor.validations.base_input_validator import BaseInputValidator
@@ -28,6 +27,7 @@ def __init__(self, input_data: Dict) -> None:
},
},
"user_input": {"type": "string", "required": True},
+ "source_type": {"type": "string", "required": True},
}
def validate(self) -> None:
diff --git a/src/interactor/validations/test/event_input_validator_test.py b/src/interactor/validations/test/event_input_validator_test.py
index ce09c25..471240e 100644
--- a/src/interactor/validations/test/event_input_validator_test.py
+++ b/src/interactor/validations/test/event_input_validator_test.py
@@ -3,7 +3,7 @@
from src.interactor.validations.event_input_validator import EventInputDtoValidator
-def test_get_window_validator_valid_data(mocker, fixture_window):
+def test_event_input_validator_valid_data(mocker, fixture_window):
mocker.patch("src.interactor.validations.base_input_validator.BaseInputValidator.verify")
input_data = {"window": fixture_window, "user_input": "test"}
schema = {
@@ -19,16 +19,17 @@ def test_get_window_validator_valid_data(mocker, fixture_window):
},
},
"user_input": {"type": "string", "required": True},
+ "source_type": {"type": "string", "required": True},
}
validator = EventInputDtoValidator(input_data)
validator.validate()
validator.verify.assert_called_once_with(schema)
-def test_get_window_validator_without_user_input(fixture_window):
+def test_event_input_validator_without_user_input(fixture_window):
# We are doing just a simple test as the complete test is done in
# base_input_validator_test.py
- input_data = {"window": fixture_window}
+ input_data = {"window": fixture_window, "source_type": "test"}
validator = EventInputDtoValidator(input_data)
with pytest.raises(ValueError) as exception_info:
validator.validate()