diff --git a/src/aiogram_dialog/api/entities/__init__.py b/src/aiogram_dialog/api/entities/__init__.py index 97c8b650..724bd35a 100644 --- a/src/aiogram_dialog/api/entities/__init__.py +++ b/src/aiogram_dialog/api/entities/__init__.py @@ -1,13 +1,24 @@ __all__ = [ - "Context", "Data", + "Context", + "Data", "ChatEvent", "LaunchMode", - "MediaAttachment", "MediaId", - "ShowMode", "StartMode", - "MarkupVariant", "NewMessage", "OldMessage", "UnknownText", - "DEFAULT_STACK_ID", "Stack", - "DIALOG_EVENT_NAME", "DialogAction", "DialogUpdateEvent", - "DialogStartEvent", "DialogSwitchEvent", "DialogUpdate", + "MediaAttachment", + "MediaId", + "ShowMode", + "StartMode", + "MarkupVariant", + "NewMessage", + "OldMessage", + "UnknownText", + "DEFAULT_STACK_ID", + "Stack", + "DIALOG_EVENT_NAME", + "DialogAction", + "DialogUpdateEvent", + "DialogStartEvent", + "DialogSwitchEvent", + "DialogUpdate", ] from .context import Context, Data @@ -18,6 +29,10 @@ from .new_message import MarkupVariant, NewMessage, OldMessage, UnknownText from .stack import DEFAULT_STACK_ID, Stack from .update_event import ( - DIALOG_EVENT_NAME, DialogAction, DialogStartEvent, DialogSwitchEvent, - DialogUpdate, DialogUpdateEvent, + DIALOG_EVENT_NAME, + DialogAction, + DialogStartEvent, + DialogSwitchEvent, + DialogUpdate, + DialogUpdateEvent, ) diff --git a/src/aiogram_dialog/api/entities/context.py b/src/aiogram_dialog/api/entities/context.py index 60c4555a..44b16ce9 100644 --- a/src/aiogram_dialog/api/entities/context.py +++ b/src/aiogram_dialog/api/entities/context.py @@ -3,7 +3,8 @@ from aiogram.fsm.state import State -Data = Union[Dict, List, int, str, float, None] +SerializationData = Union[int, str, float, None] +Data = Union[Dict[SerializationData, "Data"], List["Data"], SerializationData] DataDict = Dict[str, Data] diff --git a/src/aiogram_dialog/api/entities/events.py b/src/aiogram_dialog/api/entities/events.py index 83f89ca3..816faa55 100644 --- a/src/aiogram_dialog/api/entities/events.py +++ b/src/aiogram_dialog/api/entities/events.py @@ -1,12 +1,13 @@ from typing import Union -from aiogram.types import ( - CallbackQuery, ChatJoinRequest, ChatMemberUpdated, Message, -) +from aiogram.types import CallbackQuery, ChatJoinRequest, ChatMemberUpdated, Message from .update_event import DialogUpdateEvent ChatEvent = Union[ - CallbackQuery, Message, DialogUpdateEvent, - ChatMemberUpdated, ChatJoinRequest, + CallbackQuery, + Message, + DialogUpdateEvent, + ChatMemberUpdated, + ChatJoinRequest, ] diff --git a/src/aiogram_dialog/api/entities/media.py b/src/aiogram_dialog/api/entities/media.py index 1b03373a..202a0bb4 100644 --- a/src/aiogram_dialog/api/entities/media.py +++ b/src/aiogram_dialog/api/entities/media.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional, Union from aiogram.types import ContentType @@ -10,8 +10,8 @@ class MediaId: file_id: str file_unique_id: Optional[str] = None - def __eq__(self, other): - if type(other) is not MediaId: + def __eq__(self, other: object) -> bool: + if not isinstance(other, MediaId): return False if self.file_unique_id is None or other.file_unique_id is None: return self.file_id == other.file_id @@ -20,13 +20,13 @@ def __eq__(self, other): class MediaAttachment: def __init__( - self, - type: ContentType, - url: Optional[str] = None, - path: Union[str, Path, None] = None, - file_id: Optional[MediaId] = None, - use_pipe: bool = False, - **kwargs, + self, + type: ContentType, + url: Optional[str] = None, + path: Union[str, Path, None] = None, + file_id: Optional[MediaId] = None, + use_pipe: bool = False, + **kwargs: Any, ): if not (url or path or file_id): raise ValueError("Neither url nor path not file_id are provided") @@ -37,14 +37,14 @@ def __init__( self.use_pipe = use_pipe self.kwargs = kwargs - def __eq__(self, other): - if type(other) is not type(self): + def __eq__(self, other: object) -> bool: + if not isinstance(other, MediaAttachment): return False return ( - self.type == other.type and - self.url == other.url and - self.path == other.path and - self.file_id == other.file_id and - self.use_pipe == other.use_pipe and - self.kwargs == other.kwargs + self.type == other.type + and self.url == other.url + and self.path == other.path + and self.file_id == other.file_id + and self.use_pipe == other.use_pipe + and self.kwargs == other.kwargs ) diff --git a/src/aiogram_dialog/api/entities/new_message.py b/src/aiogram_dialog/api/entities/new_message.py index 8fd4ffa2..c819bc78 100644 --- a/src/aiogram_dialog/api/entities/new_message.py +++ b/src/aiogram_dialog/api/entities/new_message.py @@ -3,14 +3,20 @@ from typing import Optional, Union from aiogram.types import ( - Chat, ForceReply, InlineKeyboardMarkup, ReplyKeyboardMarkup, + Chat, + ForceReply, + InlineKeyboardMarkup, + ReplyKeyboardMarkup, ReplyKeyboardRemove, ) from aiogram_dialog.api.entities import MediaAttachment, ShowMode MarkupVariant = Union[ - ForceReply, InlineKeyboardMarkup, ReplyKeyboardMarkup, ReplyKeyboardRemove, + ForceReply, + InlineKeyboardMarkup, + ReplyKeyboardMarkup, + ReplyKeyboardRemove, ] diff --git a/src/aiogram_dialog/api/entities/stack.py b/src/aiogram_dialog/api/entities/stack.py index 21440771..4b52aa9c 100644 --- a/src/aiogram_dialog/api/entities/stack.py +++ b/src/aiogram_dialog/api/entities/stack.py @@ -7,6 +7,7 @@ from aiogram.fsm.state import State from aiogram_dialog.api.exceptions import DialogStackOverflow + from .context import Context, Data DEFAULT_STACK_ID = "" @@ -29,7 +30,7 @@ def id_to_str(int_id: int) -> str: return res -def new_id(): +def new_id() -> str: return id_to_str(new_int_id()) @@ -42,11 +43,12 @@ class Stack: last_media_id: Optional[str] = field(compare=False, default=None) last_media_unique_id: Optional[str] = field(compare=False, default=None) last_income_media_group_id: Optional[str] = field( - compare=False, default=None, + compare=False, + default=None, ) @property - def id(self): + def id(self) -> str: return self._id def push(self, state: State, data: Data) -> Context: @@ -64,14 +66,14 @@ def push(self, state: State, data: Data) -> Context: self.intents.append(context.id) return context - def pop(self): + def pop(self) -> str: return self.intents.pop() - def last_intent_id(self): + def last_intent_id(self) -> str: return self.intents[-1] - def empty(self): + def empty(self) -> bool: return not self.intents - def default(self): + def default(self) -> bool: return self.id == DEFAULT_STACK_ID diff --git a/src/aiogram_dialog/api/entities/update_event.py b/src/aiogram_dialog/api/entities/update_event.py index 3c189858..aa174efb 100644 --- a/src/aiogram_dialog/api/entities/update_event.py +++ b/src/aiogram_dialog/api/entities/update_event.py @@ -2,18 +2,10 @@ from typing import Any, Optional from aiogram.fsm.state import State -from aiogram.types import ( - Chat, - TelegramObject, - Update, - User, -) +from aiogram.types import Chat, TelegramObject, Update, User from pydantic import ConfigDict -from .modes import ( - ShowMode, - StartMode, -) +from .modes import ShowMode, StartMode DIALOG_EVENT_NAME = "aiogd_update" diff --git a/src/aiogram_dialog/api/exceptions.py b/src/aiogram_dialog/api/exceptions.py index 89682d61..b76eba9c 100644 --- a/src/aiogram_dialog/api/exceptions.py +++ b/src/aiogram_dialog/api/exceptions.py @@ -12,7 +12,7 @@ class UnknownIntent(DialogsError): class OutdatedIntent(DialogsError): - def __init__(self, stack_id, text): + def __init__(self, stack_id: str, text: str) -> None: super().__init__(text) self.stack_id = stack_id diff --git a/src/aiogram_dialog/api/internal/__init__.py b/src/aiogram_dialog/api/internal/__init__.py index c3e86d15..1e1449e4 100644 --- a/src/aiogram_dialog/api/internal/__init__.py +++ b/src/aiogram_dialog/api/internal/__init__.py @@ -1,22 +1,41 @@ __all__ = [ - "FakeChat", "FakeUser", "ReplyCallbackQuery", + "FakeChat", + "FakeUser", + "ReplyCallbackQuery", "DialogManagerFactory", - "CALLBACK_DATA_KEY", "CONTEXT_KEY", "EVENT_SIMULATED", - "STACK_KEY", "STORAGE_KEY", - "ButtonVariant", "DataGetter", "InputWidget", "KeyboardWidget", - "MediaWidget", "RawKeyboard", "TextWidget", "Widget", + "CALLBACK_DATA_KEY", + "CONTEXT_KEY", + "EVENT_SIMULATED", + "STACK_KEY", + "STORAGE_KEY", + "ButtonVariant", + "DataGetter", + "InputWidget", + "KeyboardWidget", + "MediaWidget", + "RawKeyboard", + "TextWidget", + "Widget", "WindowProtocol", ] from .fake_data import FakeChat, FakeUser, ReplyCallbackQuery -from .manager import ( - DialogManagerFactory, -) +from .manager import DialogManagerFactory from .middleware import ( - CALLBACK_DATA_KEY, CONTEXT_KEY, EVENT_SIMULATED, STACK_KEY, STORAGE_KEY, + CALLBACK_DATA_KEY, + CONTEXT_KEY, + EVENT_SIMULATED, + STACK_KEY, + STORAGE_KEY, ) from .widgets import ( - ButtonVariant, DataGetter, InputWidget, KeyboardWidget, - MediaWidget, RawKeyboard, TextWidget, Widget, + ButtonVariant, + DataGetter, + InputWidget, + KeyboardWidget, + MediaWidget, + RawKeyboard, + TextWidget, + Widget, ) from .window import WindowProtocol diff --git a/src/aiogram_dialog/api/internal/fake_data.py b/src/aiogram_dialog/api/internal/fake_data.py index 331d7300..0353ab73 100644 --- a/src/aiogram_dialog/api/internal/fake_data.py +++ b/src/aiogram_dialog/api/internal/fake_data.py @@ -1,12 +1,7 @@ from typing import Any, Literal from aiogram.methods import AnswerCallbackQuery -from aiogram.types import ( - CallbackQuery, - Chat, - Message, - User, -) +from aiogram.types import CallbackQuery, Chat, Message, User class ReplyCallbackQuery(CallbackQuery): diff --git a/src/aiogram_dialog/api/internal/manager.py b/src/aiogram_dialog/api/internal/manager.py index 73e399c5..001021b6 100644 --- a/src/aiogram_dialog/api/internal/manager.py +++ b/src/aiogram_dialog/api/internal/manager.py @@ -1,19 +1,19 @@ from abc import abstractmethod -from typing import Dict, Protocol +from typing import Any, Dict, Protocol from aiogram import Router from aiogram_dialog.api.entities import ChatEvent -from aiogram_dialog.api.protocols import ( - DialogManager, DialogRegistryProtocol, -) +from aiogram_dialog.api.protocols import DialogManager, DialogRegistryProtocol class DialogManagerFactory(Protocol): @abstractmethod def __call__( - self, event: ChatEvent, data: Dict, - registry: DialogRegistryProtocol, - router: Router, + self, + event: ChatEvent, + data: Dict[str, Any], + registry: DialogRegistryProtocol, + router: Router, ) -> DialogManager: raise NotImplementedError diff --git a/src/aiogram_dialog/api/internal/widgets.py b/src/aiogram_dialog/api/internal/widgets.py index df306af8..62334155 100644 --- a/src/aiogram_dialog/api/internal/widgets.py +++ b/src/aiogram_dialog/api/internal/widgets.py @@ -1,22 +1,28 @@ from abc import abstractmethod from typing import ( - Any, Awaitable, Callable, Dict, List, Optional, Protocol, - runtime_checkable, Union, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Protocol, + Union, + runtime_checkable, ) -from aiogram.types import ( - CallbackQuery, InlineKeyboardButton, KeyboardButton, Message, -) +from aiogram.types import CallbackQuery, InlineKeyboardButton, KeyboardButton, Message -from aiogram_dialog import DialogManager +from aiogram_dialog import ChatEvent, DialogManager from aiogram_dialog.api.entities import MarkupVariant, MediaAttachment +from aiogram_dialog.api.entities.context import DataDict from aiogram_dialog.api.protocols import DialogProtocol @runtime_checkable class Widget(Protocol): @abstractmethod - def managed(self, manager: DialogManager) -> Any: + def managed(self, manager: DialogManager) -> "Widget": raise NotImplementedError @abstractmethod @@ -28,7 +34,9 @@ def find(self, widget_id: str) -> Optional["Widget"]: class TextWidget(Widget, Protocol): @abstractmethod async def render_text( - self, data: Dict, manager: DialogManager, + self, + data: Dict[str, Union[DataDict, Dict[str, Any], ChatEvent]], + manager: DialogManager, ) -> str: """Create text.""" raise NotImplementedError @@ -42,15 +50,19 @@ async def render_text( class KeyboardWidget(Widget, Protocol): @abstractmethod async def render_keyboard( - self, data: Dict, manager: DialogManager, + self, + data: Dict[str, Union[DataDict, Dict[str, Any], ChatEvent]], + manager: DialogManager, ) -> RawKeyboard: """Create Inline keyboard contents.""" raise NotImplementedError @abstractmethod async def process_callback( - self, callback: CallbackQuery, dialog: DialogProtocol, - manager: DialogManager, + self, + callback: CallbackQuery, + dialog: DialogProtocol, + manager: DialogManager, ) -> bool: """ Handle user click on some inline button. @@ -66,7 +78,9 @@ async def process_callback( class MediaWidget(Widget, Protocol): @abstractmethod async def render_media( - self, data: dict, manager: DialogManager, + self, + data: Dict[str, Union[DataDict, Dict[str, Any], ChatEvent]], + manager: DialogManager, ) -> Optional[MediaAttachment]: """Create media attachment.""" raise NotImplementedError @@ -76,8 +90,10 @@ async def render_media( class InputWidget(Widget, Protocol): @abstractmethod async def process_message( - self, message: Message, dialog: DialogProtocol, - manager: DialogManager, + self, + message: Message, + dialog: DialogProtocol, + manager: DialogManager, ) -> bool: """ Handle incoming message from user. @@ -89,14 +105,17 @@ async def process_message( raise NotImplementedError -DataGetter = Callable[..., Awaitable[Dict]] +DataGetter = Callable[..., Awaitable[Dict[str, Any]]] @runtime_checkable class MarkupFactory(Protocol): @abstractmethod async def render_markup( - self, data: dict, manager: DialogManager, keyboard: RawKeyboard, + self, + data: Dict[str, Union[DataDict, Dict[str, Any], ChatEvent]], + manager: DialogManager, + keyboard: RawKeyboard, ) -> MarkupVariant: """Render reply_markup using prepared keyboard.""" raise NotImplementedError diff --git a/src/aiogram_dialog/api/internal/window.py b/src/aiogram_dialog/api/internal/window.py index bdafd22a..4f13dd5e 100644 --- a/src/aiogram_dialog/api/internal/window.py +++ b/src/aiogram_dialog/api/internal/window.py @@ -1,69 +1,72 @@ from abc import abstractmethod -from typing import ( - Any, - Dict, - Protocol, -) +from typing import Any, Dict, Protocol, Union from aiogram.fsm.state import State from aiogram.types import CallbackQuery, Message -from aiogram_dialog.api.entities import Data, NewMessage +from aiogram_dialog import ChatEvent, DialogManager +from aiogram_dialog.api.entities import Data, MarkupVariant, NewMessage +from aiogram_dialog.api.entities.context import DataDict from aiogram_dialog.api.protocols import DialogProtocol -from .manager import DialogManager -from .widgets import MarkupVariant class WindowProtocol(Protocol): @abstractmethod - async def render_text(self, data: Dict, - manager: DialogManager) -> str: + async def render_text( + self, + data: Dict[str, Union[DataDict, Dict[str, Any], ChatEvent]], + manager: DialogManager, + ) -> str: raise NotImplementedError @abstractmethod async def render_kbd( - self, data: Dict, manager: DialogManager, + self, + data: Dict[str, Union[DataDict, Dict[str, Any], ChatEvent]], + manager: DialogManager, ) -> MarkupVariant: raise NotImplementedError @abstractmethod async def load_data( - self, - dialog: "DialogProtocol", - manager: DialogManager, - ) -> Dict: + self, + dialog: "DialogProtocol", + manager: DialogManager, + ) -> Dict[str, Union[DataDict, Dict[str, Any], ChatEvent]]: raise NotImplementedError @abstractmethod async def process_message( - self, - message: Message, - dialog: "DialogProtocol", - manager: DialogManager, + self, + message: Message, + dialog: "DialogProtocol", + manager: DialogManager, ) -> None: raise NotImplementedError @abstractmethod async def process_callback( - self, - callback: CallbackQuery, - dialog: "DialogProtocol", - manager: DialogManager, + self, + callback: CallbackQuery, + dialog: "DialogProtocol", + manager: DialogManager, ) -> None: raise NotImplementedError @abstractmethod async def process_result( - self, start_data: Data, result: Any, - manager: "DialogManager", + self, + start_data: Data, + result: Any, + manager: "DialogManager", ) -> None: raise NotImplementedError @abstractmethod async def render( - self, - dialog: "DialogProtocol", - manager: DialogManager, + self, + dialog: "DialogProtocol", + manager: DialogManager, ) -> NewMessage: raise NotImplementedError @@ -72,5 +75,5 @@ def get_state(self) -> State: raise NotImplementedError @abstractmethod - def find(self, widget_id) -> Any: + def find(self, widget_id: str) -> Any: raise NotImplementedError diff --git a/src/aiogram_dialog/api/protocols/__init__.py b/src/aiogram_dialog/api/protocols/__init__.py index 77571aea..44af9102 100644 --- a/src/aiogram_dialog/api/protocols/__init__.py +++ b/src/aiogram_dialog/api/protocols/__init__.py @@ -1,9 +1,13 @@ __all__ = [ "DialogProtocol", - "BaseDialogManager", "BgManagerFactory", "DialogManager", + "BaseDialogManager", + "BgManagerFactory", + "DialogManager", "MediaIdStorageProtocol", - "MessageManagerProtocol", "MessageNotModified", - "DialogProtocol", "DialogRegistryProtocol", + "MessageManagerProtocol", + "MessageNotModified", + "DialogProtocol", + "DialogRegistryProtocol", ] from .dialog import DialogProtocol diff --git a/src/aiogram_dialog/api/protocols/dialog.py b/src/aiogram_dialog/api/protocols/dialog.py index 40fac5cc..73b818d9 100644 --- a/src/aiogram_dialog/api/protocols/dialog.py +++ b/src/aiogram_dialog/api/protocols/dialog.py @@ -1,11 +1,13 @@ from abc import abstractmethod -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable, Type +from typing import Any, Dict, List, Optional, Protocol, Type, Union, runtime_checkable from aiogram.fsm.state import State, StatesGroup -from aiogram_dialog.api.entities import ( - Data, LaunchMode, NewMessage, -) +from aiogram_dialog.api.entities import Data, LaunchMode, NewMessage +from ..internal import Widget + +from ... import ChatEvent +from ..entities.context import DataDict from .manager import DialogManager @@ -29,34 +31,39 @@ def states_group(self) -> Type[StatesGroup]: @abstractmethod async def process_close( - self, result: Any, manager: DialogManager, + self, + result: Any, + manager: DialogManager, ) -> None: raise NotImplementedError @abstractmethod async def process_start( - self, - manager: "DialogManager", - start_data: Data, - state: Optional[State] = None, + self, + manager: "DialogManager", + start_data: Data, + state: Optional[State] = None, ) -> None: raise NotImplementedError @abstractmethod async def process_result( - self, start_data: Data, result: Any, - manager: "DialogManager", + self, + start_data: Data, + result: Any, + manager: "DialogManager", ) -> None: raise NotImplementedError @abstractmethod - def find(self, widget_id) -> Any: + def find(self, widget_id: str) -> Optional[Widget]: raise NotImplementedError @abstractmethod async def load_data( - self, manager: DialogManager, - ) -> Dict: + self, + manager: DialogManager, + ) -> Dict[str, Union[DataDict, Dict[str, Any], ChatEvent]]: raise NotImplementedError @abstractmethod diff --git a/src/aiogram_dialog/api/protocols/manager.py b/src/aiogram_dialog/api/protocols/manager.py index beb202f6..8d3188dc 100644 --- a/src/aiogram_dialog/api/protocols/manager.py +++ b/src/aiogram_dialog/api/protocols/manager.py @@ -1,56 +1,63 @@ from abc import abstractmethod -from typing import Any, Dict, Optional, Protocol +from typing import Any, Dict, Optional, Protocol, Union from aiogram import Bot from aiogram.fsm.state import State from aiogram_dialog.api.entities import ( - ChatEvent, Context, Data, ShowMode, Stack, StartMode, + ChatEvent, + Context, + Data, + ShowMode, + Stack, + StartMode, ) +from aiogram_dialog.api.entities.context import DataDict +from aiogram_dialog.api.internal import Widget class BaseDialogManager(Protocol): @abstractmethod async def done( - self, - result: Any = None, - show_mode: Optional[ShowMode] = None, + self, + result: Any = None, + show_mode: Optional[ShowMode] = None, ) -> None: raise NotImplementedError @abstractmethod async def start( - self, - state: State, - data: Data = None, - mode: StartMode = StartMode.NORMAL, - show_mode: Optional[ShowMode] = None, + self, + state: State, + data: Data = None, + mode: StartMode = StartMode.NORMAL, + show_mode: Optional[ShowMode] = None, ) -> None: raise NotImplementedError @abstractmethod async def switch_to( - self, - state: State, - show_mode: Optional[ShowMode] = None, + self, + state: State, + show_mode: Optional[ShowMode] = None, ) -> None: raise NotImplementedError @abstractmethod async def update( - self, - data: Dict, - show_mode: Optional[ShowMode] = None, + self, + data: DataDict, + show_mode: Optional[ShowMode] = None, ) -> None: raise NotImplementedError @abstractmethod def bg( - self, - user_id: Optional[int] = None, - chat_id: Optional[int] = None, - stack_id: Optional[str] = None, - load: bool = False, # load chat and user + self, + user_id: Optional[int] = None, + chat_id: Optional[int] = None, + stack_id: Optional[str] = None, + load: bool = False, # load chat and user ) -> "BaseDialogManager": raise NotImplementedError @@ -58,12 +65,12 @@ def bg( class BgManagerFactory(Protocol): @abstractmethod def bg( - self, - bot: Bot, - user_id: int, - chat_id: int, - stack_id: Optional[str] = None, - load: bool = False, # load chat and user + self, + bot: Bot, + user_id: int, + chat_id: int, + stack_id: Optional[str] = None, + load: bool = False, # load chat and user ) -> "BaseDialogManager": raise NotImplementedError @@ -80,13 +87,13 @@ async def mark_closed(self) -> None: @property @abstractmethod - def middleware_data(self) -> Dict: + def middleware_data(self) -> Dict[str, Any]: """Middleware data.""" raise NotImplementedError @property @abstractmethod - def dialog_data(self) -> Dict: + def dialog_data(self) -> DataDict: """Dialog data for current context.""" raise NotImplementedError @@ -153,7 +160,7 @@ async def back(self, show_mode: Optional[ShowMode] = None) -> None: raise NotImplementedError @abstractmethod - def find(self, widget_id) -> Optional[Any]: + def find(self, widget_id: str) -> Optional[Widget]: """ Find a widget in current dialog by its id. @@ -172,7 +179,9 @@ async def reset_stack(self, remove_keyboard: bool = True) -> None: raise NotImplementedError @abstractmethod - async def load_data(self) -> Dict: + async def load_data( + self, + ) -> Dict[str, Union[Data, DataDict, Dict[str, Any], ChatEvent]]: """Load data for current state.""" raise NotImplementedError diff --git a/src/aiogram_dialog/api/protocols/media.py b/src/aiogram_dialog/api/protocols/media.py index 134a76e5..e1dea017 100644 --- a/src/aiogram_dialog/api/protocols/media.py +++ b/src/aiogram_dialog/api/protocols/media.py @@ -1,5 +1,6 @@ from abc import abstractmethod -from typing import Optional, Protocol +from pathlib import Path +from typing import Optional, Protocol, Union from aiogram.types import ContentType @@ -9,19 +10,19 @@ class MediaIdStorageProtocol(Protocol): @abstractmethod async def get_media_id( - self, - path: Optional[str], - url: Optional[str], - type: ContentType, + self, + path: Optional[Union[str, Path]], + url: Optional[str], + type: ContentType, ) -> Optional[MediaId]: raise NotImplementedError @abstractmethod async def save_media_id( - self, - path: Optional[str], - url: Optional[str], - type: ContentType, - media_id: MediaId, + self, + path: Optional[Union[str, Path]], + url: Optional[str], + type: ContentType, + media_id: MediaId, ) -> None: raise NotImplementedError diff --git a/src/aiogram_dialog/api/protocols/message_manager.py b/src/aiogram_dialog/api/protocols/message_manager.py index e459bc47..dc230e83 100644 --- a/src/aiogram_dialog/api/protocols/message_manager.py +++ b/src/aiogram_dialog/api/protocols/message_manager.py @@ -16,22 +16,26 @@ class MessageNotModified(DialogsError): class MessageManagerProtocol(Protocol): @abstractmethod async def remove_kbd( - self, - bot: Bot, - show_mode: ShowMode, - old_message: Optional[OldMessage], + self, + bot: Bot, + show_mode: ShowMode, + old_message: Optional[OldMessage], ) -> Optional[Message]: raise NotImplementedError @abstractmethod async def show_message( - self, bot: Bot, new_message: NewMessage, - old_message: Optional[OldMessage], + self, + bot: Bot, + new_message: NewMessage, + old_message: Optional[OldMessage], ) -> OldMessage: raise NotImplementedError @abstractmethod async def answer_callback( - self, bot: Bot, callback_query: CallbackQuery, + self, + bot: Bot, + callback_query: CallbackQuery, ) -> None: raise NotImplementedError diff --git a/src/aiogram_dialog/context/intent_filter.py b/src/aiogram_dialog/context/intent_filter.py index 29309fdf..b5bfcb51 100644 --- a/src/aiogram_dialog/context/intent_filter.py +++ b/src/aiogram_dialog/context/intent_filter.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import Any, Optional, Type from aiogram.filters import BaseFilter from aiogram.fsm.state import StatesGroup @@ -12,12 +12,12 @@ class IntentFilter(BaseFilter): def __init__(self, aiogd_intent_state_group: Optional[Type[StatesGroup]]): self.aiogd_intent_state_group = aiogd_intent_state_group - async def __call__(self, obj: TelegramObject, **kwargs) -> bool: + async def __call__(self, obj: TelegramObject, **kwargs: Any) -> bool: del obj # unused if self.aiogd_intent_state_group is None: return True - context: Context = kwargs.get(CONTEXT_KEY) + context: Optional[Context] = kwargs.get(CONTEXT_KEY) if not context: return False return context.state.group == self.aiogd_intent_state_group diff --git a/src/aiogram_dialog/context/intent_middleware.py b/src/aiogram_dialog/context/intent_middleware.py index 6e9193ba..7a690ded 100644 --- a/src/aiogram_dialog/context/intent_middleware.py +++ b/src/aiogram_dialog/context/intent_middleware.py @@ -1,23 +1,35 @@ from logging import getLogger -from typing import Any, Awaitable, Callable, Dict, Optional +from typing import Any, Awaitable, Callable, Dict, Optional, cast from aiogram import Router from aiogram.dispatcher.middlewares.base import BaseMiddleware -from aiogram.types import CallbackQuery, Chat, Message, User +from aiogram.types import CallbackQuery, Chat, Message, TelegramObject, User from aiogram.types.error_event import ErrorEvent from aiogram_dialog.api.entities import ( - ChatEvent, Context, DEFAULT_STACK_ID, DialogUpdateEvent, Stack, + DEFAULT_STACK_ID, + ChatEvent, + Context, + DialogUpdateEvent, + Stack, ) from aiogram_dialog.api.exceptions import ( - InvalidStackIdError, OutdatedIntent, UnknownIntent, UnknownState, + InvalidStackIdError, + OutdatedIntent, + UnknownIntent, + UnknownState, ) from aiogram_dialog.api.internal import ( - CALLBACK_DATA_KEY, CONTEXT_KEY, EVENT_SIMULATED, - ReplyCallbackQuery, STACK_KEY, STORAGE_KEY, + CALLBACK_DATA_KEY, + CONTEXT_KEY, + EVENT_SIMULATED, + STACK_KEY, + STORAGE_KEY, + ReplyCallbackQuery, ) from aiogram_dialog.api.protocols import DialogRegistryProtocol from aiogram_dialog.utils import remove_intent_id, split_reply_callback + from .storage import StorageProxy logger = getLogger(__name__) @@ -25,51 +37,47 @@ class IntentMiddlewareFactory: def __init__( - self, - registry: DialogRegistryProtocol, + self, + registry: DialogRegistryProtocol, ): super().__init__() self.registry = registry - def storage_proxy(self, data: dict): - proxy = StorageProxy( + def storage_proxy(self, data: Dict[str, Any]) -> StorageProxy: + return StorageProxy( bot=data["bot"], storage=data["fsm_storage"], user_id=data["event_from_user"].id, chat_id=data["event_chat"].id, state_groups=self.registry.states_groups(), ) - return proxy - def _check_outdated(self, intent_id: str, stack: Stack): + def _check_outdated(self, intent_id: str, stack: Stack) -> None: """Check if intent id is outdated for stack.""" if stack.empty(): raise OutdatedIntent( stack.id, - f"Outdated intent id ({intent_id}) " - f"for stack ({stack.id})", + f"Outdated intent id ({intent_id}) " f"for stack ({stack.id})", ) elif intent_id != stack.last_intent_id(): raise OutdatedIntent( stack.id, - f"Outdated intent id ({intent_id}) " - f"for stack ({stack.id})", + f"Outdated intent id ({intent_id}) " f"for stack ({stack.id})", ) async def _load_context( - self, - event: ChatEvent, - intent_id: Optional[str], - stack_id: Optional[str], - data: dict, + self, + event: ChatEvent, + intent_id: Optional[str], + stack_id: Optional[str], + data: Dict[str, Any], ) -> None: proxy = self.storage_proxy(data) logger.debug( - "Loading context for intent: `%s`, " - "stack: `%s`, user: `%s`, chat: `%s`", + "Loading context for intent: `%s`, " "stack: `%s`, user: `%s`, chat: `%s`", intent_id, stack_id, - event.from_user.id, + cast(User, event.from_user).id, proxy.chat_id, ) if intent_id is not None: @@ -91,13 +99,15 @@ async def _load_context( data[CONTEXT_KEY] = context def _intent_id_from_reply( - self, event: Message, data: dict, + self, + event: Message, + data: Dict[str, Any], ) -> Optional[str]: if not ( - event.reply_to_message and - event.reply_to_message.from_user.id == data["bot"].id and - event.reply_to_message.reply_markup and - event.reply_to_message.reply_markup.inline_keyboard + event.reply_to_message + and cast(User, event.reply_to_message.from_user).id == data["bot"].id + and event.reply_to_message.reply_markup + and event.reply_to_message.reply_markup.inline_keyboard ): return None for row in event.reply_to_message.reply_markup.inline_keyboard: @@ -108,11 +118,11 @@ def _intent_id_from_reply( return None async def process_message( - self, - handler: Callable, - event: Message, - data: dict, - ): + self, + handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], + event: Message, + data: Dict[str, Any], + ) -> Any: text, callback_data = split_reply_callback(event.text) if callback_data: query = ReplyCallbackQuery( @@ -120,7 +130,7 @@ async def process_message( message=None, original_message=event, data=callback_data, - from_user=event.from_user, + from_user=event.from_user, # type: ignore[call-arg] # we cannot know real chat instance chat_instance=str(event.chat.id), ).as_(data["bot"]) @@ -139,38 +149,38 @@ async def process_message( return await handler(event, data) async def process_my_chat_member( - self, - handler: Callable, - event: Message, - data: dict, - ) -> None: + self, + handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], + event: Message, + data: Dict[str, Any], + ) -> Any: await self._load_context(event, None, DEFAULT_STACK_ID, data) return await handler(event, data) async def process_chat_join_request( - self, - handler: Callable, - event: Message, - data: dict, - ) -> None: + self, + handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], + event: Message, + data: Dict[str, Any], + ) -> Any: await self._load_context(event, None, DEFAULT_STACK_ID, data) return await handler(event, data) async def process_aiogd_update( - self, - handler: Callable, - event: DialogUpdateEvent, - data: dict, - ): + self, + handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], + event: DialogUpdateEvent, + data: Dict[str, Any], + ) -> Any: await self._load_context(event, event.intent_id, event.stack_id, data) return await handler(event, data) async def process_callback_query( - self, - handler: Callable, - event: CallbackQuery, - data: dict, - ): + self, + handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], + event: CallbackQuery, + data: Dict[str, Any], + ) -> Any: if "event_chat" not in data: return await handler(event, data) proxy = self.storage_proxy(data) @@ -194,7 +204,14 @@ async def process_callback_query( } -async def context_saver_middleware(handler, event, data): +async def context_saver_middleware( + handler: Callable[ + [TelegramObject, Dict[str, Any]], + Awaitable[Any], + ], + event: TelegramObject, + data: Dict[str, Any], +) -> Any: result = await handler(event, data) proxy: StorageProxy = data.pop(STORAGE_KEY, None) if proxy: @@ -205,14 +222,16 @@ async def context_saver_middleware(handler, event, data): class IntentErrorMiddleware(BaseMiddleware): def __init__( - self, - registry: DialogRegistryProtocol, + self, + registry: DialogRegistryProtocol, ): super().__init__() self.registry = registry def _is_error_supported( - self, event: ErrorEvent, data: Dict[str, Any], + self, + event: ErrorEvent, + data: Dict[str, Any], ) -> bool: if isinstance(event, InvalidStackIdError): return False @@ -225,33 +244,40 @@ def _is_error_supported( return True async def _fix_broken_stack( - self, storage: StorageProxy, stack: Stack, + self, + storage: StorageProxy, + stack: Stack, ) -> None: while not stack.empty(): await storage.remove_context(stack.pop()) await storage.save_stack(stack) async def _load_last_context( - self, storage: StorageProxy, stack: Stack, - chat: Chat, user: User, + self, + storage: StorageProxy, + stack: Stack, + chat: Chat, + user: User, ) -> Optional[Context]: try: return await storage.load_context(stack.last_intent_id()) except (UnknownIntent, OutdatedIntent): logger.warning( "Stack is broken for user %s, chat %s, resetting", - user.id, chat.id, + user.id, + chat.id, ) await self._fix_broken_stack(storage, stack) return None async def __call__( - self, - handler: Callable[ - [ErrorEvent, Dict[str, Any]], Awaitable[Any], - ], - event: ErrorEvent, - data: Dict[str, Any], + self, + handler: Callable[ + [ErrorEvent, Dict[str, Any]], + Awaitable[Any], + ], + event: ErrorEvent, # type: ignore[override] + data: Dict[str, Any], ) -> Any: error = event.exception if not self._is_error_supported(event, data): @@ -277,15 +303,18 @@ async def __call__( context = None else: context = await self._load_last_context( - storage=proxy, stack=stack, chat=chat, user=user, + storage=proxy, + stack=stack, + chat=chat, + user=user, ) data[STACK_KEY] = stack data[CONTEXT_KEY] = context return await handler(event, data) finally: - proxy: StorageProxy = data.pop(STORAGE_KEY, None) - if proxy: + storage_proxy: StorageProxy = data.pop(STORAGE_KEY, None) + if storage_proxy: context = data.pop(CONTEXT_KEY) if context is not None: - await proxy.save_context(context) - await proxy.save_stack(data.pop(STACK_KEY)) + await storage_proxy.save_context(context) + await storage_proxy.save_stack(data.pop(STACK_KEY)) diff --git a/src/aiogram_dialog/context/media_storage.py b/src/aiogram_dialog/context/media_storage.py index 9a972d2f..acd1beed 100644 --- a/src/aiogram_dialog/context/media_storage.py +++ b/src/aiogram_dialog/context/media_storage.py @@ -1,4 +1,5 @@ -from typing import Optional +from pathlib import Path +from typing import Any, Optional, Union from aiogram.types import ContentType from cachetools import LRUCache @@ -8,25 +9,27 @@ class MediaIdStorage(MediaIdStorageProtocol): - def __init__(self, maxsize=10240): + cache: LRUCache[Any, Any] + + def __init__(self, maxsize: int = 10240) -> None: self.cache = LRUCache(maxsize=maxsize) async def get_media_id( - self, - path: Optional[str], - url: Optional[str], - type: ContentType, + self, + path: Optional[Union[str, Path]], + url: Optional[str], + type: ContentType, ) -> Optional[MediaId]: if not path and not url: return None return self.cache.get((path, url, type)) async def save_media_id( - self, - path: Optional[str], - url: Optional[str], - type: ContentType, - media_id: MediaId, + self, + path: Optional[Union[str, Path]], + url: Optional[str], + type: ContentType, + media_id: MediaId, ) -> None: if not path and not url: return None diff --git a/src/aiogram_dialog/context/storage.py b/src/aiogram_dialog/context/storage.py index 47d30d7b..d79c28f5 100644 --- a/src/aiogram_dialog/context/storage.py +++ b/src/aiogram_dialog/context/storage.py @@ -5,20 +5,18 @@ from aiogram.fsm.state import State, StatesGroup from aiogram.fsm.storage.base import BaseStorage, StorageKey -from aiogram_dialog.api.entities import ( - Context, DEFAULT_STACK_ID, Stack, -) +from aiogram_dialog.api.entities import DEFAULT_STACK_ID, Context, Stack from aiogram_dialog.api.exceptions import UnknownIntent, UnknownState class StorageProxy: def __init__( - self, - storage: BaseStorage, - user_id: int, - chat_id: int, - bot: Bot, - state_groups: Dict[str, Type[StatesGroup]], + self, + storage: BaseStorage, + user_id: int, + chat_id: int, + bot: Bot, + state_groups: Dict[str, Type[StatesGroup]], ): self.storage = storage self.state_groups = state_groups @@ -55,13 +53,13 @@ async def save_context(self, context: Optional[Context]) -> None: data=data, ) - async def remove_context(self, intent_id: str): + async def remove_context(self, intent_id: str) -> None: await self.storage.set_data( key=self._context_key(intent_id), data={}, ) - async def remove_stack(self, stack_id: str): + async def remove_stack(self, stack_id: str) -> None: await self.storage.set_data( key=self._stack_key(stack_id), data={}, diff --git a/src/aiogram_dialog/manager/bg_manager.py b/src/aiogram_dialog/manager/bg_manager.py index 20156ec0..0571a90d 100644 --- a/src/aiogram_dialog/manager/bg_manager.py +++ b/src/aiogram_dialog/manager/bg_manager.py @@ -1,13 +1,13 @@ from logging import getLogger -from typing import Any, Dict, Optional +from typing import Any, Optional, TypedDict from aiogram import Bot, Router from aiogram.fsm.state import State from aiogram.types import Chat, User from aiogram_dialog.api.entities import ( - Data, DEFAULT_STACK_ID, + Data, DialogAction, DialogStartEvent, DialogSwitchEvent, @@ -16,9 +16,8 @@ ShowMode, StartMode, ) -from aiogram_dialog.api.internal import ( - FakeChat, FakeUser, -) +from aiogram_dialog.api.entities.context import DataDict +from aiogram_dialog.api.internal import FakeChat, FakeUser from aiogram_dialog.api.protocols import BaseDialogManager, BgManagerFactory from aiogram_dialog.manager.updater import Updater from aiogram_dialog.utils import is_chat_loaded, is_user_loaded @@ -26,16 +25,23 @@ logger = getLogger(__name__) +class BaseEventParams(TypedDict): + from_user: User + chat: Chat + intent_id: Optional[str] + stack_id: Optional[str] + + class BgManager(BaseDialogManager): def __init__( - self, - user: User, - chat: Chat, - bot: Bot, - router: Router, - intent_id: Optional[str], - stack_id: Optional[str], - load: bool = False, + self, + user: User, + chat: Chat, + bot: Bot, + router: Router, + intent_id: Optional[str], + stack_id: Optional[str], + load: bool = False, ): self.user = user self.chat = chat @@ -47,18 +53,18 @@ def __init__( self.load = load def bg( - self, - user_id: Optional[int] = None, - chat_id: Optional[int] = None, - stack_id: Optional[str] = None, - load: bool = False, + self, + user_id: Optional[int] = None, + chat_id: Optional[int] = None, + stack_id: Optional[str] = None, + load: bool = False, ) -> "BaseDialogManager": - if chat_id in (None, self.chat.id): + if chat_id is None or chat_id == self.chat.id: chat = self.chat else: chat = FakeChat(id=chat_id, type="") - if user_id in (None, self.user.id): + if user_id is None or user_id == self.user.id: user = self.user else: user = FakeUser(id=user_id, is_bot=False, first_name="") @@ -84,7 +90,7 @@ def bg( load=load, ) - def _base_event_params(self): + def _base_event_params(self) -> BaseEventParams: return { "from_user": self.user, "chat": self.chat, @@ -92,44 +98,50 @@ def _base_event_params(self): "stack_id": self.stack_id, } - async def _notify(self, event: DialogUpdateEvent): + async def _notify(self, event: DialogUpdateEvent) -> None: await self._updater.notify( - bot=self.bot, update=DialogUpdate(aiogd_update=event), + bot=self.bot, + update=DialogUpdate(aiogd_update=event), ) - async def _load(self): + async def _load(self) -> None: if self.load: if not is_chat_loaded(self.chat): logger.debug("load chat: %s", self.chat.id) self.chat = await self.bot.get_chat(self.chat.id) if not is_user_loaded(self.user): logger.debug( - "load user %s from chat %s", self.chat.id, self.user.id, + "load user %s from chat %s", + self.chat.id, + self.user.id, ) chat_member = await self.bot.get_chat_member( - self.chat.id, self.user.id, + self.chat.id, + self.user.id, ) self.user = chat_member.user async def done( - self, - result: Any = None, - show_mode: Optional[ShowMode] = None, + self, + result: Any = None, + show_mode: Optional[ShowMode] = None, ) -> None: await self._load() await self._notify( DialogUpdateEvent( - action=DialogAction.DONE, data=result, show_mode=show_mode, + action=DialogAction.DONE, + data=result, + show_mode=show_mode, **self._base_event_params(), ), ) async def start( - self, - state: State, - data: Data = None, - mode: StartMode = StartMode.NORMAL, - show_mode: Optional[ShowMode] = None, + self, + state: State, + data: Data = None, + mode: StartMode = StartMode.NORMAL, + show_mode: Optional[ShowMode] = None, ) -> None: await self._load() await self._notify( @@ -144,9 +156,9 @@ async def start( ) async def switch_to( - self, - state: State, - show_mode: Optional[ShowMode] = None, + self, + state: State, + show_mode: Optional[ShowMode] = None, ) -> None: await self._load() await self._notify( @@ -160,14 +172,16 @@ async def switch_to( ) async def update( - self, - data: Dict, - show_mode: Optional[ShowMode] = None, + self, + data: DataDict, + show_mode: Optional[ShowMode] = None, ) -> None: await self._load() await self._notify( DialogUpdateEvent( - action=DialogAction.UPDATE, data=data, show_mode=show_mode, + action=DialogAction.UPDATE, + data=data, + show_mode=show_mode, **self._base_event_params(), ), ) @@ -178,12 +192,12 @@ def __init__(self, router: Router): self._router = router def bg( - self, - bot: Bot, - user_id: int, - chat_id: int, - stack_id: Optional[str] = None, - load: bool = False, + self, + bot: Bot, + user_id: int, + chat_id: int, + stack_id: Optional[str] = None, + load: bool = False, ) -> "BaseDialogManager": chat = FakeChat(id=chat_id, type="") user = FakeUser(id=user_id, is_bot=False, first_name="") diff --git a/src/aiogram_dialog/manager/manager.py b/src/aiogram_dialog/manager/manager.py index 87d98858..920190cf 100644 --- a/src/aiogram_dialog/manager/manager.py +++ b/src/aiogram_dialog/manager/manager.py @@ -1,44 +1,72 @@ from logging import getLogger -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union, cast from aiogram import Router from aiogram.fsm.state import State from aiogram.types import ( - CallbackQuery, Chat, ErrorEvent, Message, ReplyKeyboardMarkup, User, + CallbackQuery, + Chat, + ErrorEvent, + Message, + ReplyKeyboardMarkup, + User, ) from aiogram_dialog.api.entities import ( - ChatEvent, Context, Data, DEFAULT_STACK_ID, LaunchMode, MediaId, - NewMessage, ShowMode, Stack, StartMode, + DEFAULT_STACK_ID, + ChatEvent, + Context, + Data, + LaunchMode, + MediaId, + NewMessage, + OldMessage, + ShowMode, + Stack, + StartMode, + UnknownText, ) -from aiogram_dialog.api.entities import OldMessage, UnknownText from aiogram_dialog.api.exceptions import ( - IncorrectBackgroundError, InvalidKeyboardType, NoContextError, + IncorrectBackgroundError, + InvalidKeyboardType, + NoContextError, ) from aiogram_dialog.api.internal import ( - CONTEXT_KEY, EVENT_SIMULATED, FakeChat, FakeUser, - STACK_KEY, STORAGE_KEY, + CONTEXT_KEY, + EVENT_SIMULATED, + STACK_KEY, + STORAGE_KEY, + FakeChat, + FakeUser, + Widget, ) from aiogram_dialog.api.protocols import ( - BaseDialogManager, DialogManager, DialogProtocol, DialogRegistryProtocol, - MediaIdStorageProtocol, MessageManagerProtocol, MessageNotModified, + BaseDialogManager, + DialogManager, + DialogProtocol, + DialogRegistryProtocol, + MediaIdStorageProtocol, + MessageManagerProtocol, + MessageNotModified, ) from aiogram_dialog.context.storage import StorageProxy from aiogram_dialog.utils import get_media_id + +from ..api.entities.context import DataDict from .bg_manager import BgManager logger = getLogger(__name__) class ManagerImpl(DialogManager): - def __init__( - self, event: ChatEvent, - message_manager: MessageManagerProtocol, - media_id_storage: MediaIdStorageProtocol, - registry: DialogRegistryProtocol, - router: Router, - data: Dict, + self, + event: ChatEvent, + message_manager: MessageManagerProtocol, + media_id_storage: MediaIdStorageProtocol, + registry: DialogRegistryProtocol, + router: Router, + data: Dict[str, Any], ): self.disabled = False self.message_manager = message_manager @@ -64,12 +92,12 @@ def event(self) -> ChatEvent: return self._event @property - def middleware_data(self) -> Dict: + def middleware_data(self) -> Dict[str, Any]: """Middleware data.""" return self._data @property - def dialog_data(self) -> Dict: + def dialog_data(self) -> DataDict: """Dialog data for current context.""" return self.current_context().dialog_data @@ -78,7 +106,7 @@ def start_data(self) -> Data: """Start data for current context.""" return self.current_context().start_data - def check_disabled(self): + def check_disabled(self) -> None: if self.disabled: raise IncorrectBackgroundError( "Detected background access to dialog manager. " @@ -86,7 +114,9 @@ def check_disabled(self): "method to access methods from background tasks", ) - async def load_data(self) -> Dict: + async def load_data( + self, + ) -> Dict[str, Union[Data, DataDict, Dict[str, Any], ChatEvent]]: context = self.current_context() return { "dialog_data": context.dialog_data, @@ -116,7 +146,7 @@ def current_context(self) -> Context: return context def _current_context_unsafe(self) -> Optional[Context]: - return self._data[CONTEXT_KEY] + return cast(Context, self._data[CONTEXT_KEY]) def has_context(self) -> bool: self.check_disabled() @@ -124,10 +154,10 @@ def has_context(self) -> bool: def current_stack(self) -> Stack: self.check_disabled() - return self._data[STACK_KEY] + return cast(Stack, self._data[STACK_KEY]) def storage(self) -> StorageProxy: - return self._data[STORAGE_KEY] + return cast(StorageProxy, self._data[STORAGE_KEY]) async def _remove_kbd(self) -> None: if self.current_stack().last_message_id is None: @@ -140,17 +170,17 @@ async def _remove_kbd(self) -> None: self.current_stack().last_message_id = None async def _process_last_dialog_result( - self, - start_data: Data, - result: Any, + self, + start_data: Data, + result: Any, ) -> None: """Process closing last dialog in stack.""" await self._remove_kbd() async def done( - self, - result: Any = None, - show_mode: Optional[ShowMode] = None, + self, + result: Any = None, + show_mode: Optional[ShowMode] = None, ) -> None: self.check_disabled() await self.dialog().process_close(result, self) @@ -192,11 +222,11 @@ async def mark_closed(self) -> None: await storage.save_stack(stack) async def start( - self, - state: State, - data: Data = None, - mode: StartMode = StartMode.NORMAL, - show_mode: Optional[ShowMode] = None, + self, + state: State, + data: Data = None, + mode: StartMode = StartMode.NORMAL, + show_mode: Optional[ShowMode] = None, ) -> None: self.check_disabled() self.show_mode = show_mode or self.show_mode @@ -224,7 +254,10 @@ async def reset_stack(self, remove_keyboard: bool = True) -> None: async def _start_new_stack(self, state: State, data: Data = None) -> None: stack = Stack() await self.bg(stack_id=stack.id).start( - state, data, StartMode.NORMAL, self.show_mode, + state, + data, + StartMode.NORMAL, + self.show_mode, ) async def _start_normal(self, state: State, data: Data = None) -> None: @@ -234,8 +267,7 @@ async def _start_normal(self, state: State, data: Data = None) -> None: old_dialog = self.dialog() if old_dialog.launch_mode is LaunchMode.EXCLUSIVE: raise ValueError( - "Cannot start dialog on top " - "of one with launch_mode==SINGLE", + "Cannot start dialog on top " "of one with launch_mode==SINGLE", ) new_dialog = self._registry.find_dialog(state) @@ -271,9 +303,9 @@ async def back(self, show_mode: Optional[ShowMode] = None) -> None: await self.switch_to(new_state, show_mode) async def switch_to( - self, - state: State, - show_mode: Optional[ShowMode] = None, + self, + state: State, + show_mode: Optional[ShowMode] = None, ) -> None: self.check_disabled() context = self.current_context() @@ -286,7 +318,9 @@ async def switch_to( context.state = state def _ensure_stack_compatible( - self, stack: Stack, new_message: NewMessage, + self, + stack: Stack, + new_message: NewMessage, ) -> None: if stack.id == DEFAULT_STACK_ID: return # no limitations for default stack @@ -313,7 +347,9 @@ async def show(self, show_mode: Optional[ShowMode] = None) -> None: try: sent_message = await self.message_manager.show_message( - bot, new_message, old_message, + bot, + new_message, + old_message, ) except MessageNotModified: # nothing changed so nothing to save @@ -327,14 +363,14 @@ async def show(self, show_mode: Optional[ShowMode] = None) -> None: url=new_message.media.url, type=new_message.media.type, media_id=MediaId( - sent_message.media_id, + cast(str, sent_message.media_id), sent_message.media_uniq_id, ), ) if isinstance(self.event, Message): stack.last_income_media_group_id = self.event.media_group_id - async def _fix_cached_media_id(self, new_message: NewMessage): + async def _fix_cached_media_id(self, new_message: NewMessage) -> None: if not new_message.media or new_message.media.file_id: return new_message.media.file_id = await self.media_id_storage.get_media_id( @@ -343,13 +379,14 @@ async def _fix_cached_media_id(self, new_message: NewMessage): type=new_message.media.type, ) - def is_event_simulated(self): + def is_event_simulated(self) -> bool: return bool(self.middleware_data.get(EVENT_SIMULATED)) def _get_message_from_callback( - self, event: CallbackQuery, + self, + event: CallbackQuery, ) -> Optional[OldMessage]: - current_message = event.message + current_message = cast(Message, event.message) stack = self.current_stack() chat = self.middleware_data["event_chat"] if current_message: @@ -395,7 +432,7 @@ def _get_last_message(self) -> Optional[OldMessage]: message_id=stack.last_message_id, ) - def _save_last_message(self, message: OldMessage): + def _save_last_message(self, message: OldMessage) -> None: stack = self.current_stack() stack.last_message_id = message.message_id stack.last_media_id = message.media_id @@ -412,22 +449,24 @@ def _calc_show_mode(self) -> ShowMode: if isinstance(self.event, Message): if self.event.media_group_id is None: return ShowMode.SEND - elif self.event.media_group_id == \ - self.current_stack().last_income_media_group_id: + elif ( + self.event.media_group_id + == self.current_stack().last_income_media_group_id + ): return ShowMode.EDIT else: return ShowMode.SEND return ShowMode.EDIT async def update( - self, - data: Dict, - show_mode: Optional[ShowMode] = None, + self, + data: DataDict, + show_mode: Optional[ShowMode] = None, ) -> None: self.current_context().dialog_data.update(data) await self.show(show_mode) - def find(self, widget_id) -> Optional[Any]: + def find(self, widget_id: str) -> Optional[Widget]: widget = self.dialog().find(widget_id) if not widget: return None @@ -438,35 +477,34 @@ def is_same_chat(self, user: User, chat: Chat) -> bool: return False current_chat = self._data["event_chat"] - current_user = self.event.from_user + current_user = cast(User, self.event.from_user) return user.id == current_user.id and chat.id == current_chat.id def _get_fake_user(self, user_id: Optional[int] = None) -> User: """Get User if we have info about him or FakeUser instead.""" - current_user = self.event.from_user - if user_id in (None, current_user.id): + current_user = cast(User, self.event.from_user) + if user_id is None or user_id == current_user.id: return current_user return FakeUser(id=user_id, is_bot=False, first_name="") def _get_fake_chat(self, chat_id: Optional[int] = None) -> Chat: """Get Chat if we have info about him or FakeChat instead.""" if "event_chat" in self._data: - current_chat = self._data["event_chat"] - if chat_id in (None, current_chat.id): + current_chat = cast(Chat, self._data["event_chat"]) + if chat_id is None or chat_id == current_chat.id: return current_chat elif chat_id is None: raise ValueError( - "Explicit `chat_id` is required " - "for events without current chat", + "Explicit `chat_id` is required " "for events without current chat", ) return FakeChat(id=chat_id, type="") def bg( - self, - user_id: Optional[int] = None, - chat_id: Optional[int] = None, - stack_id: Optional[str] = None, - load: bool = False, + self, + user_id: Optional[int] = None, + chat_id: Optional[int] = None, + stack_id: Optional[str] = None, + load: bool = False, ) -> BaseDialogManager: user = self._get_fake_user(user_id) chat = self._get_fake_chat(chat_id) diff --git a/src/aiogram_dialog/manager/manager_factory.py b/src/aiogram_dialog/manager/manager_factory.py index a42f157c..c44be88e 100644 --- a/src/aiogram_dialog/manager/manager_factory.py +++ b/src/aiogram_dialog/manager/manager_factory.py @@ -1,29 +1,34 @@ -from typing import Dict +from typing import Any, Dict from aiogram import Router from aiogram_dialog.api.entities import ChatEvent from aiogram_dialog.api.internal import DialogManagerFactory from aiogram_dialog.api.protocols import ( - DialogManager, DialogRegistryProtocol, - MediaIdStorageProtocol, MessageManagerProtocol, + DialogManager, + DialogRegistryProtocol, + MediaIdStorageProtocol, + MessageManagerProtocol, ) + from .manager import ManagerImpl class DefaultManagerFactory(DialogManagerFactory): def __init__( - self, - message_manager: MessageManagerProtocol, - media_id_storage: MediaIdStorageProtocol, + self, + message_manager: MessageManagerProtocol, + media_id_storage: MediaIdStorageProtocol, ) -> None: self.message_manager = message_manager self.media_id_storage = media_id_storage def __call__( - self, event: ChatEvent, data: Dict, - registry: DialogRegistryProtocol, - router: Router, + self, + event: ChatEvent, + data: Dict[str, Any], + registry: DialogRegistryProtocol, + router: Router, ) -> DialogManager: return ManagerImpl( event=event, diff --git a/src/aiogram_dialog/manager/manager_middleware.py b/src/aiogram_dialog/manager/manager_middleware.py index 4d9d4829..021d3c29 100644 --- a/src/aiogram_dialog/manager/manager_middleware.py +++ b/src/aiogram_dialog/manager/manager_middleware.py @@ -1,13 +1,15 @@ -from typing import Any, Awaitable, Callable, Dict, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Union from aiogram import Router from aiogram.dispatcher.middlewares.base import BaseMiddleware -from aiogram.types import TelegramObject, Update +from aiogram.types import TelegramObject from aiogram_dialog.api.entities import ChatEvent, DialogUpdateEvent -from aiogram_dialog.api.internal import DialogManagerFactory, STORAGE_KEY +from aiogram_dialog.api.internal import STORAGE_KEY, DialogManagerFactory from aiogram_dialog.api.protocols import ( - BgManagerFactory, DialogManager, DialogRegistryProtocol, + BgManagerFactory, + DialogManager, + DialogRegistryProtocol, ) MANAGER_KEY = "dialog_manager" @@ -16,10 +18,10 @@ class ManagerMiddleware(BaseMiddleware): def __init__( - self, - dialog_manager_factory: DialogManagerFactory, - registry: DialogRegistryProtocol, - router: Router, + self, + dialog_manager_factory: DialogManagerFactory, + registry: DialogRegistryProtocol, + router: Router, ) -> None: super().__init__() self.dialog_manager_factory = dialog_manager_factory @@ -27,18 +29,20 @@ def __init__( self.router = router def _is_event_supported( - self, event: TelegramObject, data: Dict[str, Any], + self, + event: TelegramObject, + data: Dict[str, Any], ) -> bool: return STORAGE_KEY in data async def __call__( - self, - handler: Callable[ - [Union[Update, DialogUpdateEvent], Dict[str, Any]], - Awaitable[Any], - ], - event: ChatEvent, - data: Dict[str, Any], + self, + handler: Callable[ + [Union[ChatEvent, DialogUpdateEvent], Dict[str, Any]], + Awaitable[Any], + ], + event: ChatEvent, # type: ignore[override] + data: Dict[str, Any], ) -> Any: if self._is_event_supported(event, data): data[MANAGER_KEY] = self.dialog_manager_factory( @@ -51,27 +55,27 @@ async def __call__( try: return await handler(event, data) finally: - manager: DialogManager = data.pop(MANAGER_KEY, None) + manager: Optional[DialogManager] = data.pop(MANAGER_KEY, None) if manager: await manager.close_manager() class BgFactoryMiddleware(BaseMiddleware): def __init__( - self, - bg_manager_factory: BgManagerFactory, + self, + bg_manager_factory: BgManagerFactory, ) -> None: super().__init__() self.bg_manager_factory = bg_manager_factory async def __call__( - self, - handler: Callable[ - [Union[TelegramObject, DialogUpdateEvent], Dict[str, Any]], - Awaitable[TelegramObject], - ], - event: TelegramObject, - data: Dict[str, Any], + self, + handler: Callable[ + [Union[TelegramObject, DialogUpdateEvent], Dict[str, Any]], + Awaitable[TelegramObject], + ], + event: TelegramObject, + data: Dict[str, Any], ) -> Any: data[BG_FACTORY_KEY] = self.bg_manager_factory return await handler(event, data) diff --git a/src/aiogram_dialog/manager/message_manager.py b/src/aiogram_dialog/manager/message_manager.py index c8a12bb3..b8568281 100644 --- a/src/aiogram_dialog/manager/message_manager.py +++ b/src/aiogram_dialog/manager/message_manager.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Optional, Union +from typing import Optional, Union, cast from aiogram import Bot from aiogram.exceptions import TelegramAPIError, TelegramBadRequest @@ -7,6 +7,7 @@ CallbackQuery, ContentType, FSInputFile, + InlineKeyboardMarkup, InputFile, InputMediaAnimation, InputMediaAudio, @@ -20,11 +21,13 @@ ) from aiogram_dialog.api.entities import ( - MediaAttachment, MediaId, NewMessage, OldMessage, ShowMode, -) -from aiogram_dialog.api.protocols import ( - MessageManagerProtocol, MessageNotModified, + MediaAttachment, + MediaId, + NewMessage, + OldMessage, + ShowMode, ) +from aiogram_dialog.api.protocols import MessageManagerProtocol, MessageNotModified from aiogram_dialog.utils import get_media_id logger = getLogger(__name__) @@ -60,7 +63,8 @@ def _combine(sent_message: NewMessage, message_result: Message) -> OldMessage: message_id=message_result.message_id, chat=message_result.chat, has_reply_keyboard=isinstance( - sent_message.reply_markup, ReplyKeyboardMarkup, + sent_message.reply_markup, + ReplyKeyboardMarkup, ), text=message_result.text, media_uniq_id=(media_id.file_unique_id if media_id else None), @@ -70,7 +74,9 @@ def _combine(sent_message: NewMessage, message_result: Message) -> OldMessage: class MessageManager(MessageManagerProtocol): async def answer_callback( - self, bot: Bot, callback_query: CallbackQuery, + self, + bot: Bot, + callback_query: CallbackQuery, ) -> None: try: await bot.answer_callback_query( @@ -83,7 +89,9 @@ async def answer_callback( raise async def get_media_source( - self, media: MediaAttachment, bot: Bot, + self, + media: MediaAttachment, + bot: Bot, ) -> Union[InputFile, str]: if media.file_id: return media.file_id.file_id @@ -92,7 +100,7 @@ async def get_media_source( return URLInputFile(media.url, bot=bot) return media.url else: - return FSInputFile(media.path) + return FSInputFile(str(media.path)) def had_media(self, old_message: OldMessage) -> bool: return old_message.media_id is not None @@ -111,7 +119,9 @@ def need_reply_keyboard(self, new_message: Optional[NewMessage]) -> bool: return isinstance(new_message.reply_markup, ReplyKeyboardMarkup) def _message_changed( - self, new_message: NewMessage, old_message: OldMessage, + self, + new_message: NewMessage, + old_message: OldMessage, ) -> bool: if new_message.text != old_message.text: return True @@ -123,24 +133,27 @@ def _message_changed( return True if not self.need_media(new_message): return False - old_media_id = MediaId(old_message.media_id, old_message.media_uniq_id) - if new_message.media.file_id != old_media_id: + old_media_id = MediaId( + cast(str, old_message.media_id), old_message.media_uniq_id + ) + if cast(MediaAttachment, new_message.media).file_id != old_media_id: return True return False - def _can_edit(self, new_message: NewMessage, - old_message: OldMessage) -> bool: + def _can_edit(self, new_message: NewMessage, old_message: OldMessage) -> bool: # we cannot edit message if media appeared or removed return ( - self.had_media(old_message) == self.need_media(new_message) and - not self.had_reply_keyboard(old_message) and - not self.need_reply_keyboard(new_message) + self.had_media(old_message) == self.need_media(new_message) + and not self.had_reply_keyboard(old_message) + and not self.need_reply_keyboard(new_message) ) async def show_message( - self, bot: Bot, new_message: NewMessage, - old_message: Optional[OldMessage], + self, + bot: Bot, + new_message: NewMessage, + old_message: Optional[OldMessage], ) -> OldMessage: if new_message.show_mode is ShowMode.NO_UPDATE: logger.debug("ShowMode is NO_UPDATE, skipping show") @@ -188,39 +201,45 @@ async def show_message( # Clear async def remove_kbd( - self, - bot: Bot, - show_mode: ShowMode, - old_message: Optional[OldMessage], + self, + bot: Bot, + show_mode: ShowMode, + old_message: Optional[OldMessage], ) -> Optional[Message]: if show_mode is ShowMode.NO_UPDATE: - return + return None if show_mode is ShowMode.DELETE_AND_SEND and old_message: - return await self.remove_message_safe(bot, old_message, None) + await self.remove_message_safe(bot, old_message, None) return await self._remove_kbd(bot, old_message, None) async def _remove_kbd( - self, - bot: Bot, - old_message: Optional[OldMessage], - new_message: Optional[NewMessage], + self, + bot: Bot, + old_message: Optional[OldMessage], + new_message: Optional[NewMessage], ) -> Optional[Message]: if self.had_reply_keyboard(old_message): if not self.need_reply_keyboard(new_message): return await self.remove_reply_kbd(bot, old_message) + return None else: return await self.remove_inline_kbd(bot, old_message) async def remove_inline_kbd( - self, bot: Bot, old_message: Optional[OldMessage], + self, + bot: Bot, + old_message: Optional[OldMessage], ) -> Optional[Message]: if not old_message: - return + return None logger.debug("remove_inline_kbd in %s", old_message.chat) try: - return await bot.edit_message_reply_markup( - message_id=old_message.message_id, - chat_id=old_message.chat.id, + return cast( + Message, + await bot.edit_message_reply_markup( + message_id=old_message.message_id, + chat_id=old_message.chat.id, + ), ) except TelegramBadRequest as err: if "message is not modified" in err.message: @@ -231,12 +250,15 @@ async def remove_inline_kbd( pass else: raise err + return None async def remove_reply_kbd( - self, bot: Bot, old_message: Optional[OldMessage], + self, + bot: Bot, + old_message: Optional[OldMessage], ) -> Optional[Message]: if not old_message: - return + return None logger.debug("remove_reply_kbd in %s", old_message.chat) return await self.send_text( bot=bot, @@ -248,10 +270,10 @@ async def remove_reply_kbd( ) async def remove_message_safe( - self, - bot: Bot, - old_message: OldMessage, - new_message: Optional[NewMessage], + self, + bot: Bot, + old_message: OldMessage, + new_message: Optional[NewMessage], ) -> None: try: await bot.delete_message( @@ -266,9 +288,11 @@ async def remove_message_safe( else: raise - # Edit async def edit_message_safe( - self, bot: Bot, new_message: NewMessage, old_message: OldMessage, + self, + bot: Bot, + new_message: NewMessage, + old_message: OldMessage, ) -> Message: try: return await self.edit_message(bot, new_message, old_message) @@ -276,15 +300,18 @@ async def edit_message_safe( if "message is not modified" in err.message: raise MessageNotModified from err if ( - "message can't be edited" in err.message or - "message to edit not found" in err.message + "message can't be edited" in err.message + or "message to edit not found" in err.message ): return await self.send_message(bot, new_message) else: raise async def edit_message( - self, bot: Bot, new_message: NewMessage, old_message: OldMessage, + self, + bot: Bot, + new_message: NewMessage, + old_message: OldMessage, ) -> Message: if new_message.media: if new_message.media.file_id == old_message.media_id: @@ -294,33 +321,53 @@ async def edit_message( return await self.edit_text(bot, new_message, old_message) async def edit_caption( - self, bot: Bot, new_message: NewMessage, old_message: OldMessage, + self, + bot: Bot, + new_message: NewMessage, + old_message: OldMessage, ) -> Message: logger.debug("edit_caption to %s", new_message.chat) - return await bot.edit_message_caption( - message_id=old_message.message_id, - chat_id=old_message.chat.id, - caption=new_message.text, - reply_markup=new_message.reply_markup, - parse_mode=new_message.parse_mode, + return cast( + Message, + await bot.edit_message_caption( + message_id=old_message.message_id, + chat_id=old_message.chat.id, + caption=new_message.text, + reply_markup=cast( + InlineKeyboardMarkup | None, new_message.reply_markup + ), + parse_mode=new_message.parse_mode, + ), ) async def edit_text( - self, bot: Bot, new_message: NewMessage, old_message: OldMessage, + self, + bot: Bot, + new_message: NewMessage, + old_message: OldMessage, ) -> Message: logger.debug("edit_text to %s", new_message.chat) - return await bot.edit_message_text( - message_id=old_message.message_id, - chat_id=old_message.chat.id, - text=new_message.text, - reply_markup=new_message.reply_markup, - parse_mode=new_message.parse_mode, - disable_web_page_preview=new_message.disable_web_page_preview, + return cast( + Message, + await bot.edit_message_text( + message_id=old_message.message_id, + chat_id=old_message.chat.id, + text=cast(str, new_message.text), + reply_markup=cast( + InlineKeyboardMarkup | None, new_message.reply_markup + ), + parse_mode=new_message.parse_mode, + disable_web_page_preview=new_message.disable_web_page_preview, + ), ) async def edit_media( - self, bot: Bot, new_message: NewMessage, old_message: OldMessage, + self, + bot: Bot, + new_message: NewMessage, + old_message: OldMessage, ) -> Message: + new_message.media = cast(MediaAttachment, new_message.media) logger.debug( "edit_media to %s, media_id: %s", new_message.chat, @@ -334,11 +381,16 @@ async def edit_media( media=await self.get_media_source(new_message.media, bot), **new_message.media.kwargs, ) - return await bot.edit_message_media( - message_id=old_message.message_id, - chat_id=old_message.chat.id, - media=media, - reply_markup=new_message.reply_markup, + return cast( + Message, + await bot.edit_message_media( + message_id=old_message.message_id, + chat_id=old_message.chat.id, + media=media, + reply_markup=cast( + InlineKeyboardMarkup | None, new_message.reply_markup + ), + ), ) # Send @@ -352,13 +404,14 @@ async def send_text(self, bot: Bot, new_message: NewMessage) -> Message: logger.debug("send_text to %s", new_message.chat) return await bot.send_message( new_message.chat.id, - text=new_message.text, + text=cast(str, new_message.text), disable_web_page_preview=new_message.disable_web_page_preview, reply_markup=new_message.reply_markup, parse_mode=new_message.parse_mode, ) async def send_media(self, bot: Bot, new_message: NewMessage) -> Message: + new_message.media = cast(MediaAttachment, new_message.media) logger.debug( "send_media to %s, media_id: %s", new_message.chat, @@ -369,11 +422,14 @@ async def send_media(self, bot: Bot, new_message: NewMessage) -> Message: raise ValueError( f"ContentType {new_message.media.type} is not supported", ) - return await method( - new_message.chat.id, - await self.get_media_source(new_message.media, bot), - caption=new_message.text, - reply_markup=new_message.reply_markup, - parse_mode=new_message.parse_mode, - **new_message.media.kwargs, + return cast( + Message, + await method( + new_message.chat.id, + await self.get_media_source(new_message.media, bot), + caption=new_message.text, + reply_markup=new_message.reply_markup, + parse_mode=new_message.parse_mode, + **new_message.media.kwargs, + ), ) diff --git a/src/aiogram_dialog/manager/sub_manager.py b/src/aiogram_dialog/manager/sub_manager.py index 02056e3f..e77537e2 100644 --- a/src/aiogram_dialog/manager/sub_manager.py +++ b/src/aiogram_dialog/manager/sub_manager.py @@ -1,24 +1,28 @@ import dataclasses -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union, cast from aiogram.fsm.state import State -from aiogram.types import Message from aiogram_dialog.api.entities import ( - ChatEvent, Data, ShowMode, StartMode, + ChatEvent, + Context, + Data, + ShowMode, + Stack, + StartMode, ) -from aiogram_dialog.api.entities import Context, Stack +from aiogram_dialog.api.entities.context import DataDict from aiogram_dialog.api.internal import Widget from aiogram_dialog.api.protocols import BaseDialogManager, DialogManager class SubManager(DialogManager): def __init__( - self, - widget: Widget, - manager: DialogManager, - widget_id: str, - item_id: str, + self, + widget: Widget, + manager: DialogManager, + widget_id: str, + item_id: str, ): self.widget = widget self.manager = manager @@ -30,12 +34,12 @@ def event(self) -> ChatEvent: return self.manager.event @property - def middleware_data(self) -> Dict: + def middleware_data(self) -> Dict[str, Any]: """Middleware data.""" return self.manager.middleware_data @property - def dialog_data(self) -> Dict: + def dialog_data(self) -> DataDict: """Dialog data for current context.""" return self.current_context().dialog_data @@ -46,8 +50,8 @@ def start_data(self) -> Data: def current_context(self) -> Context: context = self.manager.current_context() - data = context.widget_data.setdefault(self.widget_id, {}) - row_data = data.setdefault(self.item_id, {}) + data = cast(Dict[str, Data], context.widget_data.setdefault(self.widget_id, {})) + row_data = cast(Dict[str, Data], data.setdefault(self.item_id, {})) return dataclasses.replace(context, widget_data=row_data) def has_context(self) -> bool: @@ -62,7 +66,7 @@ def current_stack(self) -> Stack: async def close_manager(self) -> None: return await self.manager.close_manager() - async def show(self, show_mode: Optional[ShowMode] = None) -> Message: + async def show(self, show_mode: Optional[ShowMode] = None) -> None: return await self.manager.show(show_mode) async def answer_callback(self) -> None: @@ -71,16 +75,18 @@ async def answer_callback(self) -> None: async def reset_stack(self, remove_keyboard: bool = True) -> None: return await self.manager.reset_stack(remove_keyboard) - async def load_data(self) -> Dict: + async def load_data( + self, + ) -> Dict[str, Union[Data, DataDict, Dict[str, Any], ChatEvent]]: return await self.manager.load_data() - def find(self, widget_id) -> Optional[Any]: + def find(self, widget_id: str) -> Optional[Any]: widget = self.widget.find(widget_id) if not widget: return None return widget.managed(self) - def find_in_parent(self, widget_id) -> Optional[Any]: + def find_in_parent(self, widget_id: str) -> Optional[Any]: return self.manager.find(widget_id) @property @@ -98,40 +104,54 @@ async def back(self, show_mode: Optional[ShowMode] = None) -> None: await self.manager.back(show_mode) async def done( - self, - result: Any = None, - show_mode: Optional[ShowMode] = None, + self, + result: Any = None, + show_mode: Optional[ShowMode] = None, ) -> None: await self.manager.done(result, show_mode) async def mark_closed(self) -> None: await self.manager.mark_closed() - async def start(self, state: State, data: Data = None, - mode: StartMode = StartMode.NORMAL, - show_mode: Optional[ShowMode] = None) -> None: + async def start( + self, + state: State, + data: Data = None, + mode: StartMode = StartMode.NORMAL, + show_mode: Optional[ShowMode] = None, + ) -> None: await self.manager.start( - state=state, data=data, mode=mode, show_mode=show_mode, + state=state, + data=data, + mode=mode, + show_mode=show_mode, ) async def switch_to( - self, - state: State, - show_mode: Optional[ShowMode] = None, + self, + state: State, + show_mode: Optional[ShowMode] = None, ) -> None: await self.manager.switch_to(state, show_mode) async def update( - self, - data: Dict, - show_mode: Optional[ShowMode] = None, + self, + data: DataDict, + show_mode: Optional[ShowMode] = None, ) -> None: self.current_context().dialog_data.update(data) await self.show(show_mode) - def bg(self, user_id: Optional[int] = None, chat_id: Optional[int] = None, - stack_id: Optional[str] = None, - load: bool = False) -> BaseDialogManager: + def bg( + self, + user_id: Optional[int] = None, + chat_id: Optional[int] = None, + stack_id: Optional[str] = None, + load: bool = False, + ) -> BaseDialogManager: return self.manager.bg( - user_id=user_id, chat_id=chat_id, stack_id=stack_id, load=load, + user_id=user_id, + chat_id=chat_id, + stack_id=stack_id, + load=load, ) diff --git a/src/aiogram_dialog/manager/update_handler.py b/src/aiogram_dialog/manager/update_handler.py index 19ecd117..fd7890b1 100644 --- a/src/aiogram_dialog/manager/update_handler.py +++ b/src/aiogram_dialog/manager/update_handler.py @@ -6,13 +6,14 @@ DialogSwitchEvent, DialogUpdateEvent, ) -from .manager import ManagerImpl + from .. import ShowMode +from .manager import ManagerImpl logger = getLogger(__name__) -async def handle_update(event: DialogUpdateEvent, dialog_manager: ManagerImpl): +async def handle_update(event: DialogUpdateEvent, dialog_manager: ManagerImpl) -> None: dialog_manager.show_mode = event.show_mode or ShowMode.AUTO if isinstance(event, DialogStartEvent): await dialog_manager.start( diff --git a/src/aiogram_dialog/manager/updater.py b/src/aiogram_dialog/manager/updater.py index b85a340b..145aa0cd 100644 --- a/src/aiogram_dialog/manager/updater.py +++ b/src/aiogram_dialog/manager/updater.py @@ -15,7 +15,7 @@ def __init__(self, dp: Router): self.dp = dp async def notify(self, bot: Bot, update: DialogUpdate) -> None: - def callback(): + def callback() -> None: asyncio.create_task( self._process_update(bot, update), )