Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tackle Typing and Linting Errors #379

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3f89961
Add ruff ignore F722
gileshd Sep 12, 2024
3ad1fa7
Remove unused imports
gileshd Sep 12, 2024
dc7f675
Convert strings with latex to raw-strings
gileshd Sep 13, 2024
20a1df7
Prepend space in uni-dim jaxtyping hints
gileshd Sep 12, 2024
ca9a75c
Fix jr.PRNGKey type hints
gileshd Sep 12, 2024
75205f8
Rename and change PRNGKey Type
gileshd Sep 12, 2024
42c5c22
Add IntScalar type
gileshd Sep 23, 2024
7da57ef
Minor arg and type changes in utils/utils.py
gileshd Sep 17, 2024
62df3ba
Update HMM[Parameter|Property]Set protocols
gileshd Sep 12, 2024
45134e8
Update type annotations in hmm base classes
gileshd Sep 23, 2024
0d4ca74
Update type annotations in hmm inference code.
gileshd Sep 18, 2024
cfd3bfc
Fix type annotations in hmm parallel inference
gileshd Sep 12, 2024
d0a0e0f
Add further type annotations to hmm transitions class
gileshd Oct 16, 2024
4b2c40b
Add further type annotations to hmm initial base class
gileshd Oct 16, 2024
4c84ddc
Add further type annotations to categorical hmm
gileshd Sep 18, 2024
7e04c9e
Add further type annotations to arhmm
gileshd Sep 20, 2024
b2fffa4
Add further type annotations to linreghmm
gileshd Sep 20, 2024
a45e61b
Add further type annotations to Bernoulli HMM
gileshd Sep 20, 2024
32db62a
Add further type annotations to Gamma HMM
gileshd Sep 20, 2024
81138f8
Add further type annotations to Gaussian HMMs
gileshd Sep 20, 2024
4e36520
Add further type annotations to gmhmms
gileshd Sep 20, 2024
4b7ee53
Add further type annotations to logreg hmm
gileshd Sep 21, 2024
e246606
Add further type annotations to multinomialhmm
gileshd Sep 21, 2024
e19748d
Add further type annotations to poisson hmm
gileshd Sep 23, 2024
c2098cf
Add further type annotations to categorical glm hmm
gileshd Oct 6, 2024
547c610
Fix LinearGaussianSSM.sample type hint
gileshd Sep 12, 2024
58c9127
Change type hints to jaxtyping in slds code
gileshd Sep 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions dynamax/generalized_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from jax import jacfwd, vmap, lax
import jax.numpy as jnp
from jax import lax
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from jaxtyping import Array, Float
from typing import NamedTuple, Optional, Union, Callable

Expand Down Expand Up @@ -83,7 +82,7 @@ def compute_weights_and_sigmas(self, m, P):


def _predict(m, P, f, Q, u, g_ev, g_cov):
"""Predict next mean and covariance under an additive-noise Gaussian filter
r"""Predict next mean and covariance under an additive-noise Gaussian filter

p(x_{t+1}) = N(x_{t+1} | mu_pred, Sigma_pred)
where
Expand Down Expand Up @@ -117,7 +116,7 @@ def _predict(m, P, f, Q, u, g_ev, g_cov):


def _condition_on(m, P, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter, emission_dist):
"""Condition a Gaussian potential on a new observation with arbitrary
r"""Condition a Gaussian potential on a new observation with arbitrary
likelihood with given functions for conditional moments and make a
Gaussian approximation.
p(x_t | y_t, u_t, y_{1:t-1}, u_{1:t-1})
Expand Down Expand Up @@ -172,7 +171,7 @@ def _step(carry, _):


def _statistical_linear_regression(mu, Sigma, m, S, C):
"""Return moment-matching affine coefficients and approximation noise variance
r"""Return moment-matching affine coefficients and approximation noise variance
given joint moments.

g(x) \approx Ax + b + e where e ~ N(0, Omega)
Expand Down
20 changes: 10 additions & 10 deletions dynamax/generalized_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from dynamax.nonlinear_gaussian_ssm.models import FnStateToState, FnStateAndInputToState
from dynamax.nonlinear_gaussian_ssm.models import FnStateToEmission, FnStateAndInputToEmission

FnStateToEmission2 = Callable[[Float[Array, "state_dim"]], Float[Array, "emission_dim emission_dim"]]
FnStateAndInputToEmission2 = Callable[[Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "emission_dim emission_dim"]]
FnStateToEmission2 = Callable[[Float[Array, " state_dim"]], Float[Array, "emission_dim emission_dim"]]
FnStateAndInputToEmission2 = Callable[[Float[Array, " state_dim"], Float[Array, " input_dim"]], Float[Array, "emission_dim emission_dim"]]

# emission distribution takes a mean vector and covariance matrix and returns a distribution
EmissionDistFn = Callable[ [Float[Array, "state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution]
EmissionDistFn = Callable[ [Float[Array, " state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution]


class ParamsGGSSM(NamedTuple):
Expand All @@ -42,7 +42,7 @@ class ParamsGGSSM(NamedTuple):

"""

initial_mean: Float[Array, "state_dim"]
initial_mean: Float[Array, " state_dim"]
initial_covariance: Float[Array, "state_dim state_dim"]
dynamics_function: Union[FnStateToState, FnStateAndInputToState]
dynamics_covariance: Float[Array, "state_dim state_dim"]
Expand Down Expand Up @@ -97,15 +97,15 @@ def covariates_shape(self):
def initial_distribution(
self,
params: ParamsGGSSM,
inputs: Optional[Float[Array, "input_dim"]]=None
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
return MVN(params.initial_mean, params.initial_covariance)

def transition_distribution(
self,
params: ParamsGGSSM,
state: Float[Array, "state_dim"],
inputs: Optional[Float[Array, "input_dim"]]=None
state: Float[Array, " state_dim"],
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
f = params.dynamics_function
if inputs is None:
Expand All @@ -117,8 +117,8 @@ def transition_distribution(
def emission_distribution(
self,
params: ParamsGGSSM,
state: Float[Array, "state_dim"],
inputs: Optional[Float[Array, "input_dim"]]=None
state: Float[Array, " state_dim"],
inputs: Optional[Float[Array, " input_dim"]]=None
) -> tfd.Distribution:
h = params.emission_mean_function
R = params.emission_cov_function
Expand All @@ -128,4 +128,4 @@ def emission_distribution(
else:
mean = h(state, inputs)
cov = R(state, inputs)
return params.emission_dist(mean, cov)
return params.emission_dist(mean, cov)
4 changes: 2 additions & 2 deletions dynamax/hidden_markov_model/demos/categorical_glm_hmm_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@
plt.figure()
plt.imshow(jnp.vstack((states[None, :], most_likely_states[None, :])),
aspect="auto", interpolation='none', cmap="Greys")
plt.yticks([0.0, 1.0], ["$z$", "$\hat{z}$"])
plt.yticks([0.0, 1.0], ["$z$", r"$\hat{z}$"])
plt.xlabel("time")
plt.xlim(0, 500)


print("true log prob: ", hmm.marginal_log_prob(true_params, emissions, inputs=inputs))
print("test log prob: ", test_hmm.marginal_log_prob(params, emissions, inputs=inputs))

plt.show()
plt.show()
140 changes: 75 additions & 65 deletions dynamax/hidden_markov_model/inference.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
from functools import partial
from typing import Callable, NamedTuple, Optional, Tuple, Union
import jax.numpy as jnp
import jax.random as jr
from jax import lax
from jax import vmap
from jax import jit
from functools import partial

from typing import Callable, Optional, Tuple, Union, NamedTuple
from jax import jit, lax, vmap
from jaxtyping import Int, Float, Array

from dynamax.types import Scalar, PRNGKey
from dynamax.types import IntScalar, Scalar

_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x

def get_trans_mat(transition_matrix, transition_fn, t):
def get_trans_mat(
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]],
t: IntScalar
) -> Float[Array, "num_states num_states"]:
if transition_fn is not None:
return transition_fn(t)
else:
if transition_matrix.ndim == 3: # (T,K,K)
elif transition_matrix is not None:
if transition_matrix.ndim == 3: # (T-1,K,K)
return transition_matrix[t]
else:
return transition_matrix
else:
raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.")

class HMMPosteriorFiltered(NamedTuple):
r"""Simple wrapper for properties of an HMM filtering posterior.
Expand Down Expand Up @@ -50,12 +54,12 @@ class HMMPosterior(NamedTuple):
filtered_probs: Float[Array, "num_timesteps num_states"]
predicted_probs: Float[Array, "num_timesteps num_states"]
smoothed_probs: Float[Array, "num_timesteps num_states"]
initial_probs: Float[Array, "num_states"]
trans_probs: Optional[Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]]] = None
initial_probs: Float[Array, " num_states"]
trans_probs: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]] = None


def _normalize(u, axis=0, eps=1e-15):
def _normalize(u: Array, axis=0, eps=1e-15):
"""Normalizes the values within the axis in a way that they sum up to 1.

Args:
Expand Down Expand Up @@ -97,11 +101,11 @@ def _predict(probs, A):

@partial(jit, static_argnames=["transition_fn"])
def hmm_filter(
initial_distribution: Float[Array, "num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
initial_distribution: Float[Array, " num_states"],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> HMMPosteriorFiltered:
r"""Forwards filtering

Expand Down Expand Up @@ -145,11 +149,11 @@ def _step(carry, t):

@partial(jit, static_argnames=["transition_fn"])
def hmm_backward_filter(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
) -> Tuple[Float, Float[Array, "num_timesteps num_states"]]:
transition_fn: Optional[Callable[[int], Float[Array, "num_states num_states"]]]= None
) -> Tuple[Scalar, Float[Array, "num_timesteps num_states"]]:
r"""Run the filter backwards in time. This is the second step of the forward-backward algorithm.

Transition matrix may be either 2D (if transition probabilities are fixed) or 3D
Expand Down Expand Up @@ -191,11 +195,11 @@ def _step(carry, t):

@partial(jit, static_argnames=["transition_fn"])
def hmm_two_filter_smoother(
initial_distribution: Float[Array, "num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
initial_distribution: Float[Array, " num_states"],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using the two-filter
Expand Down Expand Up @@ -245,11 +249,11 @@ def hmm_two_filter_smoother(

@partial(jit, static_argnames=["transition_fn"])
def hmm_smoother(
initial_distribution: Float[Array, "num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
initial_distribution: Float[Array, " num_states"],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using a general
Expand Down Expand Up @@ -325,12 +329,12 @@ def _step(carry, args):

@partial(jit, static_argnames=["transition_fn", "window_size"])
def hmm_fixed_lag_smoother(
initial_distribution: Float[Array, "num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
initial_distribution: Float[Array, " num_states"],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
window_size: Int,
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
window_size: int,
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
) -> HMMPosterior:
r"""Compute the smoothed state probabilities using the fixed-lag smoother.

Expand Down Expand Up @@ -439,12 +443,12 @@ def compute_posterior(filtered_probs, beta):

@partial(jit, static_argnames=["transition_fn"])
def hmm_posterior_mode(
initial_distribution: Float[Array, "num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
initial_distribution: Float[Array, " num_states"],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
) -> Int[Array, "num_timesteps"]:
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
) -> Int[Array, " num_timesteps"]:
r"""Compute the most likely state sequence. This is called the Viterbi algorithm.

Args:
Expand Down Expand Up @@ -486,13 +490,13 @@ def _forward_pass(state, best_next_state):

@partial(jit, static_argnames=["transition_fn"])
def hmm_posterior_sample(
rng: jr.PRNGKey,
initial_distribution: Float[Array, "num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
key: Array,
initial_distribution: Float[Array, " num_states"],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
) -> Int[Array, "num_timesteps"]:
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> Tuple[Scalar, Int[Array, " num_timesteps"]]:
r"""Sample a latent sequence from the posterior.

Args:
Expand All @@ -515,7 +519,7 @@ def hmm_posterior_sample(
# Run the sampler backward in time
def _step(carry, args):
next_state = carry
t, rng, filtered_probs = args
t, subkey, filtered_probs = args

A = get_trans_mat(transition_matrix, transition_fn, t)

Expand All @@ -524,15 +528,15 @@ def _step(carry, args):
smoothed_probs /= smoothed_probs.sum()

# Sample current state
state = jr.choice(rng, a=num_states, p=smoothed_probs)
state = jr.choice(subkey, a=num_states, p=smoothed_probs)

return state, state

# Run the HMM smoother
rngs = jr.split(rng, num_timesteps)
last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1])
keys = jr.split(key, num_timesteps)
last_state = jr.choice(keys[-1], a=num_states, p=filtered_probs[-1])
_, states = lax.scan(
_step, last_state, (jnp.arange(1, num_timesteps), rngs[:-1], filtered_probs[:-1]),
_step, last_state, (jnp.arange(1, num_timesteps), keys[:-1], filtered_probs[:-1]),
reverse=True
)

Expand All @@ -544,12 +548,13 @@ def _compute_sum_transition_probs(
transition_matrix: Float[Array, "num_states num_states"],
hmm_posterior: HMMPosterior) -> Float[Array, "num_states num_states"]:
"""Compute the transition probabilities from the HMM posterior messages.

Args:
transition_matrix (_type_): _description_
hmm_posterior (_type_): _description_
"""

def _step(carry, args):
def _step(carry, args: Tuple[Array, Array, Array, Int[Array, ""]]):
filtered_probs, smoothed_probs_next, predicted_probs_next, t = args

# Get parameters for time t
Expand Down Expand Up @@ -580,11 +585,13 @@ def _step(carry, args):


def _compute_all_transition_probs(
transition_matrix: Float[Array, "num_timesteps num_states num_states"],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
hmm_posterior: HMMPosterior,
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> Float[Array, "num_timesteps num_states num_states"]:
"""Compute the transition probabilities from the HMM posterior messages.

Args:
transition_matrix (_type_): _description_
hmm_posterior (_type_): _description_
Expand All @@ -598,20 +605,21 @@ def _compute_probs(t):
A = get_trans_mat(transition_matrix, transition_fn, t)
return jnp.einsum('i,ij,j->ij', filtered_probs[t], A, relative_probs_next[t])

transition_probs = vmap(_compute_probs)(jnp.arange(len(filtered_probs)-1))
transition_probs = vmap(_compute_probs)(jnp.arange(len(filtered_probs)))
return transition_probs


# TODO: Consider alternative annotation for return type:
# Float[Array, "*num_timesteps num_states num_states"] I think this would allow multiple prepended dims.
# Float[Array, "#num_timesteps num_states num_states"] this might accept (1, sd, sd) but not (sd, sd).
# TODO: This is a candidate for @overload however at present I think we would need to use
# `@beartype.typing.overload` and beartype is currently not a core dependency.
# Support for `typing.overload` might change in the future:
# https://github.com/beartype/beartype/issues/54
def compute_transition_probs(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
hmm_posterior: HMMPosterior,
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
) -> Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]]:
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]:
r"""Compute the posterior marginal distributions $p(z_{t+1}, z_t \mid y_{1:T}, u_{1:T}, \theta)$.

Args:
Expand All @@ -622,8 +630,10 @@ def compute_transition_probs(
Returns:
array of smoothed transition probabilities.
"""
reduce_sum = transition_matrix is not None and transition_matrix.ndim == 2
if reduce_sum:
if transition_matrix is None and transition_fn is None:
raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.")

if transition_matrix is not None and transition_matrix.ndim == 2:
return _compute_sum_transition_probs(transition_matrix, hmm_posterior)
else:
return _compute_all_transition_probs(transition_matrix, hmm_posterior, transition_fn=transition_fn)
1 change: 0 additions & 1 deletion dynamax/hidden_markov_model/inference_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import itertools as it
import jax.numpy as jnp
import jax.random as jr
Expand Down
Loading
Loading