Bound temperature (#142)
* Bump up tensorflow requirements

* Use exp-transformed log alpha to guarantee its positivity

* Fix preprocessor serialization in BasePolicy

* Change the SAC loss averaging to match the old SAC code

* Convert SAC._alpha to tensor for diagnostics

* Fix SAC alpha restore

* Update tf versions
hartikainen authored May 1, 2020
1 parent 0596f68 commit 05daa55
Showing 6 changed files with 21 additions and 20 deletions.
2 changes: 1 addition & 1 deletion examples/development/main.py
@@ -213,7 +213,7 @@ def _restore_algorithm(self, checkpoint_dir):
         )

         self.algorithm._alpha_optimizer.apply_gradients([(
-            tf.zeros_like(self.algorithm._alpha), self.algorithm._alpha
+            tf.zeros_like(self.algorithm._log_alpha), self.algorithm._log_alpha
         )])
         self.algorithm._policy_optimizer.apply_gradients([
             (tf.zeros_like(variable), variable)
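
The change above retargets the dummy-gradient step from _alpha to the new _log_alpha variable. A minimal sketch of the idea behind this restore step, with illustrative names and learning rate; the checkpoint wiring is a stand-in for however the runner actually restores state:

import tensorflow as tf

# Adam creates its slot variables (m, v) lazily on the first apply_gradients
# call. Applying an all-zero gradient leaves log_alpha's value unchanged but
# materializes the slots, so a checkpoint restore can match them.
log_alpha = tf.Variable(0.0, name='log_alpha')
alpha_optimizer = tf.optimizers.Adam(3e-4, name='alpha_optimizer')

alpha_optimizer.apply_gradients([(tf.zeros_like(log_alpha), log_alpha)])

checkpoint = tf.train.Checkpoint(
    log_alpha=log_alpha, alpha_optimizer=alpha_optimizer)
# checkpoint.restore(save_path)  # slots now exist and can be restored into
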
2 changes: 1 addition & 1 deletion examples/development/main_test.py
@@ -129,7 +129,7 @@ def test_checkpoint_dict(self):
         for initial_Q_weights, Q_weights in zip(initial_Qs_weights, Qs_weights):
             assert_weights_not_equal(initial_Q_weights, Q_weights)

-        experiment_runner.algorithm._alpha.assign(5.0)
+        experiment_runner.algorithm._log_alpha.assign(tf.math.log(5.0))
         expected_alpha_value = 5.0
         self.assertEqual(
             experiment_runner.algorithm._alpha.numpy(),
10 changes: 5 additions & 5 deletions requirements.txt
@@ -24,7 +24,7 @@ fasteners==0.15
 filelock==3.0.12
 funcsigs==1.0.2
 future==0.18.2
-gast==0.2.2
+gast>=0.3.2
 gitdb2==2.0.6
 GitPython==3.1.0
 glfw==1.9.1
@@ -104,10 +104,10 @@ six==1.13.0
 smmap2==2.0.5
 tabulate==0.8.6
 tensorboard==2.2.0
-tensorflow==2.2.0rc2
-tensorflow-addons==0.8.3
-tensorflow-estimator==2.2.0rc0
-tfp-nightly>=0.10.0.dev20200313
+tensorflow==2.2.0rc4
+tensorflow-addons==0.9.1
+tensorflow-estimator==2.2.0
+tensorflow-probability==0.10.0rc1
 termcolor==1.1.0
 tqdm==4.41.1
 urllib3==1.24.3
3 changes: 1 addition & 2 deletions setup.py
@@ -50,8 +50,7 @@
         'scikit-video>=1.1.11',
         'scipy>=1.4.1',
         'tensorflow',
-        # 'tensorflow-probability',
-        'tfp-nightly',
+        'tensorflow-probability>=0.10.0rc0',
     ),
     zip_safe=True,
     license='MIT'
20 changes: 11 additions & 9 deletions softlearning/algorithms/sac.py
@@ -131,7 +131,8 @@ def __init__(
             learning_rate=self._policy_lr,
             name="policy_optimizer")

-        self._alpha = tf.Variable(tf.exp(0.0), name='alpha')
+        self._log_alpha = tf.Variable(0.0)
+        self._alpha = tfp.util.DeferredTensor(self._log_alpha, tf.exp)

         self._alpha_optimizer = tf.optimizers.Adam(
             self._alpha_lr, name='alpha_optimizer')
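
A minimal sketch of the new parameterization, assuming TensorFlow 2.x and TensorFlow Probability: the trained variable is the unconstrained _log_alpha, so alpha = exp(log_alpha) is positive by construction, and tfp.util.DeferredTensor re-evaluates the transform whenever alpha is used:

import tensorflow as tf
import tensorflow_probability as tfp

# The unconstrained variable is log_alpha; alpha = exp(log_alpha) > 0 always.
log_alpha = tf.Variable(0.0)
alpha = tfp.util.DeferredTensor(log_alpha, tf.exp)

log_alpha.assign(tf.math.log(5.0))
print(tf.convert_to_tensor(alpha))  # ~5.0, recomputed from log_alpha on use
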
@@ -188,10 +189,11 @@ def _update_critic(self, batch):
         for Q, optimizer in zip(self._Qs, self._Q_optimizers):
             with tf.GradientTape() as tape:
                 Q_values = Q.values(observations, actions)
-                Q_losses = (
-                    0.5 * tf.losses.MSE(y_true=Q_targets, y_pred=Q_values))
+                Q_losses = 0.5 * (
+                    tf.losses.MSE(y_true=Q_targets, y_pred=Q_values))
+                Q_loss = tf.nn.compute_average_loss(Q_losses)

-            gradients = tape.gradient(Q_losses, Q.trainable_variables)
+            gradients = tape.gradient(Q_loss, Q.trainable_variables)
             optimizer.apply_gradients(zip(gradients, Q.trainable_variables))
             Qs_losses.append(Q_losses)
             Qs_values.append(Q_values)
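
A small sketch of the averaging change, with made-up targets and values: the per-sample losses are kept for diagnostics, while the scalar passed to the optimizer comes from tf.nn.compute_average_loss, which divides by the global batch size and therefore also scales correctly under tf.distribute strategies:

import tensorflow as tf

Q_targets = tf.constant([[1.0], [2.0], [3.0]])
Q_values = tf.constant([[1.5], [2.0], [2.0]])

# Per-sample losses (kept for diagnostics) and the averaged scalar loss.
Q_losses = 0.5 * tf.losses.MSE(y_true=Q_targets, y_pred=Q_values)  # shape (3,)
Q_loss = tf.nn.compute_average_loss(Q_losses)

# Outside a distribution strategy this equals the plain mean.
assert abs(float(Q_loss) - float(tf.reduce_mean(Q_losses))) < 1e-6
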
@@ -217,8 +219,8 @@ def _update_actor(self, batch):
             Qs_log_targets = tuple(
                 Q.values(observations, actions) for Q in self._Qs)
             Q_log_targets = tf.reduce_min(Qs_log_targets, axis=0)
-
             policy_losses = self._alpha * log_pis - Q_log_targets
+            policy_loss = tf.nn.compute_average_loss(policy_losses)

         tf.debugging.assert_shapes((
             (actions, ('B', 'nA')),
@@ -227,7 +229,7 @@ def _update_actor(self, batch):
         ))

         policy_gradients = tape.gradient(
-            policy_losses, self._policy.trainable_variables)
+            policy_loss, self._policy.trainable_variables)

         self._policy_optimizer.apply_gradients(zip(
             policy_gradients, self._policy.trainable_variables))
@@ -251,9 +253,9 @@ def _update_alpha(self, batch):
             # large learning rate.
             alpha_loss = tf.nn.compute_average_loss(alpha_losses)

-        alpha_gradients = tape.gradient(alpha_loss, [self._alpha])
+        alpha_gradients = tape.gradient(alpha_loss, [self._log_alpha])
         self._alpha_optimizer.apply_gradients(zip(
-            alpha_gradients, [self._alpha]))
+            alpha_gradients, [self._log_alpha]))

         return alpha_losses
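
With the deferred exp parameterization, the temperature gradient has to be taken with respect to _log_alpha, the actual trainable variable, which is what the edit above does. A rough, self-contained sketch of such a temperature update; the loss form, target entropy, and log-probabilities are illustrative stand-ins, not the repository's exact code:

import tensorflow as tf
import tensorflow_probability as tfp

log_alpha = tf.Variable(0.0)
alpha = tfp.util.DeferredTensor(log_alpha, tf.exp)
alpha_optimizer = tf.optimizers.Adam(3e-4)

target_entropy = -2.0                    # e.g. -|A| for a 2-D action space
log_pis = tf.constant([[-1.3], [-0.7]])  # log pi(a|s) for a sampled batch

with tf.GradientTape() as tape:
    # Converting the DeferredTensor inside the tape evaluates exp(log_alpha),
    # so the gradient flows back to log_alpha.
    alpha_losses = -1.0 * (
        tf.convert_to_tensor(alpha)
        * tf.stop_gradient(log_pis + target_entropy))
    alpha_loss = tf.nn.compute_average_loss(alpha_losses)

alpha_gradients = tape.gradient(alpha_loss, [log_alpha])
alpha_optimizer.apply_gradients(zip(alpha_gradients, [log_alpha]))
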

@@ -276,7 +278,7 @@ def _do_updates(self, batch):
             ('Q_value-mean', tf.reduce_mean(Qs_values)),
             ('Q_loss-mean', tf.reduce_mean(Qs_losses)),
             ('policy_loss-mean', tf.reduce_mean(policy_losses)),
-            ('alpha', self._alpha),
+            ('alpha', tf.convert_to_tensor(self._alpha)),
             ('alpha_loss-mean', tf.reduce_mean(alpha_losses)),
         ))
         return diagnostics
4 changes: 2 additions & 2 deletions softlearning/policies/base_policy.py
@@ -197,8 +197,8 @@ def get_config(self):
             'input_shapes': self._input_shapes,
             'output_shape': self._output_shape,
             'observation_keys': self._observation_keys,
-            # 'preprocessors': preprocessors.serialize(self._preprocessors),
-            'preprocessors': self._preprocessors,
+            'preprocessors': tree.map(
+                preprocessors_lib.serialize, self._preprocessors),
             'name': self._name,
         }
         return config
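
The fix above serializes each preprocessor in the (possibly nested) preprocessor structure instead of putting live objects into the config. A rough sketch of the idea using dm-tree's map_structure and an illustrative serializer, as stand-ins for the repository's tree.map and preprocessors_lib.serialize:

import tree  # dm-tree

def serialize(preprocessor):
    # Illustrative serializer: reduce a preprocessor to plain, JSON-friendly
    # data instead of a live object.
    if preprocessor is None:
        return None
    return {'class_name': type(preprocessor).__name__,
            'config': preprocessor.get_config()}

preprocessors = {'observations': None, 'pixels': None}  # illustrative structure
config = {'preprocessors': tree.map_structure(serialize, preprocessors)}
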
