-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b643cf1
commit a0745d6
Showing
3 changed files
with
226 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Diambra Agents | ||
|
||
import importlib_resources | ||
import sheeprl.utils.env | ||
|
||
from diambra.arena.sheeprl.make_sheeprl_env import make_sheeprl_env | ||
|
||
sheeprl.utils.env.make_env = make_sheeprl_env | ||
CONFIGS_PATH = str(importlib_resources.files("sheeprl.configs")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
# Diambra Arena | ||
|
||
from __future__ import annotations | ||
|
||
import os | ||
import warnings | ||
from typing import Any, Callable, Dict | ||
|
||
import cv2 | ||
import gymnasium as gym | ||
import hydra | ||
import numpy as np | ||
from sheeprl.envs.wrappers import ( | ||
FrameStack, | ||
GrayscaleRenderWrapper, | ||
RewardAsObservationWrapper, | ||
) | ||
|
||
|
||
def make_sheeprl_env( | ||
cfg: Dict[str, Any], | ||
seed: int, | ||
rank: int, | ||
run_name: str | None = None, | ||
prefix: str = "", | ||
vector_env_idx: int = 0, | ||
) -> Callable[[], gym.Env]: | ||
""" | ||
Create the callable function to create environment and | ||
force the environment to return an observation space of type | ||
gymnasium.spaces.Dict. | ||
Args: | ||
cfg (Dict[str, Any]): the configs of the environment to initialize. | ||
seed (int): the seed to use. | ||
rank (int): the rank of the process. | ||
run_name (str, optional): the name of the run. | ||
Default to None. | ||
prefix (str): the prefix to add to the video folder. | ||
Default to "". | ||
vector_env_idx (int): the index of the environment. | ||
Returns: | ||
The callable function that initializes the environment. | ||
""" | ||
|
||
def thunk() -> gym.Env: | ||
if "diambra" in cfg.env.wrapper._target_ and not cfg.env.sync_env: | ||
if cfg.env.wrapper.diambra_settings.pop("splash_screen", True): | ||
warnings.warn( | ||
"You must set the `splash_screen` setting to `False` when using the `AsyncVectorEnv` " | ||
"in `DIAMBRA` environments. The specified `splash_screen` setting is ignored and set " | ||
"to `False`." | ||
) | ||
cfg.env.wrapper.diambra_settings.splash_screen = False | ||
|
||
instantiate_kwargs = {} | ||
if "seed" in cfg.env.wrapper: | ||
instantiate_kwargs["seed"] = seed | ||
if "rank" in cfg.env.wrapper: | ||
instantiate_kwargs["rank"] = rank + vector_env_idx | ||
env = hydra.utils.instantiate(cfg.env.wrapper, **instantiate_kwargs) | ||
|
||
env_cnn_keys = set( | ||
[ | ||
k | ||
for k in env.observation_space.spaces.keys() | ||
if len(env.observation_space[k].shape) in {2, 3} | ||
] | ||
) | ||
if cfg.cnn_keys.encoder is None: | ||
user_cnn_keys = set() | ||
else: | ||
user_cnn_keys = set(cfg.cnn_keys.encoder) | ||
cnn_keys = env_cnn_keys.intersection(user_cnn_keys) | ||
|
||
def transform_obs(obs: Dict[str, Any]): | ||
for k in cnn_keys: | ||
current_obs = obs[k] | ||
shape = current_obs.shape | ||
is_3d = len(shape) == 3 | ||
is_grayscale = not is_3d or shape[0] == 1 or shape[-1] == 1 | ||
channel_first = not is_3d or shape[0] in (1, 3) | ||
|
||
# to 3D image | ||
if not is_3d: | ||
current_obs = np.expand_dims(current_obs, axis=0) | ||
|
||
# channel last (opencv needs it) | ||
if channel_first: | ||
current_obs = np.transpose(current_obs, (1, 2, 0)) | ||
|
||
# resize | ||
if current_obs.shape[:-1] != (cfg.env.screen_size, cfg.env.screen_size): | ||
current_obs = cv2.resize( | ||
current_obs, | ||
(cfg.env.screen_size, cfg.env.screen_size), | ||
interpolation=cv2.INTER_AREA, | ||
) | ||
|
||
# to grayscale | ||
if cfg.env.grayscale and not is_grayscale: | ||
current_obs = cv2.cvtColor(current_obs, cv2.COLOR_RGB2GRAY) | ||
|
||
# back to 3D | ||
if len(current_obs.shape) == 2: | ||
current_obs = np.expand_dims(current_obs, axis=-1) | ||
if not cfg.env.grayscale: | ||
current_obs = np.repeat(current_obs, 3, axis=-1) | ||
|
||
# channel first (PyTorch default) | ||
obs[k] = current_obs.transpose(2, 0, 1) | ||
|
||
return obs | ||
|
||
env = gym.wrappers.TransformObservation(env, transform_obs) | ||
for k in cnn_keys: | ||
env.observation_space[k] = gym.spaces.Box( | ||
0, | ||
255, | ||
( | ||
1 if cfg.env.grayscale else 3, | ||
cfg.env.screen_size, | ||
cfg.env.screen_size, | ||
), | ||
np.uint8, | ||
) | ||
|
||
if cnn_keys is not None and len(cnn_keys) > 0 and cfg.env.frame_stack > 1: | ||
if cfg.env.frame_stack_dilation <= 0: | ||
raise ValueError( | ||
f"The frame stack dilation argument must be greater than zero, got: {cfg.env.frame_stack_dilation}" | ||
) | ||
env = FrameStack( | ||
env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation | ||
) | ||
|
||
if cfg.env.reward_as_observation: | ||
env = RewardAsObservationWrapper(env) | ||
|
||
env.action_space.seed(seed) | ||
env.observation_space.seed(seed) | ||
if cfg.env.max_episode_steps and cfg.env.max_episode_steps > 0: | ||
env = gym.wrappers.TimeLimit( | ||
env, max_episode_steps=cfg.env.max_episode_steps | ||
) | ||
env = gym.wrappers.RecordEpisodeStatistics(env) | ||
if ( | ||
cfg.env.capture_video | ||
and rank == 0 | ||
and vector_env_idx == 0 | ||
and run_name is not None | ||
): | ||
if cfg.env.grayscale: | ||
env = GrayscaleRenderWrapper(env) | ||
env = gym.experimental.wrappers.RecordVideoV0( | ||
env, | ||
os.path.join(run_name, prefix + "_videos" if prefix else "videos"), | ||
disable_logger=True, | ||
) | ||
env.metadata["render_fps"] = env.frames_per_sec | ||
return env | ||
|
||
return thunk |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1,75 @@ | ||
import setuptools, os | ||
import os | ||
from pathlib import Path | ||
|
||
import setuptools | ||
|
||
try: | ||
from pip import main as pipmain | ||
except ImportError: | ||
from pip._internal import main as pipmain | ||
|
||
pipmain(['install', 'setuptools']) | ||
pipmain(['install', 'distro']) | ||
pipmain(["install", "setuptools"]) | ||
pipmain(["install", "distro"]) | ||
|
||
extras= { | ||
'core': [], | ||
'tests': ['pytest', 'pytest-mock', 'testresources'], | ||
'stable-baselines': ['stable-baselines~=2.10.2', 'gym<=0.21.0', "protobuf==3.20.1", "pyyaml"], | ||
'stable-baselines3': ['stable-baselines3[extra]~=2.1.0', "pyyaml"], | ||
'ray-rllib': ['ray[rllib]~=2.7.0', 'tensorflow', 'torch', "pyyaml"], | ||
extras = { | ||
"core": [], | ||
"tests": ["pytest", "pytest-mock", "testresources"], | ||
"stable-baselines": [ | ||
"stable-baselines~=2.10.2", | ||
"gym<=0.21.0", | ||
"protobuf==3.20.1", | ||
"pyyaml", | ||
], | ||
"stable-baselines3": ["stable-baselines3[extra]~=2.1.0", "pyyaml"], | ||
"ray-rllib": ["ray[rllib]~=2.7.0", "tensorflow", "torch", "pyyaml"], | ||
"sheeprl": [ | ||
"sheeprl @ git+https://github.com/Eclectic-Sheep/sheeprl.git", | ||
"importlib-resources==6.1.0", | ||
], | ||
} | ||
|
||
# NOTE Package data is inside MANIFEST.In | ||
|
||
setuptools.setup( | ||
name='diambra-arena', | ||
url='https://github.com/diambra/arena', | ||
version=os.environ.get('VERSION', '0.0.0'), | ||
name="diambra-arena", | ||
url="https://github.com/diambra/arena", | ||
version=os.environ.get("VERSION", "0.0.0"), | ||
author="DIAMBRA Team", | ||
author_email="[email protected]", | ||
description="DIAMBRA™ Arena. Built with OpenAI Gym Python interface, easy to use, transforms popular video games into Reinforcement Learning environments", | ||
long_description = (Path(__file__).parent / "README.md").read_text(), | ||
long_description=(Path(__file__).parent / "README.md").read_text(), | ||
long_description_content_type="text/markdown", | ||
license='Custom', | ||
license="Custom", | ||
install_requires=[ | ||
'pip>=21', | ||
'importlib-metadata<=4.12.0; python_version <= "3.7"', # problem with gym for importlib-metadata==5.0.0 and python <=3.7 | ||
'setuptools', | ||
'distro>=1', | ||
'gymnasium>=0.26.3', | ||
'inputs', | ||
'screeninfo', | ||
'tk', | ||
'opencv-python>=4.4.0.42', | ||
'grpcio', | ||
'diambra-engine~=2.2.0', | ||
'dacite'], | ||
packages=[package for package in setuptools.find_packages() if package.startswith("diambra")], | ||
"pip>=21", | ||
'importlib-metadata<=4.12.0; python_version <= "3.7"', # problem with gym for importlib-metadata==5.0.0 and python <=3.7 | ||
"setuptools", | ||
"distro>=1", | ||
"gymnasium>=0.26.3", | ||
"inputs", | ||
"screeninfo", | ||
"tk", | ||
"opencv-python>=4.4.0.42", | ||
"grpcio", | ||
"diambra-engine~=2.2.0", | ||
"dacite", | ||
], | ||
packages=[ | ||
package | ||
for package in setuptools.find_packages() | ||
if package.startswith("diambra") | ||
], | ||
include_package_data=True, | ||
extras_require=extras, | ||
classifiers=[ | ||
'Development Status :: 3 - Alpha', | ||
'Operating System :: OS Independent', | ||
'Programming Language :: Python', | ||
'Programming Language :: Python :: 3', | ||
'Topic :: Scientific/Engineering :: Artificial Intelligence', | ||
'Topic :: Scientific/Engineering :: Artificial Life', | ||
'Topic :: Games/Entertainment', | ||
'Topic :: Games/Entertainment :: Arcade', | ||
'Topic :: Education', | ||
] | ||
"Development Status :: 3 - Alpha", | ||
"Operating System :: OS Independent", | ||
"Programming Language :: Python", | ||
"Programming Language :: Python :: 3", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"Topic :: Scientific/Engineering :: Artificial Life", | ||
"Topic :: Games/Entertainment", | ||
"Topic :: Games/Entertainment :: Arcade", | ||
"Topic :: Education", | ||
], | ||
) |