More PEP8 formatting
jonashen committed Jun 8, 2018
1 parent ee1f5ed commit 49b387f
Showing 26 changed files with 584 additions and 464 deletions.
56 changes: 29 additions & 27 deletions rllab/tf/algos/batch_polopt.py
@@ -14,30 +14,28 @@ class BatchPolopt(RLAlgorithm):
This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
"""

def __init__(
self,
env,
policy,
baseline,
scope=None,
n_itr=500,
start_itr=0,
batch_size=5000,
max_path_length=500,
discount=0.99,
gae_lambda=1,
plot=False,
pause_for_plot=False,
center_adv=True,
positive_adv=False,
store_paths=False,
whole_paths=True,
fixed_horizon=False,
sampler_cls=None,
sampler_args=None,
force_batch_sampler=False,
**kwargs
):
def __init__(self,
env,
policy,
baseline,
scope=None,
n_itr=500,
start_itr=0,
batch_size=5000,
max_path_length=500,
discount=0.99,
gae_lambda=1,
plot=False,
pause_for_plot=False,
center_adv=True,
positive_adv=False,
store_paths=False,
whole_paths=True,
fixed_horizon=False,
sampler_cls=None,
sampler_args=None,
force_batch_sampler=False,
**kwargs):
"""
:param env: Environment
:param policy: Policy
@@ -119,7 +117,8 @@ def train(self, sess=None):
logger.log("Optimizing policy...")
self.optimize_policy(itr, samples_data)
logger.log("Saving snapshot...")
params = self.get_itr_snapshot(itr, samples_data) # , **kwargs)
params = self.get_itr_snapshot(itr,
samples_data) # , **kwargs)
if self.store_paths:
params["paths"] = samples_data["paths"]
logger.save_itr_params(itr, params)
@@ -128,7 +127,11 @@ def train(self, sess=None):
logger.record_tabular('ItrTime', time.time() - itr_start_time)
logger.dump_tabular(with_prefix=False)
if self.plot:
rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
rollout(
self.env,
self.policy,
animated=True,
max_path_length=self.max_path_length)
if self.pause_for_plot:
input("Plotting evaluation run: Press Enter to "
"continue...")
@@ -157,4 +160,3 @@ def get_itr_snapshot(self, itr, samples_data):

def optimize_policy(self, itr, samples_data):
raise NotImplementedError

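The train() loop excerpted above follows a sample-then-optimize pattern: collect a batch of paths, process them into training data, run the subclass's policy update, and save a snapshot. The minimal, self-contained Python sketch below illustrates that pattern; the stub bodies and the helper names obtain_samples/process_samples are assumptions for illustration, not the rllab code being diffed.

import time

class ToyBatchPolopt:
    """Illustrative stand-in for BatchPolopt; every method body is a stub."""

    def obtain_samples(self, itr):
        # Roll out the current policy; here a single fake path.
        return [{"observations": [0.0], "actions": [0.0], "rewards": [1.0]}]

    def process_samples(self, itr, paths):
        # Compute advantages, fit the baseline, etc.; here a pass-through.
        return {"paths": paths}

    def optimize_policy(self, itr, samples_data):
        pass  # subclass-specific update step (VPG, NPO, TRPO, ...)

    def get_itr_snapshot(self, itr, samples_data):
        return {"itr": itr}

    def train(self, n_itr=3):
        for itr in range(n_itr):
            itr_start_time = time.time()
            paths = self.obtain_samples(itr)
            samples_data = self.process_samples(itr, paths)
            self.optimize_policy(itr, samples_data)
            params = self.get_itr_snapshot(itr, samples_data)
            print("itr", itr, "snapshot", params,
                  "time", round(time.time() - itr_start_time, 4))

ToyBatchPolopt().train()
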
78 changes: 44 additions & 34 deletions rllab/tf/algos/npo.py
@@ -1,6 +1,3 @@



from rllab.misc import ext
from rllab.misc.overrides import overrides
import rllab.misc.logger as logger
@@ -16,13 +13,12 @@ class NPO(BatchPolopt):
Natural Policy Optimization.
"""

def __init__(
self,
optimizer=None,
optimizer_args=None,
step_size=0.01,
name="NPO",
**kwargs):
def __init__(self,
optimizer=None,
optimizer_args=None,
step_size=0.01,
name="NPO",
**kwargs):
if optimizer is None:
if optimizer_args is None:
optimizer_args = dict()
@@ -52,37 +48,52 @@ def init_opt(self):
dist = self.policy.distribution

old_dist_info_vars = {
k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k)
k: tf.placeholder(
tf.float32,
shape=[None] * (1 + is_recurrent) + list(shape),
name='old_%s' % k)
for k, shape in dist.dist_info_specs
}
old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]
}
old_dist_info_vars_list = [
old_dist_info_vars[k] for k in dist.dist_info_keys
]

state_info_vars = {
k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k)
k: tf.placeholder(
tf.float32,
shape=[None] * (1 + is_recurrent) + list(shape),
name=k)
for k, shape in self.policy.state_info_specs
}
state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]
}
state_info_vars_list = [
state_info_vars[k] for k in self.policy.state_info_keys
]

if is_recurrent:
valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid")
valid_var = tf.placeholder(
tf.float32, shape=[None, None], name="valid")
else:
valid_var = None

dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
dist_info_vars = self.policy.dist_info_sym(obs_var,
state_info_vars)
kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars)
lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars,
dist_info_vars)
if is_recurrent:
mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
surr_loss = - tf.reduce_sum(lr * advantage_var * valid_var) / tf.reduce_sum(valid_var)
mean_kl = tf.reduce_sum(
kl * valid_var) / tf.reduce_sum(valid_var)
surr_loss = -tf.reduce_sum(
lr * advantage_var * valid_var) / tf.reduce_sum(valid_var)
else:
mean_kl = tf.reduce_mean(kl, name="reduce_mean_er")
surr_loss = - tf.reduce_mean(lr * advantage_var)
surr_loss = -tf.reduce_mean(lr * advantage_var)

input_list = [
obs_var,
action_var,
advantage_var,
] + state_info_vars_list + old_dist_info_vars_list
obs_var,
action_var,
advantage_var,
] + state_info_vars_list + old_dist_info_vars_list
if is_recurrent:
input_list.append(valid_var)

@@ -91,22 +102,21 @@ def init_opt(self):
target=self.policy,
leq_constraint=(mean_kl, self.step_size),
inputs=input_list,
constraint_name="mean_kl"
)
constraint_name="mean_kl")
return dict()

@overrides
def optimize_policy(self, itr, samples_data):
all_input_values = tuple(ext.extract(
samples_data,
"observations", "actions", "advantages"
))
all_input_values = tuple(
ext.extract(samples_data, "observations", "actions", "advantages"))
agent_infos = samples_data["agent_infos"]
state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
dist_info_list = [
agent_infos[k] for k in self.policy.distribution.dist_info_keys
]
all_input_values += tuple(state_info_list) + tuple(dist_info_list)
if self.policy.recurrent:
all_input_values += (samples_data["valids"],)
all_input_values += (samples_data["valids"], )
logger.log("Computing loss before")
loss_before = self.optimizer.loss(all_input_values)
logger.log("Computing KL before")
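
The init_opt() changes above only re-wrap the symbolic construction of NPO's surrogate loss and KL constraint. As a plain-NumPy reference for what those quantities mean, here is a sketch for a one-dimensional Gaussian policy; the sample numbers and the helper gaussian_log_likelihood are illustrative assumptions, not the TensorFlow graph built in the diff.

import numpy as np

def gaussian_log_likelihood(x, mean, std):
    # log N(x; mean, std), elementwise
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)

actions = np.array([0.1, -0.3, 0.7])
advantages = np.array([1.0, -0.5, 2.0])
old_mean, old_std = 0.0, 1.0     # distribution the samples were drawn from
new_mean, new_std = 0.05, 0.95   # candidate updated distribution

# Likelihood ratio and surrogate loss, mirroring likelihood_ratio_sym / surr_loss.
lr = np.exp(gaussian_log_likelihood(actions, new_mean, new_std)
            - gaussian_log_likelihood(actions, old_mean, old_std))
surr_loss = -np.mean(lr * advantages)

# KL(old || new) for univariate Gaussians, the quantity bounded by step_size.
mean_kl = (np.log(new_std / old_std)
           + (old_std ** 2 + (old_mean - new_mean) ** 2) / (2.0 * new_std ** 2)
           - 0.5)

step_size = 0.01
print(surr_loss, mean_kl, mean_kl <= step_size)
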
83 changes: 48 additions & 35 deletions rllab/tf/algos/vpg.py
@@ -1,5 +1,3 @@


from rllab.misc import logger
from rllab.misc import ext
from rllab.misc.overrides import overrides
@@ -16,15 +14,13 @@ class VPG(BatchPolopt, Serializable):
Vanilla Policy Gradient.
"""

def __init__(
self,
env,
policy,
baseline,
optimizer=None,
optimizer_args=None,
name="VPG",
**kwargs):
def __init__(self,
env,
policy,
baseline,
optimizer=None,
optimizer_args=None,
name="VPG",
**kwargs):
Serializable.quick_init(self, locals())
if optimizer is None:
default_args = dict(
@@ -39,7 +36,8 @@ def __init__(
self.optimizer = optimizer
self.opt_info = None
self.name = name
super(VPG, self).__init__(env=env, policy=policy, baseline=baseline, **kwargs)
super(VPG, self).__init__(
env=env, policy=policy, baseline=baseline, **kwargs)

@overrides
def init_opt(self):
@@ -62,71 +60,86 @@ def init_opt(self):
dist = self.policy.distribution

old_dist_info_vars = {
k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k)
k: tf.placeholder(
tf.float32,
shape=[None] * (1 + is_recurrent) + list(shape),
name='old_%s' % k)
for k, shape in dist.dist_info_specs
}
old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]
}
old_dist_info_vars_list = [
old_dist_info_vars[k] for k in dist.dist_info_keys
]

state_info_vars = {
k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k)
k: tf.placeholder(
tf.float32,
shape=[None] * (1 + is_recurrent) + list(shape),
name=k)
for k, shape in self.policy.state_info_specs
}
state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]
}
state_info_vars_list = [
state_info_vars[k] for k in self.policy.state_info_keys
]

if is_recurrent:
valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid")
valid_var = tf.placeholder(
tf.float32, shape=[None, None], name="valid")
else:
valid_var = None

dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
dist_info_vars = self.policy.dist_info_sym(obs_var,
state_info_vars)
logli = dist.log_likelihood_sym(action_var, dist_info_vars)
kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)

# formulate as a minimization problem
# The gradient of the surrogate objective is the policy gradient
if is_recurrent:
surr_obj = - tf.reduce_sum(logli * advantage_var * valid_var) / tf.reduce_sum(valid_var)
mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
surr_obj = -tf.reduce_sum(logli * advantage_var *
valid_var) / tf.reduce_sum(valid_var)
mean_kl = tf.reduce_sum(
kl * valid_var) / tf.reduce_sum(valid_var)
max_kl = tf.reduce_max(kl * valid_var)
else:
surr_obj = - tf.reduce_mean(logli * advantage_var)
surr_obj = -tf.reduce_mean(logli * advantage_var)
mean_kl = tf.reduce_mean(kl)
max_kl = tf.reduce_max(kl)

input_list = [obs_var, action_var, advantage_var] + state_info_vars_list
input_list = [obs_var, action_var, advantage_var
] + state_info_vars_list
if is_recurrent:
input_list.append(valid_var)

self.optimizer.update_opt(loss=surr_obj, target=self.policy, inputs=input_list)
self.optimizer.update_opt(
loss=surr_obj, target=self.policy, inputs=input_list)

f_kl = tensor_utils.compile_function(
inputs=input_list + old_dist_info_vars_list,
outputs=[mean_kl, max_kl],
)
self.opt_info = dict(
f_kl=f_kl,
)
self.opt_info = dict(f_kl=f_kl, )

@overrides
def optimize_policy(self, itr, samples_data):
logger.log("optimizing policy")
inputs = ext.extract(
samples_data,
"observations", "actions", "advantages"
)
inputs = ext.extract(samples_data, "observations", "actions",
"advantages")
agent_infos = samples_data["agent_infos"]
state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
inputs += tuple(state_info_list)
if self.policy.recurrent:
inputs += (samples_data["valids"],)
dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
inputs += (samples_data["valids"], )
dist_info_list = [
agent_infos[k] for k in self.policy.distribution.dist_info_keys
]
loss_before = self.optimizer.loss(inputs)
self.optimizer.optimize(inputs)
loss_after = self.optimizer.loss(inputs)
logger.record_tabular("LossBefore", loss_before)
logger.record_tabular("LossAfter", loss_after)

mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
mean_kl, max_kl = self.opt_info['f_kl'](
*(list(inputs) + dist_info_list))
logger.record_tabular('MeanKL', mean_kl)
logger.record_tabular('MaxKL', max_kl)

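The comment in init_opt() above points out that the surrogate objective is formulated as a minimization problem whose gradient is the policy gradient. The short NumPy sketch below spells that out for a Gaussian policy with a learnable mean; the sample values are illustrative assumptions, not the symbolic TensorFlow expressions in the diff.

import numpy as np

actions = np.array([0.2, -0.1, 0.5])
advantages = np.array([1.5, -0.2, 0.8])
mean, std = 0.0, 1.0

# Log-likelihood of the sampled actions under the current policy.
logli = -0.5 * ((actions - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)

# Surrogate objective: minimizing it ascends the expected return.
surr_obj = -np.mean(logli * advantages)

# Its gradient w.r.t. the policy mean is minus the sample policy gradient.
grad_surr_wrt_mean = -np.mean(((actions - mean) / std ** 2) * advantages)

print(surr_obj, grad_surr_wrt_mean)
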
7 changes: 2 additions & 5 deletions rllab/tf/distributions/base.py
@@ -1,7 +1,3 @@




class Distribution(object):
@property
def dim(self):
@@ -19,7 +15,8 @@ def kl(self, old_dist_info, new_dist_info):
"""
raise NotImplementedError

def likelihood_ratio_sym(self, x_var, old_dist_info_vars, new_dist_info_vars):
def likelihood_ratio_sym(self, x_var, old_dist_info_vars,
new_dist_info_vars):
raise NotImplementedError

def entropy(self, dist_info):