diff --git a/dynamax/generalized_gaussian_ssm/inference.py b/dynamax/generalized_gaussian_ssm/inference.py index fa9af8fa..ca58bde5 100644 --- a/dynamax/generalized_gaussian_ssm/inference.py +++ b/dynamax/generalized_gaussian_ssm/inference.py @@ -3,7 +3,6 @@ from jax import jacfwd, vmap, lax import jax.numpy as jnp from jax import lax -from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN from jaxtyping import Array, Float from typing import NamedTuple, Optional, Union, Callable @@ -83,7 +82,7 @@ def compute_weights_and_sigmas(self, m, P): def _predict(m, P, f, Q, u, g_ev, g_cov): - """Predict next mean and covariance under an additive-noise Gaussian filter + r"""Predict next mean and covariance under an additive-noise Gaussian filter p(x_{t+1}) = N(x_{t+1} | mu_pred, Sigma_pred) where @@ -117,7 +116,7 @@ def _predict(m, P, f, Q, u, g_ev, g_cov): def _condition_on(m, P, y_cond_mean, y_cond_cov, u, y, g_ev, g_cov, num_iter, emission_dist): - """Condition a Gaussian potential on a new observation with arbitrary + r"""Condition a Gaussian potential on a new observation with arbitrary likelihood with given functions for conditional moments and make a Gaussian approximation. p(x_t | y_t, u_t, y_{1:t-1}, u_{1:t-1}) @@ -172,7 +171,7 @@ def _step(carry, _): def _statistical_linear_regression(mu, Sigma, m, S, C): - """Return moment-matching affine coefficients and approximation noise variance + r"""Return moment-matching affine coefficients and approximation noise variance given joint moments. g(x) \approx Ax + b + e where e ~ N(0, Omega) diff --git a/dynamax/generalized_gaussian_ssm/models.py b/dynamax/generalized_gaussian_ssm/models.py index dbbba7c8..7f387cf9 100644 --- a/dynamax/generalized_gaussian_ssm/models.py +++ b/dynamax/generalized_gaussian_ssm/models.py @@ -11,11 +11,11 @@ from dynamax.nonlinear_gaussian_ssm.models import FnStateToState, FnStateAndInputToState from dynamax.nonlinear_gaussian_ssm.models import FnStateToEmission, FnStateAndInputToEmission -FnStateToEmission2 = Callable[[Float[Array, "state_dim"]], Float[Array, "emission_dim emission_dim"]] -FnStateAndInputToEmission2 = Callable[[Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "emission_dim emission_dim"]] +FnStateToEmission2 = Callable[[Float[Array, " state_dim"]], Float[Array, "emission_dim emission_dim"]] +FnStateAndInputToEmission2 = Callable[[Float[Array, " state_dim"], Float[Array, " input_dim"]], Float[Array, "emission_dim emission_dim"]] # emission distribution takes a mean vector and covariance matrix and returns a distribution -EmissionDistFn = Callable[ [Float[Array, "state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution] +EmissionDistFn = Callable[ [Float[Array, " state_dim"], Float[Array, "state_dim state_dim"]], tfd.Distribution] class ParamsGGSSM(NamedTuple): @@ -42,7 +42,7 @@ class ParamsGGSSM(NamedTuple): """ - initial_mean: Float[Array, "state_dim"] + initial_mean: Float[Array, " state_dim"] initial_covariance: Float[Array, "state_dim state_dim"] dynamics_function: Union[FnStateToState, FnStateAndInputToState] dynamics_covariance: Float[Array, "state_dim state_dim"] @@ -97,15 +97,15 @@ def covariates_shape(self): def initial_distribution( self, params: ParamsGGSSM, - inputs: Optional[Float[Array, "input_dim"]]=None + inputs: Optional[Float[Array, " input_dim"]]=None ) -> tfd.Distribution: return MVN(params.initial_mean, params.initial_covariance) def transition_distribution( self, params: ParamsGGSSM, - state: Float[Array, 
"state_dim"], - inputs: Optional[Float[Array, "input_dim"]]=None + state: Float[Array, " state_dim"], + inputs: Optional[Float[Array, " input_dim"]]=None ) -> tfd.Distribution: f = params.dynamics_function if inputs is None: @@ -117,8 +117,8 @@ def transition_distribution( def emission_distribution( self, params: ParamsGGSSM, - state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "input_dim"]]=None + state: Float[Array, " state_dim"], + inputs: Optional[Float[Array, " input_dim"]]=None ) -> tfd.Distribution: h = params.emission_mean_function R = params.emission_cov_function @@ -128,4 +128,4 @@ def emission_distribution( else: mean = h(state, inputs) cov = R(state, inputs) - return params.emission_dist(mean, cov) \ No newline at end of file + return params.emission_dist(mean, cov) diff --git a/dynamax/hidden_markov_model/demos/categorical_glm_hmm_demo.py b/dynamax/hidden_markov_model/demos/categorical_glm_hmm_demo.py index 365559fe..bc90faaa 100644 --- a/dynamax/hidden_markov_model/demos/categorical_glm_hmm_demo.py +++ b/dynamax/hidden_markov_model/demos/categorical_glm_hmm_demo.py @@ -58,7 +58,7 @@ plt.figure() plt.imshow(jnp.vstack((states[None, :], most_likely_states[None, :])), aspect="auto", interpolation='none', cmap="Greys") - plt.yticks([0.0, 1.0], ["$z$", "$\hat{z}$"]) + plt.yticks([0.0, 1.0], ["$z$", r"$\hat{z}$"]) plt.xlabel("time") plt.xlim(0, 500) @@ -66,4 +66,4 @@ print("true log prob: ", hmm.marginal_log_prob(true_params, emissions, inputs=inputs)) print("test log prob: ", test_hmm.marginal_log_prob(params, emissions, inputs=inputs)) - plt.show() \ No newline at end of file + plt.show() diff --git a/dynamax/hidden_markov_model/inference.py b/dynamax/hidden_markov_model/inference.py index 338a2ee8..7b1f914b 100644 --- a/dynamax/hidden_markov_model/inference.py +++ b/dynamax/hidden_markov_model/inference.py @@ -1,25 +1,29 @@ +from functools import partial +from typing import Callable, NamedTuple, Optional, Tuple, Union import jax.numpy as jnp import jax.random as jr -from jax import lax -from jax import vmap -from jax import jit -from functools import partial - -from typing import Callable, Optional, Tuple, Union, NamedTuple +from jax import jit, lax, vmap from jaxtyping import Int, Float, Array -from dynamax.types import Scalar, PRNGKey +from dynamax.types import IntScalar, Scalar _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x -def get_trans_mat(transition_matrix, transition_fn, t): +def get_trans_mat( + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]], + t: IntScalar +) -> Float[Array, "num_states num_states"]: if transition_fn is not None: return transition_fn(t) - else: - if transition_matrix.ndim == 3: # (T,K,K) + elif transition_matrix is not None: + if transition_matrix.ndim == 3: # (T-1,K,K) return transition_matrix[t] else: return transition_matrix + else: + raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.") class HMMPosteriorFiltered(NamedTuple): r"""Simple wrapper for properties of an HMM filtering posterior. 
@@ -50,12 +54,12 @@ class HMMPosterior(NamedTuple):
     filtered_probs: Float[Array, "num_timesteps num_states"]
     predicted_probs: Float[Array, "num_timesteps num_states"]
     smoothed_probs: Float[Array, "num_timesteps num_states"]
-    initial_probs: Float[Array, "num_states"]
-    trans_probs: Optional[Union[Float[Array, "num_timesteps num_states num_states"],
-                                Float[Array, "num_states num_states"]]] = None
+    initial_probs: Float[Array, " num_states"]
+    trans_probs: Optional[Union[Float[Array, "num_states num_states"],
+                                Float[Array, "num_timesteps_minus_1 num_states num_states"]]] = None


-def _normalize(u, axis=0, eps=1e-15):
+def _normalize(u: Array, axis=0, eps=1e-15):
     """Normalize the values along the given axis so that they sum to 1.

     Args:
@@ -97,11 +101,11 @@ def _predict(probs, A):

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_filter(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
 ) -> HMMPosteriorFiltered:
     r"""Forwards filtering

@@ -145,11 +149,11 @@ def _step(carry, t):

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_backward_filter(
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
-) -> Tuple[Float, Float[Array, "num_timesteps num_states"]]:
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
+) -> Tuple[Scalar, Float[Array, "num_timesteps num_states"]]:
     r"""Run the filter backwards in time.

     This is the second step of the forward-backward algorithm.
    Transition matrix may be either 2D (if transition probabilities are fixed) or 3D
@@ -191,11 +195,11 @@ def _step(carry, t):

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_two_filter_smoother(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
     compute_trans_probs: bool = True
 ) -> HMMPosterior:
     r"""Compute the smoothed state probabilities using the two-filter
@@ -245,11 +249,11 @@ def hmm_two_filter_smoother(

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_smoother(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
     compute_trans_probs: bool = True
 ) -> HMMPosterior:
     r"""Compute the smoothed state probabilities using a general
@@ -325,12 +329,12 @@ def _step(carry, args):

 @partial(jit, static_argnames=["transition_fn", "window_size"])
 def hmm_fixed_lag_smoother(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    window_size: Int,
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
+    window_size: int,
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
 ) -> HMMPosterior:
     r"""Compute the smoothed state probabilities using the fixed-lag smoother.

@@ -439,12 +443,12 @@ def compute_posterior(filtered_probs, beta):

 @partial(jit, static_argnames=["transition_fn"])
 def hmm_posterior_mode(
-    initial_distribution: Float[Array, "num_states"],
-    transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
-                             Float[Array, "num_states num_states"]],
+    initial_distribution: Float[Array, " num_states"],
+    transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
+                                      Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
-) -> Int[Array, "num_timesteps"]:
+    transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
+) -> Int[Array, " num_timesteps"]:
     r"""Compute the most likely state sequence. This is called the Viterbi algorithm.
Args: @@ -486,13 +490,13 @@ def _forward_pass(state, best_next_state): @partial(jit, static_argnames=["transition_fn"]) def hmm_posterior_sample( - rng: jr.PRNGKey, - initial_distribution: Float[Array, "num_states"], - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + key: Array, + initial_distribution: Float[Array, " num_states"], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], log_likelihoods: Float[Array, "num_timesteps num_states"], - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None -) -> Int[Array, "num_timesteps"]: + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None +) -> Tuple[Scalar, Int[Array, " num_timesteps"]]: r"""Sample a latent sequence from the posterior. Args: @@ -515,7 +519,7 @@ def hmm_posterior_sample( # Run the sampler backward in time def _step(carry, args): next_state = carry - t, rng, filtered_probs = args + t, subkey, filtered_probs = args A = get_trans_mat(transition_matrix, transition_fn, t) @@ -524,15 +528,15 @@ def _step(carry, args): smoothed_probs /= smoothed_probs.sum() # Sample current state - state = jr.choice(rng, a=num_states, p=smoothed_probs) + state = jr.choice(subkey, a=num_states, p=smoothed_probs) return state, state # Run the HMM smoother - rngs = jr.split(rng, num_timesteps) - last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1]) + keys = jr.split(key, num_timesteps) + last_state = jr.choice(keys[-1], a=num_states, p=filtered_probs[-1]) _, states = lax.scan( - _step, last_state, (jnp.arange(1, num_timesteps), rngs[:-1], filtered_probs[:-1]), + _step, last_state, (jnp.arange(1, num_timesteps), keys[:-1], filtered_probs[:-1]), reverse=True ) @@ -544,12 +548,13 @@ def _compute_sum_transition_probs( transition_matrix: Float[Array, "num_states num_states"], hmm_posterior: HMMPosterior) -> Float[Array, "num_states num_states"]: """Compute the transition probabilities from the HMM posterior messages. + Args: transition_matrix (_type_): _description_ hmm_posterior (_type_): _description_ """ - def _step(carry, args): + def _step(carry, args: Tuple[Array, Array, Array, Int[Array, ""]]): filtered_probs, smoothed_probs_next, predicted_probs_next, t = args # Get parameters for time t @@ -580,11 +585,13 @@ def _step(carry, args): def _compute_all_transition_probs( - transition_matrix: Float[Array, "num_timesteps num_states num_states"], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], hmm_posterior: HMMPosterior, - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None ) -> Float[Array, "num_timesteps num_states num_states"]: """Compute the transition probabilities from the HMM posterior messages. 
+ Args: transition_matrix (_type_): _description_ hmm_posterior (_type_): _description_ @@ -598,20 +605,21 @@ def _compute_probs(t): A = get_trans_mat(transition_matrix, transition_fn, t) return jnp.einsum('i,ij,j->ij', filtered_probs[t], A, relative_probs_next[t]) - transition_probs = vmap(_compute_probs)(jnp.arange(len(filtered_probs)-1)) + transition_probs = vmap(_compute_probs)(jnp.arange(len(filtered_probs))) return transition_probs -# TODO: Consider alternative annotation for return type: -# Float[Array, "*num_timesteps num_states num_states"] I think this would allow multiple prepended dims. -# Float[Array, "#num_timesteps num_states num_states"] this might accept (1, sd, sd) but not (sd, sd). +# TODO: This is a candidate for @overload however at present I think we would need to use +# `@beartype.typing.overload` and beartype is currently not a core dependency. +# Support for `typing.overload` might change in the future: +# https://github.com/beartype/beartype/issues/54 def compute_transition_probs( - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], hmm_posterior: HMMPosterior, - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None -) -> Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]]: + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None +) -> Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]: r"""Compute the posterior marginal distributions $p(z_{t+1}, z_t \mid y_{1:T}, u_{1:T}, \theta)$. Args: @@ -622,8 +630,10 @@ def compute_transition_probs( Returns: array of smoothed transition probabilities. 
""" - reduce_sum = transition_matrix is not None and transition_matrix.ndim == 2 - if reduce_sum: + if transition_matrix is None and transition_fn is None: + raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.") + + if transition_matrix is not None and transition_matrix.ndim == 2: return _compute_sum_transition_probs(transition_matrix, hmm_posterior) else: return _compute_all_transition_probs(transition_matrix, hmm_posterior, transition_fn=transition_fn) diff --git a/dynamax/hidden_markov_model/inference_test.py b/dynamax/hidden_markov_model/inference_test.py index 9a4babfe..6f814279 100644 --- a/dynamax/hidden_markov_model/inference_test.py +++ b/dynamax/hidden_markov_model/inference_test.py @@ -1,4 +1,3 @@ -import pytest import itertools as it import jax.numpy as jnp import jax.random as jr diff --git a/dynamax/hidden_markov_model/models/abstractions.py b/dynamax/hidden_markov_model/models/abstractions.py index 739b747d..502a299b 100644 --- a/dynamax/hidden_markov_model/models/abstractions.py +++ b/dynamax/hidden_markov_model/models/abstractions.py @@ -1,9 +1,11 @@ from abc import abstractmethod, ABC +from typing import Any, Optional, Tuple, runtime_checkable, Union +from typing_extensions import Protocol from dynamax.ssm import SSM -from dynamax.types import Scalar +from dynamax.types import IntScalar, Scalar from dynamax.parameters import to_unconstrained, from_unconstrained from dynamax.parameters import ParameterSet, PropertySet -from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered +from dynamax.hidden_markov_model.inference import HMMPosterior from dynamax.hidden_markov_model.inference import hmm_filter from dynamax.hidden_markov_model.inference import hmm_posterior_mode from dynamax.hidden_markov_model.inference import hmm_smoother @@ -11,16 +13,15 @@ from dynamax.utils.optimize import run_gradient_descent from dynamax.utils.utils import pytree_slice import jax.numpy as jnp -import jax.random as jr from jax import vmap from jax.tree_util import tree_map -from jaxtyping import Float, Array, PyTree +from jaxtyping import Float, Array, PyTree, Real import optax from tensorflow_probability.substrates.jax import distributions as tfd -from typing import Any, Optional, Tuple -from typing_extensions import Protocol + +@runtime_checkable class HMMParameterSet(Protocol): """Container for HMM parameters. @@ -28,11 +29,20 @@ class HMMParameterSet(Protocol): :param transitions: (ParameterSet) transition distribution parameters :param emissions: (ParameterSet) emission distribution parameters """ - initial: ParameterSet - transitions: ParameterSet - emissions: ParameterSet + @property + def initial(self) -> ParameterSet: + pass + + @property + def transitions(self) -> ParameterSet: + pass + + @property + def emissions(self) -> ParameterSet: + pass +@runtime_checkable class HMMPropertySet(Protocol): """Container for properties of HMM parameter properties. 
@@ -40,10 +50,17 @@ class HMMPropertySet(Protocol): :param transitions: (PropertySet) transition distribution properties :param emissions: (PropertySet) emission distribution properties """ - initial: PropertySet - transitions: PropertySet - emissions: PropertySet + @property + def initial(self) -> PropertySet: + pass + @property + def transitions(self) -> PropertySet: + pass + + @property + def emissions(self) -> PropertySet: + pass class HMMInitialState(ABC): @@ -59,7 +76,7 @@ def __init__(self, @abstractmethod def distribution(self, params: ParameterSet, - inputs: Optional[Float[Array, "input_dim"]]=None + inputs: Optional[Float[Array, " input_dim"]]=None ) -> tfd.Distribution: """Return a distribution over the initial latent state @@ -71,7 +88,7 @@ def distribution(self, @abstractmethod def initialize(self, - key: jr.PRNGKey=None, + key: Optional[Array]=None, method: str="prior", **kwargs ) -> Tuple[ParameterSet, PropertySet]: @@ -96,14 +113,14 @@ def log_prior(self, params: ParameterSet) -> Scalar: """ raise NotImplementedError - def _compute_initial_probs(self, params, inputs=None): - return self.initial_distribution(params, inputs).probs_parameter() + def _compute_initial_probs(self, params, inputs:Optional[Array] = None): + return self.distribution(params, inputs).probs_parameter() def collect_suff_stats(self, params: ParameterSet, posterior: HMMPosterior, inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None - ) -> PyTree: + ) -> Tuple[Float[Array, " num_states"], Optional[Float[Array, " input_dim"]]]: """Collect sufficient statistics for updating the initial distribution parameters. Args: @@ -135,7 +152,7 @@ def m_step(self, batch_stats: PyTree, m_step_state: Any, scale: float=1.0 - ) -> ParameterSet: + ) -> Tuple[ParameterSet, Any]: """Perform an M-step on the initial distribution parameters. Args: @@ -194,8 +211,8 @@ def __init__(self, @abstractmethod def distribution(self, params: ParameterSet, - state: int, - inputs: Optional[Float[Array, "input_dim"]]=None + state: IntScalar, + inputs: Optional[Float[Array, " input_dim"]]=None ) -> tfd.Distribution: """Return a distribution over the next latent state @@ -212,7 +229,7 @@ def distribution(self, @abstractmethod def initialize(self, - key: jr.PRNGKey=None, + key: Optional[Array]=None, method: str="prior", **kwargs ) -> Tuple[ParameterSet, PropertySet]: @@ -238,7 +255,7 @@ def log_prior(self, params: ParameterSet) -> Scalar: """ raise NotImplementedError - def _compute_transition_matrices(self, params, inputs=None): + def _compute_transition_matrices(self, params, inputs:Optional[Array] = None): if inputs is not None: f = lambda inpt: \ vmap(lambda state: \ @@ -254,7 +271,7 @@ def collect_suff_stats(self, params: ParameterSet, posterior: HMMPosterior, inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None - ) -> PyTree: + ) -> Tuple[Float[Array, "..."], Optional[Float[Array, "num_timesteps-1 input_dim"]]]: """Collect sufficient statistics for updating the transition distribution parameters. Args: @@ -284,7 +301,7 @@ def m_step(self, batch_stats: PyTree, m_step_state: Any, scale: float=1.0 - ) -> ParameterSet: + ) -> Tuple[ParameterSet, Any]: """Perform an M-step on the transition distribution parameters. 
Args: @@ -350,8 +367,8 @@ def emission_shape(self) -> Tuple[int]: @abstractmethod def distribution(self, params: ParameterSet, - state: int, - inputs: Optional[Float[Array, "input_dim"]]=None + state: IntScalar, + inputs: Optional[Float[Array, " input_dim"]]=None ) -> tfd.Distribution: """Return a distribution over the emission @@ -368,7 +385,7 @@ def distribution(self, @abstractmethod def initialize(self, - key: jr.PRNGKey=None, + key: Optional[Array]=None, method: str="prior", **kwargs ) -> Tuple[ParameterSet, PropertySet]: @@ -394,7 +411,7 @@ def log_prior(self, params: ParameterSet) -> Scalar: """ raise NotImplementedError - def _compute_conditional_logliks(self, params, emissions, inputs=None): + def _compute_conditional_logliks(self, params, emissions, inputs:Optional[Array] = None): # Compute the log probability for each time step by # performing a nested vmap over emission time steps and states. f = lambda emission, inpt: \ @@ -405,9 +422,12 @@ def _compute_conditional_logliks(self, params, emissions, inputs=None): def collect_suff_stats(self, params: ParameterSet, posterior: HMMPosterior, - emissions: Float[Array, "num_timesteps emission_dim"], + emissions: Union[Real[Array, "num_timesteps emission_dim"], + Real[Array, " num_timesteps"]], inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None - ) -> PyTree: + ) -> Tuple[Float[Array, "num_timesteps num_states"], + Union[Real[Array, "num_timesteps emission_dim"], Real[Array, " num_timesteps"]], + Optional[Float[Array, "num_timesteps input_dim"]]]: """Collect sufficient statistics for updating the emission distribution parameters. Args: @@ -438,7 +458,7 @@ def m_step(self, batch_stats: PyTree, m_step_state: Any, scale: float=1.0 - ) -> ParameterSet: + ) -> Tuple[ParameterSet, Any]: """Perform an M-step on the emission distribution parameters. Args: @@ -485,7 +505,7 @@ def _single_expected_log_like(stats): class HMM(SSM): - """Abstract base class of Hidden Markov Models (HMMs). + r"""Abstract base class of Hidden Markov Models (HMMs). 
The model is defined as follows @@ -532,43 +552,48 @@ def __init__(self, def emission_shape(self): return self.emission_component.emission_shape - def initial_distribution(self, params, inputs=None): + def initial_distribution(self, params: HMMParameterSet, inputs:Optional[Array] = None) -> tfd.Distribution: return self.initial_component.distribution(params.initial, inputs=inputs) - def transition_distribution(self, params, state, inputs=None): + def transition_distribution(self, params: HMMParameterSet, state: IntScalar, inputs:Optional[Array] = None) -> tfd.Distribution: return self.transition_component.distribution(params.transitions, state, inputs=inputs) - def emission_distribution(self, params, state, inputs=None): + def emission_distribution(self, params: HMMParameterSet, state: IntScalar, inputs:Optional[Array] = None): return self.emission_component.distribution(params.emissions, state, inputs=inputs) - def log_prior(self, params): + def log_prior(self, params: HMMParameterSet) -> Scalar: lp = self.initial_component.log_prior(params.initial) lp += self.transition_component.log_prior(params.transitions) lp += self.emission_component.log_prior(params.emissions) return lp # The inference functions all need the same arguments - def _inference_args(self, params, emissions, inputs): + def _inference_args(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]): return (self.initial_component._compute_initial_probs(params.initial, inputs), self.transition_component._compute_transition_matrices(params.transitions, inputs), self.emission_component._compute_conditional_logliks(params.emissions, emissions, inputs)) # Convenience wrappers for the inference code - def marginal_log_prob(self, params, emissions, inputs=None): + def marginal_log_prob(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]=None): post = hmm_filter(*self._inference_args(params, emissions, inputs)) return post.marginal_loglik - def most_likely_states(self, params, emissions, inputs=None): + def most_likely_states(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]=None): return hmm_posterior_mode(*self._inference_args(params, emissions, inputs)) - def filter(self, params, emissions, inputs=None): + def filter(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]=None): return hmm_filter(*self._inference_args(params, emissions, inputs)) - def smoother(self, params, emissions, inputs=None): + def smoother(self, params: HMMParameterSet, emissions: Array, inputs: Optional[Array]=None): return hmm_smoother(*self._inference_args(params, emissions, inputs)) # Expectation-maximization (EM) code - def e_step(self, params, emissions, inputs=None): + def e_step( + self, + params: HMMParameterSet, + emissions: Array, + inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None + ) -> Tuple[PyTree, Scalar]: """The E-step computes expected sufficient statistics under the posterior. In the generic case, we simply return the posterior itself. """ @@ -580,7 +605,7 @@ def e_step(self, params, emissions, inputs=None): emission_stats = self.emission_component.collect_suff_stats(params.emissions, posterior, emissions, inputs) return (initial_stats, transition_stats, emission_stats), posterior.marginal_loglik - def initialize_m_step_state(self, params, props): + def initialize_m_step_state(self, params: HMMParameterSet, props: HMMPropertySet): """Initialize any required state for the M step. For example, this might include the optimizer state for Adam. 
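# Sketch of the EM contract that the e_step/m_step annotations here describe,
# using CategoricalHMM as a stand-in concrete model; the sizes, key values,
# and iteration count are arbitrary choices for illustration.
import jax.random as jr
from jax import vmap
from dynamax.hidden_markov_model import CategoricalHMM

hmm = CategoricalHMM(num_states=2, emission_dim=1, num_classes=3)
params, props = hmm.initialize(jr.PRNGKey(0))
batch_emissions = jr.randint(jr.PRNGKey(1), (5, 100, 1), 0, 3)  # (batch, time, emission_dim)

m_step_state = hmm.initialize_m_step_state(params, props)
for _ in range(10):
    # E-step per sequence: (sufficient statistics, marginal log likelihood)
    batch_stats, lls = vmap(lambda y: hmm.e_step(params, y))(batch_emissions)
    # M-step consumes the batched statistics and threads its own state.
    params, m_step_state = hmm.m_step(params, props, batch_stats, m_step_state)
# dynamax's higher-level fit_em performs essentially this loop.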
@@ -590,7 +615,13 @@ def initialize_m_step_state(self, params, props): emissions_m_step_state = self.emission_component.initialize_m_step_state(params.emissions, props.emissions) return initial_m_step_state, transitions_m_step_state, emissions_m_step_state - def m_step(self, params, props, batch_stats, m_step_state): + def m_step( + self, + params: HMMParameterSet, + props: HMMPropertySet, + batch_stats: PyTree, + m_step_state: Any + ) -> Tuple[HMMParameterSet, Any]: batch_initial_stats, batch_transition_stats, batch_emission_stats = batch_stats initial_m_step_state, transitions_m_step_state, emissions_m_step_state = m_step_state diff --git a/dynamax/hidden_markov_model/models/arhmm.py b/dynamax/hidden_markov_model/models/arhmm.py index 2eff832b..b7d1fa61 100644 --- a/dynamax/hidden_markov_model/models/arhmm.py +++ b/dynamax/hidden_markov_model/models/arhmm.py @@ -1,8 +1,10 @@ +from typing import NamedTuple, Optional, Tuple, Union import jax.numpy as jnp import jax.random as jr from jax import lax from jax.tree_util import tree_map -from jaxtyping import Float, Array +from jaxtyping import Int, Float, Array + from dynamax.hidden_markov_model.models.abstractions import HMM, HMMParameterSet, HMMPropertySet from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions @@ -11,7 +13,6 @@ from dynamax.types import Scalar from dynamax.utils.bijectors import RealToPSDBijector from tensorflow_probability.substrates import jax as tfp -from typing import NamedTuple, Optional, Tuple, Union tfd = tfp.distributions tfb = tfp.bijectors @@ -25,21 +26,22 @@ class ParamsLinearAutoregressiveHMM(NamedTuple): class LinearAutoregressiveHMMEmissions(LinearRegressionHMMEmissions): def __init__(self, - num_states, - emission_dim, - num_lags=1): + num_states: int, + emission_dim: int, + num_lags: int=1): self.num_lags = num_lags self.emission_dim = emission_dim input_dim = num_lags * emission_dim super().__init__(num_states, input_dim, emission_dim) def initialize(self, - key=jr.PRNGKey(0), - method="prior", - emission_weights=None, - emission_biases=None, - emission_covariances=None, - emissions=None): + key: Array=jr.PRNGKey(0), + method: str="prior", + emission_weights: Optional[Float[Array, "num_states emission_dim input_dim"]]=None, + emission_biases: Optional[Float[Array, "num_states emission_dim"]]=None, + emission_covariances: Optional[Float[Array, "num_states emission_dim emission_dim"]]=None, + emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None + ) -> Tuple[ParamsLinearRegressionHMMEmissions, ParamsLinearRegressionHMMEmissions]: if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" 
from sklearn.cluster import KMeans @@ -107,8 +109,8 @@ def __init__(self, num_states: int, emission_dim: int, num_lags: int=1, - initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, - transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, + initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, + transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, transition_matrix_stickiness: Scalar=0.0): self.emission_dim = emission_dim self.num_lags = num_lags @@ -125,9 +127,9 @@ def inputs_shape(self): return (self.num_lags * self.emission_dim,) def initialize(self, - key: jr.PRNGKey=jr.PRNGKey(0), + key: Array=jr.PRNGKey(0), method: str="prior", - initial_probs: Optional[Float[Array, "num_states"]]=None, + initial_probs: Optional[Float[Array, " num_states"]]=None, transition_matrix: Optional[Float[Array, "num_states num_states"]]=None, emission_weights: Optional[Float[Array, "num_states emission_dim emission_dim_times_num_lags"]]=None, emission_biases: Optional[Float[Array, "num_states emission_dim"]]=None, @@ -163,10 +165,10 @@ def initialize(self, def sample(self, params: HMMParameterSet, - key: jr.PRNGKey, + key: Array, num_timesteps: int, prev_emissions: Optional[Float[Array, "num_lags emission_dim"]]=None, - ) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: + ) -> Tuple[Int[Array, " num_timesteps"], Float[Array, "num_timesteps emission_dim"]]: r"""Sample states $z_{1:T}$ and emissions $y_{1:T}$ given parameters $\theta$. Args: @@ -211,7 +213,7 @@ def _step(carry, key): def compute_inputs(self, emissions: Float[Array, "num_timesteps emission_dim"], prev_emissions: Optional[Float[Array, "num_lags emission_dim"]]=None - ) -> Float[Array, "num_timesteps emission_dim_times_num_lags"]: + ) -> Float[Array, "num_timesteps {self.num_lags}*{self.emission_dim}"]: r"""Helper function to compute the matrix of lagged emissions. Args: diff --git a/dynamax/hidden_markov_model/models/bernoulli_hmm.py b/dynamax/hidden_markov_model/models/bernoulli_hmm.py index ab6bc8c3..4e7d29ec 100644 --- a/dynamax/hidden_markov_model/models/bernoulli_hmm.py +++ b/dynamax/hidden_markov_model/models/bernoulli_hmm.py @@ -17,7 +17,7 @@ class ParamsBernoulliHMMEmissions(NamedTuple): - probs: Union[Float[Array, "emission_dim"], ParameterProperties] + probs: Union[Float[Array, " emission_dim"], ParameterProperties] class ParamsBernoulliHMM(NamedTuple): @@ -28,11 +28,13 @@ class ParamsBernoulliHMM(NamedTuple): class BernoulliHMMEmissions(HMMEmissions): - def __init__(self, - num_states, - emission_dim, - emission_prior_concentration1=1.1, - emission_prior_concentration0=1.1): + def __init__( + self, + num_states: int, + emission_dim: int, + emission_prior_concentration1: Scalar = 1.1, + emission_prior_concentration0: Scalar = 1.1, + ): """_summary_ Args: emission_probs (_type_): _description_ @@ -43,22 +45,26 @@ def __init__(self, self.emission_prior_concentration1 = emission_prior_concentration1 @property - def emission_shape(self): + def emission_shape(self) -> Tuple[int]: return (self.emission_dim,) - def distribution(self, params, state, inputs=None): + def distribution(self, params, state, inputs=None) -> tfd.Distribution: # This model assumes the emissions are a vector of conditionally independent # Bernoulli observations. 
The `reinterpreted_batch_ndims` argument tells # `tfd.Independent` that only the last dimension should be considered a "batch" # of conditionally independent observations. return tfd.Independent(tfd.Bernoulli(probs=params.probs[state]), reinterpreted_batch_ndims=1) - def log_prior(self, params): - prior = tfd.Beta(self.emission_prior_concentration1, - self.emission_prior_concentration0) + def log_prior(self, params) -> Float[Array, ""]: + prior = tfd.Beta(self.emission_prior_concentration1, self.emission_prior_concentration0) return prior.log_prob(params.probs).sum() - def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None): + def initialize( + self, + key: Array = jr.PRNGKey(0), + method="prior", + emission_probs: Optional[Float[Array, "num_states emission_dim"]] = None, + ) -> Tuple[ParamsBernoulliHMMEmissions, ParamsBernoulliHMMEmissions]: if emission_probs is None: if method.lower() == "prior": prior = tfd.Beta(self.emission_prior_concentration1, self.emission_prior_concentration0) @@ -90,8 +96,8 @@ def m_step(self, params, props, batch_stats, m_step_state): if props.probs.trainable: sum_x, sum_1mx = pytree_sum(batch_stats, axis=0) probs = tfd.Beta( - self.emission_prior_concentration1 + sum_x, - self.emission_prior_concentration0 + sum_1mx).mode() + self.emission_prior_concentration1 + sum_x, self.emission_prior_concentration0 + sum_1mx + ).mode() params = params._replace(probs=probs) return params, m_step_state @@ -117,25 +123,37 @@ class BernoulliHMM(HMM): :param emission_prior_concentration1: $\gamma_1$ """ - def __init__(self, num_states: int, - emission_dim: int, - initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, - transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, - transition_matrix_stickiness: Scalar=0.0, - emission_prior_concentration0: Scalar=1.1, - emission_prior_concentration1: Scalar=1.1): + + def __init__( + self, + num_states: int, + emission_dim: int, + initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]] = 1.1, + transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]] = 1.1, + transition_matrix_stickiness: Scalar = 0.0, + emission_prior_concentration0: Scalar = 1.1, + emission_prior_concentration1: Scalar = 1.1, + ): self.emission_dim = emission_dim initial_component = StandardHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration) - transition_component = StandardHMMTransitions(num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness) - emission_component = BernoulliHMMEmissions(num_states, emission_dim, emission_prior_concentration0=emission_prior_concentration0, emission_prior_concentration1=emission_prior_concentration1) + transition_component = StandardHMMTransitions( + num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness + ) + emission_component = BernoulliHMMEmissions( + num_states, + emission_dim, + emission_prior_concentration0=emission_prior_concentration0, + emission_prior_concentration1=emission_prior_concentration1, + ) super().__init__(num_states, initial_component, transition_component, emission_component) - def initialize(self, - key: jr.PRNGKey=jr.PRNGKey(0), - method: str="prior", - initial_probs: Optional[Float[Array, "num_states"]]=None, - transition_matrix: Optional[Float[Array, "num_states num_states"]]=None, - emission_probs: Optional[Float[Array, "num_states emission_dim"]]=None + def initialize( + self, + key: 
Array = jr.PRNGKey(0), + method: str = "prior", + initial_probs: Optional[Float[Array, " num_states"]] = None, + transition_matrix: Optional[Float[Array, "num_states num_states"]] = None, + emission_probs: Optional[Float[Array, "num_states emission_dim"]] = None, ) -> Tuple[ParameterSet, PropertySet]: """Initialize the model parameters and their corresponding properties. @@ -155,9 +173,15 @@ def initialize(self, Returns: Model parameters and their properties. """ - key1, key2, key3 = jr.split(key , 3) + key1, key2, key3 = jr.split(key, 3) params, props = dict(), dict() - params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs) - params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix) - params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_probs=emission_probs) + params["initial"], props["initial"] = self.initial_component.initialize( + key1, method=method, initial_probs=initial_probs + ) + params["transitions"], props["transitions"] = self.transition_component.initialize( + key2, method=method, transition_matrix=transition_matrix + ) + params["emissions"], props["emissions"] = self.emission_component.initialize( + key3, method=method, emission_probs=emission_probs + ) return ParamsBernoulliHMM(**params), ParamsBernoulliHMM(**props) diff --git a/dynamax/hidden_markov_model/models/categorical_glm_hmm.py b/dynamax/hidden_markov_model/models/categorical_glm_hmm.py index 71364e37..c1127a89 100644 --- a/dynamax/hidden_markov_model/models/categorical_glm_hmm.py +++ b/dynamax/hidden_markov_model/models/categorical_glm_hmm.py @@ -1,4 +1,3 @@ -import jax.numpy as jnp import jax.random as jr import tensorflow_probability.substrates.jax.distributions as tfd from jaxtyping import Float, Array @@ -6,7 +5,7 @@ from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions -from dynamax.types import Scalar +from dynamax.types import IntScalar, Scalar import optax from typing import NamedTuple, Optional, Tuple, Union @@ -25,9 +24,9 @@ class ParamsCategoricalRegressionHMM(NamedTuple): class CategoricalRegressionHMMEmissions(HMMEmissions): def __init__(self, - num_states, - num_classes, - input_dim, + num_states: int, + num_classes: int, + input_dim: int, m_step_optimizer=optax.adam(1e-2), m_step_num_iters=50): """_summary_ @@ -51,7 +50,13 @@ def inputs_shape(self): def log_prior(self, params): return 0.0 - def initialize(self, key=jr.PRNGKey(0), method="prior", emission_weights=None, emission_biases=None): + def initialize( + self, + key: Array=jr.PRNGKey(0), + method: str="prior", + emission_weights: Optional[Float[Array, "num_states num_classes input_dim"]]=None, + emission_biases: Optional[Float[Array, "num_states num_classes"]]=None, + ): """Initialize the model parameters and their corresponding properties. 
You can either specify parameters manually via the keyword arguments, or you can have @@ -89,7 +94,11 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", emission_weights=None, e biases=ParameterProperties()) return params, props - def distribution(self, params, state, inputs=None): + def distribution( + self, + params: ParamsCategoricalRegressionHMMEmissions, + state: IntScalar, + inputs: Float[Array, " input_dim"]): logits = params.weights[state] @ inputs + params.biases[state] return tfd.Categorical(logits=logits) @@ -121,8 +130,8 @@ def __init__(self, num_states: int, num_classes: int, input_dim: int, - initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, - transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, + initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, + transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, transition_matrix_stickiness: Scalar=0.0, m_step_optimizer: optax.GradientTransformation=optax.adam(1e-2), m_step_num_iters: int=50): @@ -137,9 +146,9 @@ def inputs_shape(self): return (self.input_dim,) def initialize(self, - key: jr.PRNGKey=jr.PRNGKey(0), + key: Array=jr.PRNGKey(0), method: str="prior", - initial_probs: Optional[Float[Array, "num_states"]]=None, + initial_probs: Optional[Float[Array, " num_states"]]=None, transition_matrix: Optional[Float[Array, "num_states num_states"]]=None, emission_weights: Optional[Float[Array, "num_states num_classes input_dim"]]=None, emission_biases: Optional[Float[Array, "num_states num_classes"]]=None, diff --git a/dynamax/hidden_markov_model/models/categorical_hmm.py b/dynamax/hidden_markov_model/models/categorical_hmm.py index 7dfd341a..f6fba693 100644 --- a/dynamax/hidden_markov_model/models/categorical_hmm.py +++ b/dynamax/hidden_markov_model/models/categorical_hmm.py @@ -13,7 +13,7 @@ from dynamax.hidden_markov_model.models.transitions import ParamsStandardHMMTransitions from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions from dynamax.parameters import ParameterProperties, ParameterSet, PropertySet -from dynamax.types import Scalar +from dynamax.types import IntScalar, Scalar from dynamax.utils.utils import pytree_sum @@ -30,10 +30,10 @@ class ParamsCategoricalHMM(NamedTuple): class CategoricalHMMEmissions(HMMEmissions): def __init__(self, - num_states, - emission_dim, - num_classes, - emission_prior_concentration=1.1): + num_states: int, + emission_dim: int, + num_classes: int, + emission_prior_concentration: Union[Scalar, Float[Array, " num_classes"]]=1.1): """_summary_ Args: @@ -45,18 +45,22 @@ def __init__(self, self.emission_prior_concentration = emission_prior_concentration * jnp.ones(num_classes) @property - def emission_shape(self): + def emission_shape(self) -> Tuple[int]: return (self.emission_dim,) - def distribution(self, params, state, inputs=None): + def distribution(self, params: ParamsCategoricalHMMEmissions, state: IntScalar, inputs=None) -> tfd.Distribution: return tfd.Independent( tfd.Categorical(probs=params.probs[state]), reinterpreted_batch_ndims=1) - def log_prior(self, params): + def log_prior(self, params: ParamsCategoricalHMMEmissions) -> Scalar: return tfd.Dirichlet(self.emission_prior_concentration).log_prob(params.probs).sum() - def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None): + def initialize(self, + key:Optional[Array]=jr.PRNGKey(0), + method="prior", + emission_probs:Optional[Float[Array, "num_states emission_dim 
num_classes"]]=None + ) -> Tuple[ParamsCategoricalHMMEmissions, ParamsCategoricalHMMEmissions]: """Initialize the model parameters and their corresponding properties. You can either specify parameters manually via the keyword arguments, or you can have @@ -77,6 +81,8 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None): # Initialize the emission probabilities if emission_probs is None: if method.lower() == "prior": + if key is None: + raise ValueError("key must not be None when emission_probs is None") prior = tfd.Dirichlet(self.emission_prior_concentration) emission_probs = prior.sample(seed=key, sample_shape=(self.num_states, self.emission_dim)) elif method.lower() == "kmeans": @@ -134,8 +140,8 @@ class CategoricalHMM(HMM): def __init__(self, num_states: int, emission_dim: int, num_classes: int, - initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, - transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, + initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, + transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, transition_matrix_stickiness: Scalar=0.0, emission_prior_concentration=1.1): self.emission_dim = emission_dim @@ -145,9 +151,9 @@ def __init__(self, num_states: int, super().__init__(num_states, initial_component, transition_component, emission_component) def initialize(self, - key: jr.PRNGKey=jr.PRNGKey(0), + key: Array=jr.PRNGKey(0), method: str="prior", - initial_probs: Optional[Float[Array, "num_states"]]=None, + initial_probs: Optional[Float[Array, " num_states"]]=None, transition_matrix: Optional[Float[Array, "num_states num_states"]]=None, emission_probs: Optional[Float[Array, "num_states emission_dim num_classes"]]=None ) -> Tuple[ParameterSet, PropertySet]: diff --git a/dynamax/hidden_markov_model/models/gamma_hmm.py b/dynamax/hidden_markov_model/models/gamma_hmm.py index 2efcdd86..45825f34 100644 --- a/dynamax/hidden_markov_model/models/gamma_hmm.py +++ b/dynamax/hidden_markov_model/models/gamma_hmm.py @@ -13,32 +13,38 @@ class ParamsGammaHMMEmissions(NamedTuple): - concentration: Union[Float[Array, "state_dim"], ParameterProperties] - rate: Union[Float[Array, "state_dim"], ParameterProperties] + concentration: Union[Float[Array, " state_dim"], ParameterProperties] + rate: Union[Float[Array, " state_dim"], ParameterProperties] class GammaHMMEmissions(HMMEmissions): - def __init__(self, - num_states, - m_step_optimizer=optax.adam(1e-2), - m_step_num_iters=50): + def __init__( + self, + num_states: int, + m_step_optimizer: optax.GradientTransformation = optax.adam(1e-2), + m_step_num_iters: int = 50, + ): super().__init__(m_step_optimizer=m_step_optimizer, m_step_num_iters=m_step_num_iters) self.num_states = num_states @property - def emission_shape(self): + def emission_shape(self) -> Tuple: return () - def initialize(self, - key=jr.PRNGKey(0), - method="prior", - emission_concentrations=None, - emission_rates=None, - emissions=None): + def initialize( + self, + key: Array = jr.PRNGKey(0), + method="prior", + emission_concentrations: Optional[Float[Array, " num_states"]] = None, + emission_rates: Optional[Float[Array, " num_states"]] = None, + emissions: Optional[Float[Array, " num_timesteps"]] = None, + # ) -> Tuple[ParamsGammaHMMEmissions, ParamsGammaHMMEmissions]: + ) -> Tuple[ParamsGammaHMMEmissions, ParamsGammaHMMEmissions]: if method.lower() == "kmeans": assert emissions is not None, "Need emissions to initialize the model with K-Means!" 
from sklearn.cluster import KMeans + key, subkey = jr.split(key) # Create a random seed for SKLearn. sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, 1)) @@ -57,18 +63,19 @@ def initialize(self, default = lambda x, x0: x if x is not None else x0 params = ParamsGammaHMMEmissions( concentration=default(emission_concentrations, _emission_concentrations), - rate=default(emission_rates, _emission_rates)) + rate=default(emission_rates, _emission_rates), + ) props = ParamsGammaHMMEmissions( concentration=ParameterProperties(constrainer=tfb.Softplus()), - rate=ParameterProperties(constrainer=tfb.Softplus())) + rate=ParameterProperties(constrainer=tfb.Softplus()), + ) return params, props - def log_prior(self, params): + def log_prior(self, params) -> float: return 0.0 - def distribution(self, params, state, inputs=None): - return tfd.Gamma(concentration=params.concentration[state], - rate=params.rate[state]) + def distribution(self, params: ParamsGammaHMMEmissions, state, inputs=None) -> tfd.Distribution: + return tfd.Gamma(concentration=params.concentration[state], rate=params.rate[state]) class ParamsGammaHMM(NamedTuple): @@ -96,27 +103,35 @@ class GammaHMM(HMM): :param m_step_num_iters: number of optimizer steps per M-step. """ - def __init__(self, - num_states: int, - initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, - transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, - transition_matrix_stickiness: Scalar=0.0, - m_step_optimizer: optax.GradientTransformation=optax.adam(1e-2), - m_step_num_iters: int=50): + + def __init__( + self, + num_states: int, + initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]] = 1.1, + transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]] = 1.1, + transition_matrix_stickiness: Scalar = 0.0, + m_step_optimizer: optax.GradientTransformation = optax.adam(1e-2), + m_step_num_iters: int = 50, + ): initial_component = StandardHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration) - transition_component = StandardHMMTransitions(num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness) - emission_component = GammaHMMEmissions(num_states, m_step_optimizer=m_step_optimizer, m_step_num_iters=m_step_num_iters) + transition_component = StandardHMMTransitions( + num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness + ) + emission_component = GammaHMMEmissions( + num_states, m_step_optimizer=m_step_optimizer, m_step_num_iters=m_step_num_iters + ) super().__init__(num_states, initial_component, transition_component, emission_component) - def initialize(self, - key: jr.PRNGKey=jr.PRNGKey(0), - method: str="prior", - initial_probs: Optional[Float[Array, "num_states"]]=None, - transition_matrix: Optional[Float[Array, "num_states num_states"]]=None, - emission_concentrations: Optional[Float[Array, "num_states"]]=None, - emission_rates: Optional[Float[Array, "num_states"]]=None, - emissions: Optional[Float[Array, "num_timesteps"]]=None, - ) -> Tuple[HMMParameterSet, HMMPropertySet]: + def initialize( + self, + key: Array = jr.PRNGKey(0), + method: str = "prior", + initial_probs: Optional[Float[Array, " num_states"]] = None, + transition_matrix: Optional[Float[Array, "num_states num_states"]] = None, + emission_concentrations: Optional[Float[Array, " 
num_states"]] = None,
+                   emission_rates: Optional[Float[Array, " num_states"]] = None,
+                   emissions: Optional[Float[Array, " num_timesteps"]] = None,
+                   ) -> Tuple[HMMParameterSet, HMMPropertySet]:
         """Initialize the model parameters and their corresponding properties.
 
         You can either specify parameters manually via the keyword arguments, or you can have
@@ -136,9 +151,19 @@ def initialize(self,
             Model parameters and their properties.
         """
-        key1, key2, key3 = jr.split(key , 3)
+        key1, key2, key3 = jr.split(key, 3)
         params, props = dict(), dict()
-        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
-        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
-        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_concentrations=emission_concentrations, emission_rates=emission_rates, emissions=emissions)
+        params["initial"], props["initial"] = self.initial_component.initialize(
+            key1, method=method, initial_probs=initial_probs
+        )
+        params["transitions"], props["transitions"] = self.transition_component.initialize(
+            key2, method=method, transition_matrix=transition_matrix
+        )
+        params["emissions"], props["emissions"] = self.emission_component.initialize(
+            key3,
+            method=method,
+            emission_concentrations=emission_concentrations,
+            emission_rates=emission_rates,
+            emissions=emissions,
+        )
         return ParamsGammaHMM(**params), ParamsGammaHMM(**props)
diff --git a/dynamax/hidden_markov_model/models/gaussian_hmm.py b/dynamax/hidden_markov_model/models/gaussian_hmm.py
index c1904878..61205090 100644
--- a/dynamax/hidden_markov_model/models/gaussian_hmm.py
+++ b/dynamax/hidden_markov_model/models/gaussian_hmm.py
@@ -1,15 +1,18 @@
+from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
+from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr
 import tensorflow_probability.substrates.jax.bijectors as tfb
 import tensorflow_probability.substrates.jax.distributions as tfd
-from jax import vmap
 from jaxtyping import Float, Array
 import optax
+
 from dynamax.parameters import ParameterProperties
+from dynamax.hidden_markov_model.inference import HMMPosterior
 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
-from dynamax.types import Scalar
+from dynamax.types import IntScalar, Scalar
 from dynamax.utils.distributions import InverseWishart
 from dynamax.utils.distributions import NormalInverseGamma
 from dynamax.utils.distributions import NormalInverseWishart
@@ -17,7 +20,6 @@ from dynamax.utils.distributions import niw_posterior_update
 from dynamax.utils.bijectors import RealToPSDBijector
 from dynamax.utils.utils import pytree_sum
-from typing import NamedTuple, Optional, Tuple, Union
 
 
 class ParamsGaussianHMMEmissions(NamedTuple):
@@ -28,19 +30,21 @@ class ParamsGaussianHMMEmissions(NamedTuple):
 
 class GaussianHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 emission_dim,
-                 emission_prior_mean=0.0,
-                 emission_prior_concentration=1e-4,
-                 emission_prior_scale=1e-4,
-                 emission_prior_extra_df=0.1):
-        """_summary_
+                 num_states: int,
+                 emission_dim: int,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]] = 0.0,
+                 emission_prior_concentration: Scalar = 1e-4,
+                 emission_prior_scale: Union[Scalar, Float[Array, "emission_dim emission_dim"]] = 1e-4,
+                 emission_prior_extra_df: Scalar = 0.1):
+        """Initialize GaussianHMMEmissions.
 
         Args:
-            initial_probabilities (_type_): _description_
-            transition_matrix (_type_): _description_
-            emission_means (_type_): _description_
-            emission_covariance_matrices (_type_): _description_
+            num_states: number of discrete states
+            emission_dim: dimension of the emission vector
+            emission_prior_mean: prior mean for emissions
+            emission_prior_concentration: concentration parameter for the prior
+            emission_prior_scale: scale matrix for the prior
+            emission_prior_extra_df: extra degrees of freedom for the prior
         """
         self.num_states = num_states
         self.emission_dim = emission_dim
@@ -51,23 +55,42 @@ def __init__(self,
         self.emission_prior_df = emission_dim + emission_prior_extra_df
 
     @property
-    def emission_shape(self):
+    def emission_shape(self) -> Tuple[int]:
         return (self.emission_dim,)
 
-    def distribution(self, params, state, inputs=None):
+    def distribution(
+        self,
+        params: ParamsGaussianHMMEmissions,
+        state: IntScalar,
+        inputs: Optional[Array] = None
+    ) -> tfd.Distribution:
         return tfd.MultivariateNormalFullCovariance(
             params.means[state], params.covs[state])
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsGaussianHMMEmissions) -> Float[Array, ""]:
         return NormalInverseWishart(self.emission_prior_mean, self.emission_prior_conc,
                                     self.emission_prior_df, self.emission_prior_scale).log_prob(
             (params.covs, params.means)).sum()
 
-    def initialize(self, key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_means=None,
-                   emission_covariances=None,
-                   emissions=None):
+    def initialize(self,
+                   key: Array = jr.PRNGKey(0),
+                   method: str = "prior",
+                   emission_means: Optional[Float[Array, "num_states emission_dim"]] = None,
+                   emission_covariances: Optional[Float[Array, "num_states emission_dim emission_dim"]] = None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]] = None
+                   ) -> Tuple[ParamsGaussianHMMEmissions, ParamsGaussianHMMEmissions]:
+        """Initialize the model parameters and their corresponding properties.
+
+        Args:
+            key: random number generator for unspecified parameters. Must not be None if there are any unspecified parameters.
+            method: method for initializing unspecified parameters. Both "prior" and "kmeans" are supported.
+            emission_means: manually specified emission means.
+            emission_covariances: manually specified emission covariances.
+            emissions: emissions for initializing the parameters with kmeans.
+
+        Returns:
+            Tuple of (params, props) where params are the initialized parameters and props are their properties.
+        """
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
             from sklearn.cluster import KMeans
@@ -97,7 +120,13 @@ def initialize(self, key=jr.PRNGKey(0),
                       covs=ParameterProperties(constrainer=RealToPSDBijector()))
         return params, props
 
-    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
+    def collect_suff_stats(
+        self,
+        params: ParamsGaussianHMMEmissions,
+        posterior: HMMPosterior,
+        emissions: Float[Array, "num_timesteps emission_dim"],
+        inputs: Optional[Array] = None
+    ) -> Dict[str, Float[Array, "..."]]:
         expected_states = posterior.smoothed_probs
         return dict(
             sum_w=jnp.einsum("tk->k", expected_states),
@@ -105,10 +134,16 @@ def collect_suff_stats(self, params, posterior, emissions, inputs=None):
             sum_xxT=jnp.einsum("tk,ti,tj->kij", expected_states, emissions, emissions)
         )
 
-    def initialize_m_step_state(self, params, props):
+    def initialize_m_step_state(self, params: ParamsGaussianHMMEmissions, props: ParamsGaussianHMMEmissions) -> None:
         return None
 
-    def m_step(self, params, props, batch_stats, m_step_state):
+    def m_step(
+        self,
+        params: ParamsGaussianHMMEmissions,
+        props: ParamsGaussianHMMEmissions,
+        batch_stats: Dict[str, Float[Array, "..."]],
+        m_step_state: Any
+    ) -> Tuple[ParamsGaussianHMMEmissions, Any]:
         if props.covs.trainable and props.means.trainable:
             niw_prior = NormalInverseWishart(loc=self.emission_prior_mean,
                                              mean_concentration=self.emission_prior_conc,
@@ -141,17 +176,18 @@ class ParamsDiagonalGaussianHMMEmissions(NamedTuple):
 
 class DiagonalGaussianHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 emission_dim,
-                 emission_prior_mean=0.0,
-                 emission_prior_mean_concentration=1e-4,
-                 emission_prior_concentration=0.1,
-                 emission_prior_scale=0.1):
+                 num_states: int,
+                 emission_dim: int,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]] = 0.0,
+                 emission_prior_mean_concentration: Scalar = 1e-4,
+                 emission_prior_concentration: Union[Scalar, Float[Array, " emission_dim"]] = 0.1,
+                 emission_prior_scale: Scalar = 0.1):
         self.num_states = num_states
         self.emission_dim = emission_dim
         self.emission_prior_mean = emission_prior_mean * jnp.ones(emission_dim)
         self.emission_prior_mean_conc = emission_prior_mean_concentration
+        # TODO: Problem here if prior conc is Array shape ()?
         self.emission_prior_conc = emission_prior_concentration * jnp.ones(emission_dim) \
             if isinstance(emission_prior_concentration, float) else emission_prior_concentration
         self.emission_prior_scale = emission_prior_scale
@@ -160,11 +196,13 @@ def __init__(self,
     def emission_shape(self):
         return (self.emission_dim,)
 
-    def initialize(self, key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_means=None,
-                   emission_scale_diags=None,
-                   emissions=None):
+    def initialize(self,
+                   key: Array = jr.PRNGKey(0),
+                   method: str = "prior",
+                   emission_means: Optional[Float[Array, "num_states emission_dim"]] = None,
+                   emission_scale_diags: Optional[Float[Array, "num_states emission_dim"]] = None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]] = None
+                   ) -> Tuple[ParamsDiagonalGaussianHMMEmissions, ParamsDiagonalGaussianHMMEmissions]:
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
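For orientation, all of these `initialize` methods are driven the same way from user code. The following minimal sketch (toy sizes and data, not part of this diff) exercises the newly annotated signature on `GaussianHMM`, assuming the class is importable from `dynamax.hidden_markov_model` as in the released package:

    # Sketch: driving the annotated initialize() API (illustrative values only).
    import jax.random as jr
    from dynamax.hidden_markov_model import GaussianHMM

    hmm = GaussianHMM(num_states=3, emission_dim=2)
    emissions = jr.normal(jr.PRNGKey(1), (100, 2))  # toy emission sequence

    # Draw any unspecified parameters from the prior...
    params, props = hmm.initialize(key=jr.PRNGKey(0), method="prior")
    # ...or seed the emission means with k-means on observed emissions.
    params, props = hmm.initialize(key=jr.PRNGKey(0), method="kmeans", emissions=emissions)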
@@ -235,20 +273,20 @@ def _single_m_step(stats):
 class ParamsSphericalGaussianHMMEmissions(NamedTuple):
     means: Union[Float[Array, "state_dim emission_dim"], ParameterProperties]
-    scales: Union[Float[Array, "state_dim"], ParameterProperties]
+    scales: Union[Float[Array, " state_dim"], ParameterProperties]
 
 
 class SphericalGaussianHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 emission_dim,
-                 emission_prior_mean=0.0,
-                 emission_prior_mean_covariance=1.0,
-                 emission_var_concentration=1.1,
-                 emission_var_rate=1.1,
-                 m_step_optimizer=optax.adam(1e-2),
-                 m_step_num_iters=50):
+                 num_states: int,
+                 emission_dim: int,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]] = 0.0,
+                 emission_prior_mean_covariance: Union[Scalar, Float[Array, "emission_dim emission_dim"]] = 1.0,
+                 emission_var_concentration: Scalar = 1.1,
+                 emission_var_rate: Scalar = 1.1,
+                 m_step_optimizer: optax.GradientTransformation = optax.adam(1e-2),
+                 m_step_num_iters: int = 50):
         super().__init__(m_step_optimizer=m_step_optimizer,
                          m_step_num_iters=m_step_num_iters)
         self.num_states = num_states
         self.emission_dim = emission_dim
@@ -260,14 +298,16 @@ def __init__(self,
         self.emission_var_rate = emission_var_rate
 
     @property
-    def emission_shape(self):
+    def emission_shape(self) -> Tuple[int]:
         return (self.emission_dim,)
 
-    def initialize(self, key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_means=None,
-                   emission_scales=None,
-                   emissions=None):
+    def initialize(self,
+                   key: Array = jr.PRNGKey(0),
+                   method: str = "prior",
+                   emission_means: Optional[Float[Array, "num_states emission_dim"]] = None,
+                   emission_scales: Optional[Float[Array, " num_states"]] = None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]] = None
+                   ) -> Tuple[ParamsSphericalGaussianHMMEmissions, ParamsSphericalGaussianHMMEmissions]:
         """Initialize the model parameters and their corresponding properties.
 
         You can either specify parameters manually via the keyword arguments, or you can have
@@ -319,12 +359,16 @@ def initialize(self, key=jr.PRNGKey(0),
                       scales=ParameterProperties(constrainer=tfb.Softplus()))
         return params, props
 
-    def distribution(self, params, state, inputs=None):
-        dim = self.emission_dim
+    def distribution(
+        self,
+        params: ParamsSphericalGaussianHMMEmissions,
+        state: IntScalar,
+        inputs: Optional[Array] = None
+    ) -> tfd.Distribution:
         return tfd.MultivariateNormalDiag(params.means[state],
-                                          params.scales[state] * jnp.ones((dim,)))
+                                          params.scales[state] * jnp.ones(self.emission_shape))
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsSphericalGaussianHMMEmissions) -> Float[Array, ""]:
         lp = tfd.MultivariateNormalFullCovariance(
             self.emission_prior_mean, self.emission_prior_mean_cov)\
             .log_prob(params.means).sum()
@@ -341,17 +385,21 @@ class ParamsSharedCovarianceGaussianHMMEmissions(NamedTuple):
 
 class SharedCovarianceGaussianHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 emission_dim,
-                 emission_prior_mean=0.0,
-                 emission_prior_concentration=1e-4,
-                 emission_prior_scale=1e-4,
-                 emission_prior_extra_df=0.1):
-        """_summary_
+                 num_states: int,
+                 emission_dim: int,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]] = 0.0,
+                 emission_prior_concentration: Scalar = 1e-4,
+                 emission_prior_scale: Union[Scalar, Float[Array, "emission_dim emission_dim"]] = 1e-4,
+                 emission_prior_extra_df: Scalar = 0.1):
+        """Initialize SharedCovarianceGaussianHMMEmissions.
 
         Args:
-            emission_means (_type_): _description_
-            emission_covariance_matrix (_type_): _description_
+            num_states: number of discrete states
+            emission_dim: dimension of the emission vector
+            emission_prior_mean: prior mean for emissions
+            emission_prior_concentration: concentration parameter for the prior
+            emission_prior_scale: scale matrix for the prior
+            emission_prior_extra_df: extra degrees of freedom for the prior
         """
         self.num_states = num_states
         self.emission_dim = emission_dim
@@ -362,14 +410,16 @@ def __init__(self,
         self.emission_prior_df = emission_dim + emission_prior_extra_df
 
     @property
-    def emission_shape(self):
+    def emission_shape(self) -> Tuple[int]:
         return (self.emission_dim,)
 
-    def initialize(self, key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_means=None,
-                   emission_covariance=None,
-                   emissions=None):
+    def initialize(self,
+                   key: Array = jr.PRNGKey(0),
+                   method: str = "prior",
+                   emission_means: Optional[Float[Array, "num_states emission_dim"]] = None,
+                   emission_covariance: Optional[Float[Array, "emission_dim emission_dim"]] = None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]] = None
+                   ) -> Tuple[ParamsSharedCovarianceGaussianHMMEmissions, ParamsSharedCovarianceGaussianHMMEmissions]:
         """Initialize the model parameters and their corresponding properties.
 
         You can either specify parameters manually via the keyword arguments, or you can have
@@ -420,11 +470,16 @@ def initialize(self, key=jr.PRNGKey(0),
                       cov=ParameterProperties(constrainer=RealToPSDBijector()))
         return params, props
 
-    def distribution(self, params, state, inputs=None):
+    def distribution(
+        self,
+        params: ParamsSharedCovarianceGaussianHMMEmissions,
+        state: IntScalar,
+        inputs: Optional[Array] = None
+    ) -> tfd.Distribution:
         return tfd.MultivariateNormalFullCovariance(
             params.means[state], params.cov)
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsSharedCovarianceGaussianHMMEmissions) -> Float[Array, ""]:
         mus = params.means
         Sigma = params.cov
         mu0 = self.emission_prior_mean
@@ -436,7 +491,13 @@ def log_prior(self, params):
         lp += tfd.MultivariateNormalFullCovariance(mu0, Sigma / kappa0).log_prob(mus).sum()
         return lp
 
-    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
+    def collect_suff_stats(
+        self,
+        params: ParamsSharedCovarianceGaussianHMMEmissions,
+        posterior: HMMPosterior,
+        emissions: Float[Array, "num_timesteps emission_dim"],
+        inputs: Optional[Array] = None
+    ) -> Dict[str, Union[Float[Array, "..."], int]]:
         expected_states = posterior.smoothed_probs
         sum_w = jnp.einsum("tk->k", expected_states)
         sum_x = jnp.einsum("tk,ti->ki", expected_states, emissions)
@@ -445,10 +506,20 @@ def collect_suff_stats(self, params, posterior, emissions, inputs=None):
         stats = dict(sum_w=sum_w, sum_x=sum_x, sum_xxT=sum_xxT, sum_T=sum_T)
         return stats
 
-    def initialize_m_step_state(self, params, props):
+    def initialize_m_step_state(
+        self,
+        params: ParamsSharedCovarianceGaussianHMMEmissions,
+        props: ParamsSharedCovarianceGaussianHMMEmissions
+    ) -> None:
         return None
 
-    def m_step(self, params, props, batch_stats, m_step_state):
+    def m_step(
+        self,
+        params: ParamsSharedCovarianceGaussianHMMEmissions,
+        props: ParamsSharedCovarianceGaussianHMMEmissions,
+        batch_stats: Dict[str, Array],
+        m_step_state: Any
+    ) -> Tuple[ParamsSharedCovarianceGaussianHMMEmissions, Any]:
         mu0 = self.emission_prior_mean
         kappa0 = self.emission_prior_conc
         Psi0 = self.emission_prior_scale
@@ -473,11 +544,14 @@ class ParamsLowRankGaussianHMMEmissions(NamedTuple):
 
 class LowRankGaussianHMMEmissions(HMMEmissions):
-    def __init__(self, num_states, emission_dim, emission_rank,
-                 emission_diag_factor_concentration=1.1,
-                 emission_diag_factor_rate=1.1,
-                 m_step_optimizer=optax.adam(1e-2),
-                 m_step_num_iters=50):
+    def __init__(self,
+                 num_states: int,
+                 emission_dim: int,
+                 emission_rank: int,
+                 emission_diag_factor_concentration: Scalar = 1.1,
+                 emission_diag_factor_rate: Scalar = 1.1,
+                 m_step_optimizer: optax.GradientTransformation = optax.adam(1e-2),
+                 m_step_num_iters: int = 50):
         super().__init__(m_step_optimizer=m_step_optimizer,
                          m_step_num_iters=m_step_num_iters)
         self.num_states = num_states
         self.emission_dim = emission_dim
@@ -485,12 +559,14 @@ def __init__(self, num_states, emission_dim, emission_rank,
         self.emission_diag_factor_conc = emission_diag_factor_concentration
         self.emission_diag_factor_rate = emission_diag_factor_rate
 
-    def initialize(self, key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_means=None,
-                   emission_cov_diag_factors=None,
-                   emission_cov_low_rank_factors=None,
-                   emissions=None):
+    def initialize(self,
+                   key: Array = jr.PRNGKey(0),
+                   method: str = "prior",
+                   emission_means: Optional[Float[Array, "num_states emission_dim"]] = None,
+                   emission_cov_diag_factors: Optional[Float[Array, "num_states emission_dim"]] = None,
+                   emission_cov_low_rank_factors: Optional[Float[Array, "num_states emission_dim emission_rank"]] = None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]] = None
+                   ) -> Tuple[ParamsLowRankGaussianHMMEmissions, ParamsLowRankGaussianHMMEmissions]:
         """Initialize the model parameters and their corresponding properties.
 
         You can either specify parameters manually via the keyword arguments, or you can have
@@ -546,17 +622,22 @@ def initialize(self,
         return params, props
 
     @property
-    def emission_shape(self):
+    def emission_shape(self) -> Tuple[int]:
         return (self.emission_dim,)
 
-    def distribution(self, params, state, inputs=None):
+    def distribution(
+        self,
+        params: ParamsLowRankGaussianHMMEmissions,
+        state: IntScalar,
+        inputs: Optional[Array] = None
+    ) -> tfd.Distribution:
         return tfd.MultivariateNormalDiagPlusLowRankCovariance(
             params.means[state],
             params.cov_diag_factors[state],
             params.cov_low_rank_factors[state]
         )
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsLowRankGaussianHMMEmissions) -> Float[Array, ""]:
         lp = tfd.Gamma(self.emission_diag_factor_conc, self.emission_diag_factor_rate)\
             .log_prob(params.cov_diag_factors).sum()
         return lp
@@ -598,10 +679,10 @@ class GaussianHMM(HMM):
     """
    def __init__(self, num_states: int,
                 emission_dim: int,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0,
-                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.0,
                  emission_prior_concentration: Scalar=1e-4,
                  emission_prior_scale: Union[Scalar, Float[Array, "emission_dim emission_dim"]]=1e-4,
                  emission_prior_extra_df: Scalar=0.1):
@@ -617,9 +698,9 @@ def __init__(self, num_states: int,
         super().__init__(num_states, initial_component, transition_component, emission_component)
 
     def initialize(self,
-                   key: jr.PRNGKey=jr.PRNGKey(0),
+                   key: Array=jr.PRNGKey(0),
                    method: str="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
                    emission_covariances: Optional[Float[Array, "num_states emission_dim emission_dim"]]=None,
@@ -646,9 +727,15 @@ def initialize(self,
         """
-        key1, key2, key3 = jr.split(key , 3)
+        key1, key2, key3 = jr.split(key, 3)
         params, props = dict(), dict()
-        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
-        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
-        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_means=emission_means, emission_covariances=emission_covariances, emissions=emissions)
+        params["initial"], props["initial"] = self.initial_component.initialize(
+            key1, method=method, initial_probs=initial_probs
+        )
+        params["transitions"], props["transitions"] = self.transition_component.initialize(
+            key2, method=method, transition_matrix=transition_matrix
+        )
+        params["emissions"], props["emissions"] = self.emission_component.initialize(
+            key3, method=method, emission_means=emission_means, emission_covariances=emission_covariances, emissions=emissions
+        )
         return ParamsGaussianHMM(**params), ParamsGaussianHMM(**props)
@@ -691,17 +778,21 @@ class DiagonalGaussianHMM(HMM):
     """
     def __init__(self, num_states: int,
                  emission_dim: int,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0,
-                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
-                 emission_prior_mean_concentration: Union[Scalar, Float[Array, "emission_dim"]]=1e-4,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.0,
+                 emission_prior_mean_concentration: Union[Scalar, Float[Array, " emission_dim"]]=1e-4,
                  emission_prior_concentration: Scalar=0.1,
                  emission_prior_scale: Scalar=0.1):
         self.emission_dim = emission_dim
-        initial_component = StandardHMMInitialState(num_states, initial_probs_concentration=initial_probs_concentration)
-        transition_component = StandardHMMTransitions(num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness)
+        initial_component = StandardHMMInitialState(
+            num_states, initial_probs_concentration=initial_probs_concentration
+        )
+        transition_component = StandardHMMTransitions(
+            num_states, concentration=transition_matrix_concentration, stickiness=transition_matrix_stickiness
+        )
         emission_component = DiagonalGaussianHMMEmissions(
             num_states, emission_dim,
             emission_prior_mean=emission_prior_mean,
@@ -711,15 +802,15 @@ def __init__(self, num_states: int,
         super().__init__(num_states, initial_component, transition_component, emission_component)
 
-    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
+    def initialize(self, key: Array=jr.PRNGKey(0),
                    method: str="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
                    emission_scale_diags: Optional[Float[Array, "num_states emission_dim"]]=None,
                    emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
                    ) -> Tuple[HMMParameterSet, HMMPropertySet]:
-        """Initialize the model parameters and their corresponding properties.
+        r"""Initialize the model parameters and their corresponding properties.
 
         You can either specify parameters manually via the keyword arguments, or you can have
         them set automatically. If any parameters are not specified, you must supply a PRNGKey.
@@ -739,9 +830,15 @@ def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
         """
-        key1, key2, key3 = jr.split(key , 3)
+        key1, key2, key3 = jr.split(key, 3)
         params, props = dict(), dict()
-        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
-        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
-        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_means=emission_means, emission_scale_diags=emission_scale_diags, emissions=emissions)
+        params["initial"], props["initial"] = self.initial_component.initialize(
+            key1, method=method, initial_probs=initial_probs
+        )
+        params["transitions"], props["transitions"] = self.transition_component.initialize(
+            key2, method=method, transition_matrix=transition_matrix
+        )
+        params["emissions"], props["emissions"] = self.emission_component.initialize(
+            key3, method=method, emission_means=emission_means, emission_scale_diags=emission_scale_diags, emissions=emissions
+        )
         return ParamsDiagonalGaussianHMM(**params), ParamsDiagonalGaussianHMM(**props)
@@ -786,10 +883,10 @@ class SphericalGaussianHMM(HMM):
     """
     def __init__(self, num_states: int,
                  emission_dim: int,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0,
-                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.0,
                  emission_prior_mean_covariance: Union[Scalar, Float[Array, "emission_dim emission_dim"]]=1.0,
                  emission_var_concentration: Scalar=1.1,
                  emission_var_rate: Scalar=1.1,
@@ -809,12 +906,12 @@ def __init__(self, num_states: int,
         super().__init__(num_states, initial_component, transition_component, emission_component)
 
-    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
+    def initialize(self, key: Array=jr.PRNGKey(0),
                    method: str="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
-                   emission_scales: Optional[Float[Array, "num_states"]]=None,
+                   emission_scales: Optional[Float[Array, " num_states"]]=None,
                    emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
                    ) -> Tuple[HMMParameterSet, HMMPropertySet]:
         """Initialize the model parameters and their corresponding properties.
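Aside: the spherical and low-rank emission classes above route their M-step through an optax optimizer (`m_step_optimizer=optax.adam(1e-2)`, `m_step_num_iters=50`) rather than a closed-form update. A generic sketch of that pattern follows; it illustrates the idea under stated assumptions and is not dynamax's internal implementation:

    # Generic gradient-based M-step loop in the style of the optax defaults above.
    import jax
    import jax.numpy as jnp
    import optax

    def gradient_m_step(loss_fn, params, num_iters=50, optimizer=optax.adam(1e-2)):
        opt_state = optimizer.init(params)
        for _ in range(num_iters):
            grads = jax.grad(loss_fn)(params)  # gradient of the objective being minimized
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        return params

    # Toy usage: fit a scalar location parameter by minimizing a quadratic loss.
    data = jnp.array([1.0, 2.0, 3.0])
    mu_hat = gradient_m_step(lambda mu: jnp.sum((data - mu) ** 2), jnp.array(0.0))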
@@ -838,9 +935,15 @@ def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
         """
-        key1, key2, key3 = jr.split(key , 3)
+        key1, key2, key3 = jr.split(key, 3)
         params, props = dict(), dict()
-        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
-        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
-        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_means=emission_means, emission_scales=emission_scales, emissions=emissions)
+        params["initial"], props["initial"] = self.initial_component.initialize(
+            key1, method=method, initial_probs=initial_probs
+        )
+        params["transitions"], props["transitions"] = self.transition_component.initialize(
+            key2, method=method, transition_matrix=transition_matrix
+        )
+        params["emissions"], props["emissions"] = self.emission_component.initialize(
+            key3, method=method, emission_means=emission_means, emission_scales=emission_scales, emissions=emissions
+        )
         return ParamsSphericalGaussianHMM(**params), ParamsSphericalGaussianHMM(**props)
@@ -880,10 +983,10 @@ class SharedCovarianceGaussianHMM(HMM):
     """
     def __init__(self, num_states: int,
                  emission_dim: int,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0,
-                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.0,
                  emission_prior_concentration: Scalar=1e-4,
                  emission_prior_scale: Scalar=1e-4,
                  emission_prior_extra_df: Scalar=0.1):
@@ -899,9 +1002,9 @@ def __init__(self, num_states: int,
                                       emission_prior_extra_df=emission_prior_extra_df)
         super().__init__(num_states, initial_component, transition_component, emission_component)
 
-    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
+    def initialize(self, key: Array=jr.PRNGKey(0),
                    method: str="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
                    emission_covariance: Optional[Float[Array, "emission_dim emission_dim"]]=None,
@@ -927,9 +1030,15 @@ def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
         """
-        key1, key2, key3 = jr.split(key , 3)
+        key1, key2, key3 = jr.split(key, 3)
         params, props = dict(), dict()
-        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
-        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
-        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_means=emission_means, emission_covariance=emission_covariance, emissions=emissions)
+        params["initial"], props["initial"] = self.initial_component.initialize(
+            key1, method=method, initial_probs=initial_probs
+        )
+        params["transitions"], props["transitions"] = self.transition_component.initialize(
+            key2, method=method, transition_matrix=transition_matrix
+        )
+        params["emissions"], props["emissions"] = self.emission_component.initialize(
+            key3, method=method, emission_means=emission_means, emission_covariance=emission_covariance, emissions=emissions
+        )
         return ParamsSharedCovarianceGaussianHMM(**params), ParamsSharedCovarianceGaussianHMM(**props)
@@ -976,8 +1085,8 @@ class LowRankGaussianHMM(HMM):
     def __init__(self, num_states: int,
                  emission_dim: int,
                  emission_rank: int,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0,
                  emission_diag_factor_concentration: Scalar=1.1,
                  emission_diag_factor_rate: Scalar=1.1,
@@ -995,9 +1104,9 @@ def __init__(self, num_states: int,
                                       m_step_num_iters=m_step_num_iters)
         super().__init__(num_states, initial_component, transition_component, emission_component)
 
-    def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
+    def initialize(self, key: Array=jr.PRNGKey(0),
                    method: str="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_means: Optional[Float[Array, "num_states emission_dim"]]=None,
                    emission_cov_diag_factors: Optional[Float[Array, "num_states emission_dim"]]=None,
@@ -1025,7 +1134,13 @@ def initialize(self, key: jr.PRNGKey=jr.PRNGKey(0),
         """
-        key1, key2, key3 = jr.split(key , 3)
+        key1, key2, key3 = jr.split(key, 3)
         params, props = dict(), dict()
-        params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
-        params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
-        params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_means=emission_means, emission_cov_diag_factors=emission_cov_diag_factors, emission_cov_low_rank_factors=emission_cov_low_rank_factors, emissions=emissions)
+        params["initial"], props["initial"] = self.initial_component.initialize(
+            key1, method=method, initial_probs=initial_probs
+        )
+        params["transitions"], props["transitions"] = self.transition_component.initialize(
+            key2, method=method, transition_matrix=transition_matrix
+        )
+        params["emissions"], props["emissions"] = self.emission_component.initialize(
+            key3,
+            method=method,
+            emission_means=emission_means,
+            emission_cov_diag_factors=emission_cov_diag_factors,
+            emission_cov_low_rank_factors=emission_cov_low_rank_factors,
+            emissions=emissions,
+        )
         return ParamsLowRankGaussianHMM(**params), ParamsLowRankGaussianHMM(**props)
diff --git a/dynamax/hidden_markov_model/models/gmm_hmm.py b/dynamax/hidden_markov_model/models/gmm_hmm.py
index 8b7e778c..3d41ace8 100644
--- a/dynamax/hidden_markov_model/models/gmm_hmm.py
+++ b/dynamax/hidden_markov_model/models/gmm_hmm.py
@@ -1,3 +1,4 @@
+from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
 import jax.numpy as jnp
 import jax.random as jr
 import tensorflow_probability.substrates.jax.bijectors as tfb
@@ -10,16 +11,15 @@ from dynamax.utils.distributions import NormalInverseWishart
 from dynamax.utils.distributions import nig_posterior_update
 from dynamax.utils.distributions import niw_posterior_update
+from dynamax.hidden_markov_model.inference import HMMPosterior
 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
 from dynamax.utils.bijectors import RealToPSDBijector
 from dynamax.utils.utils import pytree_sum
-from dynamax.types import Scalar
-from typing import NamedTuple, Optional, Tuple, Union
+from dynamax.types import IntScalar, Scalar
 
-# Types
 class ParamsGaussianMixtureHMMEmissions(NamedTuple):
     weights: Union[Float[Array, "state_dim num_components"], ParameterProperties]
     means: Union[Float[Array, "state_dim num_components emission_dim"], ParameterProperties]
@@ -48,14 +48,14 @@ class ParamsDiagonalGaussianMixtureHMM(NamedTuple):
 
 class GaussianMixtureHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 num_components,
-                 emission_dim,
-                 emission_weights_concentration=1.1,
-                 emission_prior_mean=0.,
-                 emission_prior_mean_concentration=1e-4,
-                 emission_prior_extra_df=1e-4,
-                 emission_prior_scale=0.1):
+                 num_states: int,
+                 num_components: int,
+                 emission_dim: int,
+                 emission_weights_concentration: Union[Scalar, Float[Array, " num_components"]]=1.1,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.,
+                 emission_prior_mean_concentration: Scalar=1e-4,
+                 emission_prior_extra_df: Scalar=1e-4,
+                 emission_prior_scale: Union[Scalar, Float[Array, "emission_dim emission_dim"]]=0.1):
         self.num_states = num_states
         self.num_components = num_components
         self.emission_dim = emission_dim
@@ -69,12 +69,14 @@ def __init__(self,
     def emission_shape(self):
         return (self.emission_dim,)
 
-    def initialize(self, key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_weights=None,
-                   emission_means=None,
-                   emission_covariances=None,
-                   emissions=None):
+    def initialize(self,
+                   key: Array=jr.PRNGKey(0),
+                   method: str="prior",
+                   emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
+                   emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
+                   emission_covariances: Optional[Float[Array, "num_states num_components emission_dim emission_dim"]]=None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
+                   ) -> Tuple[ParamsGaussianMixtureHMMEmissions, ParamsGaussianMixtureHMMEmissions]:
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
             from sklearn.cluster import KMeans
@@ -111,13 +113,17 @@ def initialize(self, key=jr.PRNGKey(0),
                       covs=ParameterProperties(constrainer=RealToPSDBijector()))
         return params, props
 
-    def distribution(self, params, state, inputs=None):
+    def distribution(self,
+                     params: ParamsGaussianMixtureHMMEmissions,
+                     state: IntScalar,
+                     inputs: Optional[Array] = None
+                     ) -> tfd.Distribution:
         return tfd.MixtureSameFamily(
             mixture_distribution=tfd.Categorical(probs=params.weights[state]),
             components_distribution=tfd.MultivariateNormalFullCovariance(
                 loc=params.means[state], covariance_matrix=params.covs[state]))
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsGaussianMixtureHMMEmissions) -> Float[Array, ""]:
         lp = tfd.Dirichlet(self.emission_weights_concentration).log_prob(
             params.weights).sum()
         lp += NormalInverseWishart(self.emission_prior_mean, self.emission_prior_mean_concentration,
@@ -125,7 +131,12 @@ def log_prior(self, params):
             (params.covs, params.means)).sum()
         return lp
 
-    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
+    def collect_suff_stats(self,
+                           params: ParamsGaussianMixtureHMMEmissions,
+                           posterior: HMMPosterior,
+                           emissions: Float[Array, "num_timesteps emission_dim"],
+                           inputs: Optional[Array] = None
+                           ) -> Dict[str, Float[Array, "..."]]:
         def prob_fn(x):
             logprobs = vmap(lambda mus, sigmas, weights: tfd.MultivariateNormalFullCovariance(
                 loc=mus, covariance_matrix=sigmas).log_prob(x) + jnp.log(weights))(
@@ -141,10 +152,20 @@ def prob_fn(x):
         N = weights.sum(axis=0)
         return dict(N=N, Sx=Sx, SxxT=SxxT)
 
-    def initialize_m_step_state(self, params, props):
+    def initialize_m_step_state(
+        self,
+        params: ParamsGaussianMixtureHMMEmissions,
+        props: ParamsGaussianMixtureHMMEmissions
+    ) -> None:
         return None
 
-    def m_step(self, params, props, batch_stats, m_step_state):
+    def m_step(
+        self,
+        params: ParamsGaussianMixtureHMMEmissions,
+        props: ParamsGaussianMixtureHMMEmissions,
+        batch_stats: Dict[str, Float[Array, "..."]],
+        m_step_state: Any
+    ) -> Tuple[ParamsGaussianMixtureHMMEmissions, Any]:
         assert props.weights.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
         assert props.means.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
         assert props.covs.trainable, "GaussianMixtureHMM.fit_em() does not support fitting a subset of parameters"
@@ -207,11 +228,11 @@ def __init__(self,
                  num_states: int,
                  num_components: int,
                  emission_dim: int,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0,
-                 emission_weights_concentration: Union[Scalar, Float[Array, "num_components"]]=1.1,
-                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
+                 emission_weights_concentration: Union[Scalar, Float[Array, " num_components"]]=1.1,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.0,
                  emission_prior_mean_concentration: Scalar=1e-4,
                  emission_prior_extra_df: Scalar=1e-4,
                  emission_prior_scale: Union[Scalar, Float[Array, "emission_dim emission_dim"]]=1e-4):
@@ -229,9 +250,9 @@ def __init__(self,
         super().__init__(num_states, initial_component, transition_component, emission_component)
 
     def initialize(self,
-                   key: jr.PRNGKey=jr.PRNGKey(0),
+                   key: Array=jr.PRNGKey(0),
                    method: str="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
                    emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
@@ -268,14 +289,14 @@ def initialize(self,
 
 class DiagonalGaussianMixtureHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 num_components,
-                 emission_dim,
-                 emission_weights_concentration=1.1,
-                 emission_prior_mean=0.,
-                 emission_prior_mean_concentration=1e-4,
-                 emission_prior_shape=1.,
-                 emission_prior_scale=1.):
+                 num_states: int,
+                 num_components: int,
+                 emission_dim: int,
+                 emission_weights_concentration: Union[Scalar, Float[Array, " num_components"]]=1.1,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.,
+                 emission_prior_mean_concentration: Scalar=1e-4,
+                 emission_prior_shape: Scalar=1.,
+                 emission_prior_scale: Union[Scalar, Float[Array, " emission_dim"]]=1.):
         self.num_states = num_states
         self.num_components = num_components
         self.emission_dim = emission_dim
@@ -288,15 +309,17 @@ def __init__(self,
         self.emission_prior_scale = emission_prior_scale
 
     @property
-    def emission_shape(self):
+    def emission_shape(self) -> Tuple[int]:
         return (self.emission_dim,)
 
-    def initialize(self, key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_weights=None,
-                   emission_means=None,
-                   emission_scale_diags=None,
-                   emissions=None):
+    def initialize(self,
+                   key: Array=jr.PRNGKey(0),
+                   method: str="prior",
+                   emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
+                   emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
+                   emission_scale_diags: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
+                   ) -> Tuple[ParamsDiagonalGaussianMixtureHMMEmissions, ParamsDiagonalGaussianMixtureHMMEmissions]:
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
             from sklearn.cluster import KMeans
@@ -333,14 +356,18 @@ def initialize(self, key=jr.PRNGKey(0),
                       scale_diags=ParameterProperties(constrainer=tfb.Softplus()))
         return params, props
 
-    def distribution(self, params, state, inputs=None):
+    def distribution(self,
+                     params: ParamsDiagonalGaussianMixtureHMMEmissions,
+                     state: IntScalar,
+                     inputs: Optional[Array] = None
+                     ) -> tfd.Distribution:
         return tfd.MixtureSameFamily(
             mixture_distribution=tfd.Categorical(probs=params.weights[state]),
             components_distribution=tfd.MultivariateNormalDiag(
                 loc=params.means[state], scale_diag=params.scale_diags[state]))
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsDiagonalGaussianMixtureHMMEmissions) -> Float[Array, ""]:
         lp = tfd.Dirichlet(self.emission_weights_concentration).log_prob(
             params.weights).sum()
         lp += NormalInverseGamma(self.emission_prior_mean, self.emission_prior_mean_concentration,
@@ -349,7 +376,12 @@ def log_prior(self, params):
         return lp
 
     # Expectation-maximization (EM) code
-    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
+    def collect_suff_stats(self,
+                           params: ParamsDiagonalGaussianMixtureHMMEmissions,
+                           posterior: HMMPosterior,
+                           emissions: Float[Array, "num_timesteps emission_dim"],
+                           inputs: Optional[Array] = None
+                           ) -> Dict[str, Float[Array, "..."]]:
         # Evaluate the posterior probability of each discrete class
         def prob_fn(x):
             logprobs = vmap(lambda mus, sigmas, weights: tfd.MultivariateNormalDiag(
@@ -367,10 +399,18 @@ def prob_fn(x):
         N = weights.sum(axis=0)
         return dict(N=N, Sx=Sx, Sxsq=Sxsq)
 
-    def initialize_m_step_state(self, params, props):
+    def initialize_m_step_state(self,
+                                params: ParamsDiagonalGaussianMixtureHMMEmissions,
+                                props: ParamsDiagonalGaussianMixtureHMMEmissions
+                                ) -> None:
         return None
 
-    def m_step(self, params, props, batch_stats, m_step_state):
+    def m_step(self,
+               params: ParamsDiagonalGaussianMixtureHMMEmissions,
+               props: ParamsDiagonalGaussianMixtureHMMEmissions,
+               batch_stats: Dict[str, Float[Array, "..."]],
+               m_step_state: None
+               ) -> Tuple[ParamsDiagonalGaussianMixtureHMMEmissions, None]:
         assert props.weights.trainable, "GaussianMixtureDiagHMM.fit_em() does not support fitting a subset of parameters"
         assert props.means.trainable, "GaussianMixtureDiagHMM.fit_em() does not support fitting a subset of parameters"
         assert props.scale_diags.trainable, "GaussianMixtureDiagHMM.fit_em() does not support fitting a subset of parameters"
@@ -440,11 +480,11 @@ def __init__(self,
                  num_states: int,
                  num_components: int,
                  emission_dim: int,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0,
-                 emission_weights_concentration: Union[Scalar, Float[Array, "num_components"]]=1.1,
-                 emission_prior_mean: Union[Scalar, Float[Array, "emission_dim"]]=0.0,
+                 emission_weights_concentration: Union[Scalar, Float[Array, " num_components"]]=1.1,
+                 emission_prior_mean: Union[Scalar, Float[Array, " emission_dim"]]=0.0,
                  emission_prior_mean_concentration: Scalar=1e-4,
                  emission_prior_shape: Scalar=1.,
                  emission_prior_scale: Scalar=1.):
@@ -463,9 +503,9 @@ def __init__(self,
 
     def initialize(self,
-                   key: jr.PRNGKey=jr.PRNGKey(0),
+                   key: Array=jr.PRNGKey(0),
                    method: str="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_weights: Optional[Float[Array, "num_states num_components"]]=None,
                    emission_means: Optional[Float[Array, "num_states num_components emission_dim"]]=None,
diff --git a/dynamax/hidden_markov_model/models/initial.py b/dynamax/hidden_markov_model/models/initial.py
index 20897413..06d8fbca 100644
--- a/dynamax/hidden_markov_model/models/initial.py
+++ b/dynamax/hidden_markov_model/models/initial.py
@@ -1,23 +1,25 @@
-from dynamax.hidden_markov_model.models.abstractions import HMMInitialState
-from dynamax.parameters import ParameterProperties
+from typing import Any, cast, NamedTuple, Optional, Tuple, Union
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import Float, Array
 import tensorflow_probability.substrates.jax.distributions as tfd
 import tensorflow_probability.substrates.jax.bijectors as tfb
-from typing import NamedTuple, Union
+from dynamax.hidden_markov_model.inference import HMMPosterior
+from dynamax.hidden_markov_model.models.abstractions import HMMInitialState
+from dynamax.parameters import ParameterProperties
+from dynamax.types import Scalar
 
 
 class ParamsStandardHMMInitialState(NamedTuple):
-    probs: Union[Float[Array, "state_dim"], ParameterProperties]
+    probs: Union[Float[Array, " state_dim"], ParameterProperties]
 
 
 class StandardHMMInitialState(HMMInitialState):
     """Abstract class for HMM initial distributions.
     """
     def __init__(self,
-                 num_states,
-                 initial_probs_concentration=1.1):
+                 num_states: int,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1):
         """
         Args:
            initial_probabilities[k]: prob(hidden(1)=k)
@@ -25,10 +27,15 @@ def __init__(self,
         self.num_states = num_states
         self.initial_probs_concentration = initial_probs_concentration * jnp.ones(num_states)
 
-    def distribution(self, params, inputs=None):
+    def distribution(self, params: ParamsStandardHMMInitialState, inputs=None) -> tfd.Distribution:
         return tfd.Categorical(probs=params.probs)
 
-    def initialize(self, key=None, method="prior", initial_probs=None):
+    def initialize(
+        self,
+        key: Optional[Array]=None,
+        method="prior",
+        initial_probs: Optional[Float[Array, " num_states"]]=None
+    ) -> Tuple[ParamsStandardHMMInitialState, ParamsStandardHMMInitialState]:
         """Initialize the model parameters and their corresponding properties.
         Args:
@@ -41,27 +48,38 @@ def initialize(self, key=None, method="prior", initial_probs=None):
         """
         # Initialize the initial probabilities
         if initial_probs is None:
-            this_key, key = jr.split(key)
-            initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key)
+            if key is None:
+                raise ValueError("key must be provided if initial_probs is not provided.")
+            else:
+                this_key, key = jr.split(key)
+                initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key)
 
         # Package the results into dictionaries
         params = ParamsStandardHMMInitialState(probs=initial_probs)
         props = ParamsStandardHMMInitialState(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
         return params, props
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsStandardHMMInitialState) -> Scalar:
         return tfd.Dirichlet(self.initial_probs_concentration).log_prob(params.probs)
 
-    def _compute_initial_probs(self, params, inputs=None):
+    def _compute_initial_probs(
+        self, params: ParamsStandardHMMInitialState, inputs=None
+    ) -> Float[Array, " num_states"]:
         return params.probs
 
-    def collect_suff_stats(self, params, posterior, inputs=None):
+    def collect_suff_stats(self, params, posterior: HMMPosterior, inputs=None) -> Float[Array, " num_states"]:
         return posterior.smoothed_probs[0]
 
-    def initialize_m_step_state(self, params, props):
+    def initialize_m_step_state(self, params, props) -> None:
         return None
 
-    def m_step(self, params, props, batch_stats, m_step_state):
+    def m_step(
+        self,
+        params: ParamsStandardHMMInitialState,
+        props: ParamsStandardHMMInitialState,
+        batch_stats: Float[Array, "batch num_states"],
+        m_step_state: Any
+    ) -> Tuple[ParamsStandardHMMInitialState, Any]:
         if props.probs.trainable:
             if self.num_states == 1:
                 probs = jnp.array([1.0])
diff --git a/dynamax/hidden_markov_model/models/linreg_hmm.py b/dynamax/hidden_markov_model/models/linreg_hmm.py
index df63c1bd..1f98bc31 100644
--- a/dynamax/hidden_markov_model/models/linreg_hmm.py
+++ b/dynamax/hidden_markov_model/models/linreg_hmm.py
@@ -1,7 +1,11 @@
+from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
 import jax.numpy as jnp
 import jax.random as jr
 from jax import vmap
-from jaxtyping import Float, Array
+from jaxtyping import Array, Float, Int
+from tensorflow_probability.substrates import jax as tfp
+
+from dynamax.hidden_markov_model.inference import HMMPosterior
 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
@@ -9,8 +13,6 @@ from dynamax.types import Scalar
 from dynamax.utils.utils import pytree_sum
 from dynamax.utils.bijectors import RealToPSDBijector
-from tensorflow_probability.substrates import jax as tfp
-from typing import NamedTuple, Optional, Tuple, Union
 
 tfd = tfp.distributions
 tfb = tfp.bijectors
@@ -29,9 +31,9 @@ class ParamsLinearRegressionHMM(NamedTuple):
 
 class LinearRegressionHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 input_dim,
-                 emission_dim):
+                 num_states: int,
+                 input_dim: int,
+                 emission_dim: int):
-        """_summary_
+        """Initialize LinearRegressionHMMEmissions.
 
         Args:
@@ -46,16 +48,17 @@ def __init__(self,
         self.emission_dim = emission_dim
 
     @property
-    def emission_shape(self):
+    def emission_shape(self) -> Tuple[int]:
         return (self.emission_dim,)
 
     def initialize(self,
-                   key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_weights=None,
-                   emission_biases=None,
-                   emission_covariances=None,
-                   emissions=None):
+                   key: Array=jr.PRNGKey(0),
+                   method: str="prior",
+                   emission_weights: Optional[Float[Array, "num_states emission_dim input_dim"]]=None,
+                   emission_biases: Optional[Float[Array, "num_states emission_dim"]]=None,
+                   emission_covariances: Optional[Float[Array, "num_states emission_dim emission_dim"]]=None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
+                   ) -> Tuple[ParamsLinearRegressionHMMEmissions, ParamsLinearRegressionHMMEmissions]:
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
             from sklearn.cluster import KMeans
@@ -87,16 +90,27 @@ def initialize(self,
                       covs=ParameterProperties(constrainer=RealToPSDBijector()))
         return params, props
 
-    def distribution(self, params, state, inputs):
+    def distribution(
+        self,
+        params: ParamsLinearRegressionHMMEmissions,
+        state: Union[int, Int[Array, ""]],
+        inputs: Float[Array, " input_dim"]
+    ) -> tfd.Distribution:
         prediction = params.weights[state] @ inputs
         prediction += params.biases[state]
         return tfd.MultivariateNormalFullCovariance(prediction, params.covs[state])
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsLinearRegressionHMMEmissions) -> float:
         return 0.0
 
     # Expectation-maximization (EM) code
-    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
+    def collect_suff_stats(
+        self,
+        params: ParamsLinearRegressionHMMEmissions,
+        posterior: HMMPosterior,
+        emissions: Float[Array, "num_timesteps emission_dim"],
+        inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None
+    ) -> Dict[str, Float[Array, "..."]]:
         expected_states = posterior.smoothed_probs
         sum_w = jnp.einsum("tk->k", expected_states)
         sum_x = jnp.einsum("tk,ti->ki", expected_states, inputs)
@@ -106,10 +120,16 @@ def collect_suff_stats(self, params, posterior, emissions, inputs=None):
         sum_yyT = jnp.einsum("tk,ti,tj->kij", expected_states, emissions, emissions)
         return dict(sum_w=sum_w, sum_x=sum_x, sum_y=sum_y, sum_xxT=sum_xxT, sum_xyT=sum_xyT, sum_yyT=sum_yyT)
 
-    def initialize_m_step_state(self, params, props):
+    def initialize_m_step_state(self, params, props) -> None:
         return None
 
-    def m_step(self, params, props, batch_stats, m_step_state):
+    def m_step(
+        self,
+        params: ParamsLinearRegressionHMMEmissions,
+        props: ParamsLinearRegressionHMMEmissions,
+        batch_stats: Dict[str, Float[Array, "..."]],
+        m_step_state: Any
+    ) -> Tuple[ParamsLinearRegressionHMMEmissions, Any]:
         def _single_m_step(stats):
             sum_w = stats['sum_w']
             sum_x = stats['sum_x']
@@ -169,8 +189,8 @@ def __init__(self,
                  num_states: int,
                  input_dim: int,
                  emission_dim: int,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0):
         self.emission_dim = emission_dim
         self.input_dim = input_dim
@@ -184,9 +204,9 @@ def inputs_shape(self):
         return (self.input_dim,)
 
     def initialize(self,
-                   key: jr.PRNGKey=jr.PRNGKey(0),
+                   key: Array=jr.PRNGKey(0),
                    method: str="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_weights: Optional[Float[Array, "num_states emission_dim input_dim"]]=None,
                    emission_biases: Optional[Float[Array, "num_states emission_dim"]]=None,
diff --git a/dynamax/hidden_markov_model/models/logreg_hmm.py b/dynamax/hidden_markov_model/models/logreg_hmm.py
index 2da4dd84..a5017fa5 100644
--- a/dynamax/hidden_markov_model/models/logreg_hmm.py
+++ b/dynamax/hidden_markov_model/models/logreg_hmm.py
@@ -1,49 +1,51 @@
+from typing import NamedTuple, Optional, Tuple, Union
 import jax.numpy as jnp
 import jax.random as jr
+from jaxtyping import Float, Array
+import optax
 import tensorflow_probability.substrates.jax.distributions as tfd
 import tensorflow_probability.substrates.jax.bijectors as tfb
-from jaxtyping import Float, Array
+
 from dynamax.parameters import ParameterProperties
 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
-from dynamax.types import Scalar
-import optax
-from typing import NamedTuple, Optional, Tuple, Union
+from dynamax.types import IntScalar, Scalar
 
 
 class ParamsLogisticRegressionHMMEmissions(NamedTuple):
     weights: Union[Float[Array, "state_dim input_dim"], ParameterProperties]
-    biases: Union[Float[Array, "state_dim"], ParameterProperties]
+    biases: Union[Float[Array, " state_dim"], ParameterProperties]
 
 
 class LogisticRegressionHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 input_dim,
-                 emission_matrices_scale=1e8,
-                 m_step_optimizer=optax.adam(1e-2),
-                 m_step_num_iters=50):
+                 num_states: int,
+                 input_dim: int,
+                 emission_matrices_scale: Scalar = 1e8,
+                 m_step_optimizer: optax.GradientTransformation = optax.adam(1e-2),
+                 m_step_num_iters: int = 50):
         super().__init__(m_step_optimizer=m_step_optimizer,
                          m_step_num_iters=m_step_num_iters)
         self.num_states = num_states
         self.input_dim = input_dim
         self.emission_weights_scale = emission_matrices_scale
 
     @property
-    def emission_shape(self):
+    def emission_shape(self) -> Tuple:
         return ()
 
     @property
-    def inputs_shape(self):
+    def inputs_shape(self) -> Tuple[int]:
         return (self.input_dim,)
 
     def initialize(self,
-                   key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_weights=None,
-                   emission_biases=None,
-                   emissions=None,
-                   inputs=None):
+                   key: Array = jr.PRNGKey(0),
+                   method: str = "prior",
+                   emission_weights: Optional[Float[Array, "num_states input_dim"]] = None,
+                   emission_biases: Optional[Float[Array, " num_states"]] = None,
+                   emissions: Optional[Float[Array, " num_timesteps"]] = None,
+                   inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None
+                   ) -> Tuple[ParamsLogisticRegressionHMMEmissions, ParamsLogisticRegressionHMMEmissions]:
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
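A note on the recurring shape-string edits throughout this diff (`"num_states"` → `" num_states"`, `"state_dim"` → `" state_dim"`, and so on): for one-dimensional shapes, jaxtyping treats both spellings as the same single named axis; the leading space is a documented jaxtyping convention that keeps the bare axis name from being read as a forward-referenced type by linters and type checkers. A small self-contained illustration (toy function, not from this diff):

    # Both annotations describe the same 1-D axis; the leading space exists
    # purely to placate static tooling, per jaxtyping's convention.
    import jax.numpy as jnp
    from jaxtyping import Array, Float

    def normalize(p: Float[Array, " num_states"]) -> Float[Array, " num_states"]:
        # p: a 1-D vector of nonnegative weights over discrete states.
        return p / p.sum()

    normalize(jnp.array([0.2, 0.3, 0.5]))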
@@ -78,10 +80,15 @@ def initialize(self,
                       biases=ParameterProperties())
         return params, props
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsLogisticRegressionHMMEmissions) -> Float[Array, ""]:
         return tfd.Normal(0, self.emission_weights_scale).log_prob(params.weights).sum()
 
-    def distribution(self, params, state, inputs):
+    def distribution(
+        self,
+        params: ParamsLogisticRegressionHMMEmissions,
+        state: IntScalar,
+        inputs: Float[Array, " input_dim"]
+    ) -> tfd.Distribution:
         logits = params.weights[state] @ inputs + params.biases[state]
         return tfd.Bernoulli(logits=logits)
@@ -121,8 +128,8 @@ class LogisticRegressionHMM(HMM):
     def __init__(self,
                  num_states: int,
                  input_dim: int,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0,
                  emission_matrices_scale: Scalar=1e8,
                  m_step_optimizer: optax.GradientTransformation=optax.adam(1e-2),
@@ -134,17 +141,17 @@ def __init__(self,
         super().__init__(num_states, initial_component, transition_component, emission_component)
 
     @property
-    def inputs_shape(self):
+    def inputs_shape(self) -> Tuple[int, ...]:
         return (self.inputs_dim,)
 
     def initialize(self,
-                   key: jr.PRNGKey=jr.PRNGKey(0),
+                   key: Array=jr.PRNGKey(0),
                    method: str="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_weights: Optional[Float[Array, "num_states input_dim"]]=None,
-                   emission_biases: Optional[Float[Array, "num_states"]]=None,
-                   emissions: Optional[Float[Array, "num_timesteps"]]=None,
+                   emission_biases: Optional[Float[Array, " num_states"]]=None,
+                   emissions: Optional[Float[Array, " num_timesteps"]]=None,
                    inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None,
                    ) -> Tuple[HMMParameterSet, HMMPropertySet]:
         """Initialize the model parameters and their corresponding properties.
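Related to the stricter key handling added to `StandardHMMInitialState.initialize` in `initial.py` above: passing no key when parameters are unspecified is now a loud `ValueError` rather than an opaque failure inside `jr.split(None)`. A minimal sketch of both call paths (toy sizes; assumes the import path shown in this diff):

    # Sketch: the new key-handling contract for component initializers.
    import jax.numpy as jnp
    import jax.random as jr
    from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState

    init = StandardHMMInitialState(num_states=3)

    # Fully specified parameters: no PRNG key required.
    params, props = init.initialize(initial_probs=jnp.array([0.5, 0.3, 0.2]))

    # Unspecified parameters: a key is required, otherwise ValueError is raised.
    params, props = init.initialize(key=jr.PRNGKey(0))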
diff --git a/dynamax/hidden_markov_model/models/multinomial_hmm.py b/dynamax/hidden_markov_model/models/multinomial_hmm.py
index 326c0e98..3afe2921 100644
--- a/dynamax/hidden_markov_model/models/multinomial_hmm.py
+++ b/dynamax/hidden_markov_model/models/multinomial_hmm.py
@@ -1,18 +1,19 @@
-from typing import NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
 
 import jax.numpy as jnp
 import jax.random as jr
 import tensorflow_probability.substrates.jax.bijectors as tfb
 import tensorflow_probability.substrates.jax.distributions as tfd
-from jaxtyping import Array, Float
+from jaxtyping import Array, Float, Real
 
+from dynamax.hidden_markov_model.inference import HMMPosterior
 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions
 from dynamax.hidden_markov_model.models.initial import ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import ParamsStandardHMMTransitions
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions
 from dynamax.parameters import ParameterProperties, ParameterSet, PropertySet
-from dynamax.types import Scalar
+from dynamax.types import IntScalar, Scalar
 from dynamax.utils.utils import pytree_sum
@@ -23,11 +24,11 @@ class ParamsMultinomialHMMEmissions(NamedTuple):
 
 class MultinomialHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 emission_dim,
-                 num_classes,
-                 num_trials,
-                 emission_prior_concentration=1.1):
+                 num_states: int,
+                 emission_dim: int,
+                 num_classes: int,
+                 num_trials: int,
+                 emission_prior_concentration: Union[Scalar, Float[Array, " num_classes"]] = 1.1):
         self.num_states = num_states
         self.emission_dim = emission_dim
         self.num_classes = num_classes
@@ -38,7 +39,11 @@ def __init__(self,
     def emission_shape(self):
         return (self.emission_dim, self.num_classes)
 
-    def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None):
+    def initialize(self,
+                   key: Array = jr.PRNGKey(0),
+                   method: str = "prior",
+                   emission_probs: Optional[Float[Array, "num_states emission_dim num_classes"]] = None
+                   ) -> Tuple[ParamsMultinomialHMMEmissions, ParamsMultinomialHMMEmissions]:
         # Initialize the emission probabilities
         if emission_probs is None:
             if method.lower() == "prior":
@@ -58,26 +63,44 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None):
         props = ParamsMultinomialHMMEmissions(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
         return params, props
 
-    def distribution(self, params, state, inputs=None):
+    def distribution(
+        self,
+        params: ParamsMultinomialHMMEmissions,
+        state: IntScalar,
+        inputs: Optional[Array] = None
+    ) -> tfd.Distribution:
         return tfd.Independent(
             tfd.Multinomial(self.num_trials, probs=params.probs[state]),
             reinterpreted_batch_ndims=1)
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsMultinomialHMMEmissions) -> Float[Array, ""]:
         return tfd.Dirichlet(self.emission_prior_concentration).log_prob(params.probs).sum()
 
-    def collect_suff_stats(self, params, posterior, emissions, inputs=None):
+    def collect_suff_stats(
+        self,
+        params: ParamsMultinomialHMMEmissions,
+        posterior: HMMPosterior,
+        emissions: Real[Array, "num_timesteps emission_dim num_classes"],
+        inputs: Optional[Array] = None
+    ) -> Dict[str, Float[Array, "num_states emission_dim num_classes"]]:
         expected_states = posterior.smoothed_probs
         return dict(sum_x=jnp.einsum("tk, tdi->kdi", expected_states, emissions))
 
-    def initialize_m_step_state(self, params, props):
+    def initialize_m_step_state(self, params, props) -> None:
         return None
 
-    def m_step(self, params, props, batch_stats, m_step_state):
+    def m_step(
+        self,
+        params: ParamsMultinomialHMMEmissions,
+        props: ParamsMultinomialHMMEmissions,
+        batch_stats: Dict[str, Float[Array, "batch_dim num_states emission_dim num_classes"]],
+        m_step_state: Any
+    ) -> Tuple[ParamsMultinomialHMMEmissions, Any]:
         if props.probs.trainable:
             emission_stats = pytree_sum(batch_stats, axis=0)
             probs = tfd.Dirichlet(
-                self.emission_prior_concentration + emission_stats['sum_x']).mode()
+                self.emission_prior_concentration + emission_stats['sum_x']
+            ).mode()
             params = params._replace(probs=probs)
         return params, m_step_state
@@ -111,14 +134,14 @@ class MultinomialHMM(HMM):
     """
     def __init__(self,
-                 num_states,
-                 emission_dim,
-                 num_classes,
-                 num_trials,
-                 initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
-                 transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1,
+                 num_states: int,
+                 emission_dim: int,
+                 num_classes: int,
+                 num_trials: int,
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
+                 transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
                  transition_matrix_stickiness: Scalar=0.0,
-                 emission_prior_concentration: Scalar=1.1):
+                 emission_prior_concentration: Union[Scalar, Float[Array, " num_classes"]]=1.1):
         self.emission_dim = emission_dim
         self.num_classes = num_classes
         self.num_trials = num_trials
@@ -129,7 +152,7 @@ def __init__(self,
     def initialize(self, key=jr.PRNGKey(0), method="prior",
-                   initial_probs: Optional[Float[Array, "num_states"]]=None,
+                   initial_probs: Optional[Float[Array, " num_states"]]=None,
                    transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
                    emission_probs: Optional[Float[Array, "num_states emission_dim num_classes"]]=None
                    ) -> Tuple[ParameterSet, PropertySet]:
diff --git a/dynamax/hidden_markov_model/models/poisson_hmm.py b/dynamax/hidden_markov_model/models/poisson_hmm.py
index 198bf128..b5c79203 100644
--- a/dynamax/hidden_markov_model/models/poisson_hmm.py
+++ b/dynamax/hidden_markov_model/models/poisson_hmm.py
@@ -1,4 +1,4 @@
-from typing import NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
 
 import jax.numpy as jnp
 import jax.random as jr
@@ -6,13 +6,14 @@
 import tensorflow_probability.substrates.jax.distributions as tfd
 from jaxtyping import Array, Float
 
+from dynamax.hidden_markov_model.inference import HMMPosterior
 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions
 from dynamax.hidden_markov_model.models.initial import ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import ParamsStandardHMMTransitions
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions
 from dynamax.parameters import ParameterProperties, ParameterSet, PropertySet
-from dynamax.types import Scalar
+from dynamax.types import IntScalar, Scalar
 from dynamax.utils.utils import pytree_sum
@@ -23,10 +24,10 @@ class ParamsPoissonHMMEmissions(NamedTuple):
 
 class PoissonHMMEmissions(HMMEmissions):
     def __init__(self,
-                 num_states,
-                 emission_dim,
-                 emission_prior_concentration=1.1,
-                 emission_prior_rate=0.1):
+                 num_states: int,
+                 emission_dim: int,
+                 emission_prior_concentration: Scalar = 1.1,
+                 emission_prior_rate: Scalar = 0.1):
-        """_summary_
+        """Initialize PoissonHMMEmissions.
Args: @@ -40,12 +41,13 @@ def __init__(self, self.emission_prior_rate = emission_prior_rate @property - def emission_shape(self): + def emission_shape(self) -> Tuple[int]: return (self.emission_dim,) - def initialize(self, key=jr.PRNGKey(0), - method="prior", - emission_rates=None): + def initialize(self, key: Array=jr.PRNGKey(0), + method: str = "prior", + emission_rates: Optional[Float[Array, "num_states emission_dim"]] = None + ) -> Tuple[ParamsPoissonHMMEmissions, ParamsPoissonHMMEmissions]: # Initialize the emission probabilities if emission_rates is None: if method.lower() == "prior": @@ -64,24 +66,41 @@ def initialize(self, key=jr.PRNGKey(0), props = ParamsPoissonHMMEmissions(rates=ParameterProperties(constrainer=tfb.Softplus())) return params, props - def distribution(self, params, state, inputs=None): + def distribution( + self, + params: ParamsPoissonHMMEmissions, + state: IntScalar, + inputs: Optional[Array] = None + ) -> tfd.Distribution: return tfd.Independent(tfd.Poisson(rate=params.rates[state]), reinterpreted_batch_ndims=1) - def log_prior(self, params): + def log_prior(self, params: ParamsPoissonHMMEmissions) -> Float[Array, ""]: prior = tfd.Gamma(self.emission_prior_concentration, self.emission_prior_rate) return prior.log_prob(params.rates).sum() - def collect_suff_stats(self, params, posterior, emissions, inputs=None): + def collect_suff_stats( + self, + params: ParamsPoissonHMMEmissions, + posterior: HMMPosterior, + emissions: Float[Array, "num_timesteps emission_dim"], + inputs: Optional[Array] = None + ) -> Dict[str, Float[Array, "..."]]: expected_states = posterior.smoothed_probs sum_w = jnp.einsum("tk->k", expected_states)[:, None] sum_x = jnp.einsum("tk, ti->ki", expected_states, emissions) return dict(sum_w=sum_w, sum_x=sum_x) - def initialize_m_step_state(self, params, props): + def initialize_m_step_state(self, params: ParamsPoissonHMMEmissions, props: ParamsPoissonHMMEmissions) -> None: return None - def m_step(self, params, props, batch_stats, m_step_state): + def m_step( + self, + params: ParamsPoissonHMMEmissions, + props: ParamsPoissonHMMEmissions, + batch_stats: Dict[str, Float[Array, "..."]], + m_step_state: Any + ) -> Tuple[ParamsPoissonHMMEmissions, Any]: if props.rates.trainable: emission_stats = pytree_sum(batch_stats, axis=0) post_concentration = self.emission_prior_concentration + emission_stats['sum_x'] @@ -121,8 +140,8 @@ class PoissonHMM(HMM): def __init__(self, num_states: int, emission_dim: int, - initial_probs_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, - transition_matrix_concentration: Union[Scalar, Float[Array, "num_states"]]=1.1, + initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, + transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, transition_matrix_stickiness: Scalar=0.0, emission_prior_concentration: Scalar=1.1, emission_prior_rate: Scalar=0.1): @@ -132,9 +151,9 @@ def __init__(self, emission_component = PoissonHMMEmissions(num_states, emission_dim, emission_prior_concentration=emission_prior_concentration, emission_prior_rate=emission_prior_rate) super().__init__(num_states, initial_component, transition_component, emission_component) - def initialize(self, key=jr.PRNGKey(0), + def initialize(self, key: Array=jr.PRNGKey(0), method="prior", - initial_probs: Optional[Float[Array, "num_states"]]=None, + initial_probs: Optional[Float[Array, " num_states"]]=None, transition_matrix: Optional[Float[Array, "num_states num_states"]]=None, emission_rates: 
Optional[Float[Array, "num_states emission_dim"]]=None ) -> Tuple[ParameterSet, PropertySet]: diff --git a/dynamax/hidden_markov_model/models/transitions.py b/dynamax/hidden_markov_model/models/transitions.py index 59461f4d..cdb8a6ea 100644 --- a/dynamax/hidden_markov_model/models/transitions.py +++ b/dynamax/hidden_markov_model/models/transitions.py @@ -1,11 +1,13 @@ +from typing import Any, cast, NamedTuple, Optional, Tuple, Union import jax.numpy as jnp -import jax.random as jr +from jaxtyping import Float, Array import tensorflow_probability.substrates.jax.distributions as tfd import tensorflow_probability.substrates.jax.bijectors as tfb + from dynamax.hidden_markov_model.models.abstractions import HMMTransitions +from dynamax.hidden_markov_model.inference import HMMPosterior from dynamax.parameters import ParameterProperties -from jaxtyping import Float, Array -from typing import NamedTuple, Union +from dynamax.types import IntScalar, Scalar class ParamsStandardHMMTransitions(NamedTuple): @@ -29,7 +31,12 @@ class StandardHMMTransitions(HMMTransitions): """ - def __init__(self, num_states, concentration=1.1, stickiness=0.0): + def __init__( + self, + num_states: int, + concentration: Union[Scalar, Float[Array, "num_states num_states"]]=1.1, + stickiness: Union[Scalar, Float[Array, " num_states"]]=0.0 + ): """ Args: transition_matrix[j,k]: prob(hidden(t) = k | hidden(t-1)j) @@ -39,10 +46,15 @@ def __init__(self, num_states, concentration=1.1, stickiness=0.0): concentration * jnp.ones((num_states, num_states)) + \ stickiness * jnp.eye(num_states) - def distribution(self, params, state, inputs=None): + def distribution(self, params: ParamsStandardHMMTransitions, state: IntScalar, inputs=None): return tfd.Categorical(probs=params.transition_matrix[state]) - def initialize(self, key=None, method="prior", transition_matrix=None): + def initialize( + self, + key: Optional[Array]=None, + method="prior", + transition_matrix: Optional[Float[Array, "num_states num_states"]]=None + ) -> Tuple[ParamsStandardHMMTransitions, ParamsStandardHMMTransitions]: """Initialize the model parameters and their corresponding properties. 
Args: @@ -54,27 +66,44 @@ def initialize(self, key=None, method="prior", transition_matrix=None): _type_: _description_ """ if transition_matrix is None: - this_key, key = jr.split(key) - transition_matrix = tfd.Dirichlet(self.concentration).sample(seed=this_key) + if key is None: + raise ValueError("key must be provided if transition_matrix is not provided.") + else: + transition_matrix_sample = tfd.Dirichlet(self.concentration).sample(seed=key) + transition_matrix = cast(Float[Array, "num_states num_states"], transition_matrix_sample) # Package the results into dictionaries params = ParamsStandardHMMTransitions(transition_matrix=transition_matrix) props = ParamsStandardHMMTransitions(transition_matrix=ParameterProperties(constrainer=tfb.SoftmaxCentered())) return params, props - def log_prior(self, params): + def log_prior(self, params: ParamsStandardHMMTransitions) -> Scalar: return tfd.Dirichlet(self.concentration).log_prob(params.transition_matrix).sum() - def _compute_transition_matrices(self, params, inputs=None): + def _compute_transition_matrices( + self, params: ParamsStandardHMMTransitions, inputs=None + ) -> Float[Array, "num_states num_states"]: return params.transition_matrix - def collect_suff_stats(self, params, posterior, inputs=None): + def collect_suff_stats( + self, + params, + posterior: HMMPosterior, + inputs=None + ) -> Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]: return posterior.trans_probs def initialize_m_step_state(self, params, props): return None - def m_step(self, params, props, batch_stats, m_step_state): + def m_step( + self, + params: ParamsStandardHMMTransitions, + props: ParamsStandardHMMTransitions, + batch_stats: Float[Array, "batch num_states num_states"], + m_step_state: Any + ) -> Tuple[ParamsStandardHMMTransitions, Any]: if props.transition_matrix.trainable: if self.num_states == 1: transition_matrix = jnp.array([[1.0]]) diff --git a/dynamax/hidden_markov_model/parallel_inference.py b/dynamax/hidden_markov_model/parallel_inference.py index 37fa7fb2..927e3030 100644 --- a/dynamax/hidden_markov_model/parallel_inference.py +++ b/dynamax/hidden_markov_model/parallel_inference.py @@ -2,17 +2,17 @@ import jax.random as jr from jax import lax, vmap, value_and_grad from jaxtyping import Array, Float, Int -from typing import NamedTuple, Union -from functools import partial +from typing import NamedTuple, Tuple from dynamax.hidden_markov_model.inference import HMMPosterior, HMMPosteriorFiltered +from dynamax.types import Scalar #---------------------------------------------------------------------------# # Filtering # #---------------------------------------------------------------------------# class FilterMessage(NamedTuple): - """Filtering associative scan elements. + r"""Filtering associative scan elements. 
Attributes: A: $p(z_j \mid z_i)$ log_b: $\log P(y_j \mid z_j)$ @@ -30,7 +30,7 @@ def _condition_on(A, ll, axis=-1): return A_cond, jnp.log(norm) + ll_max -def hmm_filter(initial_probs: Float[Array, "num_states"], +def hmm_filter(initial_probs: Float[Array, " num_states"], transition_matrix: Float[Array, "num_states num_states"], log_likelihoods: Float[Array, "num_timesteps num_states"] ) -> HMMPosteriorFiltered: @@ -89,10 +89,10 @@ def marginalize(m_ij, m_jk): #---------------------------------------------------------------------------# -def hmm_smoother(initial_probs: Float[Array, "num_states"], +def hmm_smoother(initial_probs: Float[Array, " num_states"], transition_matrix: Float[Array, "num_states num_states"], log_likelihoods: Float[Array, "num_timesteps num_states"] -) -> HMMPosteriorFiltered: +) -> HMMPosterior: r"""Parallel implementation of HMM smoothing with `jax.lax.associative_scan`. **Notes:** @@ -132,43 +132,43 @@ def log_normalizer(log_initial_probs, log_transition_matrix, log_likelihoods): #---------------------------------------------------------------------------# # Sampling # #---------------------------------------------------------------------------# -"""Associative scan elements $E_ij$ are vectors specifying a sample:: +r"""Associative scan elements $E_ij$ are vectors specifying a sample:: $z_j ~ p(z_j \mid z_i)$ for each possible value of $z_i$. """ -def _initialize_sampling_messages(rng, transition_matrix, filtered_probs): +def _initialize_sampling_messages(key, transition_matrix, filtered_probs): """Preprocess filtering output to construct input for sampling associative scan.""" T, K = filtered_probs.shape - rngs = jr.split(rng, T) + keys = jr.split(key, T) - def _last_message(rng, probs): - state = jr.choice(rng, K, p=probs) + def _last_message(key, probs): + state = jr.choice(key, K, p=probs) return jnp.repeat(state, K) @vmap - def _generic_message(rng, probs): + def _generic_message(key, probs): smoothed_probs = probs * transition_matrix.T smoothed_probs = smoothed_probs / smoothed_probs.sum(1).reshape(K,1) - return vmap(lambda p: jr.choice(rng, K, p=p))(smoothed_probs) + return vmap(lambda p: jr.choice(key, K, p=p))(smoothed_probs) - En = _last_message(rngs[-1], filtered_probs[-1]) - Et = _generic_message(rngs[:-1], filtered_probs[:-1]) + En = _last_message(keys[-1], filtered_probs[-1]) + Et = _generic_message(keys[:-1], filtered_probs[:-1]) return jnp.concatenate([Et, En[None]]) -def hmm_posterior_sample(rng: jr.PRNGKey, - initial_distribution: Float[Array, "num_states"], +def hmm_posterior_sample(key: Array, + initial_distribution: Float[Array, " num_states"], transition_matrix: Float[Array, "num_states num_states"], log_likelihoods: Float[Array, "num_timesteps num_states"] -) -> Int[Array, "num_timesteps"]: +) -> Tuple[Scalar, Int[Array, " num_timesteps"]]: r"""Sample a sequence of hidden states from the posterior. Args: - rng: random number generator + key: random number generator initial_distribution: $p(z_1 \mid u_1, \theta)$ transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$ log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
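Since both the `rng` to `key` rename and the corrected return annotation land in this function, here is a minimal usage sketch; the toy model (uniform initial distribution, sticky transition matrix, Gaussian noise standing in for log-likelihoods) is assumed for illustration only:

import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model.parallel_inference import hmm_posterior_sample

T, K = 50, 3                                        # assumed sizes
initial_probs = jnp.ones(K) / K
transition_matrix = 0.9 * jnp.eye(K) + 0.1 / K      # each row sums to one
log_likelihoods = jr.normal(jr.PRNGKey(1), (T, K))

# Per the corrected annotation, the sampler returns a tuple:
# the log normalizer log p(y_{1:T}) and one posterior draw of z_{1:T}.
log_normalizer, states = hmm_posterior_sample(
    jr.PRNGKey(0), initial_probs, transition_matrix, log_likelihoods)
assert states.shape == (T,)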
@@ -177,8 +177,6 @@ def hmm_posterior_sample(rng: jr.PRNGKey, log_normalizer: $\log P(y_{1:T} \mid u_{1:T}, \theta)$ states: sequence of hidden states $z_{1:T}$ """ - T, K = log_likelihoods.shape - # Run the HMM filter post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods) log_normalizer = post.marginal_loglik @@ -188,7 +186,7 @@ def hmm_posterior_sample(rng: jr.PRNGKey, def _operator(E_jk, E_ij): return jnp.take(E_ij, E_jk) - initial_messages = _initialize_sampling_messages(rng, transition_matrix, filtered_probs) + initial_messages = _initialize_sampling_messages(key, transition_matrix, filtered_probs) final_messages = lax.associative_scan(_operator, initial_messages, reverse=True) states = final_messages[:,0] return log_normalizer, states diff --git a/dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py b/dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py index 85565928..88dfcc11 100644 --- a/dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py +++ b/dynamax/linear_gaussian_ssm/demos/kf_linreg_jax_vs_pt.py @@ -1,29 +1,16 @@ import numpy as np -from jaxtyping import Float, Array -from typing import Callable, NamedTuple, Union, Tuple, Any -from functools import partial -import chex -import optax import jax import jax.numpy as jnp import jax.random as jr -from jax import lax, jacfwd, vmap, grad, jit -from jax.tree_util import tree_map, tree_reduce -from jax.flatten_util import ravel_pytree +from jax import lax import jax.numpy as jnp import jax.random as jr from jax import lax import time import platform -import matplotlib.pyplot as plt -import matplotlib.cm as cm -from dataclasses import dataclass -from itertools import cycle -import tensorflow as tf -import tensorflow_probability as tfp from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN import torch diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index 1588ad05..4ff4733b 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -14,7 +14,7 @@ from typing import NamedTuple, Optional, Union, Tuple from dynamax.utils.utils import psd_solve, symmetrize from dynamax.parameters import ParameterProperties -from dynamax.types import PRNGKey, Scalar +from dynamax.types import PRNGKeyT, Scalar class ParamsLGSSMInitial(NamedTuple): r"""Parameters of the initial distribution @@ -27,9 +27,9 @@ class ParamsLGSSMInitial(NamedTuple): :param cov: $Q_1$ """ - mean: Union[Float[Array, "state_dim"], ParameterProperties] + mean: Union[Float[Array, " state_dim"], ParameterProperties] # unconstrained parameters are stored as a vector. 
- cov: Union[Float[Array, "state_dim state_dim"], Float[Array, "state_dim_triu"], ParameterProperties] + cov: Union[Float[Array, "state_dim state_dim"], Float[Array, " state_dim_triu"], ParameterProperties] class ParamsLGSSMDynamics(NamedTuple): @@ -50,7 +50,7 @@ class ParamsLGSSMDynamics(NamedTuple): Float[Array, "ntime state_dim state_dim"]] bias: Union[ParameterProperties, - Float[Array, "state_dim"], + Float[Array, " state_dim"], Float[Array, "ntime state_dim"]] input_weights: Union[ParameterProperties, @@ -60,7 +60,7 @@ class ParamsLGSSMDynamics(NamedTuple): cov: Union[ParameterProperties, Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"], - Float[Array, "state_dim_triu"]] + Float[Array, " state_dim_triu"]] class ParamsLGSSMEmissions(NamedTuple): @@ -81,7 +81,7 @@ class ParamsLGSSMEmissions(NamedTuple): Float[Array, "ntime emission_dim state_dim"]] bias: Union[ParameterProperties, - Float[Array, "emission_dim"], + Float[Array, " emission_dim"], Float[Array, "ntime emission_dim"]] input_weights: Union[ParameterProperties, @@ -91,9 +91,9 @@ class ParamsLGSSMEmissions(NamedTuple): cov: Union[ParameterProperties, Float[Array, "emission_dim emission_dim"], Float[Array, "ntime emission_dim emission_dim"], - Float[Array, "emission_dim"], + Float[Array, " emission_dim"], Float[Array, "ntime emission_dim"], - Float[Array, "emission_dim_triu"]] + Float[Array, " emission_dim_triu"]] class ParamsLGSSM(NamedTuple): @@ -117,7 +117,7 @@ class PosteriorGSSMFiltered(NamedTuple): :param filtered_covariances: array of filtered covariances $\mathrm{Cov}[z_t \mid y_{1:t}, u_{1:t}]$ """ - marginal_loglik: Union[Scalar, Float[Array, "ntime"]] + marginal_loglik: Union[Scalar, Float[Array, " ntime"]] filtered_means: Optional[Float[Array, "ntime state_dim"]] = None filtered_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None predicted_means: Optional[Float[Array, "ntime state_dim"]] = None @@ -363,7 +363,7 @@ def wrapper(*args, **kwargs): def lgssm_joint_sample( params: ParamsLGSSM, - key: PRNGKey, + key: PRNGKeyT, num_timesteps: int, inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None )-> Tuple[Float[Array, "num_timesteps state_dim"], @@ -559,7 +559,7 @@ def _step(carry, args): def lgssm_posterior_sample( - key: PRNGKey, + key: PRNGKeyT, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]]=None, diff --git a/dynamax/linear_gaussian_ssm/info_inference.py b/dynamax/linear_gaussian_ssm/info_inference.py index 99ca5125..b9018bc9 100644 --- a/dynamax/linear_gaussian_ssm/info_inference.py +++ b/dynamax/linear_gaussian_ssm/info_inference.py @@ -9,7 +9,7 @@ class ParamsLGSSMInfo(NamedTuple): """Lightweight container for passing LGSSM parameters in information form to inference algorithms.""" - initial_mean: Float[Array, "state_dim"] + initial_mean: Float[Array, " state_dim"] dynamics_weights: Float[Array, "state_dim state_dim"] emission_weights: Float[Array, "emission_dim state_dim"] @@ -19,9 +19,9 @@ class ParamsLGSSMInfo(NamedTuple): # Optional parameters (None means zeros) dynamics_input_weights: Optional[Float[Array, "input_dim state_dim"]] = None - dynamics_bias: Optional[Float[Array, "state_dim"]] = None + dynamics_bias: Optional[Float[Array, " state_dim"]] = None emission_input_weights: Optional[Float[Array, "input_dim emission_dim"]] = None - emission_bias: Optional[Float[Array, "emission_dim"]] = None + emission_bias: Optional[Float[Array, " emission_dim"]] = None class 
PosteriorGSSMInfoFiltered(NamedTuple): diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 453de651..4fe22d24 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -4,24 +4,25 @@ import jax.numpy as jnp import jax.random as jr from jax.tree_util import tree_map -from jaxtyping import Array, Float, PyTree +from jaxtyping import Array, Float import tensorflow_probability.substrates.jax.distributions as tfd from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN -from typing import Any, Optional, Tuple, Union -from typing_extensions import Protocol +from typing import Any, Optional, Tuple, Union, runtime_checkable +from typing_extensions import Protocol from dynamax.ssm import SSM from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed -from dynamax.parameters import ParameterProperties, ParameterSet -from dynamax.types import PRNGKey, Scalar +from dynamax.parameters import ParameterProperties +from dynamax.types import PRNGKeyT, Scalar from dynamax.utils.bijectors import RealToPSDBijector from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW from dynamax.utils.distributions import NormalInverseWishart as NIW from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update from dynamax.utils.utils import pytree_stack, psd_solve +@runtime_checkable class SuffStatsLGSSM(Protocol): """A :class:`NamedTuple` with sufficient statistics for LGSSM parameter estimation.""" pass @@ -87,8 +88,8 @@ def inputs_shape(self): def initialize( self, - key: PRNGKey =jr.PRNGKey(0), - initial_mean: Optional[Float[Array, "state_dim"]]=None, + key: PRNGKeyT =jr.PRNGKey(0), + initial_mean: Optional[Float[Array, " state_dim"]]=None, initial_covariance=None, dynamics_weights=None, dynamics_bias=None, @@ -178,7 +179,7 @@ def initial_distribution( def transition_distribution( self, params: ParamsLGSSM, - state: Float[Array, "state_dim"], + state: Float[Array, " state_dim"], inputs: Optional[Float[Array, "ntime input_dim"]]=None ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) @@ -190,22 +191,22 @@ def transition_distribution( def emission_distribution( self, params: ParamsLGSSM, - state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + state: Float[Array, " state_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) mean = params.emissions.weights @ state + params.emissions.input_weights @ inputs if self.has_emissions_bias: mean += params.emissions.bias return MVN(mean, params.emissions.cov) - + def sample( self, params: ParamsLGSSM, - key: PRNGKey, + key: PRNGKeyT, num_timesteps: int, - inputs: Optional[Float[Array, "ntime input_dim"]] = None - ) -> PosteriorGSSMFiltered: + inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None, + ) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: return lgssm_joint_sample(params, key, num_timesteps, inputs) def marginal_log_prob( @@ -235,7 +236,7 @@ def smoother( def posterior_sample( self, - 
key: PRNGKey, + key: PRNGKeyT, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]]=None @@ -500,7 +501,7 @@ def m_step( def fit_blocked_gibbs( self, - key: PRNGKey, + key: PRNGKeyT, initial_params: ParamsLGSSM, sample_size: int, emissions: Float[Array, "nbatch ntime emission_dim"], diff --git a/dynamax/linear_gaussian_ssm/models_test.py b/dynamax/linear_gaussian_ssm/models_test.py index c4394858..fefe8b9e 100644 --- a/dynamax/linear_gaussian_ssm/models_test.py +++ b/dynamax/linear_gaussian_ssm/models_test.py @@ -1,5 +1,4 @@ import pytest -from datetime import datetime import jax.random as jr from dynamax.linear_gaussian_ssm import LinearGaussianSSM from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index 2e50e8d6..a212f3b2 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -34,7 +34,7 @@ from jax import vmap, lax from jaxtyping import Array, Float from typing import NamedTuple -from dynamax.types import PRNGKey +from dynamax.types import PRNGKeyT from functools import partial import warnings @@ -148,7 +148,7 @@ class FilterMessage(NamedTuple): C: Float[Array, "ntime state_dim state_dim"] J: Float[Array, "ntime state_dim state_dim"] eta: Float[Array, "ntime state_dim"] - logZ: Float[Array, "ntime"] + logZ: Float[Array, " ntime"] def _initialize_filtering_messages(params, emissions): @@ -354,7 +354,7 @@ def _initialize_sampling_messages(key, params, filtered_means, filtered_covarian def lgssm_posterior_sample( - key: PRNGKey, + key: PRNGKeyT, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"] ) -> Float[Array, "ntime state_dim"]: @@ -379,4 +379,4 @@ def _operator(elem1, elem2): initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs) _, samples = lax.associative_scan(_operator, initial_messages, reverse=True) - return samples \ No newline at end of file + return samples diff --git a/dynamax/nonlinear_gaussian_ssm/inference_ekf.py b/dynamax/nonlinear_gaussian_ssm/inference_ekf.py index d936ed45..2577a610 100644 --- a/dynamax/nonlinear_gaussian_ssm/inference_ekf.py +++ b/dynamax/nonlinear_gaussian_ssm/inference_ekf.py @@ -9,7 +9,7 @@ from dynamax.utils.utils import psd_solve, symmetrize from dynamax.nonlinear_gaussian_ssm.models import ParamsNLGSSM from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed -from dynamax.types import PRNGKey +from dynamax.types import PRNGKeyT # Helper functions _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x @@ -258,7 +258,7 @@ def _step(carry, args): def extended_kalman_posterior_sample( - key: PRNGKey, + key: PRNGKeyT, params: ParamsNLGSSM, emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]] = None diff --git a/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py b/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py index ef6794ab..52af8331 100644 --- a/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py +++ b/dynamax/nonlinear_gaussian_ssm/inference_test_utils.py @@ -15,7 +15,7 @@ from dynamax.parameters import ParameterProperties from dynamax.ssm import SSM from dynamax.utils.bijectors import RealToPSDBijector -from dynamax.types import PRNGKey +from dynamax.types import PRNGKeyT tfd = tfp.distributions @@ -37,7 +37,7 @@ def 
lgssm_to_nlgssm(params: ParamsLGSSM) -> ParamsNLGSSM: def random_lgssm_args( - key: Union[int, PRNGKey] = 0, + key: Union[int, PRNGKeyT] = 0, num_timesteps: int = 15, state_dim: int = 4, emission_dim: int = 2 diff --git a/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py b/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py index 6900a08d..bb8f9a11 100644 --- a/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py +++ b/dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py @@ -1,5 +1,4 @@ import jax.numpy as jnp -import jax.random as jr from dynamax.nonlinear_gaussian_ssm.inference_ukf import unscented_kalman_smoother, UKFHyperParams from dynamax.nonlinear_gaussian_ssm.sarkka_lib import ukf, uks diff --git a/dynamax/nonlinear_gaussian_ssm/models.py b/dynamax/nonlinear_gaussian_ssm/models.py index 53d94a40..13769738 100644 --- a/dynamax/nonlinear_gaussian_ssm/models.py +++ b/dynamax/nonlinear_gaussian_ssm/models.py @@ -10,10 +10,10 @@ tfb = tfp.bijectors -FnStateToState = Callable[ [Float[Array, "state_dim"]], Float[Array, "state_dim"]] -FnStateAndInputToState = Callable[ [Float[Array, "state_dim"], Float[Array, "input_dim"]], Float[Array, "state_dim"]] -FnStateToEmission = Callable[ [Float[Array, "state_dim"]], Float[Array, "emission_dim"]] -FnStateAndInputToEmission = Callable[ [Float[Array, "state_dim"], Float[Array, "input_dim"] ], Float[Array, "emission_dim"]] +FnStateToState = Callable[ [Float[Array, " state_dim"]], Float[Array, " state_dim"]] +FnStateAndInputToState = Callable[ [Float[Array, " state_dim"], Float[Array, " input_dim"]], Float[Array, " state_dim"]] +FnStateToEmission = Callable[ [Float[Array, " state_dim"]], Float[Array, " emission_dim"]] +FnStateAndInputToEmission = Callable[ [Float[Array, " state_dim"], Float[Array, " input_dim"] ], Float[Array, " emission_dim"]] class ParamsNLGSSM(NamedTuple): @@ -34,7 +34,7 @@ class ParamsNLGSSM(NamedTuple): """ - initial_mean: Float[Array, "state_dim"] + initial_mean: Float[Array, " state_dim"] initial_covariance: Float[Array, "state_dim state_dim"] dynamics_function: Union[FnStateToState, FnStateAndInputToState] dynamics_covariance: Float[Array, "state_dim state_dim"] @@ -85,15 +85,15 @@ def inputs_shape(self): def initial_distribution( self, params: ParamsNLGSSM, - inputs: Optional[Float[Array, "input_dim"]] = None + inputs: Optional[Float[Array, " input_dim"]] = None ) -> tfd.Distribution: return MVN(params.initial_mean, params.initial_covariance) def transition_distribution( self, params: ParamsNLGSSM, - state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "input_dim"]] = None + state: Float[Array, " state_dim"], + inputs: Optional[Float[Array, " input_dim"]] = None ) -> tfd.Distribution: f = params.dynamics_function if inputs is None: @@ -105,8 +105,8 @@ def transition_distribution( def emission_distribution( self, params: ParamsNLGSSM, - state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "input_dim"]] = None + state: Float[Array, " state_dim"], + inputs: Optional[Float[Array, " input_dim"]] = None ) -> tfd.Distribution: h = params.emission_function if inputs is None: diff --git a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py index ff65405d..875bf5f9 100644 --- a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py +++ b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py @@ -6,7 +6,6 @@ """ import jax.numpy as jnp -import jax.random as jr from jax import vmap from jax import lax from jax import jacfwd diff --git a/dynamax/parameters.py b/dynamax/parameters.py index 
7e7a98f3..0182f689 100644 --- a/dynamax/parameters.py +++ b/dynamax/parameters.py @@ -2,18 +2,19 @@ from jax import lax from jax.tree_util import tree_reduce, tree_map, register_pytree_node_class import tensorflow_probability.substrates.jax.bijectors as tfb -from typing import Optional, Union +from typing import Optional, runtime_checkable from typing_extensions import Protocol -from jaxtyping import Array, Float -from dynamax.types import PRNGKey, Scalar +from dynamax.types import Scalar +@runtime_checkable class ParameterSet(Protocol): """A :class:`NamedTuple` with parameters stored as :class:`jax.DeviceArray` in the leaf nodes. """ pass +@runtime_checkable class PropertySet(Protocol): """A matching :class:`NamedTuple` with :class:`ParameterProperties` stored in the leaf nodes. diff --git a/dynamax/parameters_test.py b/dynamax/parameters_test.py index 4e75d53e..3b8fdc8d 100644 --- a/dynamax/parameters_test.py +++ b/dynamax/parameters_test.py @@ -10,7 +10,7 @@ class InitialParams(NamedTuple): - probs: Union[Float[Array, "state_dim"], ParameterProperties] + probs: Union[Float[Array, " state_dim"], ParameterProperties] class TransitionsParams(NamedTuple): transition_matrix: Union[Float[Array, "state_dim state_dim"], ParameterProperties] diff --git a/dynamax/slds/inference.py b/dynamax/slds/inference.py index 3e8ee48f..6f7ea39e 100644 --- a/dynamax/slds/inference.py +++ b/dynamax/slds/inference.py @@ -6,10 +6,10 @@ from jaxtyping import Array, Float, Int from typing import NamedTuple, Optional from dynamax.utils.utils import psd_solve -from dynamax.types import PRNGKey +from dynamax.types import PRNGKeyT class DiscreteParamsSLDS(NamedTuple): - initial_distribution: Float[Array, "num_states"] + initial_distribution: Float[Array, " num_states"] transition_matrix : Float[Array, "num_states num_states"] proposal_transition_matrix : Float[Array, "num_states num_states"] @@ -164,7 +164,7 @@ def rbpfilter( num_particles: int, params: ParamsSLDS, emissions: Float[Array, "ntime emission_dim"], - key: PRNGKey = jr.PRNGKey(0), + key: PRNGKeyT = jr.PRNGKey(0), inputs: Optional[Float[Array, "ntime input_dim"]] = None, ess_threshold: float = 0.5 ): @@ -253,7 +253,7 @@ def rbpfilter_optimal( num_particles: int, params: ParamsSLDS, emissions: Float[Array, "ntime emission_dim"], - key: PRNGKey = jr.PRNGKey(0), + key: PRNGKeyT = jr.PRNGKey(0), inputs: Optional[Float[Array, "ntime input_dim"]]=None ): ''' @@ -336,4 +336,4 @@ def _step(carry, t): _, out = lax.scan(_step, carry, jnp.arange(num_timesteps)) - return out \ No newline at end of file + return out diff --git a/dynamax/slds/inference_test.py b/dynamax/slds/inference_test.py index 6986492e..5e592fe2 100644 --- a/dynamax/slds/inference_test.py +++ b/dynamax/slds/inference_test.py @@ -2,7 +2,6 @@ import jax.random as jr from dynamax.slds import SLDS, DiscreteParamsSLDS, LGParamsSLDS, ParamsSLDS, rbpfilter, rbpfilter_optimal from functools import partial -import matplotlib.pyplot as plt import dynamax.slds.mixture_kalman_filter_demo as kflib from functools import partial from jax.scipy.special import logit diff --git a/dynamax/slds/mixture_kalman_filter_demo.py b/dynamax/slds/mixture_kalman_filter_demo.py index 9302b913..18c0d7aa 100644 --- a/dynamax/slds/mixture_kalman_filter_demo.py +++ b/dynamax/slds/mixture_kalman_filter_demo.py @@ -3,11 +3,12 @@ # Author: Gerardo Durán-Martín (@gerdm) +from dataclasses import dataclass import jax -import jax.numpy as jnp from jax import random +import jax.numpy as jnp from jax.scipy.special import logit -from 
dataclasses import dataclass +from jaxtyping import Array, Float @dataclass @@ -24,12 +25,12 @@ class RBPFParamsDiscrete: noise1_next ~ N(0, Q) noise2_next ~ N(0, R) """ - A: jnp.array - B: jnp.array - C: jnp.array - Q: jnp.array - R: jnp.array - transition_matrix: jnp.array + A: Float[Array, "dim_hidden dim_hidden"] + B: Float[Array, "dim_hidden dim_control"] + C: Float[Array, "dim_emission dim_hidden"] + Q: Float[Array, "dim_hidden dim_hidden"] + R: Float[Array, "dim_emission dim_emission"] + transition_matrix: Float[Array, "dim_control dim_control"] def draw_state(val, key, params): @@ -42,7 +43,7 @@ def draw_state(val, key, params): ---------- val: tuple (int, jnp.array) (latent value of system, state value of system). - params: PRBPFParamsDiscrete + params: RBPFParamsDiscrete key: PRNGKey """ latent_old, state_old = val @@ -158,4 +159,4 @@ def rbpf_optimal(current_config, xt, params, nparticles=100): weights_t = jnp.ones(nparticles) / nparticles - return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp) \ No newline at end of file + return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp) diff --git a/dynamax/slds/models.py b/dynamax/slds/models.py index b0418796..52a0a6cf 100644 --- a/dynamax/slds/models.py +++ b/dynamax/slds/models.py @@ -1,27 +1,15 @@ -from fastprogress.fastprogress import progress_bar -from functools import partial -from jax import jit, lax +from jax import lax import jax.numpy as jnp import jax.random as jr from jax.tree_util import tree_map -from jaxtyping import Array, Float, PyTree +from jaxtyping import Array, Float import tensorflow_probability.substrates.jax.distributions as tfd from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN -from typing import Any, Optional, Tuple, Union -from typing_extensions import Protocol +from typing import Optional, Tuple from dynamax.ssm import SSM -from dynamax.linear_gaussian_ssm.models import LinearGaussianSSM -from dynamax.linear_gaussian_ssm.inference import lgssm_filter, lgssm_smoother, lgssm_posterior_sample from dynamax.slds.inference import ParamsSLDS -from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed -from dynamax.parameters import ParameterProperties, ParameterSet -from dynamax.types import PRNGKey, Scalar -from dynamax.utils.bijectors import RealToPSDBijector -from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW -from dynamax.utils.distributions import NormalInverseWishart as NIW -from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update -from dynamax.utils.utils import pytree_stack, psd_solve +from dynamax.types import PRNGKeyT class SLDS(SSM): @@ -58,7 +46,7 @@ def transition_distribution( self, params: ParamsSLDS, dstate: int, - cstate: Float[Array, "state_dim"], + cstate: Float[Array, " state_dim"], inputs: Optional[Float[Array, "ntime input_dim"]]=None ) -> tfd.Distribution: params = params.linear_gaussian @@ -71,7 +59,7 @@ def emission_distribution( self, params: ParamsSLDS, dstate: int, - cstate: Float[Array, "state_dim"], + cstate: Float[Array, " state_dim"], inputs: Optional[Float[Array, "ntime input_dim"]]=None ) -> tfd.Distribution: params = params.linear_gaussian @@ -83,7 +71,7 @@ def emission_distribution( def sample( self, params: ParamsSLDS, - key: PRNGKey, + key: PRNGKeyT, num_timesteps: int, inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None ) -> Tuple[Float[Array, 
"num_timesteps state_dim"], @@ -130,4 +118,4 @@ def _step(prev_states, args): dstates = tree_map(expand_and_cat, initial_dstate, next_dstates) cstates = tree_map(expand_and_cat, initial_cstate, next_cstates) emissions = tree_map(expand_and_cat, initial_emission, next_emissions) - return dstates, cstates, emissions \ No newline at end of file + return dstates, cstates, emissions diff --git a/dynamax/ssm.py b/dynamax/ssm.py index 9ded8439..7b4ae2f8 100644 --- a/dynamax/ssm.py +++ b/dynamax/ssm.py @@ -6,23 +6,25 @@ import jax.random as jr from jax import jit, lax, vmap from jax.tree_util import tree_map -from jaxtyping import Float, Array, PyTree +from jaxtyping import Array, Float, Real import optax from tensorflow_probability.substrates.jax import distributions as tfd -from typing import Optional, Union, Tuple, Any +from typing import Optional, Union, Tuple, Any, runtime_checkable from typing_extensions import Protocol from dynamax.parameters import to_unconstrained, from_unconstrained from dynamax.parameters import ParameterSet, PropertySet -from dynamax.types import PRNGKey, Scalar +from dynamax.types import PRNGKeyT, Scalar from dynamax.utils.optimize import run_sgd from dynamax.utils.utils import ensure_array_has_batch_dim +@runtime_checkable class Posterior(Protocol): """A :class:`NamedTuple` with parameters stored as :class:`jax.DeviceArray` in the leaf nodes.""" pass +@runtime_checkable class SuffStatsSSM(Protocol): """A :class:`NamedTuple` with sufficient statics stored as :class:`jax.DeviceArray` in the leaf nodes.""" pass @@ -85,7 +87,7 @@ class SSM(ABC): def initial_distribution( self, params: ParameterSet, - inputs: Optional[Float[Array, "input_dim"]] + inputs: Optional[Float[Array, " input_dim"]] ) -> tfd.Distribution: r"""Return an initial distribution over latent states. @@ -103,8 +105,8 @@ def initial_distribution( def transition_distribution( self, params: ParameterSet, - state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "input_dim"]] + state: Float[Array, " state_dim"], + inputs: Optional[Float[Array, " input_dim"]] ) -> tfd.Distribution: r"""Return a distribution over next latent state given current state. @@ -123,8 +125,8 @@ def transition_distribution( def emission_distribution( self, params: ParameterSet, - state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "input_dim"]]=None + state: Float[Array, " state_dim"], + inputs: Optional[Float[Array, " input_dim"]]=None ) -> tfd.Distribution: r"""Return a distribution over emissions given current state. @@ -171,7 +173,7 @@ def inputs_shape(self) -> Optional[Tuple[int]]: def sample( self, params: ParameterSet, - key: PRNGKey, + key: PRNGKeyT, num_timesteps: int, inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None ) -> Tuple[Float[Array, "num_timesteps state_dim"], @@ -349,13 +351,13 @@ def fit_em( self, params: ParameterSet, props: PropertySet, - emissions: Union[Float[Array, "num_timesteps emission_dim"], - Float[Array, "num_batches num_timesteps emission_dim"]], + emissions: Union[Real[Array, "num_timesteps emission_dim"], + Real[Array, "num_batches num_timesteps emission_dim"]], inputs: Optional[Union[Float[Array, "num_timesteps input_dim"], Float[Array, "num_batches num_timesteps input_dim"]]]=None, num_iters: int=50, verbose: bool=True - ) -> Tuple[ParameterSet, Float[Array, "num_iters"]]: + ) -> Tuple[ParameterSet, Float[Array, " num_iters"]]: r"""Compute parameter MLE/ MAP estimate using Expectation-Maximization (EM). 
EM aims to find parameters that maximize the marginal log probability, @@ -412,8 +414,8 @@ def fit_sgd( batch_size: int=1, num_epochs: int=50, shuffle: bool=False, - key: PRNGKey=jr.PRNGKey(0) - ) -> Tuple[ParameterSet, Float[Array, "niter"]]: + key: PRNGKeyT=jr.PRNGKey(0) + ) -> Tuple[ParameterSet, Float[Array, " niter"]]: r"""Compute parameter MLE/ MAP estimate using Stochastic Gradient Descent (SGD). SGD aims to find parameters that maximize the marginal log probability, diff --git a/dynamax/types.py b/dynamax/types.py index 3fff53f3..267cf5aa 100644 --- a/dynamax/types.py +++ b/dynamax/types.py @@ -1,9 +1,8 @@ -from typing import Optional, Union -from typing_extensions import Protocol -from jaxtyping import Array, Float -import jax._src.random as prng +from typing import Union +from jaxtyping import Array, Float, Int - -PRNGKey = prng.KeyArray +PRNGKeyT = Array Scalar = Union[float, Float[Array, ""]] # python float or scalar jax device array with dtype float + +IntScalar = Union[int, Int[Array, ""]] diff --git a/dynamax/utils/distributions_test.py b/dynamax/utils/distributions_test.py index 88355e19..fc2f5d37 100644 --- a/dynamax/utils/distributions_test.py +++ b/dynamax/utils/distributions_test.py @@ -1,4 +1,3 @@ -import pytest import jax.numpy as jnp import jax.random as jr from jax.tree_util import tree_map diff --git a/dynamax/utils/utils.py b/dynamax/utils/utils.py index e75d6f01..a663ac31 100644 --- a/dynamax/utils/utils.py +++ b/dynamax/utils/utils.py @@ -8,7 +8,6 @@ import jaxlib from jaxtyping import Array, Int from scipy.optimize import linear_sum_assignment -from typing import Optional from jax.scipy.linalg import cho_factor, cho_solve def has_tpu(): @@ -44,7 +43,7 @@ def pad(seq, len): return dataset -def monotonically_increasing(x, atol=0, rtol=0): +def monotonically_increasing(x, atol=0., rtol=0.): thresh = atol + rtol*jnp.abs(x[:-1]) return jnp.all(jnp.diff(x) >= -thresh) @@ -56,7 +55,7 @@ def pytree_len(pytree): return len(tree_leaves(pytree)[0]) -def pytree_sum(pytree, axis=None, keepdims=None, where=None): +def pytree_sum(pytree, axis=None, keepdims=False, where=None): return tree_map(partial(jnp.sum, axis=axis, keepdims=keepdims, where=where), pytree) @@ -148,8 +147,8 @@ def _expand_dim(x, shp): def compute_state_overlap( - z1: Int[Array, "num_timesteps"], - z2: Int[Array, "num_timesteps"] + z1: Int[Array, " num_timesteps"], + z2: Int[Array, " num_timesteps"] ): """ Compute a matrix describing the state-wise overlap between two state vectors @@ -167,7 +166,7 @@ def compute_state_overlap( assert z1.shape == z2.shape assert z1.min() >= 0 and z2.min() >= 0 - K = max(z1.max(), z2.max()) + 1 + K = max(max(z1), max(z2)) + 1 overlap = jnp.sum( (z1[:, None] == jnp.arange(K))[:, :, None] @@ -178,8 +177,8 @@ def compute_state_overlap( def find_permutation( - z1: Int[Array, "num_timesteps"], - z2: Int[Array, "num_timesteps"] + z1: Int[Array, " num_timesteps"], + z2: Int[Array, " num_timesteps"] ): """ Find the permutation of the state labels in sequence ``z1`` so that they @@ -208,4 +207,4 @@ def psd_solve(A, b, diagonal_boost=1e-9): def symmetrize(A): """Symmetrize one or more matrices.""" - return 0.5 * (A + jnp.swapaxes(A, -1, -2)) \ No newline at end of file + return 0.5 * (A + jnp.swapaxes(A, -1, -2)) diff --git a/pyproject.toml b/pyproject.toml index 87a3b2c6..8d5ae8be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,4 +2,7 @@ requires = ["setuptools >= 30.3.0", "wheel"] [tool.black] -line-length = 120 \ No newline at end of file +line-length = 120 + 
+[tool.ruff.lint] +ignore = ["F722"]
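The recurring `"state_dim"` to `" state_dim"` edits and the new ruff table are two halves of the same jaxtyping linting convention: shape strings are not valid Python forward references, so a single-axis annotation like `Float[Array, "state_dim"]` reads to a linter as a reference to an undefined name, while the leading space makes it an obviously non-identifier string; multi-axis strings instead trip F722 (syntax error in forward annotation), which the `[tool.ruff.lint]` ignore now silences project-wide. A small sketch of the convention (the function and its names are hypothetical):

from jaxtyping import Array, Float

def predict_mean(mean: Float[Array, " state_dim"],          # 1-D shape: leading space avoids an undefined-name warning
                 cov: Float[Array, "state_dim state_dim"],  # 2-D shape: flagged as F722 unless ignored
                 ) -> Float[Array, " state_dim"]:
    return mean                                             # placeholder body for the shape demo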
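Likewise, the `PRNGKey` to `PRNGKeyT` rename in `dynamax/types.py` drops the import of the private `jax._src.random.KeyArray`, which was removed in recent JAX releases, in favor of a plain `jax.Array` alias. A hedged sketch of the aliases in a signature (the helper below is hypothetical, not part of the diff):

import jax.numpy as jnp
import jax.random as jr
from dynamax.types import IntScalar, PRNGKeyT, Scalar

def mean_of_draws(key: PRNGKeyT, num_draws: int) -> Scalar:
    # PRNGKeyT is just an Array alias, so keys from jr.PRNGKey annotate cleanly.
    return jnp.mean(jr.normal(key, (num_draws,)))

avg = mean_of_draws(jr.PRNGKey(0), 10)
state: IntScalar = jnp.argmax(jnp.ones(3))  # a Python int or a 0-d int array both satisfy IntScalar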