Skip to content

Commit

Permalink
Upgrade ALE use for latest version of Gymnasium.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691514007
  • Loading branch information
psc-g committed Oct 31, 2024
1 parent 02a3547 commit 5a9c4ee
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
6 changes: 3 additions & 3 deletions dopamine/discrete_domains/atari_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def reset(self):
environment.
"""
self.environment.reset()
- self.lives = self.environment.ale.lives()
+ self.lives = self.environment.env.ale.lives()
self._fetch_grayscale_observation(self.screen_buffer[0])
self.screen_buffer[1].fill(0)
return self._pool_and_resize()
Expand Down Expand Up @@ -608,7 +608,7 @@ def step(self, action):
accumulated_reward += reward

if self.terminal_on_life_loss:
- new_lives = self.environment.ale.lives()
+ new_lives = self.environment.env.ale.lives()
is_terminal = game_over or new_lives < self.lives
self.lives = new_lives
else:
Expand Down Expand Up @@ -639,7 +639,7 @@ def _fetch_grayscale_observation(self, output):
Returns:
observation: numpy array, the current observation in grayscale.
"""
- self.environment.ale.getScreenGrayscale(output)
+ self.environment.env.ale.getScreenGrayscale(output)
return output

def _pool_and_resize(self):
Expand Down
15 changes: 11 additions & 4 deletions tests/dopamine/discrete_domains/atari_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,32 +67,39 @@ def getScreenGrayscale(self, screen): # pylint: disable=invalid-name
screen.fill(self.screen_value)


class MockEnvALEWrapper(object):
  """Mock ALE env wrapper.

  Holds a `MockALE` in its `ale` attribute. `MockEnvironment` stores an
  instance of this class as `env`, so code that reads `environment.env.ale`
  reaches the mock ALE.
  """

  def __init__(self):
    # The mock ALE handle is the only state this wrapper carries.
    self.ale = MockALE()


class MockEnvironment(object):
"""Mock environment for testing."""

def __init__(self, screen_size=10, max_steps=10):
self.max_steps = max_steps
self.screen_size = screen_size
self.ale = MockALE()
self.env = MockEnvALEWrapper()
self.observation_space = np.empty((screen_size, screen_size))
self.game_over = False

def reset(self):
self.ale.screen_value = 10
self.env.ale.screen_value = 10
self.num_steps = 0
return self.get_observation()

def get_observation(self):
observation = np.empty((self.screen_size, self.screen_size))
return self.ale.getScreenGrayscale(observation)
return self.env.ale.getScreenGrayscale(observation)

def step(self, action):
reward = -1.0 if action > 0 else 1.0
self.num_steps += 1
is_terminal = self.num_steps >= self.max_steps

unused = 0
self.ale.screen_value = max(0, self.ale.screen_value - 2)
self.env.ale.screen_value = max(0, self.env.ale.screen_value - 2)
return (self.get_observation(), reward, is_terminal, False, unused)

def render(self, mode):
Expand Down

0 comments on commit 5a9c4ee

Please sign in to comment.