diff --git a/diambra/arena/sheeprl/__init__.py b/diambra/arena/sheeprl/__init__.py new file mode 100644 index 0000000..c319f81 --- /dev/null +++ b/diambra/arena/sheeprl/__init__.py @@ -0,0 +1,9 @@ +# Diambra Agents + +import importlib_resources +import sheeprl.utils.env + +from diambra.arena.sheeprl.make_sheeprl_env import make_sheeprl_env + +sheeprl.utils.env.make_env = make_sheeprl_env +CONFIGS_PATH = str(importlib_resources.files("sheeprl.configs")) diff --git a/diambra/arena/sheeprl/make_sheeprl_env.py b/diambra/arena/sheeprl/make_sheeprl_env.py new file mode 100644 index 0000000..e7da196 --- /dev/null +++ b/diambra/arena/sheeprl/make_sheeprl_env.py @@ -0,0 +1,164 @@ +# Diambra Arena + +from __future__ import annotations + +import os +import warnings +from typing import Any, Callable, Dict + +import cv2 +import gymnasium as gym +import hydra +import numpy as np +from sheeprl.envs.wrappers import ( + FrameStack, + GrayscaleRenderWrapper, + RewardAsObservationWrapper, +) + + +def make_sheeprl_env( + cfg: Dict[str, Any], + seed: int, + rank: int, + run_name: str | None = None, + prefix: str = "", + vector_env_idx: int = 0, +) -> Callable[[], gym.Env]: + """ + Create the callable function to create environment and + force the environment to return an observation space of type + gymnasium.spaces.Dict. + + Args: + cfg (Dict[str, Any]): the configs of the environment to initialize. + seed (int): the seed to use. + rank (int): the rank of the process. + run_name (str, optional): the name of the run. + Default to None. + prefix (str): the prefix to add to the video folder. + Default to "". + vector_env_idx (int): the index of the environment. + + Returns: + The callable function that initializes the environment. + """ + + def thunk() -> gym.Env: + if "diambra" in cfg.env.wrapper._target_ and not cfg.env.sync_env: + if cfg.env.wrapper.diambra_settings.pop("splash_screen", True): + warnings.warn( + "You must set the `splash_screen` setting to `False` when using the `AsyncVectorEnv` " + "in `DIAMBRA` environments. The specified `splash_screen` setting is ignored and set " + "to `False`." + ) + cfg.env.wrapper.diambra_settings.splash_screen = False + + instantiate_kwargs = {} + if "seed" in cfg.env.wrapper: + instantiate_kwargs["seed"] = seed + if "rank" in cfg.env.wrapper: + instantiate_kwargs["rank"] = rank + vector_env_idx + env = hydra.utils.instantiate(cfg.env.wrapper, **instantiate_kwargs) + + env_cnn_keys = set( + [ + k + for k in env.observation_space.spaces.keys() + if len(env.observation_space[k].shape) in {2, 3} + ] + ) + if cfg.cnn_keys.encoder is None: + user_cnn_keys = set() + else: + user_cnn_keys = set(cfg.cnn_keys.encoder) + cnn_keys = env_cnn_keys.intersection(user_cnn_keys) + + def transform_obs(obs: Dict[str, Any]): + for k in cnn_keys: + current_obs = obs[k] + shape = current_obs.shape + is_3d = len(shape) == 3 + is_grayscale = not is_3d or shape[0] == 1 or shape[-1] == 1 + channel_first = not is_3d or shape[0] in (1, 3) + + # to 3D image + if not is_3d: + current_obs = np.expand_dims(current_obs, axis=0) + + # channel last (opencv needs it) + if channel_first: + current_obs = np.transpose(current_obs, (1, 2, 0)) + + # resize + if current_obs.shape[:-1] != (cfg.env.screen_size, cfg.env.screen_size): + current_obs = cv2.resize( + current_obs, + (cfg.env.screen_size, cfg.env.screen_size), + interpolation=cv2.INTER_AREA, + ) + + # to grayscale + if cfg.env.grayscale and not is_grayscale: + current_obs = cv2.cvtColor(current_obs, cv2.COLOR_RGB2GRAY) + + # back to 3D + if len(current_obs.shape) == 2: + current_obs = np.expand_dims(current_obs, axis=-1) + if not cfg.env.grayscale: + current_obs = np.repeat(current_obs, 3, axis=-1) + + # channel first (PyTorch default) + obs[k] = current_obs.transpose(2, 0, 1) + + return obs + + env = gym.wrappers.TransformObservation(env, transform_obs) + for k in cnn_keys: + env.observation_space[k] = gym.spaces.Box( + 0, + 255, + ( + 1 if cfg.env.grayscale else 3, + cfg.env.screen_size, + cfg.env.screen_size, + ), + np.uint8, + ) + + if cnn_keys is not None and len(cnn_keys) > 0 and cfg.env.frame_stack > 1: + if cfg.env.frame_stack_dilation <= 0: + raise ValueError( + f"The frame stack dilation argument must be greater than zero, got: {cfg.env.frame_stack_dilation}" + ) + env = FrameStack( + env, cfg.env.frame_stack, cnn_keys, cfg.env.frame_stack_dilation + ) + + if cfg.env.reward_as_observation: + env = RewardAsObservationWrapper(env) + + env.action_space.seed(seed) + env.observation_space.seed(seed) + if cfg.env.max_episode_steps and cfg.env.max_episode_steps > 0: + env = gym.wrappers.TimeLimit( + env, max_episode_steps=cfg.env.max_episode_steps + ) + env = gym.wrappers.RecordEpisodeStatistics(env) + if ( + cfg.env.capture_video + and rank == 0 + and vector_env_idx == 0 + and run_name is not None + ): + if cfg.env.grayscale: + env = GrayscaleRenderWrapper(env) + env = gym.experimental.wrappers.RecordVideoV0( + env, + os.path.join(run_name, prefix + "_videos" if prefix else "videos"), + disable_logger=True, + ) + env.metadata["render_fps"] = env.frames_per_sec + return env + + return thunk diff --git a/setup.py b/setup.py index de633c8..2ea732c 100644 --- a/setup.py +++ b/setup.py @@ -1,59 +1,75 @@ -import setuptools, os +import os from pathlib import Path +import setuptools + try: from pip import main as pipmain except ImportError: from pip._internal import main as pipmain -pipmain(['install', 'setuptools']) -pipmain(['install', 'distro']) +pipmain(["install", "setuptools"]) +pipmain(["install", "distro"]) -extras= { - 'core': [], - 'tests': ['pytest', 'pytest-mock', 'testresources'], - 'stable-baselines': ['stable-baselines~=2.10.2', 'gym<=0.21.0', "protobuf==3.20.1", "pyyaml"], - 'stable-baselines3': ['stable-baselines3[extra]~=2.1.0', "pyyaml"], - 'ray-rllib': ['ray[rllib]~=2.7.0', 'tensorflow', 'torch', "pyyaml"], +extras = { + "core": [], + "tests": ["pytest", "pytest-mock", "testresources"], + "stable-baselines": [ + "stable-baselines~=2.10.2", + "gym<=0.21.0", + "protobuf==3.20.1", + "pyyaml", + ], + "stable-baselines3": ["stable-baselines3[extra]~=2.1.0", "pyyaml"], + "ray-rllib": ["ray[rllib]~=2.7.0", "tensorflow", "torch", "pyyaml"], + "sheeprl": [ + "sheeprl @ git+https://github.com/Eclectic-Sheep/sheeprl.git", + "importlib-resources==6.1.0", + ], } # NOTE Package data is inside MANIFEST.In setuptools.setup( - name='diambra-arena', - url='https://github.com/diambra/arena', - version=os.environ.get('VERSION', '0.0.0'), + name="diambra-arena", + url="https://github.com/diambra/arena", + version=os.environ.get("VERSION", "0.0.0"), author="DIAMBRA Team", author_email="info@diambra.ai", description="DIAMBRA™ Arena. Built with OpenAI Gym Python interface, easy to use, transforms popular video games into Reinforcement Learning environments", - long_description = (Path(__file__).parent / "README.md").read_text(), + long_description=(Path(__file__).parent / "README.md").read_text(), long_description_content_type="text/markdown", - license='Custom', + license="Custom", install_requires=[ - 'pip>=21', - 'importlib-metadata<=4.12.0; python_version <= "3.7"', # problem with gym for importlib-metadata==5.0.0 and python <=3.7 - 'setuptools', - 'distro>=1', - 'gymnasium>=0.26.3', - 'inputs', - 'screeninfo', - 'tk', - 'opencv-python>=4.4.0.42', - 'grpcio', - 'diambra-engine~=2.2.0', - 'dacite'], - packages=[package for package in setuptools.find_packages() if package.startswith("diambra")], + "pip>=21", + 'importlib-metadata<=4.12.0; python_version <= "3.7"', # problem with gym for importlib-metadata==5.0.0 and python <=3.7 + "setuptools", + "distro>=1", + "gymnasium>=0.26.3", + "inputs", + "screeninfo", + "tk", + "opencv-python>=4.4.0.42", + "grpcio", + "diambra-engine~=2.2.0", + "dacite", + ], + packages=[ + package + for package in setuptools.find_packages() + if package.startswith("diambra") + ], include_package_data=True, extras_require=extras, classifiers=[ - 'Development Status :: 3 - Alpha', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Scientific/Engineering :: Artificial Life', - 'Topic :: Games/Entertainment', - 'Topic :: Games/Entertainment :: Arcade', - 'Topic :: Education', - ] + "Development Status :: 3 - Alpha", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Artificial Life", + "Topic :: Games/Entertainment", + "Topic :: Games/Entertainment :: Arcade", + "Topic :: Education", + ], )