Skip to content

Commit

Permalink
Fixing engine terminate behaviour when resumed (#2678)
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored Aug 29, 2022
1 parent 942af82 commit 26f7cec
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
25 changes: 18 additions & 7 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def execute_something():
except ValueError:
_check_signature(handler, "handler", *(event_args + args), **kwargs)
self._event_handlers[event_name].append((handler, args, kwargs))
self.logger.debug(f"added handler for event {event_name}")
self.logger.debug(f"Added handler for event {event_name}")

return RemovableEventHandle(event_name, handler, self)

Expand Down Expand Up @@ -406,7 +406,7 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
**event_kwargs: optional keyword args to be passed to all handlers.
"""
self.logger.debug(f"firing handlers for event {event_name}")
self.logger.debug(f"{self.state.epoch} | {self.state.iteration}, Firing handlers for event {event_name}")
self.last_event_name = event_name
for func, args, kwargs in self._event_handlers[event_name]:
kwargs.update(event_kwargs)
Expand Down Expand Up @@ -720,6 +720,11 @@ def switch_batch(engine):
if self.state.epoch_length is None and data is None:
raise ValueError("epoch_length should be provided if data is None")

if self.should_terminate:
# If engine was terminated and now is resuming from terminated state
# we need to initialize iter_counter as 0
self._init_iter.append(0)

self.state.dataloader = data
return self._internal_run()

Expand Down Expand Up @@ -750,12 +755,13 @@ def _setup_dataloader_iter(self) -> None:

def _setup_engine(self) -> None:
self._setup_dataloader_iter()
iteration = self.state.iteration

# Below we define initial counter value for _run_once_on_dataset to measure a single epoch
if self.state.epoch_length is not None:
iteration %= self.state.epoch_length
self._init_iter.append(iteration)
if len(self._init_iter) == 0:
iteration = self.state.iteration
# Below we define initial counter value for _run_once_on_dataset to measure a single epoch
if self.state.epoch_length is not None:
iteration %= self.state.epoch_length
self._init_iter.append(iteration)

def _internal_run(self) -> State:
self.should_terminate = self.should_terminate_single_epoch = False
Expand Down Expand Up @@ -826,6 +832,11 @@ def _run_once_on_dataset(self) -> float:
start_time = time.time()

# We need to setup iter_counter > 0 if we resume from an iteration
if len(self._init_iter) > 1:
raise RuntimeError(
"Internal error, len(self._init_iter) should 0 or 1, "
f"but got: {len(self._init_iter)}, {self._init_iter}"
)
iter_counter = self._init_iter.pop() if len(self._init_iter) > 0 else 0
should_exit = False
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def check_iter_and_data():

assert state.epoch == max_epochs
assert not engine.should_terminate
assert state.iteration == real_epoch_length * (max_epochs - 1)
assert state.iteration == real_epoch_length * (max_epochs - 1) + (iteration_to_stop % real_epoch_length)


class RecordedEngine(Engine):
Expand Down

0 comments on commit 26f7cec

Please sign in to comment.