Add classification losses for offline RL.
PiperOrigin-RevId: 600260970
agarwl authored and psc-g committed Apr 3, 2024
1 parent 485ea99 commit 90f986e
Showing 6 changed files with 528 additions and 30 deletions.
49 changes: 49 additions & 0 deletions dopamine/labs/offline_rl/jax/configs/jax_classy_cql.gin
@@ -0,0 +1,49 @@
# Hyperparameters follow Hessel et al. (2018), except for sticky_actions,
# which was False (not using sticky actions) in the original paper.
import dopamine.jax.agents.full_rainbow.full_rainbow_agent
import dopamine.jax.agents.dqn.dqn_agent
import dopamine.discrete_domains.atari_lib
import dopamine.discrete_domains.run_experiment
import dopamine.labs.offline_rl.fixed_replay
import dopamine.labs.offline_rl.jax.networks
import dopamine.labs.offline_rl.jax.offline_rainbow_agent

JaxDQNAgent.gamma = 0.99
JaxDQNAgent.update_horizon = 1
JaxDQNAgent.min_replay_history = 20000 # agent steps
# update_period=1 is a sane default for offline RL.
JaxDQNAgent.update_period = 1
JaxDQNAgent.target_update_period = 2000 # agent steps
JaxDQNAgent.epsilon_eval = 0.001
JaxDQNAgent.epsilon_decay_period = 250000 # agent steps
JaxDQNAgent.optimizer = 'adam'
JaxDQNAgent.summary_writing_frequency = 2500

JaxFullRainbowAgent.dueling = False  # Don't use dueling networks.
JaxFullRainbowAgent.double_dqn = True
JaxFullRainbowAgent.num_atoms = 51
JaxFullRainbowAgent.replay_scheme = 'uniform'
JaxFullRainbowAgent.vmax = 10.

OfflineClassyCQLAgent.td_coefficient = 1.0
OfflineClassyCQLAgent.bc_coefficient = 0.1

JaxFullRainbowAgent.network = @networks.ParameterizedRainbowNetwork

# Use parameters similar to those of C51.
create_optimizer.learning_rate = 6.25e-5
create_optimizer.eps = 0.0003125

atari_lib.create_atari_environment.game_name = 'Pong'
# Sticky actions with probability 0.25, as suggested by Machado et al. (2017).
atari_lib.create_atari_environment.sticky_actions = True
create_runner.schedule = 'continuous_train'
create_agent.agent_name = 'classy_cql'
create_agent.debug_mode = True
Runner.num_iterations = 100
Runner.training_steps = 62_500 # agent steps
Runner.evaluation_steps = 125000 # agent steps
Runner.max_steps_per_episode = 27000 # agent steps

JaxFixedReplayBuffer.replay_capacity = 50000
JaxFixedReplayBuffer.batch_size = 32
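
The gin file above is consumed through Dopamine's gin-driven launch flow. The sketch below is a minimal, hypothetical launch script: it assumes the standard `run_experiment` API plus an offline-RL entry point that registers the `classy_cql` agent (that entry point and its dataset bindings are not part of this commit), and the replay-data binding name is an illustrative guess.

```python
# Minimal launch sketch (assumptions noted above): parse the new gin config,
# then build and run a Dopamine runner.
import gin
from dopamine.discrete_domains import run_experiment

CONFIG = 'dopamine/labs/offline_rl/jax/configs/jax_classy_cql.gin'

gin.parse_config_files_and_bindings(
    [CONFIG],
    bindings=[
        # Hypothetical binding: point the fixed replay buffer at an offline
        # Atari dataset; the actual parameter name may differ.
        "JaxFixedReplayBuffer.replay_data_dir = '/tmp/dqn_replay/Pong/1'",
    ],
)

# create_runner picks up `create_runner.schedule = 'continuous_train'` from gin.
# The agent name 'classy_cql' must be resolvable by whatever create_agent
# function the offline_rl lab wires in; the stock factory does not know it.
runner = run_experiment.create_runner('/tmp/classy_cql_pong')
runner.run_experiment()
```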
22 changes: 14 additions & 8 deletions dopamine/labs/offline_rl/jax/networks.py
@@ -18,14 +18,16 @@
 from typing import Tuple

 from absl import logging
-from dopamine.discrete_domains import atari_lib
+from dopamine.google.experiments.two_hot import losses as classy_transforms
 from flax import linen as nn
 import gin
 import jax
 import jax.numpy as jnp


 NetworkType = collections.namedtuple('network', ['q_values', 'representation'])
+ClassyNetworkType = collections.namedtuple(
+    'classy_network', ['q_values', 'logits', 'probabilities', 'representation'])


 def preprocess_atari_inputs(x):
@@ -228,13 +230,14 @@ class ParameterizedRainbowNetwork(nn.Module):

   num_actions: int
   num_atoms: int
-  dueling: bool = True
+  dueling: bool = False
   noisy: bool = False  # No exploration in offline RL, kept for compatibility.
   distributional: bool = True
   inputs_preprocessed: bool = False
   feature_dim: int = 512
   use_impala_encoder: bool = False
   nn_scale: int = 1
+  transform: classy_transforms.HistogramLoss | None = None

   def setup(self):
     if self.use_impala_encoder:
@@ -259,7 +262,7 @@ def __call__(
     x = nn.Dense(
         features=self.feature_dim * self.nn_scale, kernel_init=initializer
     )(x)
-    x = nn.relu(x)
+    x = representation = nn.relu(x)

     if self.dueling:
       adv = nn.Dense(features=self.num_actions * self.num_atoms)(x)
@@ -271,9 +274,12 @@
       x = nn.Dense(features=self.num_actions * self.num_atoms)(x)
       logits = x.reshape((self.num_actions, self.num_atoms))

-    if self.distributional:
-      probabilities = nn.softmax(logits)
+    probabilities = nn.softmax(logits)
+    if self.transform is not None:
+      q_values = jax.vmap(self.transform.transform_from_probs)(probabilities)
+    else:
       q_values = jnp.sum(support * probabilities, axis=1)
-      return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
-    q_values = jnp.sum(logits, axis=1)  # Sum over all the num_atoms
-    return atari_lib.DQNNetworkType(q_values)
+    if self.distributional:
+      return ClassyNetworkType(
+          q_values, logits, probabilities, representation)
+    return NetworkType(q_values, representation)
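
The new `transform` hook converts the network's categorical output back to scalar Q-values via `transform_from_probs`, replacing the fixed C51 expectation over `support`. The actual `classy_transforms.HistogramLoss` lives in a module not included in this commit; the snippet below is a hypothetical two-hot style transform illustrating the interface the network appears to expect (class name, methods, and signatures are assumptions).

```python
# Hypothetical sketch of a histogram ("two-hot") transform, illustrating what
# classy_transforms.HistogramLoss.transform_from_probs might compute.
import jax
import jax.numpy as jnp


class TwoHotTransform:
  """Maps scalars to/from categorical distributions over a fixed support."""

  def __init__(self, vmin: float, vmax: float, num_atoms: int):
    self.support = jnp.linspace(vmin, vmax, num_atoms)

  def transform_to_probs(self, target: jnp.ndarray) -> jnp.ndarray:
    """Two-hot encodes a scalar target onto its two nearest support atoms."""
    target = jnp.clip(target, self.support[0], self.support[-1])
    idx = jnp.sum(self.support < target) - 1  # index of the lower atom
    idx = jnp.clip(idx, 0, self.support.shape[0] - 2)
    lower, upper = self.support[idx], self.support[idx + 1]
    upper_weight = (target - lower) / (upper - lower)
    probs = jnp.zeros_like(self.support)
    probs = probs.at[idx].set(1.0 - upper_weight)
    return probs.at[idx + 1].set(upper_weight)

  def transform_from_probs(self, probs: jnp.ndarray) -> jnp.ndarray:
    """Recovers a scalar value as the expectation under the support."""
    return jnp.sum(probs * self.support)


# Usage mirroring the network code above: one distribution per action,
# mapped back to per-action Q-values with vmap.
transform = TwoHotTransform(vmin=-10.0, vmax=10.0, num_atoms=51)
per_action_probs = jax.nn.softmax(jnp.zeros((6, 51)), axis=-1)  # 6 actions
q_values = jax.vmap(transform.transform_from_probs)(per_action_probs)
```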