diff --git a/ops/_private/harness.py b/ops/_private/harness.py index 2be1e457b..96a58d5c0 100644 --- a/ops/_private/harness.py +++ b/ops/_private/harness.py @@ -312,6 +312,12 @@ def __init__( self._meta, self._model, juju_debug_at=self._juju_context.debug_at, + # Harness tests will often have defer() usage without 'purging' the + # deferred handler with reemit(), but still expect the next emit() + # to result in a call, so we can't safely skip duplicate events. + # When behaviour matching production is required, Scenario tests + # should be used instead. + skip_duplicate_events=False, ) warnings.warn( diff --git a/ops/framework.py b/ops/framework.py index 0fd089510..f05a20aba 100644 --- a/ops/framework.py +++ b/ops/framework.py @@ -608,6 +608,7 @@ def __init__( model: 'Model', event_name: Optional[str] = None, juju_debug_at: Optional[Set[str]] = None, + skip_duplicate_events: bool = True, ): super().__init__(self, None) @@ -624,6 +625,7 @@ def __init__( self._event_name = event_name self.meta = meta self.model = model + self.skip_duplicate_events = skip_duplicate_events # [(observer_path, method_name, parent_path, event_key)] self._observers: _ObserverPath = [] # {observer_path: observing Object} @@ -719,15 +721,15 @@ def register_type( self._type_registry[(parent_path, kind_)] = cls self._type_known.add(cls) - def save_snapshot(self, value: Union['StoredStateData', 'EventBase']): - """Save a persistent snapshot of the provided value.""" + def _validate_snapshot_data( + self, value: Union['StoredStateData', 'EventBase'], data: Dict[str, Any] + ): if type(value) not in self._type_known: raise RuntimeError( f'cannot save {type(value).__name__} values before registering that type' ) - data = value.snapshot() - # Use marshal as a validator, enforcing the use of simple types, as we later the + # Use marshal as a validator, enforcing the use of simple types, as later the # information is really pickled, which is too error-prone for future evolution of the # stored data 
def _event_is_in_storage(
    self, observer_path: str, method_name: str, event_path: str, event_data: Dict[str, Any]
) -> bool:
    """Check if there is already a notice with the same snapshot in the storage.

    A stored notice counts as a duplicate when it targets the same observer
    and method, refers to the same event (ignoring the trailing event ID),
    and its saved snapshot equals ``event_data``.

    Args:
        observer_path: handle path of the observer that would be notified.
        method_name: name of the observer method that would be called.
        event_path: handle path of the event being emitted, including its ID.
        event_data: snapshot of the event being emitted.

    Returns:
        ``True`` if an equivalent notice is already queued, ``False`` otherwise.
    """
    # Only the final '[...]' is the event ID. Splitting at the *first* '['
    # would also drop everything after a keyed parent (presumably rendered as
    # 'Parent[key]/kind[id]' -- confirm against Handle.path), making distinct
    # event kinds from that parent compare as equal. A path with no '[' at all
    # is left unchanged by rsplit, exactly as it was by split.
    event_path_without_id = event_path.rsplit('[', 1)[0]
    # Check all the notices to see if there is one that is the same other
    # than the event ID.
    for (
        existing_event_path,
        existing_observer_path,
        existing_method_name,
    ) in self._storage.notices():
        if (
            existing_observer_path != observer_path
            or existing_method_name != method_name
            or existing_event_path.rsplit('[', 1)[0] != event_path_without_id
        ):
            continue
        # Check if the snapshot for this notice is the same.
        try:
            existing_event_data = self._storage.load_snapshot(existing_event_path)
        except NoSnapshotError:
            # A notice without a stored snapshot is compared as an empty one.
            existing_event_data = {}
        if event_data == existing_event_data:
            return True
    return False
self._storage.save_notice(event_path, observer_path, method_name) diff --git a/test/test_framework.py b/test/test_framework.py index ad92aef0f..184094c80 100644 --- a/test/test_framework.py +++ b/test/test_framework.py @@ -373,7 +373,15 @@ def test_defer_and_reemit(self, request: pytest.FixtureRequest): framework = create_framework(request) class MyEvent(ops.EventBase): - pass + def __init__(self, handle: ops.Handle, data: str): + super().__init__(handle) + self.data: str = data + + def restore(self, snapshot: typing.Dict[str, typing.Any]): + self.data = typing.cast(str, snapshot['data']) + + def snapshot(self) -> typing.Dict[str, typing.Any]: + return {'data': self.data} class MyNotifier1(ops.Object): a = ops.EventSource(MyEvent) @@ -404,18 +412,18 @@ def on_any(self, event: ops.EventBase): framework.observe(pub1.b, obs2.on_any) framework.observe(pub2.c, obs2.on_any) - pub1.a.emit() - pub1.b.emit() - pub2.c.emit() + pub1.a.emit('a') + pub1.b.emit('b') + pub2.c.emit('c') - # Events remain stored because they were deferred. + # Events remain stored because they were deferred (and distinct). ev_a_handle = ops.Handle(pub1, 'a', '1') framework.load_snapshot(ev_a_handle) ev_b_handle = ops.Handle(pub1, 'b', '2') framework.load_snapshot(ev_b_handle) ev_c_handle = ops.Handle(pub2, 'c', '3') framework.load_snapshot(ev_c_handle) - # make sure the objects are gone before we reemit them + # Make sure the objects are gone before we reemit them. 
def test_repeated_defer(self, request: pytest.FixtureRequest):
    """Check that re-emitting an already-deferred event does not queue a duplicate notice.

    Two observers defer everything they receive. The test emits events with
    equal and differing snapshots and counts the stored notices, then lets
    the first observer stop deferring and checks that only the second
    observer's queue keeps (de-duplicated) notices.
    """
    framework = create_framework(request)

    # Event without custom snapshot data; its stored snapshot is asserted
    # to be {} further down.
    class MyEvent(ops.EventBase):
        data: typing.Optional[str] = None

    # Event whose snapshot carries a single 'data' value, so two emits can
    # be made equal or different on purpose.
    class MyDataEvent(MyEvent):
        def __init__(self, handle: ops.Handle, data: str):
            super().__init__(handle)
            self.data: typing.Optional[str] = data

        def restore(self, snapshot: typing.Dict[str, typing.Any]):
            self.data = typing.cast(typing.Optional[str], snapshot['data'])

        def snapshot(self) -> typing.Dict[str, typing.Any]:
            return {'data': self.data}

    # Signal used only to switch an observer out of "defer everything" mode.
    class ReleaseEvent(ops.EventBase):
        pass

    class MyNotifier(ops.Object):
        n = ops.EventSource(MyEvent)
        d = ops.EventSource(MyDataEvent)
        r = ops.EventSource(ReleaseEvent)

    class MyObserver(ops.Object):
        def __init__(self, parent: ops.Object, key: str):
            super().__init__(parent, key)
            # While True, every event seen by on_any is deferred.
            self.defer_all = True

        def stop_deferring(self, _: ops.EventBase):
            self.defer_all = False

        def on_any(self, event: MyEvent):
            if self.defer_all:
                event.defer()

    pub = MyNotifier(framework, 'n')
    obs1 = MyObserver(framework, '1')
    obs2 = MyObserver(framework, '2')

    framework.observe(pub.n, obs1.on_any)
    framework.observe(pub.n, obs2.on_any)
    framework.observe(pub.d, obs1.on_any)
    framework.observe(pub.d, obs2.on_any)
    # Only the first observer listens for the release signal.
    framework.observe(pub.r, obs1.stop_deferring)

    # Emit an event, which will be deferred.
    pub.d.emit('foo')
    notices = tuple(framework._storage.notices())
    assert len(notices) == 2  # One per observer.
    assert framework._storage.load_snapshot(notices[0][0]) == {'data': 'foo'}

    # Emit the same event again, and we'll still have just one notice per observer.
    pub.d.emit('foo')
    assert len(tuple(framework._storage.notices())) == 2

    # Emit the same event kind but with a different snapshot, and we'll get a new notice
    # for each observer.
    pub.d.emit('bar')
    notices = tuple(framework._storage.notices())
    assert len(notices) == 4
    assert framework._storage.load_snapshot(notices[2][0]) == {'data': 'bar'}

    # Emit a totally different event, and we'll get a new notice for each observer.
    pub.n.emit()
    notices = tuple(framework._storage.notices())
    assert len(notices) == 6
    assert framework._storage.load_snapshot(notices[2][0]) == {'data': 'bar'}
    assert framework._storage.load_snapshot(notices[4][0]) == {}

    # Even though these events are far back in the queue, since they're
    # duplicates, they will get skipped.
    pub.d.emit('foo')
    pub.d.emit('bar')
    pub.n.emit()
    assert len(tuple(framework._storage.notices())) == 6

    def notices_for_observer(n: int) -> typing.List[typing.Tuple[str, str, str]]:
        # Select notices by observer path (notice[1]); this relies on the
        # path ending with the observer's key in brackets ('1' or '2' above).
        return [
            notice for notice in framework._storage.notices() if notice[1].endswith(f'[{n}]')
        ]

    # Stop deferring on the first observer, and all those events will be
    # completed and the notices removed, while the second observer will
    # still have them queued.
    pub.r.emit()
    # The release event itself is not deferred, so the count is unchanged.
    assert len(tuple(framework._storage.notices())) == 6
    pub.n.emit()
    framework.reemit()
    assert len(notices_for_observer(1)) == 0
    assert len(notices_for_observer(2)) == 3

    # Without the defer active, the first observer always ends up with an
    # empty queue, while the second observer's queue continues to skip
    # duplicates and add new events.
    pub.d.emit('foo')
    pub.d.emit('foo')
    pub.d.emit('bar')
    pub.n.emit()
    pub.d.emit('foo')
    pub.d.emit('bar')
    pub.n.emit()
    pub.d.emit('baz')
    framework.reemit()
    assert len(notices_for_observer(1)) == 0
    # 'foo', 'bar' and the data-less event were already queued; only 'baz' is new.
    assert len(notices_for_observer(2)) == 4