diff --git a/sopel/bot.py b/sopel/bot.py index abe2f506e..d52e1707c 100644 --- a/sopel/bot.py +++ b/sopel/bot.py @@ -602,30 +602,26 @@ def rate_limit_info( if trigger.admin or rule.is_unblockable(): return False, None + nick = trigger.nick is_channel = trigger.sender and not trigger.sender.is_nick() channel = trigger.sender if is_channel else None at_time = trigger.time - - user_metrics = rule.get_user_metrics(trigger.nick) - channel_metrics = rule.get_channel_metrics(channel) - global_metrics = rule.get_global_metrics() - - if user_metrics.is_limited(at_time - rule.user_rate_limit): + if rule.is_user_rate_limited(nick, at_time): template = rule.user_rate_template rate_limit_type = "user" rate_limit = rule.user_rate_limit - metrics = user_metrics - elif is_channel and channel_metrics.is_limited(at_time - rule.channel_rate_limit): + metrics = rule.get_user_metrics(nick) + elif channel and rule.is_channel_rate_limited(channel, at_time): template = rule.channel_rate_template rate_limit_type = "channel" rate_limit = rule.channel_rate_limit - metrics = channel_metrics - elif global_metrics.is_limited(at_time - rule.global_rate_limit): + metrics = rule.get_channel_metrics(channel) + elif rule.is_global_rate_limited(at_time): template = rule.global_rate_template rate_limit_type = "global" rate_limit = rule.global_rate_limit - metrics = global_metrics + metrics = rule.get_global_metrics() else: return False, None diff --git a/sopel/plugins/rules.py b/sopel/plugins/rules.py index 36aee06eb..6276f27da 100644 --- a/sopel/plugins/rules.py +++ b/sopel/plugins/rules.py @@ -765,40 +765,49 @@ def global_rate_limit(self) -> datetime.timedelta: def is_user_rate_limited( self, nick: Identifier, - at_time: Optional[datetime.datetime] = None, + at_time: datetime.datetime, ) -> bool: """Tell when the rule reached the ``nick``'s rate limit. :param nick: the nick associated with this check - :param at_time: optional aware datetime for the rate limit check; - if not given, ``utcnow`` will be used + :param at_time: aware datetime for the rate limit check :return: ``True`` when the rule reached the limit, ``False`` otherwise. + + .. versionchanged:: 8.0.1 + + Parameter ``at_time`` is now required. + """ @abc.abstractmethod def is_channel_rate_limited( self, channel: Identifier, - at_time: Optional[datetime.datetime] = None, + at_time: datetime.datetime, ) -> bool: """Tell when the rule reached the ``channel``'s rate limit. :param channel: the channel associated with this check - :param at_time: optional aware datetime for the rate limit check; - if not given, ``utcnow`` will be used + :param at_time: aware datetime for the rate limit check :return: ``True`` when the rule reached the limit, ``False`` otherwise. + + .. versionchanged:: 8.0.1 + + Parameter ``at_time`` is now required. + """ @abc.abstractmethod - def is_global_rate_limited( - self, - at_time: Optional[datetime.datetime] = None, - ) -> bool: + def is_global_rate_limited(self, at_time: datetime.datetime) -> bool: """Tell when the rule reached the global rate limit. - :param at_time: optional aware datetime for the rate limit check; - if not given, ``utcnow`` will be used + :param at_time: aware datetime for the rate limit check :return: ``True`` when the rule reached the limit, ``False`` otherwise. + + .. versionchanged:: 8.0.1 + + Parameter ``at_time`` is now required. + """ @property @@ -1209,29 +1218,29 @@ def global_rate_limit(self) -> datetime.timedelta: def is_user_rate_limited( self, nick: Identifier, - at_time: Optional[datetime.datetime] = None, + at_time: datetime.datetime, ) -> bool: - if at_time is None: - at_time = datetime.datetime.now(datetime.timezone.utc) + if self._user_rate_limit <= 0: + return False + metrics = self.get_user_metrics(nick) return metrics.is_limited(at_time - self.user_rate_limit) def is_channel_rate_limited( self, channel: Identifier, - at_time: Optional[datetime.datetime] = None, + at_time: datetime.datetime, ) -> bool: - if at_time is None: - at_time = datetime.datetime.now(datetime.timezone.utc) + if self._channel_rate_limit <= 0: + return False + metrics = self.get_channel_metrics(channel) return metrics.is_limited(at_time - self.channel_rate_limit) - def is_global_rate_limited( - self, - at_time: Optional[datetime.datetime] = None, - ) -> bool: - if at_time is None: - at_time = datetime.datetime.now(datetime.timezone.utc) + def is_global_rate_limited(self, at_time: datetime.datetime) -> bool: + if self._global_rate_limit <= 0: + return False + metrics = self.get_global_metrics() return metrics.is_limited(at_time - self.global_rate_limit) diff --git a/test/plugins/test_plugins_rules.py b/test/plugins/test_plugins_rules.py index d252439ca..1ff854dbb 100644 --- a/test/plugins/test_plugins_rules.py +++ b/test/plugins/test_plugins_rules.py @@ -1566,14 +1566,16 @@ def handler(bot, trigger): global_rate_limit=20, channel_rate_limit=20, ) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False rule.execute(mockbot, mocktrigger) - assert rule.is_user_rate_limited(mocktrigger.nick) is True - assert rule.is_channel_rate_limited(mocktrigger.sender) is True - assert rule.is_global_rate_limited() is True + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is True + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is True + assert rule.is_global_rate_limited(at_time) is True def test_rule_rate_limit_no_limit(mockbot, triggerfactory): @@ -1592,14 +1594,16 @@ def handler(bot, trigger): global_rate_limit=0, channel_rate_limit=0, ) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False rule.execute(mockbot, mocktrigger) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False def test_rule_rate_limit_ignore_rate_limit(mockbot, triggerfactory): @@ -1619,14 +1623,16 @@ def handler(bot, trigger): channel_rate_limit=20, threaded=False, # make sure there is no race-condition here ) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False rule.execute(mockbot, mocktrigger) - assert rule.is_user_rate_limited(mocktrigger.nick) is False - assert rule.is_channel_rate_limited(mocktrigger.sender) is False - assert rule.is_global_rate_limited() is False + at_time = datetime.datetime.now(datetime.timezone.utc) + assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False + assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False + assert rule.is_global_rate_limited(at_time) is False def test_rule_rate_limit_messages(mockbot, triggerfactory): diff --git a/test/test_bot.py b/test/test_bot.py index 964878bee..73904e810 100644 --- a/test/test_bot.py +++ b/test/test_bot.py @@ -15,7 +15,9 @@ if typing.TYPE_CHECKING: from sopel.config import Config - from sopel.tests.factories import BotFactory, IRCFactory, UserFactory + from sopel.tests.factories import ( + BotFactory, ConfigFactory, IRCFactory, TriggerFactory, UserFactory, + ) from sopel.tests.mocks import MockIRCServer @@ -81,17 +83,17 @@ def ignored(): @pytest.fixture -def tmpconfig(configfactory): +def tmpconfig(configfactory: ConfigFactory) -> Config: return configfactory('test.cfg', TMP_CONFIG) @pytest.fixture -def mockbot(tmpconfig, botfactory): +def mockbot(tmpconfig: Config, botfactory: BotFactory) -> bot.Sopel: return botfactory(tmpconfig) @pytest.fixture -def mockplugin(tmpdir): +def mockplugin(tmpdir) -> plugins.handlers.PyFilePlugin: root = tmpdir.mkdir('loader_mods') mod_file = root.join('mockplugin.py') mod_file.write(MOCK_MODULE_CONTENT) @@ -676,7 +678,7 @@ def url_callback_http(bot, trigger, match): # call_rule @pytest.fixture -def match_hello_rule(mockbot, triggerfactory): +def match_hello_rule(mockbot: bot.Sopel, triggerfactory: TriggerFactory): """Helper for generating matches to each `Rule` in the following tests""" def _factory(rule_hello): # trigger @@ -694,7 +696,25 @@ def _factory(rule_hello): return _factory -def test_call_rule(mockbot, match_hello_rule): +@pytest.fixture +def multimatch_hello_rule(mockbot: bot.Sopel, triggerfactory: TriggerFactory): + def _factory(rule_hello): + # trigger + line = ':Test!test@example.com PRIVMSG #channel :hello hello hello' + + trigger = triggerfactory(mockbot, line) + pretrigger = trigger._pretrigger + + for match in rule_hello.match(mockbot, pretrigger): + wrapper = bot.SopelWrapper(mockbot, trigger) + yield match, trigger, wrapper + return _factory + + +def test_call_rule( + mockbot: bot.Sopel, + match_hello_rule: typing.Callable, +) -> None: # setup items = [] @@ -721,9 +741,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is not rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert not rule_hello.is_channel_rate_limited('#channel') - assert not rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert not rule_hello.is_channel_rate_limited('#channel', at_time) + assert not rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -738,6 +759,36 @@ def testrule(bot, trigger): assert items == [1, 1] +def test_call_rule_multiple_matches( + mockbot: bot.Sopel, + multimatch_hello_rule: typing.Callable, +) -> None: + # setup + items = [] + + def testrule(bot, trigger): + bot.say('hi') + items.append(1) + return "Return Value" + + find_hello = rules.FindRule( + [re.compile(r'(hi|hello|hey|sup)')], + plugin='testplugin', + label='testrule', + handler=testrule) + + for match, rule_trigger, wrapper in multimatch_hello_rule(find_hello): + mockbot.call_rule(find_hello, wrapper, rule_trigger) + + # assert the rule has been executed three times now + assert mockbot.backend.message_sent == rawlist( + 'PRIVMSG #channel :hi', + 'PRIVMSG #channel :hi', + 'PRIVMSG #channel :hi', + ) + assert items == [1, 1, 1] + + def test_call_rule_rate_limited_user(mockbot, match_hello_rule): items = [] @@ -767,9 +818,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert rule_hello.is_user_rate_limited(Identifier('Test')) - assert not rule_hello.is_channel_rate_limited('#channel') - assert not rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert not rule_hello.is_channel_rate_limited('#channel', at_time) + assert not rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -852,9 +904,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert rule_hello.is_channel_rate_limited('#channel') - assert not rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert rule_hello.is_channel_rate_limited('#channel', at_time) + assert not rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -897,9 +950,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert rule_hello.is_channel_rate_limited('#channel') - assert not rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert rule_hello.is_channel_rate_limited('#channel', at_time) + assert not rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -942,9 +996,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert not rule_hello.is_channel_rate_limited('#channel') - assert rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert not rule_hello.is_channel_rate_limited('#channel', at_time) + assert rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello) @@ -987,9 +1042,10 @@ def testrule(bot, trigger): assert items == [1] # assert the rule is now rate limited - assert not rule_hello.is_user_rate_limited(Identifier('Test')) - assert not rule_hello.is_channel_rate_limited('#channel') - assert rule_hello.is_global_rate_limited() + at_time = datetime.now(timezone.utc) + assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time) + assert not rule_hello.is_channel_rate_limited('#channel', at_time) + assert rule_hello.is_global_rate_limited(at_time) match, rule_trigger, wrapper = match_hello_rule(rule_hello)