Skip to content

Commit

Permalink
Merge pull request #240 from huangshiyu13/main
Browse files Browse the repository at this point in the history
fix arena bugs
  • Loading branch information
huangshiyu13 authored Oct 13, 2023
2 parents 87cd1cb + 8ddbdba commit 4024a72
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 8 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,7 @@ conda-upload:
./scripts/conda_upload.sh

doc:
./scripts/gen_api_docs.sh
./scripts/gen_api_docs.sh

upload-codecov:
codecov --file coverage.xml -t $(CODECOV_TOKEN)
5 changes: 3 additions & 2 deletions examples/arena/run_arena.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ def run_arena(
seed=0,
total_games: int = 10,
max_game_onetime: int = 5,
use_tqdm: bool = True,
):
env_wrappers = [RecordWinner]
if render:
from examples.selfplay.tictactoe_utils.tictactoe_render import TictactoeRender

env_wrappers.append(TictactoeRender)

arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=True)
arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=use_tqdm)

agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent")
agent2 = LocalAgent("../selfplay/opponent_templates/random_opponent")
Expand All @@ -52,4 +53,4 @@ def run_arena(

if __name__ == "__main__":
run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10)
# run_arena(render=True, parallel=True, seed=1, total_games=10, max_game_onetime=2)
# run_arena(render=False, parallel=False, seed=1, total_games=1, max_game_onetime=1,use_tqdm=False)
13 changes: 11 additions & 2 deletions examples/snake/jidi_random_vs_openrl_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def run_arena(
seed=0,
total_games: int = 10,
max_game_onetime: int = 5,
use_tqdm: bool = True,
):
env_wrappers = [RecordWinner]

Expand All @@ -36,7 +37,7 @@ def run_arena(
f"snakes_{player_num}v{player_num}",
env_wrappers=env_wrappers,
render=render,
use_tqdm=True,
use_tqdm=use_tqdm,
)

agent1 = JiDiAgent("./submissions/random_agent", player_num=player_num)
Expand All @@ -55,4 +56,12 @@ def run_arena(


if __name__ == "__main__":
run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=5)
# run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=5)
run_arena(
render=False,
parallel=False,
seed=0,
total_games=1,
max_game_onetime=1,
use_tqdm=False,
)
1 change: 1 addition & 0 deletions openrl/envs/PettingZoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def PettingZoo_make(id, render_mode, disable_env_checker, **kwargs):
from pettingzoo.classic import tictactoe_v3

env = tictactoe_v3.env(render_mode=render_mode)

else:
raise NotImplementedError
return env
Expand Down
3 changes: 2 additions & 1 deletion openrl/envs/snake/snake_pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def action_space(self, agent):
return deepcopy(self._action_spaces[agent])

def observe(self, agent):
return self.raw_obs[self.agent_name_to_slice[agent]]
obs = self.raw_obs[self.agent_name_to_slice[agent]]
return obs

def reset(
self,
Expand Down
4 changes: 4 additions & 0 deletions openrl/selfplay/opponents/random_opponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ def _sample_random_action(
):
action_space = self.env.action_space(player_name)
if isinstance(action_space, list):
if not isinstance(observation, list):
observation = [observation]

action = []

for obs, space in zip(observation, action_space):
mask = obs.get("action_mask", None)
action.append(space.sample(mask))
Expand Down
1 change: 1 addition & 0 deletions openrl/supports/opendata/utils/opendata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def data_server_wrapper(fp):

def load_dataset(data_path: str, split: str):
from datasets import load_from_disk

if Path(data_path).exists():
dataset = load_from_disk("{}/{}".format(data_path, split))
elif "data_server:" in data_path:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def get_extra_requires() -> dict:
"retro": ["gym-retro"],
"super_mario": ["gym-super-mario-bros"],
}
req["test"].extend(req["selfplay"])
return req


Expand Down
72 changes: 72 additions & 0 deletions tests/test_arena/test_reproducibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import os
import sys

import pytest

from openrl.arena import make_arena
from openrl.arena.agents.local_agent import LocalAgent
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner


def run_arena(
    render: bool = False,
    parallel: bool = True,
    seed=0,
    total_games: int = 10,
    max_game_onetime: int = 5,
):
    """Run a batch of ``tictactoe_v3`` arena games between two random agents.

    Args:
        render: When true, the ``TictactoeRender`` wrapper is added on top of
            ``RecordWinner``.
        parallel: Forwarded unchanged to ``arena.run``.
        seed: Forwarded unchanged to ``arena.reset``.
        total_games: Forwarded unchanged to ``arena.reset``.
        max_game_onetime: Forwarded unchanged to ``arena.reset``.

    Returns:
        Whatever ``arena.run`` returns; it is also printed before returning.
    """
    wrappers = [RecordWinner]
    if render:
        # Imported lazily so the renderer is only loaded when it is requested.
        from examples.selfplay.tictactoe_utils.tictactoe_render import TictactoeRender

        wrappers.append(TictactoeRender)

    # The progress bar is switched off for test runs.
    arena = make_arena("tictactoe_v3", env_wrappers=wrappers, use_tqdm=False)

    # NOTE(review): the agent path is relative, so this presumably expects the
    # working directory to be the repository root -- confirm.
    first_agent = LocalAgent("./examples/selfplay/opponent_templates/random_opponent")
    second_agent = LocalAgent("./examples/selfplay/opponent_templates/random_opponent")
    players = {"agent1": first_agent, "agent2": second_agent}

    arena.reset(
        agents=players,
        total_games=total_games,
        max_game_onetime=max_game_onetime,
        seed=seed,
    )
    outcome = arena.run(parallel=parallel)
    arena.close()

    print(outcome)
    return outcome


@pytest.mark.unittest
def test_seed():
    """Check that arena results are identical across repeated seeded runs.

    ``run_arena`` is called three times sequentially and then three times in
    parallel with the same seed. Each result is compared with the one
    immediately before it, so the first parallel run is also compared with the
    last sequential run.
    """
    seed = 0
    repeats = 3
    previous = None
    for parallel in (False, True):
        for _ in range(repeats):
            current = run_arena(seed=seed, parallel=parallel, total_games=20)
            # The very first run has nothing to be compared against.
            if previous is not None:
                assert previous == current, f"parallel={parallel}, seed={seed}"
            previous = current


if __name__ == "__main__":
    # Run only this file under pytest (-s: no output capture, -v: verbose) and
    # hand pytest's exit code back to the calling shell.
    exit_code = pytest.main(["-sv", os.path.basename(__file__)])
    sys.exit(exit_code)
2 changes: 0 additions & 2 deletions tests/test_supports/test_opendata/test_opendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,5 @@ def test_data_abs_path():
assert data_abs_path(data_path) == data_path




if __name__ == "__main__":
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

0 comments on commit 4024a72

Please sign in to comment.