Commit ee90cb5

High-level API: Establish a strong link between the actor and the
distribution function (dist_fn) used in policies by creating
the distribution function in the actor factory, which knows which
function is appropriate.

Consequently, remove the policy parameter 'dist_fn' from the high-level API
because it is determined automatically, eliminating the possibility
of misspecification by the user. [breaking change: code must not
specify the 'dist_fn' parameter, but persisted objects continue to work
as expected]

Implements #1194
opcode81 committed Aug 7, 2024
1 parent fb0561a commit ee90cb5
Showing 11 changed files with 89 additions and 60 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -2,7 +2,7 @@

## Release 1.1.0

### Api Extensions
### Changes/Improvements
- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 #1141 #1183
- `data`:
- `Batch`:
@@ -107,6 +107,11 @@ continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077
- `AtariEnvFactory` constructor (in examples, so not really breaking) now requires explicit train and test seeds. #1074
- `EnvFactoryRegistered` now requires an explicit `test_seed` in the constructor. #1074
- `highlevel`:
- The parameter `dist_fn` has been removed from the parameter objects (`PGParams`, `A2CParams`, `PPOParams`, `NPGParams`, `TRPOParams`).
The correct distribution is now determined automatically based on the actor factory being used, avoiding the possibility of
misspecification. Persisted configurations/policies continue to work as expected, but code must not specify the `dist_fn` parameter.
#1194 #1195


### Tests
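For illustration, a minimal before/after sketch of the resulting migration in user code; `lr=3e-4` and the factory name are taken from the MuJoCo examples below, the rest is assumed:

```python
from tianshou.highlevel.params.policy_params import PPOParams

# Before (pre-1.1.0): dist_fn was specified manually and could mismatch the actor.
params = PPOParams(
    lr=3e-4,
    dist_fn=DistributionFunctionFactoryIndependentGaussians(),  # no longer accepted
)

# After (1.1.0): omit dist_fn; the actor factory supplies the matching
# distribution function automatically.
params = PPOParams(lr=3e-4)
```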
16 changes: 15 additions & 1 deletion examples/atari/atari_network.py
@@ -14,6 +14,8 @@
IntermediateModule,
IntermediateModuleFactory,
)
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import NetBase
from tianshou.utils.net.discrete import Actor, NoisyLinear

@@ -246,6 +248,8 @@ def forward(


class ActorFactoryAtariDQN(ActorFactory):
USE_SOFTMAX_OUTPUT = False

def __init__(
self,
scale_obs: bool = True,
@@ -274,7 +278,17 @@ def create_module(self, envs: Environments, device: TDevice) -> Actor:
)
if self.scale_obs:
net = scale_obs(net)
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
return Actor(
net,
envs.get_action_shape(),
device=device,
softmax_output=self.USE_SOFTMAX_OUTPUT,
).to(device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return DistributionFunctionFactoryCategorical(
is_probs_input=self.USE_SOFTMAX_OUTPUT,
).create_dist_fn(envs)


class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
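Note that `USE_SOFTMAX_OUTPUT` now drives both the actor's `softmax_output` and the distribution factory's `is_probs_input`, so the two can no longer diverge. A standalone sketch of the difference in plain PyTorch (independent of Tianshou):

```python
import torch

logits = torch.tensor([1.0, 2.0, 0.5])

# is_probs_input=False: the actor's raw output is interpreted as logits.
dist_from_logits = torch.distributions.Categorical(logits=logits)

# is_probs_input=True: the actor's output is already a probability vector.
dist_from_probs = torch.distributions.Categorical(probs=torch.softmax(logits, dim=-1))

# Same distribution either way; only the expected input format differs.
assert torch.allclose(dist_from_logits.probs, dist_from_probs.probs)
```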
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_npg_hl.py
@@ -12,9 +12,6 @@
ExperimentConfig,
NPGExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import NPGParams
from tianshou.utils import logging
@@ -78,7 +75,6 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_ppo_hl.py
@@ -12,9 +12,6 @@
ExperimentConfig,
PPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
@@ -88,7 +85,6 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_ppo_hl_multi.py
@@ -26,9 +26,6 @@
PPOExperimentBuilder,
)
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
@@ -115,7 +112,6 @@ def main(
recompute_advantage=True,
lr=3e-4,
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config),
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
4 changes: 0 additions & 4 deletions examples/mujoco/mujoco_trpo_hl.py
@@ -12,9 +12,6 @@
ExperimentConfig,
TRPOExperimentBuilder,
)
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import TRPOParams
from tianshou.utils import logging
@@ -82,7 +79,6 @@ def main(
lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
if lr_decay
else None,
dist_fn=DistributionFunctionFactoryIndependentGaussians(),
),
)
.with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
1 change: 1 addition & 0 deletions test/highlevel/test_experiment_builder.py
@@ -56,6 +56,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime
@pytest.mark.parametrize(
"builder_cls",
[
PGExperimentBuilder,
PPOExperimentBuilder,
A2CExperimentBuilder,
DQNExperimentBuilder,
4 changes: 4 additions & 0 deletions tianshou/highlevel/agent.py
@@ -273,11 +273,14 @@ def _create_policy(self, envs: Environments, device: TDevice) -> PGPolicy:
optim_factory=self.optim_factory,
),
)
dist_fn = self.actor_factory.create_dist_fn(envs)
assert dist_fn is not None
return PGPolicy(
actor=actor.module,
optim=actor.optim,
action_space=envs.get_action_space(),
observation_space=envs.get_observation_space(),
dist_fn=dist_fn,
**kwargs,
)

@@ -333,6 +336,7 @@ def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
kwargs["critic"] = actor_critic.critic
kwargs["optim"] = actor_critic.optim
kwargs["action_space"] = envs.get_action_space()
kwargs["dist_fn"] = self.actor_factory.create_dist_fn(envs)
return kwargs

def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
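The `assert dist_fn is not None` encodes a contract: `create_dist_fn` may return `None` only for actors that do not output distribution parameters, and the policy-gradient algorithms wired up here require one. A minimal runnable sketch of how such a function is ultimately used (plain PyTorch; the sample actor output is assumed):

```python
import torch

def dist_fn(logits: torch.Tensor) -> torch.distributions.Categorical:
    return torch.distributions.Categorical(logits=logits)

actor_output = torch.tensor([0.2, 1.3, -0.5])  # assumed raw actor output
dist = dist_fn(actor_output)       # actor output -> torch distribution
action = dist.sample()             # stochastic action selection
log_prob = dist.log_prob(action)   # enters the policy-gradient loss
```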
42 changes: 38 additions & 4 deletions tianshou/highlevel/module/actor.py
@@ -19,6 +19,11 @@
)
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactoryCategorical,
DistributionFunctionFactoryIndependentGaussians,
)
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, ModuleType, Net
from tianshou.utils.string import ToStringMixin
@@ -47,6 +52,14 @@ class ActorFactory(ModuleFactory, ToStringMixin, ABC):
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
pass

@abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
"""
:param envs: the environments
:return: the distribution function, which converts the actor's output into a distribution, or None
if the actor does not output distribution parameters
"""

def create_module_opt(
self,
envs: Environments,
@@ -70,7 +83,7 @@ def create_module_opt(
def _init_linear(actor: torch.nn.Module) -> None:
"""Initializes linear layers of an actor module using default mechanisms.
:param module: the actor module.
:param actor: the actor module.
"""
init_linear_orthogonal(actor)
if hasattr(actor, "mu"):
Expand Down Expand Up @@ -104,7 +117,7 @@ def __init__(
self.hidden_activation = hidden_activation
self.discrete_softmax = discrete_softmax

def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
def _create_factory(self, envs: Environments) -> ActorFactory:
env_type = envs.get_type()
factory: ActorFactoryContinuousDeterministicNet | ActorFactoryContinuousGaussianNet | ActorFactoryDiscreteNet
if env_type == EnvType.CONTINUOUS:
@@ -125,15 +138,22 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
raise ValueError("Continuous action spaces are not supported by the algorithm")
case _:
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE:
factory = ActorFactoryDiscreteNet(
self.DEFAULT_HIDDEN_SIZES,
softmax_output=self.discrete_softmax,
)
return factory.create_module(envs, device)
else:
raise ValueError(f"{env_type} not supported")
return factory

def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
factory = self._create_factory(envs)
return factory.create_module(envs, device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
factory = self._create_factory(envs)
return factory.create_dist_fn(envs)


class ActorFactoryContinuous(ActorFactory, ABC):
@@ -159,6 +179,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
device=device,
).to(device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return None


class ActorFactoryContinuousGaussianNet(ActorFactoryContinuous):
def __init__(
@@ -202,6 +225,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:

return actor

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)


class ActorFactoryDiscreteNet(ActorFactory):
def __init__(
@@ -229,6 +255,11 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
softmax_output=self.softmax_output,
).to(device)

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return DistributionFunctionFactoryCategorical(
is_probs_input=self.softmax_output,
).create_dist_fn(envs)


class ActorFactoryTransientStorageDecorator(ActorFactory):
"""Wraps an actor factory, storing the most recently created actor instance such that it can be retrieved."""
@@ -254,6 +285,9 @@ def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.M
self._actor_future.actor = module
return module

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
return self.actor_factory.create_dist_fn(envs)


class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory):
def __init__(self, actor_factory: ActorFactory):
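Taken together, the contract for `ActorFactory` subclasses is that `create_module` and `create_dist_fn` remain consistent. A hypothetical custom factory under that contract; `MyNet` is an assumed user-defined backbone, and the import paths follow the files in this diff where visible but are otherwise assumptions:

```python
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.params.dist_fn import DistributionFunctionFactoryCategorical
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.discrete import Actor

class MyDiscreteActorFactory(ActorFactory):
    SOFTMAX_OUTPUT = False  # single switch keeps actor and dist_fn in sync

    def create_module(self, envs: Environments, device) -> Actor:
        net = MyNet(envs.get_observation_shape())  # assumed backbone module
        return Actor(
            net,
            envs.get_action_shape(),
            device=device,
            softmax_output=self.SOFTMAX_OUTPUT,
        ).to(device)

    def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont | None:
        return DistributionFunctionFactoryCategorical(
            is_probs_input=self.SOFTMAX_OUTPUT,
        ).create_dist_fn(envs)
```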
36 changes: 21 additions & 15 deletions tianshou/highlevel/params/dist_fn.py
@@ -4,7 +4,7 @@

import torch

from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.env import Environments
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin

@@ -20,32 +20,38 @@ def create_dist_fn(


class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
def __init__(self, is_probs_input: bool = True):
"""
:param is_probs_input: If True, the distribution function shall create a categorical distribution from a
tensor containing probabilities; otherwise the tensor is assumed to contain logits.
"""
self.is_probs_input = is_probs_input

def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
envs.get_type().assert_discrete(self)
return self._dist_fn
if self.is_probs_input:
return self._dist_fn_probs
else:
return self._dist_fn

# NOTE: Do not move/rename because a reference to the function can appear in persisted policies
@staticmethod
def _dist_fn(logits: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=logits)

# NOTE: Do not move/rename because a reference to the function can appear in persisted policies
@staticmethod
def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=p)
def _dist_fn_probs(probs: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(probs=probs)


class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
envs.get_type().assert_continuous(self)
return self._dist_fn

# NOTE: Do not move/rename because a reference to the function can appear in persisted policies
@staticmethod
def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
loc, scale = loc_scale
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)


class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
match envs.get_type():
case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
case EnvType.CONTINUOUS:
return DistributionFunctionFactoryIndependentGaussians().create_dist_fn(envs)
case _:
raise ValueError(envs.get_type())
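The Gaussian case wraps `Normal` in `Independent(..., 1)` so that all action dimensions form a single event. A plain-PyTorch sketch of why that matters:

```python
import torch

loc, scale = torch.zeros(2, 3), torch.ones(2, 3)  # batch of 2, action dim 3

normal = torch.distributions.Normal(loc, scale)
independent = torch.distributions.Independent(normal, 1)

# Without Independent, log_prob is computed per action dimension; with it,
# the last dimension is an event dim, giving one log-prob per batch element.
assert normal.log_prob(torch.zeros(2, 3)).shape == (2, 3)
assert independent.log_prob(torch.zeros(2, 3)).shape == (2,)
```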
27 changes: 4 additions & 23 deletions tianshou/highlevel/params/policy_params.py
@@ -12,15 +12,11 @@
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.dist_fn import (
DistributionFunctionFactory,
DistributionFunctionFactoryDefault,
)
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.pickle import setstate
from tianshou.utils.string import ToStringMixin


@@ -209,15 +205,6 @@ def change_value(self, value: Any, data: ParamTransformerData) -> Any:
return value


class ParamTransformerDistributionFunction(ParamTransformerChangeValue):
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
if value == "default":
value = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
elif isinstance(value, DistributionFunctionFactory):
value = value.create_dist_fn(data.envs)
return value


class ParamTransformerActionScaling(ParamTransformerChangeValue):
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
if value == "default":
@@ -322,20 +309,14 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche
whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
Does not affect training.
"""
dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default"
"""
This can either be a function which maps the model output to a torch distribution or a
factory for the creation of such a function.
When set to "default", a factory which creates Gaussian distributions from mean and standard
deviation will be used for the continuous case and which creates categorical distributions
for the discrete case (see :class:`DistributionFunctionFactoryDefault`)
"""

def __setstate__(self, state: dict[str, Any]) -> None:
setstate(PGParams, self, state, removed_properties=["dist_fn"])

def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.extend(ParamsMixinActionScaling._get_param_transformers(self))
transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
transformers.append(ParamTransformerDistributionFunction("dist_fn"))
return transformers


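The `__setstate__` hook is what keeps previously persisted `PGParams` objects loadable: the removed field is simply discarded when an old pickle is restored. A generic sketch of the pattern, not the actual `tianshou.utils.pickle.setstate`, whose implementation is not shown in this diff:

```python
from typing import Any

def setstate_sketch(obj: Any, state: dict[str, Any], removed_properties: list[str]) -> None:
    # Drop attributes persisted by older versions that the current class
    # no longer defines, then restore the remaining state as usual.
    for prop in removed_properties:
        state.pop(prop, None)
    obj.__dict__.update(state)
```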
