Skip to content

Commit

Permalink
Refactor lin-tanh to sequential (#139)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Sep 8, 2023
1 parent 5ab435f commit 93ee78e
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions src/stream_ml/pytorch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,59 @@

from torch import nn

__all__ = ["lin_tanh"]
__all__ = ["sequential"]


def lin_tanh(
n_in: int = 1,
n_hidden: int = 50,
n_layers: int = 3,
n_out: int = 3,
def sequential(
data: int = 1,
layers: int = 3,
hidden_features: int = 50,
features: int = 3,
*,
dropout: float = 0.0,
activation: type[nn.Module] | None = None,
) -> nn.Sequential:
"""Linear tanh network.
Parameters
----------
n_in : int, optional
data : int, optional
Number of input features, by default 1
n_hidden : int, optional
Number of hidden units, by default 50.
n_layers : int, optional
layers : int, optional
Number of hidden layers, by default 3.
Must be >= 2.
n_out : int, optional
hidden_features : int, optional
Number of hidden units, by default 50.
features : int, optional
Number of output features, by default 3.
dropout : float, optional
Dropout probability, by default 0.0
activation : type[`torch.nn.Module`] | None, optional
Activation function. If `None` (default), uses `torch.nn.Tanh`.
Returns
-------
`torch.nn.Sequential`
"""
activation_func = nn.Tanh if activation is None else activation

def make_layer(n_in: int, n_hidden: int) -> tuple[nn.Module, ...]:
return (nn.Linear(n_in, n_hidden), nn.Tanh()) + (
def make_layer(data: int, hidden_features: int) -> tuple[nn.Module, ...]:
return (nn.Linear(data, hidden_features), activation_func()) + (
(nn.Dropout(p=dropout),) if dropout > 0 else ()
)

mid_layers = (
functools.reduce(
operator.add,
(make_layer(n_hidden, n_hidden) for _ in range(n_layers - 2)),
(make_layer(hidden_features, hidden_features) for _ in range(layers - 2)),
)
if n_layers >= 3 # noqa: PLR2004
if layers >= 3 # noqa: PLR2004
else ()
)

return nn.Sequential(
*make_layer(n_in, n_hidden),
*make_layer(data, hidden_features),
*mid_layers,
nn.Linear(n_hidden, n_out),
nn.Linear(hidden_features, features),
)

0 comments on commit 93ee78e

Please sign in to comment.