Skip to content

Commit

Permalink
WIP - Multiple changes, see below
Browse files Browse the repository at this point in the history
- Acquire Role as Enum modification
- Rework base observation space keeping P1 and P2 as in engine
- Use wrapper to add action to obs space
- Keep discrete and multi discrete action type in observation space
- TODO:
  - Check how normalization works for new actions
  - Add wrapper to make observations relative (own_ opp_) for both 1P and 2P (with agent_0 and agent_1 differentiation)
  • Loading branch information
alexpalms committed Sep 20, 2023
1 parent ef1f465 commit f09ae9f
Show file tree
Hide file tree
Showing 18 changed files with 265 additions and 165 deletions.
2 changes: 1 addition & 1 deletion diambra/arena/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from diambra.engine import SpaceType
from diambra.engine import SpaceTypes, Roles
from diambra.engine import model
from .make_env import make
from .utils.gym_utils import available_games, game_sha_256, check_game_sha_256, get_num_envs
34 changes: 16 additions & 18 deletions diambra/arena/arena_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from diambra.arena.engine.interface import DiambraEngine
from diambra.arena.env_settings import EnvironmentSettings1P, EnvironmentSettings2P
from typing import Union, Any, Dict, List
from diambra.engine import model, SpaceType
from diambra.engine import model, SpaceTypes

class DiambraGymBase(gym.Env):
"""Diambra Environment gymnasium base interface"""
Expand Down Expand Up @@ -54,35 +54,33 @@ def __init__(self, env_settings: Union[EnvironmentSettings1P, EnvironmentSetting
self.print_actions_dict = [move_dict, attack_dict]

# Maximum difference in players health
category_key_enum = model.RamStatesCategories.Value("P1")
for k in sorted(self.env_info.ram_states_categories[category_key_enum].ram_states.keys()):
for k in sorted(self.env_info.ram_states_categories[model.RamStatesCategories.P1].ram_states.keys()):
key_enum_name = model.RamStates.Name(k)
if "health" in key_enum_name:
self.max_delta_health = self.env_info.ram_states_categories[category_key_enum].ram_states[k].max - \
self.env_info.ram_states_categories[category_key_enum].ram_states[k].min
self.max_delta_health = self.env_info.ram_states_categories[model.RamStatesCategories.P1].ram_states[k].max - \
self.env_info.ram_states_categories[model.RamStatesCategories.P1].ram_states[k].min
break

# Observation space
# Dictionary
observation_space_dict = {}
observation_space_dict['frame'] = gym.spaces.Box(low=0, high=255, shape=(self.env_info.frame_shape.h,
self.env_info.frame_shape.w,
self.env_info.frame_shape.c),
self.env_info.frame_shape.w,
self.env_info.frame_shape.c),
dtype=np.uint8)

# Adding RAM states observations
for k, v in self.env_info.ram_states_categories.items():
print("Processing {}, {}".format(model.RamStatesCategories.Name(k), v))
if k == model.RamStatesCategories.common:
target_dict = observation_space_dict
else:
observation_space_dict[model.RamStatesCategories.Name(k)] = {}
target_dict = observation_space_dict[model.RamStatesCategories.Name(k)]

for k2, v2 in v.ram_states.items():
if v2.type == SpaceType.BINARY or v2.type == SpaceType.DISCRETE:
if v2.type == SpaceTypes.BINARY or v2.type == SpaceTypes.DISCRETE:
target_dict[model.RamStates.Name(k2)] = gym.spaces.Discrete(v2.max + 1)
elif v2.type == SpaceType.BOX:
elif v2.type == SpaceTypes.BOX:
target_dict[model.RamStates.Name(k2)] = gym.spaces.Box(low=v2.min, high=v2.max, shape=(1,), dtype=np.int16)
else:
raise RuntimeError("Only Discrete (Binary/Categorical) | Box Spaces allowed")
Expand Down Expand Up @@ -203,7 +201,7 @@ def _get_obs(self, response):

for k2, v2 in v.ram_states.items():
# Box spaces
if v2.type == SpaceType.BOX:
if v2.type == SpaceTypes.BOX:
target_dict[model.RamStates.Name(k2)] = np.array([category_ram_states.ram_states[k2]])
else: # Discrete spaces (binary / categorical)
target_dict[model.RamStates.Name(k2)] = category_ram_states.ram_states[k2]
Expand All @@ -222,11 +220,11 @@ def __init__(self, env_settings):
# Discrete actions:
# - Arrows U Buttons -> One discrete set
# NB: use the convention NOOP = 0
if env_settings.action_space == SpaceType.MULTI_DISCRETE:
if env_settings.action_space == SpaceTypes.MULTI_DISCRETE:
self.action_space = gym.spaces.MultiDiscrete(self.n_actions)
elif env_settings.action_space == SpaceType.DISCRETE:
elif env_settings.action_space == SpaceTypes.DISCRETE:
self.action_space = gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)
self.logger.debug("Using {} action space".format(SpaceType.Name(env_settings.action_space)))
self.logger.debug("Using {} action space".format(SpaceTypes.Name(env_settings.action_space)))

# Return the no-op action
def get_no_op_action(self):
Expand All @@ -251,16 +249,16 @@ def __init__(self, env_settings):

# Action space
# Dictionary
action_spaces_values = {SpaceType.MULTI_DISCRETE: gym.spaces.MultiDiscrete(self.n_actions),
SpaceType.DISCRETE: gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)}
action_spaces_values = {SpaceTypes.MULTI_DISCRETE: gym.spaces.MultiDiscrete(self.n_actions),
SpaceTypes.DISCRETE: gym.spaces.Discrete(self.n_actions[0] + self.n_actions[1] - 1)}
action_space_dict = self._map_action_spaces_to_agents(action_spaces_values)
self.logger.debug("Using the following action spaces: {}".format(action_space_dict))
self.action_space = gym.spaces.Dict(action_space_dict)

# Return the no-op action
def get_no_op_action(self):
no_op_values = {SpaceType.MULTI_DISCRETE: [0, 0],
SpaceType.DISCRETE: 0}
no_op_values = {SpaceTypes.MULTI_DISCRETE: [0, 0],
SpaceTypes.DISCRETE: 0}
return self._map_action_spaces_to_agents(no_op_values)

# Step the environment
Expand Down
46 changes: 28 additions & 18 deletions diambra/arena/env_settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Union, List, Tuple, Any, Dict
from diambra.arena.utils.gym_utils import available_games
from diambra.arena import SpaceType
from diambra.arena import SpaceTypes, Roles
import numpy as np
import random
from diambra.engine import model
Expand All @@ -23,7 +23,11 @@ def check_val_in_list(key, value, valid_list):
assert (type(value)==type(valid_list[valid_list.index(value)])), error_message

def check_space_type(key, value, valid_list):
error_message = "ERROR: \"{}\" ({}) admissible values are {}".format(key, SpaceType.Name(value), [SpaceType.Name(elem) for elem in valid_list])
error_message = "ERROR: \"{}\" ({}) admissible values are {}".format(key, SpaceTypes.Name(value), [SpaceTypes.Name(elem) for elem in valid_list])
assert (value in valid_list), error_message

def check_roles(key, value, valid_list):
error_message = "ERROR: \"{}\" ({}) admissible values are {}".format(key, Roles.Name(value), [Roles.Name(elem) for elem in valid_list])
assert (value in valid_list), error_message

@dataclass
Expand Down Expand Up @@ -192,10 +196,10 @@ class EnvironmentSettings1P(EnvironmentSettings):
"""Single Agent Environment Settings Class"""

# Env settings
action_space: int = SpaceType.MULTI_DISCRETE
action_space: int = SpaceTypes.MULTI_DISCRETE

# Episode settings
role: Union[None, str] = None
role: Union[None, int] = None
characters: Union[None, str, Tuple[str], Tuple[str, str], Tuple[str, str, str]] = None
outfits: int = 1
super_art: Union[None, int] = None # SFIII Specific
Expand All @@ -207,10 +211,11 @@ def _sanity_check(self):

# Env settings
check_num_in_range("n_players", self.n_players, [1, 1])
check_space_type("action_space", self.action_space, [SpaceType.DISCRETE, SpaceType.MULTI_DISCRETE])
check_space_type("action_space", self.action_space, [SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE])

# Episode settings
check_val_in_list("role", self.role, [None, "P1", "P2"])
if self.role is not None:
check_roles("role", self.role, [Roles.P1, Roles.P2])
if isinstance(self.characters, str) or self.characters is None:
self.characters = (self.characters, None, None)
else:
Expand Down Expand Up @@ -243,7 +248,7 @@ def _process_random_values(self):
self.characters = tuple(characters_tmp)

if self.role is None:
self.role = random.choice(["P1", "P2"])
self.role = random.choice([Roles.P1, Roles.P2])
if self.super_art is None:
self.super_art = random.choice(list(range(1, 4)))
if self.fighting_style is None:
Expand All @@ -267,10 +272,10 @@ def _get_player_specific_values(self):
class EnvironmentSettings2P(EnvironmentSettings):
"""Single Agent Environment Settings Class"""
# Env Settings
action_space: Tuple[int, int] = (SpaceType.MULTI_DISCRETE, SpaceType.MULTI_DISCRETE)
action_space: Tuple[int, int] = (SpaceTypes.MULTI_DISCRETE, SpaceTypes.MULTI_DISCRETE)

# Episode Settings
role: Union[Tuple[None, None], Tuple[str, str]] = (None, None)
role: Union[Tuple[None, None], Tuple[int, int]] = (None, None)
characters: Union[Tuple[None, None], Tuple[str, None], Tuple[None, str], Tuple[str, str],
Tuple[Tuple[str], Tuple[str]], Tuple[Tuple[str, str], Tuple[str, str]],
Tuple[Tuple[str, str, str], Tuple[str, str, str]]] = (None, None)
Expand All @@ -285,7 +290,7 @@ def _sanity_check(self):
# Env Settings
check_num_in_range("n_players", self.n_players, [2, 2])
for idx in range(2):
check_space_type("action_space[{}]".format(idx), self.action_space[idx], [SpaceType.DISCRETE, SpaceType.MULTI_DISCRETE])
check_space_type("action_space[{}]".format(idx), self.action_space[idx], [SpaceTypes.DISCRETE, SpaceTypes.MULTI_DISCRETE])

# Episode Settings
if isinstance(self.characters[0], str) or self.characters[0] is None:
Expand All @@ -299,7 +304,8 @@ def _sanity_check(self):
char_list = list(self.env_info.characters_info.char_list)
char_list.append(None)
for idx in range(2):
check_val_in_list("role[{}]".format(idx), self.role[idx], [None, "P1", "P2"])
if self.role[idx] is not None:
check_roles("role[{}]".format(idx), self.role[idx], [Roles.P1, Roles.P2])
for jdx in range(3):
check_val_in_list("characters[{}][{}]".format(idx, jdx), self.characters[idx][jdx], char_list)
check_num_in_range("outfits[{}]".format(idx), self.outfits[idx], self.games_dict[self.game_id]["outfits"])
Expand All @@ -326,13 +332,13 @@ def _process_random_values(self):

if self.role[0] is None:
if self.role[1] is None:
idx = random.choice([1, 2])
self.role = ("P{}".format(idx), "P{}".format((idx % 2) + 1))
coin = random.choice([True, False])
self.role = (Roles.P1, Roles.P2) if coin is True else (Roles.P2, Roles.P1)
else:
self.role = ("P1" if self.role[1] == "P2" else "P2", self.role[1])
self.role = (Roles.P1 if self.role[1] == Roles.P2 else Roles.P2, self.role[1])
else:
if self.role[1] is None:
self.role = (self.role[0], "P1" if self.role[0] == "P2" else "P2")
self.role = (self.role[0], Roles.P1 if self.role[0] == Roles.P2 else Roles.P2)

self.super_art = tuple([random.choice(list(range(1, 4))) if self.super_art[idx] is None else self.super_art[idx] for idx in range(2)])
self.fighting_style = tuple([random.choice(list(range(1, 4))) if self.fighting_style[idx] is None else self.fighting_style[idx] for idx in range(2)])
Expand Down Expand Up @@ -360,14 +366,15 @@ def _get_player_specific_values(self):

@dataclass
class WrappersSettings:
no_attack_buttons_combinations: bool = False
no_op_max: int = 0
sticky_actions: int = 1
clip_rewards: bool = False
reward_normalization: bool = False
reward_normalization_factor: float = 0.5
clip_rewards: bool = False
no_attack_buttons_combinations: bool = False
frame_stack: int = 1
dilation: int = 1
add_last_action_to_observation: bool = False
actions_stack: int = 1
scale: bool = False
exclude_image_scaling: bool = False
Expand All @@ -382,7 +389,10 @@ def sanity_check(self):
check_num_in_range("sticky_actions", self.sticky_actions, [1, 12])
check_num_in_range("frame_stack", self.frame_stack, [1, MAX_STACK_VALUE])
check_num_in_range("dilation", self.dilation, [1, MAX_STACK_VALUE])
check_num_in_range("actions_stack", self.actions_stack, [1, MAX_STACK_VALUE])
actions_stack_bounds = [1, 1]
if self.add_last_action_to_observation is True:
actions_stack_bounds = [1, MAX_STACK_VALUE]
check_num_in_range("actions_stack", self.actions_stack, actions_stack_bounds)
check_num_in_range("reward_normalization_factor", self.reward_normalization_factor, [0.0, 1000000])

check_val_in_list("frame_shape[2]", self.frame_shape[2], [0, 1, 3])
Expand Down
4 changes: 2 additions & 2 deletions diambra/arena/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def make(game_id, env_settings:dict={}, wrappers_settings:dict={}, episode_recor

# Apply episode recorder wrapper
if len(episode_recording_settings) != 0:
episode_recording_settings = from_dict(RecordingSettings, episode_recording_settings)
episode_recording_settings = from_dict(RecordingSettings, episode_recording_settings, config=Config(strict=True))
env = EpisodeRecorder(env, episode_recording_settings)

# Apply environment wrappers
wrappers_settings = from_dict(WrappersSettings, wrappers_settings)
wrappers_settings = from_dict(WrappersSettings, wrappers_settings, config=Config(strict=True))
wrappers_settings.sanity_check()
env = env_wrapping(env, wrappers_settings)

Expand Down
Loading

0 comments on commit f09ae9f

Please sign in to comment.