add ddpg #62

Open

wants to merge 4 commits into master

Changes from 1 commit
27 changes: 27 additions & 0 deletions examples/ddpg_gym_tf_pendulum.py
@@ -0,0 +1,27 @@
from sandbox.rocky.tf.algos.ddpg.ddpg import DDPG
from sandbox.rocky.tf.algos.ddpg.noise import OrnsteinUhlenbeckActionNoise

import gym
import numpy as np
import tensorflow as tf

RANDOM_SEED = 1234

np.random.seed(RANDOM_SEED)
tf.set_random_seed(RANDOM_SEED)

Collaborator @eric-heiden commented on May 17, 2018:

Remove these lines; call set_seed(seed) instead.
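
For reference, a minimal sketch of the suggested seeding, assuming the reviewer means rllab's rllab.misc.ext.set_seed (it seeds Python's random module and NumPy's global RNG; the TF graph seed may still need to be set separately):

from rllab.misc.ext import set_seed

RANDOM_SEED = 1234  # same constant as in the script above
set_seed(RANDOM_SEED)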


env = gym.make('Pendulum-v0')
env.seed(RANDOM_SEED)

action_dim = env.action_space.shape[-1]
action_noise = OrnsteinUhlenbeckActionNoise(

Collaborator:
the noise should also be based on a seed
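
A minimal sketch of what a seeded noise process could look like; the class name and the seed argument are hypothetical and not part of this PR:

import numpy as np


class SeededOrnsteinUhlenbeckActionNoise:
    def __init__(self, mu, sigma, theta=0.15, dt=1e-2, x0=None, seed=None):
        self._theta, self._mu, self._sigma = theta, mu, sigma
        self._dt, self._x0 = dt, x0
        self._rng = np.random.RandomState(seed)  # dedicated RNG instead of the global np.random
        self.reset()

    def gen(self):
        # Mean-reverting drift towards mu plus scaled Gaussian noise, drawn from the seeded RNG.
        x = (self._x_prev
             + self._theta * (self._mu - self._x_prev) * self._dt
             + self._sigma * np.sqrt(self._dt) * self._rng.normal(size=self._mu.shape))
        self._x_prev = x
        return x

    def reset(self):
        self._x_prev = self._x0 if self._x0 is not None else np.zeros_like(self._mu)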

Collaborator:
Please add another launcher script which runs your code via run_experiment_lite. It should work like this:

def run_task(*_):
    # initialize actor, critic, noise, env, etc.
    env = ...

    algo = DDPG(env=env, ...)
    algo.train()

run_experiment_lite(
    run_task,
    n_parallel=20,
    plot=False,
)

We need to make sure your DDPG implementation is serializable.
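
(For reference, and as an assumption about rllab's layout: run_experiment_lite is typically imported from rllab.misc.instrument.)

from rllab.misc.instrument import run_experiment_lite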

    mu=np.zeros(action_dim), sigma=0.02 * np.ones(action_dim))

ddpg = DDPG(
    env,
    plot=False,
    action_noise=action_noise,
    check_point_dir='pendulum',
    log_dir="pendulum_ou_noise")

ddpg.train()
233 changes: 233 additions & 0 deletions sandbox/rocky/tf/algos/ddpg/ddpg.py
@@ -0,0 +1,233 @@
from sandbox.rocky.tf.algos.network.actor_critic_net import ActorNet, CriticNet
from sandbox.rocky.tf.algos.ddpg.replay_buffer import ReplayBuffer
from rllab.algos.base import RLAlgorithm

import tensorflow as tf
from copy import copy
import numpy as np
import os


class DDPG(RLAlgorithm):
    def __init__(self,
                 env,
                 gamma=0.99,
                 tau=0.001,
                 observation_range=(-5, 5),

Collaborator:
observation_range isn't used anywhere (and shouldn't be needed)

                 action_range=(-1, 1),

Collaborator:
Assume that the algorithm always gets a normalized environment, where action_range is always in [-1, 1]. We need to implement normalization for gym Envs anyways as described in #64.
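
A sketch of how the environment could be wrapped on the rllab side, assuming the usual GymEnv / normalize / TfEnv wrappers (not part of this PR):

from rllab.envs.gym_env import GymEnv
from rllab.envs.normalized_env import normalize
from sandbox.rocky.tf.envs.base import TfEnv

env = TfEnv(normalize(GymEnv("Pendulum-v0")))  # normalize maps agent actions from [-1, 1] to the env's bounds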

                 actor_lr=1e-4,
                 critic_lr=1e-3,
                 reward_scale=1,

Collaborator:
reward_scale should be called discount (in accordance with the other implementations). A value of 1 seems strange; are you sure it isn't <1, e.g. 0.99?

Author:
In the paper and in the baselines implementation, this value is set to 1.

                 batch_size=64,
                 critic_l2_weight_decay=0.01,
                 action_noise=None,
                 plot=False,
                 check_point_dir=None,

Collaborator:
Use logger.save_itr_params which takes care of checkpoint folders etc.
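
A sketch of the suggested pattern, assuming rllab.misc.logger; the contents of the params dict are illustrative:

from rllab.misc import logger

itr = 0  # illustrative; would come from the training loop
params = dict(itr=itr, env=None, policy=None, qf=None)
logger.save_itr_params(itr, params)  # writes a snapshot if a snapshot directory is configured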

Author:
I am not sure what this function is for yet. The checkpoint folder is used by TensorFlow; it does not seem to need to save anything other than the graph.

                 log_dir=None):
        """
        A DDPG model as described in https://arxiv.org/pdf/1509.02971.pdf.

        Hyperparameters:
        :param env: the environment to train on
        :param gamma: discount factor
        :param tau: coefficient for the soft updates of the target networks
        :param observation_range: observation space range
        :param action_range: action space range
        :param actor_lr: learning rate for the actor network
        :param critic_lr: learning rate for the critic network
        :param reward_scale: scale factor applied to rewards before they are stored
        :param batch_size: mini-batch size sampled from the replay buffer
        :param critic_l2_weight_decay: L2 weight decay on the critic network weights
        :param action_noise: exploration noise added to the predicted actions
        :param plot: whether to render the environment during training
        :param check_point_dir: directory for saving model checkpoints
        :param log_dir: directory for saving TensorBoard logs
        """

        self._env = env
        self._session = tf.Session()

        observation_shape = env.observation_space.shape

        action_shape = env.action_space.shape
        actions_dim = env.action_space.shape[-1]
        self._max_action = self._env.action_space.high

        # Parameters.
        self._observation_shape = observation_shape
        self._action_shape = action_shape

        self._tau = tau
        self._action_noise = action_noise
        self._action_range = action_range
        self._reward_scale = reward_scale
        self._batch_size = batch_size
        self._plot = plot
        self._check_point_dir = check_point_dir
        self._saver = None

        # Inputs.
        self._state = tf.placeholder(
            tf.float32, shape=(None, ) + observation_shape, name='state')
        self._next_state = tf.placeholder(
            tf.float32, shape=(None, ) + observation_shape, name='next_state')
        self._terminals = tf.placeholder(
            tf.float32, shape=(None, 1), name='terminals')
        self._rewards = tf.placeholder(
            tf.float32, shape=(None, 1), name='rewards')
        self._actions = tf.placeholder(
            tf.float32, shape=(None, actions_dim), name='actions')
        self._critic_target = tf.placeholder(
            tf.float32, shape=(None, 1), name='critic_target')

        # Actor network and its target copy.
        self._actor_net = ActorNet(self._session, actions_dim, lr=actor_lr)
        self._target_actor = copy(self._actor_net)
        self._target_actor.name = 'target_actor'

        # Critic network and its target copy.
        self._critic_net = CriticNet(
            self._session,
            gamma=gamma,
            lr=critic_lr,
            weight_decay=critic_l2_weight_decay)
        self._target_critic = copy(self._critic_net)
        self._target_critic.name = 'target_critic'

        # Replay buffer.
        self._replay_buffer = ReplayBuffer(1e6)

Collaborator:
The size of the replay buffer should be user-definable
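
A sketch of the suggested change; the replay_buffer_size name is hypothetical:

from sandbox.rocky.tf.algos.ddpg.replay_buffer import ReplayBuffer

replay_buffer_size = int(1e6)  # would become a constructor argument of DDPG
replay_buffer = ReplayBuffer(replay_buffer_size)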


        if log_dir:
            self._summary_writer = tf.summary.FileWriter(

Collaborator:
Use rllab's logger

                log_dir, self._session.graph)
        else:
            self._summary_writer = None
        self._initialize()

    def _initialize(self):
        # Build the actor/critic networks and their target copies.
        self._actor_net.build_net(self._state)
        self._target_actor.build_net(self._next_state)

        self._critic_net.build_net(self._state, self._actions, self._rewards,
                                   self._terminals, self._critic_target,
                                   self._actor_net.action)
        self._target_critic.build_net(
            self._next_state, self._actions, self._rewards, self._terminals,
            self._critic_target, self._target_actor.action)

        # Chain rule: feed the critic's action gradients into the actor update.
        self._actor_net.set_grad(self._critic_net.action_grads)

        # Ops for the soft updates of the target networks.
        self._actor_net.setup_target_net(self._target_actor, self._tau)
        self._critic_net.setup_target_net(self._target_critic, self._tau)

        self._global_step = tf.Variable(
            initial_value=0, name='global_step', trainable=False)

        self.load_session()

    def _train_net(self):
        # Sample a mini-batch of transitions from the replay buffer.
        [state, action, reward, terminal,
         next_state] = self._replay_buffer.get_batch_data(self._batch_size)
        reward = reward.reshape(-1, 1)
        terminal = terminal.reshape(-1, 1)

        target_action = self._target_actor.predict(next_state)
        target_Q = self._target_critic.predict_target_Q(
            next_state, target_action, reward, terminal)

        self._critic_net.train(state, action, target_Q)
        self._actor_net.train(state)

        self._actor_net.update_target_net()
        self._critic_net.update_target_net()
        return

    def _report_total_reward(self, reward, step):
        summary = tf.Summary()
        summary.value.add(tag='rollout/reward', simple_value=float(reward))
        summary.value.add(
            tag='train/episode_reward', simple_value=float(reward))
        if self._summary_writer:
            self._summary_writer.add_summary(summary, step)

    def predict(self, state):
        action = self._actor_net.predict(np.array(state).reshape(1, -1))[0]
        if self._action_noise:
            noise = self._action_noise.gen()
            action = action + noise
        action = np.clip(action, self._action_range[0], self._action_range[1])
        return action

    def train(self,
              epochs=500,
              epoch_cycles=20,
              rollout_steps=100,
              train_steps=50):
        state = self._env.reset()
        if self._action_noise:
            self._action_noise.reset()
        total_reward = 0.0
        episode_step = self._session.run(self._global_step)

        for epoch in range(epochs):
            for step in range(epoch_cycles):
                for rollout in range(rollout_steps):
                    if self._plot:
                        self._env.render()
                    action = self.predict(state)

                    next_state, reward, terminal, info = self._env.step(
                        action * self._max_action)
                    self._replay_buffer.add_data(state, action,
                                                 reward * self._reward_scale,
                                                 terminal, next_state)
                    state = next_state
                    total_reward += reward
                    if terminal:
                        self._report_total_reward(total_reward, episode_step)
                        print("epoch %d, total reward %f\n" %
                              (episode_step, total_reward))

Collaborator:
Use rllab's logger for this (logger.record_tabular('key', value)) and log more performance measures than reward, e.g.:

  • actor loss
  • critic loss
  • max and average predicted Q value
  • etc. (maybe you can think of more things that could help us understand the performance)
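
A sketch of the suggested logging, assuming rllab.misc.logger; the loss and Q values shown are placeholders for quantities the actor/critic networks would have to expose:

import numpy as np
from rllab.misc import logger

epoch, episode_reward = 0, -1234.5   # placeholders for loop variables
actor_loss, critic_loss = 0.0, 0.0   # placeholders for values returned by the train ops
target_Q = np.zeros((64, 1))         # placeholder for the critic's target Q batch

logger.record_tabular('Epoch', epoch)
logger.record_tabular('EpisodeReward', episode_reward)
logger.record_tabular('ActorLoss', actor_loss)
logger.record_tabular('CriticLoss', critic_loss)
logger.record_tabular('MaxQ', np.max(target_Q))
logger.record_tabular('AverageQ', np.mean(target_Q))
logger.dump_tabular()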

                        episode_step = self._session.run(
                            self._global_step.assign_add(1))
                        total_reward = 0
                        state = self._env.reset()
                        if self._action_noise:
                            self._action_noise.reset()

                for train in range(train_steps):
                    self._train_net()
            self.save_session(episode_step)

    def load_session(self):

Collaborator:
We don't need this code if everything is serializable. Make sure you can run your implementation with run_experiment_lite.
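
A sketch of the usual rllab pattern for making the algorithm serializable, shown with a trimmed, illustrative constructor:

from rllab.algos.base import RLAlgorithm
from rllab.core.serializable import Serializable


class SerializableDDPG(RLAlgorithm, Serializable):
    def __init__(self, env, gamma=0.99, tau=0.001):
        # Record the constructor arguments so the object can be pickled
        # and re-created by run_experiment_lite.
        Serializable.quick_init(self, locals())
        self._env = env
        self._gamma = gamma
        self._tau = tau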

        if not self._check_point_dir:
            # No checkpoint directory configured: just initialize the variables.
            self._session.run(tf.global_variables_initializer())
            return

        if not self._saver:
            self._saver = tf.train.Saver()
        try:
            print("Trying to restore last checkpoint ...:",
                  self._check_point_dir)
            last_chk_path = tf.train.latest_checkpoint(
                checkpoint_dir=self._check_point_dir)
            self._saver.restore(self._session, save_path=last_chk_path)
            print("restore last checkpoint %s done" % self._check_point_dir)
        except Exception as e:
            if not os.path.exists(self._check_point_dir):
                os.mkdir(self._check_point_dir)
            assert os.path.exists(self._check_point_dir), \
                "failed to create checkpoint directory %s" % self._check_point_dir
            print("Failed to restore checkpoint. Initializing variables instead.", e)
            self._session.run(tf.global_variables_initializer())

    def save_session(self, step):
        if not self._saver:
            return
        save_path = self._check_point_dir + "/event"
        self._saver.save(self._session, save_path=save_path, global_step=step)
22 changes: 22 additions & 0 deletions sandbox/rocky/tf/algos/ddpg/noise.py
@@ -0,0 +1,22 @@
import numpy as np


class OrnsteinUhlenbeckActionNoise:
    def __init__(self, mu, sigma, theta=.15, dt=1e-2, x0=None):
        self._theta = theta
        self._mu = mu
        self._sigma = sigma
        self._dt = dt
        self._x0 = x0
        self.reset()

    def gen(self):
        # Mean-reverting drift towards mu plus scaled Gaussian noise.
        x = self._x_prev + self._theta * (
            self._mu - self._x_prev) * self._dt + self._sigma * np.sqrt(
                self._dt) * np.random.normal(size=self._mu.shape)
        self._x_prev = x
        return x

    def reset(self):
        self._x_prev = self._x0 if self._x0 is not None else np.zeros_like(
            self._mu)
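
For completeness, a quick usage sketch of the noise process:

import numpy as np
from sandbox.rocky.tf.algos.ddpg.noise import OrnsteinUhlenbeckActionNoise

noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(1), sigma=0.2 * np.ones(1))
noise.reset()
samples = [noise.gen() for _ in range(5)]  # temporally correlated exploration noise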
25 changes: 25 additions & 0 deletions sandbox/rocky/tf/algos/ddpg/replay_buffer.py
@@ -0,0 +1,25 @@
import numpy as np
import random


class ReplayBuffer:
    def __init__(self, buffer_size):
        self._buffer = []
        self._buffer_size = buffer_size

    def add_data(self, state, action, reward, terminal, next_state):
        self._buffer.append((state, action, reward, terminal, next_state))
        if self.get_buffer_size() > self._buffer_size:
            # Drop the oldest transition once the buffer is full.
            self._buffer = self._buffer[1:]

    def get_batch_data(self, batch_size):
        data = random.sample(self._buffer, batch_size)
        states = np.array([d[0] for d in data])
        actions = np.array([d[1] for d in data])
        rewards = np.array([d[2] for d in data])
        terminals = np.array([d[3] for d in data])
        next_states = np.array([d[4] for d in data])
        return [states, actions, rewards, terminals, next_states]

    def get_buffer_size(self):
        return len(self._buffer)