Merge pull request #71 from Armandpl/sanity-check-robot
improve gym env / tune hps
Armandpl authored Feb 1, 2024
2 parents 07d92d4 + bcda11d commit f12a1da
Showing 23 changed files with 942 additions and 704 deletions.
21 changes: 16 additions & 5 deletions furuta/rl/algos.py
@@ -1,16 +1,27 @@
import sb3_contrib
import sbx
import stable_baselines3


# wrapper class for stable-baselines3.SAC
# TODO can we make one class for all algos?
# check if they all have the train freq param
# check if they have other tuple args
# check if it would be cleaner for sb3 to accept list instead of tuple?
class SAC(stable_baselines3.SAC):
# TODO is there a cleaner way to do this?
class BaseAlgoWrapper:
def __init__(self, **kwargs):
# sb3 expects tuple, omegaconf returns list
# so we need to convert the train_freq kwarg from list to tuple
if "train_freq" in kwargs and isinstance(kwargs["train_freq"], list):
kwargs.update({"train_freq": tuple(kwargs["train_freq"])})

super().__init__(**kwargs)


class SAC(BaseAlgoWrapper, stable_baselines3.SAC):
pass


class TQC(BaseAlgoWrapper, sb3_contrib.TQC):
pass


class SBXTQC(BaseAlgoWrapper, sbx.TQC):
pass
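
Note: the list-to-tuple conversion matters when hyperparameters come from a YAML/OmegaConf config, where train_freq: [1, episode] loads as a list. A minimal usage sketch, not part of the diff (the env id below is a stand-in for illustration):

from furuta.rl.algos import SAC

# train_freq arrives as a list when loaded from a config file;
# BaseAlgoWrapper converts it to the tuple sb3 expects, i.e. (1, "episode").
algo = SAC(
    policy="MlpPolicy",
    env="Pendulum-v1",  # stand-in env id, not part of this repo
    train_freq=[1, "episode"],
)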
55 changes: 43 additions & 12 deletions furuta/rl/envs/furuta_base.py
@@ -7,17 +7,36 @@
from furuta.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT, Timing


def exp_alpha_theta_reward(state, exp=2):
al_rew = exp_alpha_reward(state, exp)
th_rew = theta_reward(state)
return al_rew * th_rew


def exp_alpha_reward(state, exp=2):
al = np.mod((state[ALPHA] + np.pi), 2 * np.pi) - np.pi # between -pi and pi
al_rew = np.abs(al) / np.pi # 0 at 0, 1 at pi
al_rew = (np.exp(al_rew * exp) - np.exp(0)) / np.exp(exp)
return al_rew


def alpha_theta_reward(state):
return alpha_reward(state) * theta_reward(state)


def alpha_reward(state):
return (1 - np.cos(state[ALPHA])) / 2


def alpha_theta_reward(state):
return alpha_reward(state) + (1 + np.cos(state[THETA])) / 2
def theta_reward(state):
return (1 + np.cos(state[THETA])) / 2


REWARDS = {
"alpha": alpha_reward,
"alpha_theta": alpha_theta_reward,
"cos_alpha": alpha_theta_reward,
"exp_alpha_2": lambda x: exp_alpha_theta_reward(x, exp=2),
"exp_alpha_4": lambda x: exp_alpha_theta_reward(x, exp=4),
"exp_alpha_6": lambda x: exp_alpha_theta_reward(x, exp=6),
}
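
Note: a quick sanity check of the shaped rewards, a sketch assuming the functions above are importable from furuta.rl.envs.furuta_base. The exp_alpha variants keep the theta term but concentrate the alpha term near the upright position:

import numpy as np

from furuta.rl.envs.furuta_base import alpha_theta_reward, exp_alpha_theta_reward
from furuta.utils import ALPHA, THETA

state = np.zeros(4)  # (theta, alpha, theta_dot, alpha_dot)

state[ALPHA] = np.pi  # pendulum upright, arm at zero
print(alpha_theta_reward(state))             # 1.0
print(exp_alpha_theta_reward(state, exp=6))  # ~1.0

state[ALPHA] = np.pi / 2  # pendulum horizontal
print(alpha_theta_reward(state))             # 0.5
print(exp_alpha_theta_reward(state, exp=6))  # ~0.05: much steeper penalty away from upright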


@@ -50,14 +69,14 @@ def __init__(

act_max = np.array([1.0], dtype=np.float32)

if angle_limits is None:
angle_limits = [np.inf, np.inf]
if speed_limits is None:
speed_limits = [np.inf, np.inf]
angle_limits = np.array(angle_limits, dtype=np.float32)
speed_limits = np.array(speed_limits, dtype=np.float32)

self.state_max = np.array(
[angle_limits[0], angle_limits[1], speed_limits[0], speed_limits[1]], dtype=np.float32
)
# replace None values (NaN after the float cast) with inf
angle_limits = np.where(np.isnan(angle_limits), np.inf, angle_limits) # noqa
speed_limits = np.where(np.isnan(speed_limits), np.inf, speed_limits) # noqa

self.state_max = np.concatenate([angle_limits, speed_limits])

# max obs based on max speeds measured on the robot
# in sim the speeds spike at 30 rad/s when trained
@@ -67,6 +86,12 @@ def __init__(
# obs is [cos(th), sin(th), cos(al), sin(al), th_d, al_d)]
obs_max = np.array([1.0, 1.0, 1.0, 1.0, 30, 30], dtype=np.float32)

# if the angles are limited, include the raw angles in the obs (bounded by their limits)
if not np.isinf(self.state_max[ALPHA]):
obs_max = np.concatenate([np.array([self.state_max[ALPHA]]), obs_max])
if not np.isinf(self.state_max[THETA]):
obs_max = np.concatenate([np.array([self.state_max[THETA]]), obs_max])

# Spaces
self.state_space = Box(
# ('theta', 'alpha', 'theta_dot', 'alpha_dot'),
@@ -107,7 +132,7 @@ def step(self, action):
return obs, rwd, terminated, truncated, {}

def get_obs(self):
return np.float32(
obs = np.float32(
[
np.cos(self._state[THETA]),
np.sin(self._state[THETA]),
@@ -117,6 +142,12 @@ def get_obs(self):
self._state[ALPHA_DOT],
]
)
if not np.isinf(self.state_max[ALPHA]):
obs = np.concatenate([np.array([self._state[ALPHA]]), obs])
if not np.isinf(self.state_max[THETA]):
obs = np.concatenate([np.array([self._state[THETA]]), obs])

return obs
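
Note: when an angle limit is finite, the raw angle is prepended to the observation (and its limit to obs_max), so the observation grows from 6 to up to 8 entries. A rough sketch, assuming the observation space is built from obs_max (that part of __init__ is not shown in this diff):

import numpy as np

from furuta.rl.envs.furuta_sim import FurutaSim

env = FurutaSim()                            # no limits: obs is the 6-dim trig/velocity vector
print(env.observation_space.shape)           # expected: (6,)

env = FurutaSim(angle_limits=[np.pi, None])  # finite theta limit only: raw theta is prepended
print(env.observation_space.shape)           # expected: (7,)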

def reset(
self,
63 changes: 40 additions & 23 deletions furuta/rl/envs/furuta_real.py
@@ -3,34 +3,35 @@
from typing import Optional

import numpy as np
from simple_pid import PID

from furuta.rl.envs.furuta_base import FurutaBase
from furuta.robot import Robot
from furuta.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT, VelocityFilter

MAX_RESET_TIME = 7 # seconds
MAX_MOTOR_RESET_TIME = 0.2 # seconds
RESET_TIME = 0.5  # seconds
ALPHA_THRESH = np.deg2rad(
2
ALPHA_THRESH = np.cos(
np.deg2rad(2)
) # alpha should stay between -2 and 2 deg for 0.5 sec for us to consider the env reset


class FurutaReal(FurutaBase):
def __init__(
self,
control_freq=100,
reward="alpha_theta",
reward="cos_alpha",
angle_limits=None,
speed_limits=None,
usb_device="/dev/ttyACM0",
motor_stop_pid=[0.04, 0.0, 0.001],
):
super().__init__(control_freq, reward, angle_limits, speed_limits)
self.motor_stop_pid = motor_stop_pid

self.robot = Robot(usb_device)

self._init_vel_filt()

self._update_state(0.0)
self._state = None

def _init_vel_filt(self):
self.vel_filt = VelocityFilter(2, dt=self.timing.dt)
@@ -52,30 +53,46 @@ def reset(
super().reset(seed=seed)
logging.info("Reset env...")

# wait for pendulum to fall back to start position
reset_time = 0
time_under_thresh = 0

while time_under_thresh < RESET_TIME and reset_time < MAX_RESET_TIME:
sleep(0.01)
if abs(self._state[ALPHA]) < ALPHA_THRESH:
time_under_thresh += 0.01
else:
time_under_thresh = 0
reset_time += 0.01
self._update_state(0.0)

if reset_time >= MAX_RESET_TIME:
logging.error("Reset timeout")
if self._state is not None: # if not first reset
logging.debug("Stopping motor")
motor_pid = PID(
self.motor_stop_pid[0],
self.motor_stop_pid[1],
self.motor_stop_pid[2],
setpoint=0.0,
output_limits=(-1, 1),
)

reset_time = 0
while abs(self._state[THETA_DOT]) > 0.5 and reset_time < MAX_MOTOR_RESET_TIME:
act = motor_pid(self._state[THETA_DOT])
self._update_state(act)
reset_time += self.timing.dt
sleep(self.timing.dt)

logging.debug("Waiting for pendulum to fall back down")
time_under_thresh = 0
reset_time = 0
while time_under_thresh < RESET_TIME and reset_time < MAX_RESET_TIME:
if np.cos(self._state[ALPHA]) > ALPHA_THRESH:
time_under_thresh += self.timing.dt
else:
time_under_thresh = 0
self._update_state(0.0)
reset_time += self.timing.dt
sleep(self.timing.dt)

if reset_time >= MAX_RESET_TIME:
logging.info(f"Reset timeout, alpha: {np.rad2deg(self._state[ALPHA])}")

# reset both encoder, motor back to pos=0
self.robot.reset_encoders()

logging.info("Reset done")
self._update_state(0.0)
# otherwise the first computed velocity would include motion from the previous episode,
# producing a huge, incorrect value that can terminate the episode immediately
self._init_vel_filt()
self._update_state(0.0) # initial state
return self.get_obs(), {}

# TODO: override parent render function
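
Note: comparing cos(alpha) against cos(2 deg) instead of comparing abs(alpha) against 2 deg makes the "pendulum at rest" check insensitive to full turns accumulated during the episode. A small illustration, not part of the diff:

import numpy as np

ALPHA_THRESH = np.cos(np.deg2rad(2))  # ~0.99939

for alpha_deg in (0.0, 1.9, 2.1, 361.0, 180.0):
    print(alpha_deg, bool(np.cos(np.deg2rad(alpha_deg)) > ALPHA_THRESH))
# 0.0, 1.9 and 361.0 pass (within +/-2 deg of hanging down, even after a full turn);
# 2.1 and 180.0 do not.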
6 changes: 3 additions & 3 deletions furuta/rl/envs/furuta_sim.py
@@ -14,9 +14,9 @@ def __init__(
self,
dyn: QubeDynamics = QubeDynamics(),
control_freq=50,
reward="alpha",
angle_limits=None,
speed_limits=None,
reward="cos_alpha",
angle_limits=[None, None],
speed_limits=[None, None],
encoders_CPRs: Optional[List[float]] = None,
velocity_filter: int = None,
render_mode="rgb_array",
31 changes: 26 additions & 5 deletions furuta/rl/wrappers.py
@@ -24,12 +24,33 @@ def step(self, action):
self.unwrapped.robot.step(0.0)
return observation, reward, terminated, truncated, info

def reset(
self,
seed: Optional[int] = None,
options: Optional[dict] = None,

class DeadZone(gym.Wrapper):
"""On the real robot, if it isn't moving, a zero command won't move it.
When using gSDE, having actions that don't move the robot seem to cause issues.
Also add the option to limit the max action because the robot can't really handle full power.
"""

def __init__(
self, env: gym.Env, deadzone: float = 0.2, center: float = 0.01, max_act: float = 0.75
):
return self.env.reset()
super().__init__(env)
self.deadzone = deadzone
self.center = center
self.max_act = max_act

def step(self, action):
# TODO this only works if the action is between -1 and 1
if abs(action) > self.center:
action = np.sign(action) * (
np.abs(action) * (self.max_act - self.deadzone) + self.deadzone
)
else:
action = np.zeros_like(action)
observation, reward, terminated, truncated, info = self.env.step(action)
return observation, reward, terminated, truncated, info
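
Note: a standalone numeric sketch of the DeadZone remapping with its default parameters, not part of the diff; it re-implements the same formula for illustration:

import numpy as np


def remap(action, deadzone=0.2, center=0.01, max_act=0.75):
    # same mapping as DeadZone.step, on a scalar action in [-1, 1]
    if abs(action) > center:
        return float(np.sign(action) * (abs(action) * (max_act - deadzone) + deadzone))
    return 0.0


print(remap(0.005))  # 0.0    tiny actions are zeroed
print(remap(0.02))   # ~0.21  anything past `center` jumps over the motor's dead band
print(remap(-1.0))   # -0.75  full action is capped at max_act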


class MCAPLogger(gym.Wrapper):
95 changes: 95 additions & 0 deletions notebooks/2023_01_29_debug_gsde.ipynb
@@ -0,0 +1,95 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sbx import SAC as SAC_SBX\n",
"from stable_baselines3 import SAC as SAC_SB3\n",
"import matplotlib.pyplot as plt\n",
"import gymnasium as gym\n",
"from furuta.rl.envs.furuta_sim import FurutaSim\n",
"from gymnasium.wrappers import TimeLimit\n",
"\n",
"class ActionLogger(gym.Wrapper):\n",
" def __init__(self, env):\n",
" super().__init__(env)\n",
" self.actions = []\n",
" def step(self, action):\n",
" self.actions.append(action)\n",
" return self.env.step(action)\n",
" def plot_act(self):\n",
" plt.plot(self.actions[-100:])\n",
" plt.show()\n",
"\n",
"env = TimeLimit(ActionLogger(FurutaSim(speed_limits=[400, 400])), max_episode_steps=100)\n",
"\n",
"model = SAC_SB3(\"MlpPolicy\", env, verbose=1, use_sde=True, use_sde_at_warmup=True, learning_starts=500)\n",
"model.learn(total_timesteps=1000, log_interval=4)\n",
"\n",
"env.plot_act()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = SAC_SB3(\"MlpPolicy\", env, verbose=1, use_sde=True, use_sde_at_warmup=True, learning_starts=500, train_freq=(1, \"episode\"))\n",
"model.learn(total_timesteps=1000, log_interval=4)\n",
"\n",
"env.plot_act()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = SAC_SBX(\"MlpPolicy\", env, verbose=1, use_sde=True, use_sde_at_warmup=True, learning_starts=500)\n",
"model.learn(total_timesteps=1000, log_interval=4)\n",
"\n",
"env.plot_act()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sbx import TQC\n",
"\n",
"model = TQC(\"MlpPolicy\", env, verbose=1, use_sde=True, use_sde_at_warmup=True, learning_starts=500, train_freq=(1, \"episode\"))\n",
"model.learn(total_timesteps=1000, log_interval=4)\n",
"\n",
"env.plot_act()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}