improve gym env / tune hps #71

Merged · 33 commits · Feb 1, 2024

Commits
9086aed
use pid to speed up real env reset
Armandpl Jan 22, 2024
01451f4
setup deadzone wrapper
Armandpl Jan 23, 2024
9e7d5fa
setup tqc + deadzone sim exp
Armandpl Jan 23, 2024
347befd
add angles to the obs when limiting them; add time feat wrapper
Armandpl Jan 23, 2024
cd9cf1c
check for nan instead of none bc none gets converted to nan by np.arr…
Armandpl Jan 23, 2024
5c4b6c4
setup real exp
Armandpl Jan 24, 2024
e2d6360
add deadzone wrapper
Armandpl Jan 24, 2024
99a80dd
use cos for env reset bc pendulum can do multiple turns
Armandpl Jan 25, 2024
df6cf31
add a center to the deadzone wrapper to allow for a zero action
Armandpl Jan 25, 2024
b5bdca6
add exp reward
Armandpl Jan 25, 2024
3e00fae
setup tqc + exp reward sim experiment
Armandpl Jan 25, 2024
06116f7
add sbx and jax-metal
Armandpl Jan 25, 2024
7c9f844
fix base wrapper; make sim exp closer to real w/ n_envs=1
Armandpl Jan 25, 2024
22c5a43
add timeout to motor stopping, hack to avoid pid oscillations instead…
Armandpl Jan 25, 2024
8a62cdc
setup rew sweep
Armandpl Jan 25, 2024
1b925e6
investigate tqc action looking like gaussian noise
Armandpl Jan 29, 2024
d650c03
fix default args
Armandpl Jan 29, 2024
0f6c7ad
add max act to deadzone wrapper
Armandpl Jan 29, 2024
5503027
setup hp tune sweep for tqc
Armandpl Jan 29, 2024
5b7e11d
actually use early stopping
Armandpl Jan 29, 2024
a46febd
use sb3_contrib master to train
Armandpl Jan 30, 2024
e3d3761
use progressbar
Armandpl Jan 30, 2024
2186f79
add option to use either sbx or sbx tqc
Armandpl Jan 31, 2024
df26804
add new reward taking theta into account
Armandpl Jan 31, 2024
f679592
remove unused code
Armandpl Jan 31, 2024
722456e
make progressbar a param
Armandpl Jan 31, 2024
e751650
setup new rew sweep
Armandpl Jan 31, 2024
2d76294
fix error in reward
Armandpl Jan 31, 2024
5c933d0
setup real training
Armandpl Jan 31, 2024
9759dc8
fix sweep setup
Armandpl Jan 31, 2024
5986850
add theta reward
Armandpl Feb 1, 2024
9070298
clear gsde notebook output
Armandpl Feb 1, 2024
bcda11d
fix default reward key
Armandpl Feb 1, 2024
21 changes: 16 additions & 5 deletions furuta/rl/algos.py
@@ -1,16 +1,27 @@
import sb3_contrib
import sbx
import stable_baselines3


# wrapper class for stable-baselines3.SAC
# TODO can we make one class for all algos?
# check if they all have the train freq param
# check if they have other tuple args
# check if it would be cleaner for sb3 to accept list instead of tuple?
class SAC(stable_baselines3.SAC):
# TODO is there a cleaner way to do this?
class BaseAlgoWrapper:
    def __init__(self, **kwargs):
        # sb3 expects a tuple but omegaconf returns a list,
        # so convert the train_freq kwarg from list to tuple
        if "train_freq" in kwargs and isinstance(kwargs["train_freq"], list):
            kwargs.update({"train_freq": tuple(kwargs["train_freq"])})

        super().__init__(**kwargs)


class SAC(BaseAlgoWrapper, stable_baselines3.SAC):
    pass


class TQC(BaseAlgoWrapper, sb3_contrib.TQC):
    pass


class SBXTQC(BaseAlgoWrapper, sbx.TQC):
    pass
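
For context, a minimal usage sketch (not part of the PR) of why `BaseAlgoWrapper` converts `train_freq`: kwargs loaded through OmegaConf come back as Python lists, which stable-baselines3 rejects for `train_freq`. The config keys and the `Pendulum-v1` stand-in environment below are illustrative assumptions.

```python
# Hypothetical sketch: instantiate the wrapped SAC from an OmegaConf config.
# OmegaConf turns the YAML `[1, "episode"]` into a Python list; BaseAlgoWrapper
# converts it back to the (1, "episode") tuple that stable-baselines3 expects.
import gymnasium as gym
from omegaconf import OmegaConf

from furuta.rl.algos import SAC

cfg = OmegaConf.create({"train_freq": [1, "episode"], "learning_rate": 3e-4})
env = gym.make("Pendulum-v1")  # stand-in env, not the Furuta robot
model = SAC(policy="MlpPolicy", env=env, **OmegaConf.to_container(cfg))
model.learn(total_timesteps=10)
```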
55 changes: 43 additions & 12 deletions furuta/rl/envs/furuta_base.py
@@ -7,17 +7,36 @@
from furuta.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT, Timing


def exp_alpha_theta_reward(state, exp=2):
    al_rew = exp_alpha_reward(state, exp)
    th_rew = theta_reward(state)
    return al_rew * th_rew


def exp_alpha_reward(state, exp=2):
    al = np.mod((state[ALPHA] + np.pi), 2 * np.pi) - np.pi  # between -pi and pi
    al_rew = np.abs(al) / np.pi  # 0 at 0, 1 at pi
    al_rew = (np.exp(al_rew * exp) - np.exp(0)) / np.exp(exp)
    return al_rew


def alpha_theta_reward(state):
    return alpha_reward(state) * theta_reward(state)


def alpha_reward(state):
    return (1 + -np.cos(state[ALPHA])) / 2


def alpha_theta_reward(state):
    return alpha_reward(state) + (1 + np.cos(state[THETA])) / 2
def theta_reward(state):
    return (1 + np.cos(state[THETA])) / 2


REWARDS = {
    "alpha": alpha_reward,
    "alpha_theta": alpha_theta_reward,
    "cos_alpha": alpha_theta_reward,
    "exp_alpha_2": lambda x: exp_alpha_theta_reward(x, exp=2),
    "exp_alpha_4": lambda x: exp_alpha_theta_reward(x, exp=4),
    "exp_alpha_6": lambda x: exp_alpha_theta_reward(x, exp=6),
}
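
To see how the new exponential rewards compare to the cosine reward, here is a small sketch (not from the PR), assuming the `('theta', 'alpha', 'theta_dot', 'alpha_dot')` state layout noted in the `state_space` comment further down and alpha = pi as upright:

```python
# Sketch: evaluate a cosine-shaped and an exponential reward at a few pendulum
# angles; the exponential variants concentrate most of the reward near upright.
import numpy as np

from furuta.rl.envs.furuta_base import REWARDS

for deg in (0, 90, 150, 175, 180):
    # state = [theta, alpha, theta_dot, alpha_dot], theta kept at 0
    state = np.array([0.0, np.deg2rad(deg), 0.0, 0.0])
    cos_rew = float(REWARDS["cos_alpha"](state))
    exp_rew = float(REWARDS["exp_alpha_4"](state))
    print(f"alpha={deg:>3} deg  cos_alpha={cos_rew:.3f}  exp_alpha_4={exp_rew:.3f}")
```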


@@ -50,14 +69,14 @@ def __init__(

act_max = np.array([1.0], dtype=np.float32)

if angle_limits is None:
angle_limits = [np.inf, np.inf]
if speed_limits is None:
speed_limits = [np.inf, np.inf]
angle_limits = np.array(angle_limits, dtype=np.float32)
speed_limits = np.array(speed_limits, dtype=np.float32)

self.state_max = np.array(
[angle_limits[0], angle_limits[1], speed_limits[0], speed_limits[1]], dtype=np.float32
)
# replace none values with inf
angle_limits = np.where(np.isnan(angle_limits), np.inf, angle_limits) # noqa
speed_limits = np.where(np.isnan(speed_limits), np.inf, speed_limits) # noqa

self.state_max = np.concatenate([angle_limits, speed_limits])

# max obs based on max speeds measured on the robot
# in sim the speeds spike at 30 rad/s when trained
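
The `isnan` check works because NumPy converts `None` entries to `nan` when casting to a float dtype, which is what the "check for nan instead of none" commit refers to. A quick standalone sketch:

```python
import numpy as np

# None becomes nan once cast to a float dtype, so the limits are sanitized
# with isnan/inf instead of comparing against None.
angle_limits = np.array([None, 3.0], dtype=np.float32)
print(angle_limits)                                    # [nan 3.]
angle_limits = np.where(np.isnan(angle_limits), np.inf, angle_limits)
print(angle_limits)                                    # [inf 3.]
```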
@@ -67,6 +86,12 @@ def __init__(
        # obs is [cos(th), sin(th), cos(al), sin(al), th_d, al_d]
        obs_max = np.array([1.0, 1.0, 1.0, 1.0, 30, 30], dtype=np.float32)

        # if limit on angles, add them to the obs
        if not np.isinf(self.state_max[ALPHA]):
            obs_max = np.concatenate([np.array([self.state_max[ALPHA]]), obs_max])
        if not np.isinf(self.state_max[THETA]):
            obs_max = np.concatenate([np.array([self.state_max[THETA]]), obs_max])

        # Spaces
        self.state_space = Box(
            # ('theta', 'alpha', 'theta_dot', 'alpha_dot'),
@@ -107,7 +132,7 @@ def step(self, action):
        return obs, rwd, terminated, truncated, {}

    def get_obs(self):
        return np.float32(
        obs = np.float32(
            [
                np.cos(self._state[THETA]),
                np.sin(self._state[THETA]),
@@ -117,6 +142,12 @@ def get_obs(self):
                self._state[ALPHA_DOT],
            ]
        )
        if not np.isinf(self.state_max[ALPHA]):
            obs = np.concatenate([np.array([self._state[ALPHA]]), obs])
        if not np.isinf(self.state_max[THETA]):
            obs = np.concatenate([np.array([self._state[THETA]]), obs])

        return obs

    def reset(
        self,
63 changes: 40 additions & 23 deletions furuta/rl/envs/furuta_real.py
@@ -3,34 +3,35 @@
from typing import Optional

import numpy as np
from simple_pid import PID

from furuta.rl.envs.furuta_base import FurutaBase
from furuta.robot import Robot
from furuta.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT, VelocityFilter

MAX_RESET_TIME = 7 # seconds
MAX_MOTOR_RESET_TIME = 0.2 # seconds
RESET_TIME = 0.5
ALPHA_THRESH = np.deg2rad(
    2
ALPHA_THRESH = np.cos(
    np.deg2rad(2)
)  # alpha should stay between -2 and 2 deg for 0.5 sec for us to consider the env reset


class FurutaReal(FurutaBase):
    def __init__(
        self,
        control_freq=100,
        reward="alpha_theta",
        reward="cos_alpha",
        angle_limits=None,
        speed_limits=None,
        usb_device="/dev/ttyACM0",
        motor_stop_pid=[0.04, 0.0, 0.001],
    ):
        super().__init__(control_freq, reward, angle_limits, speed_limits)
        self.motor_stop_pid = motor_stop_pid

        self.robot = Robot(usb_device)

        self._init_vel_filt()

        self._update_state(0.0)
        self._state = None

    def _init_vel_filt(self):
        self.vel_filt = VelocityFilter(2, dt=self.timing.dt)
@@ -52,30 +53,46 @@ def reset(
        super().reset(seed=seed)
        logging.info("Reset env...")

        # wait for pendulum to fall back to start position
        reset_time = 0
        time_under_thresh = 0

        while time_under_thresh < RESET_TIME and reset_time < MAX_RESET_TIME:
            sleep(0.01)
            if abs(self._state[ALPHA]) < ALPHA_THRESH:
                time_under_thresh += 0.01
            else:
                time_under_thresh = 0
            reset_time += 0.01
            self._update_state(0.0)

        if reset_time >= MAX_RESET_TIME:
            logging.error("Reset timeout")
        if self._state is not None:  # if not first reset
            logging.debug("Stopping motor")
            motor_pid = PID(
                self.motor_stop_pid[0],
                self.motor_stop_pid[1],
                self.motor_stop_pid[2],
                setpoint=0.0,
                output_limits=(-1, 1),
            )

            reset_time = 0
            while abs(self._state[THETA_DOT]) > 0.5 and reset_time < MAX_MOTOR_RESET_TIME:
                act = motor_pid(self._state[THETA_DOT])
                self._update_state(act)
                reset_time += self.timing.dt
                sleep(self.timing.dt)

            logging.debug("Waiting for pendulum to fall back down")
            time_under_thresh = 0
            reset_time = 0
            while time_under_thresh < RESET_TIME and reset_time < MAX_RESET_TIME:
                if np.cos(self._state[ALPHA]) > ALPHA_THRESH:
                    time_under_thresh += self.timing.dt
                else:
                    time_under_thresh = 0
                self._update_state(0.0)
                reset_time += self.timing.dt
                sleep(self.timing.dt)

            if reset_time >= MAX_RESET_TIME:
                logging.info(f"Reset timeout, alpha: {np.rad2deg(self._state[ALPHA])}")

            # reset both encoders, motor back to pos=0
            self.robot.reset_encoders()

        logging.info("Reset done")
        self._update_state(0.0)
        # else the first computed velocity will take into account the previous episode
        # and it'll be huge and wrong and will terminate the episode
        self._init_vel_filt()
        self._update_state(0.0)  # initial state
        return self.get_obs(), {}

# TODO: override parent render function
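
A standalone sketch (not from the repo) of why the reset check switched from `abs(alpha)` to `cos(alpha)`: after the pendulum has done full turns, the unwrapped angle never gets close to zero, but its cosine still does. The `DEG_THRESH`/`COS_THRESH` names below are illustrative stand-ins for the old and new `ALPHA_THRESH`.

```python
import numpy as np

DEG_THRESH = np.deg2rad(2)          # old threshold: |alpha| < 2 deg
COS_THRESH = np.cos(np.deg2rad(2))  # new threshold: cos(alpha) > cos(2 deg)

# After a full turn the encoder reads alpha ~ 2*pi even though the pendulum
# hangs straight down, so the absolute check never passes; the cosine check does.
for alpha in (0.0, np.deg2rad(1), 2 * np.pi, -4 * np.pi + np.deg2rad(1)):
    print(
        f"alpha={np.rad2deg(alpha):8.1f} deg  "
        f"abs check={abs(alpha) < DEG_THRESH}  cos check={np.cos(alpha) > COS_THRESH}"
    )
```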
6 changes: 3 additions & 3 deletions furuta/rl/envs/furuta_sim.py
@@ -14,9 +14,9 @@ def __init__(
        self,
        dyn: QubeDynamics = QubeDynamics(),
        control_freq=50,
        reward="alpha",
        angle_limits=None,
        speed_limits=None,
        reward="cos_alpha",
        angle_limits=[None, None],
        speed_limits=[None, None],
        encoders_CPRs: Optional[List[float]] = None,
        velocity_filter: int = None,
        render_mode="rgb_array",
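
A hypothetical usage sketch (not in the PR) of the new defaults together with the angle-limit observation change from `furuta_base.py`; the expected shapes follow the `get_obs` logic above but have not been verified here.

```python
# Sketch: with no angle limits the observation stays
# [cos(th), sin(th), cos(al), sin(al), th_d, al_d]; a limited angle is
# prepended to the observation so the policy can anticipate termination.
from furuta.rl.envs.furuta_sim import FurutaSim

env = FurutaSim()  # new defaults: reward="cos_alpha", no angle/speed limits
obs, _ = env.reset()
print(obs.shape)  # expected: (6,)

# angle_limits = [theta_limit, alpha_limit] per the state layout;
# here only the arm angle theta is limited, alpha stays unlimited.
env = FurutaSim(angle_limits=[2.0, None])
obs, _ = env.reset()
print(obs.shape)  # expected: (7,) with theta prepended
```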
31 changes: 26 additions & 5 deletions furuta/rl/wrappers.py
@@ -24,12 +24,33 @@ def step(self, action):
        self.unwrapped.robot.step(0.0)
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,

class DeadZone(gym.Wrapper):
    """On the real robot, if it isn't moving, a zero command won't move it.

    When using gSDE, having actions that don't move the robot seems to cause issues.

    Also add the option to limit the max action because the robot can't really handle full power.
    """

    def __init__(
        self, env: gym.Env, deadzone: float = 0.2, center: float = 0.01, max_act: float = 0.75
    ):
        return self.env.reset()
        super().__init__(env)
        self.deadzone = deadzone
        self.center = center
        self.max_act = max_act

    def step(self, action):
        # TODO this only works if the action is between -1 and 1
        if abs(action) > self.center:
            action = np.sign(action) * (
                np.abs(action) * (self.max_act - self.deadzone) + self.deadzone
            )
        else:
            action = np.zeros_like(action)
        observation, reward, terminated, truncated, info = self.env.step(action)
        return observation, reward, terminated, truncated, info
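
To make the remapping concrete, a standalone sketch (not from the repo, the helper name is illustrative) of what `DeadZone.step` does to actions in [-1, 1] with the default parameters:

```python
import numpy as np

def deadzone_remap(a, deadzone=0.2, center=0.01, max_act=0.75):
    # mirrors DeadZone.step: zero out near-zero commands, then squeeze the
    # rest into [deadzone, max_act] so small actions still move the robot
    if abs(a) > center:
        return float(np.sign(a) * (abs(a) * (max_act - deadzone) + deadzone))
    return 0.0

for a in (0.0, 0.005, 0.1, 0.5, 1.0):
    print(f"{a:5.3f} -> {deadzone_remap(a):5.3f}")
# 0.000 -> 0.000, 0.005 -> 0.000, 0.100 -> 0.255, 0.500 -> 0.475, 1.000 -> 0.750
```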


class MCAPLogger(gym.Wrapper):
95 changes: 95 additions & 0 deletions notebooks/2023_01_29_debug_gsde.ipynb
@@ -0,0 +1,95 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sbx import SAC as SAC_SBX\n",
"from stable_baselines3 import SAC as SAC_SB3\n",
"import matplotlib.pyplot as plt\n",
"import gymnasium as gym\n",
"from furuta.rl.envs.furuta_sim import FurutaSim\n",
"from gymnasium.wrappers import TimeLimit\n",
"\n",
"class ActionLogger(gym.Wrapper):\n",
" def __init__(self, env):\n",
" super().__init__(env)\n",
" self.actions = []\n",
" def step(self, action):\n",
" self.actions.append(action)\n",
" return self.env.step(action)\n",
" def plot_act(self):\n",
" plt.plot(self.actions[-100:])\n",
" plt.show()\n",
"\n",
"env = TimeLimit(ActionLogger(FurutaSim(speed_limits=[400, 400])), max_episode_steps=100)\n",
"\n",
"model = SAC_SB3(\"MlpPolicy\", env, verbose=1, use_sde=True, use_sde_at_warmup=True, learning_starts=500)\n",
"model.learn(total_timesteps=1000, log_interval=4)\n",
"\n",
"env.plot_act()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = SAC_SB3(\"MlpPolicy\", env, verbose=1, use_sde=True, use_sde_at_warmup=True, learning_starts=500, train_freq=(1, \"episode\"))\n",
"model.learn(total_timesteps=1000, log_interval=4)\n",
"\n",
"env.plot_act()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = SAC_SBX(\"MlpPolicy\", env, verbose=1, use_sde=True, use_sde_at_warmup=True, learning_starts=500)\n",
"model.learn(total_timesteps=1000, log_interval=4)\n",
"\n",
"env.plot_act()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sbx import TQC\n",
"\n",
"model = TQC(\"MlpPolicy\", env, verbose=1, use_sde=True, use_sde_at_warmup=True, learning_starts=500, train_freq=(1, \"episode\"))\n",
"model.learn(total_timesteps=1000, log_interval=4)\n",
"\n",
"env.plot_act()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}