diff --git a/dopamine/discrete_domains/atari_lib.py b/dopamine/discrete_domains/atari_lib.py index 4781cc26..aa1e8473 100644 --- a/dopamine/discrete_domains/atari_lib.py +++ b/dopamine/discrete_domains/atari_lib.py @@ -553,7 +553,7 @@ def reset(self): environment. """ self.environment.reset() - self.lives = self.environment.ale.lives() + self.lives = self.environment.env.ale.lives() self._fetch_grayscale_observation(self.screen_buffer[0]) self.screen_buffer[1].fill(0) return self._pool_and_resize() @@ -608,7 +608,7 @@ def step(self, action): accumulated_reward += reward if self.terminal_on_life_loss: - new_lives = self.environment.ale.lives() + new_lives = self.environment.env.ale.lives() is_terminal = game_over or new_lives < self.lives self.lives = new_lives else: @@ -639,7 +639,7 @@ def _fetch_grayscale_observation(self, output): Returns: observation: numpy array, the current observation in grayscale. """ - self.environment.ale.getScreenGrayscale(output) + self.environment.env.ale.getScreenGrayscale(output) return output def _pool_and_resize(self): diff --git a/tests/dopamine/discrete_domains/atari_lib_test.py b/tests/dopamine/discrete_domains/atari_lib_test.py index c52a99a3..c3c5014d 100644 --- a/tests/dopamine/discrete_domains/atari_lib_test.py +++ b/tests/dopamine/discrete_domains/atari_lib_test.py @@ -67,24 +67,31 @@ def getScreenGrayscale(self, screen): # pylint: disable=invalid-name screen.fill(self.screen_value) +class MockEnvALEWrapper(object): + """Mock ALE env wrapper.""" + + def __init__(self): + self.ale = MockALE() + + class MockEnvironment(object): """Mock environment for testing.""" def __init__(self, screen_size=10, max_steps=10): self.max_steps = max_steps self.screen_size = screen_size - self.ale = MockALE() + self.env = MockEnvALEWrapper() self.observation_space = np.empty((screen_size, screen_size)) self.game_over = False def reset(self): - self.ale.screen_value = 10 + self.env.ale.screen_value = 10 self.num_steps = 0 return self.get_observation() def get_observation(self): observation = np.empty((self.screen_size, self.screen_size)) - return self.ale.getScreenGrayscale(observation) + return self.env.ale.getScreenGrayscale(observation) def step(self, action): reward = -1.0 if action > 0 else 1.0 @@ -92,7 +99,7 @@ def step(self, action): is_terminal = self.num_steps >= self.max_steps unused = 0 - self.ale.screen_value = max(0, self.ale.screen_value - 2) + self.env.ale.screen_value = max(0, self.env.ale.screen_value - 2) return (self.get_observation(), reward, is_terminal, False, unused) def render(self, mode):