diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py
index 18840bab..80a86ff2 100644
--- a/dynamax/linear_gaussian_ssm/inference.py
+++ b/dynamax/linear_gaussian_ssm/inference.py
@@ -7,7 +7,8 @@
 from tensorflow_probability.substrates.jax.distributions import (
     MultivariateNormalDiagPlusLowRankCovariance as MVNLowRank,
-    MultivariateNormalFullCovariance as MVN)
+    MultivariateNormalFullCovariance as MVN,
+)
 from jax.tree_util import tree_map
 from jaxtyping import Array, Float
@@ -16,15 +17,16 @@
 from dynamax.parameters import ParameterProperties
 from dynamax.types import PRNGKey, Scalar
+
 class ParamsLGSSMInitial(NamedTuple):
     r"""Parameters of the initial distribution
-    $$p(z_1) = \mathcal{N}(z_1 \mid \mu_1, Q_1)$$
+    $$p(z_0) = \mathcal{N}(z_0 \mid \mu_0, Q_0)$$
     The tuple doubles as a container for the ParameterProperties.
-    :param mean: $\mu_1$
-    :param cov: $Q_1$
+    :param mean: $\mu_0$
+    :param cov: $Q_0$
     """
     mean: Union[Float[Array, "state_dim"], ParameterProperties]
@@ -35,7 +37,7 @@ class ParamsLGSSMInitial(NamedTuple):
 class ParamsLGSSMDynamics(NamedTuple):
     r"""Parameters of the dynamics distribution
-    $$p(z_{t+1} \mid z_t, u_t) = \mathcal{N}(z_{t+1} \mid F z_t + B u_t + b, Q)$$
+    $$p(z_{t+1} \mid z_t, u_{t+1}) = \mathcal{N}(z_{t+1} \mid F_{t+1} z_t + B_{t+1} u_{t+1} + b_{t+1}, Q_{t+1})$$
     The tuple doubles as a container for the ParameterProperties.
@@ -45,28 +47,36 @@ class ParamsLGSSMDynamics(NamedTuple):
     :param cov: dynamics covariance $Q$
     """
-    weights: Union[ParameterProperties,
-                   Float[Array, "state_dim state_dim"],
-                   Float[Array, "ntime state_dim state_dim"]]
-
-    bias: Union[ParameterProperties,
-                Float[Array, "state_dim"],
-                Float[Array, "ntime state_dim"]]
-
-    input_weights: Union[ParameterProperties,
-                         Float[Array, "state_dim input_dim"],
-                         Float[Array, "ntime state_dim input_dim"]]
-
-    cov: Union[ParameterProperties,
-               Float[Array, "state_dim state_dim"],
-               Float[Array, "ntime state_dim state_dim"],
-               Float[Array, "state_dim_triu"]]
+    weights: Union[
+        ParameterProperties,
+        Float[Array, "state_dim state_dim"],
+        Float[Array, "ntime state_dim state_dim"],
+    ]
+
+    bias: Union[
+        ParameterProperties,
+        Float[Array, "state_dim"],
+        Float[Array, "ntime state_dim"],
+    ]
+
+    input_weights: Union[
+        ParameterProperties,
+        Float[Array, "state_dim input_dim"],
+        Float[Array, "ntime state_dim input_dim"],
+    ]
+
+    cov: Union[
+        ParameterProperties,
+        Float[Array, "state_dim state_dim"],
+        Float[Array, "ntime state_dim state_dim"],
+        Float[Array, "state_dim_triu"],
+    ]
 class ParamsLGSSMEmissions(NamedTuple):
     r"""Parameters of the emission distribution
-    $$p(y_t \mid z_t, u_t) = \mathcal{N}(y_t \mid H z_t + D u_t + d, R)$$
+    $$p(y_t \mid z_t, u_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)$$
     The tuple doubles as a container for the ParameterProperties.
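As the `Union` annotations above indicate, each parameter field may be given either as a single array (static across time) or with a leading `ntime` axis (time-varying); for the emissions covariance, 1-D diagonal forms are also accepted (see the shapes listed below). A minimal sketch of these conventions — the sizes and values here are purely illustrative:

    import jax.numpy as jnp
    import jax.random as jr

    T, D, E = 100, 2, 3                         # illustrative sizes: ntime, state_dim, emission_dim
    F_static = 0.9 * jnp.eye(D)                 # (state_dim, state_dim): shared across all time steps
    F_tv = jr.normal(jr.PRNGKey(0), (T, D, D))  # (ntime, state_dim, state_dim): one matrix per step
    R_diag = 0.5 * jnp.ones(E)                  # (emission_dim,): interpreted as a diagonal covariance
    R_tv_diag = 0.5 * jnp.ones((T, E))          # (ntime, emission_dim): time-varying diagonal covariance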
@@ -76,24 +86,32 @@ class ParamsLGSSMEmissions(NamedTuple): :param cov: emission covariance $R$ """ - weights: Union[ParameterProperties, - Float[Array, "emission_dim state_dim"], - Float[Array, "ntime emission_dim state_dim"]] - - bias: Union[ParameterProperties, - Float[Array, "emission_dim"], - Float[Array, "ntime emission_dim"]] - - input_weights: Union[ParameterProperties, - Float[Array, "emission_dim input_dim"], - Float[Array, "ntime emission_dim input_dim"]] - - cov: Union[ParameterProperties, - Float[Array, "emission_dim emission_dim"], - Float[Array, "ntime emission_dim emission_dim"], - Float[Array, "emission_dim"], - Float[Array, "ntime emission_dim"], - Float[Array, "emission_dim_triu"]] + weights: Union[ + ParameterProperties, + Float[Array, "emission_dim state_dim"], + Float[Array, "ntime emission_dim state_dim"], + ] + + bias: Union[ + ParameterProperties, + Float[Array, "emission_dim"], + Float[Array, "ntime emission_dim"], + ] + + input_weights: Union[ + ParameterProperties, + Float[Array, "emission_dim input_dim"], + Float[Array, "ntime emission_dim input_dim"], + ] + + cov: Union[ + ParameterProperties, + Float[Array, "emission_dim emission_dim"], + Float[Array, "ntime emission_dim emission_dim"], + Float[Array, "emission_dim"], + Float[Array, "ntime emission_dim"], + Float[Array, "emission_dim_triu"], + ] class ParamsLGSSM(NamedTuple): @@ -145,6 +163,7 @@ class PosteriorGSSMSmoothed(NamedTuple): # Helper functions + def _get_one_param(x, dim, t): """Helper function to get one parameter at time t.""" if callable(x): @@ -154,6 +173,7 @@ def _get_one_param(x, dim, t): else: return x + def _get_params(params, num_timesteps, t): """Helper function to get parameters at time t.""" assert not callable(params.emissions.cov), "Emission covariance cannot be a callable." @@ -166,9 +186,9 @@ def _get_params(params, num_timesteps, t): D = _get_one_param(params.emissions.input_weights, 2, t) d = _get_one_param(params.emissions.bias, 1, t) - if len(params.emissions.cov.shape) == 1: + if len(params.emissions.cov.shape) == 1: R = _get_one_param(params.emissions.cov, 1, t) - elif len(params.emissions.cov.shape) > 2: + elif len(params.emissions.cov.shape) > 2: R = _get_one_param(params.emissions.cov, 2, t) elif params.emissions.cov.shape[0] != num_timesteps: R = _get_one_param(params.emissions.cov, 2, t) @@ -179,7 +199,8 @@ def _get_params(params, num_timesteps, t): warnings.warn( "Emission covariance has shape (N,N) where N is the number of timesteps. " "The covariance will be interpreted as static and non-diagonal. To " - "specify a dynamic and diagonal covariance, pass it as a 3D array.") + "specify a dynamic and diagonal covariance, pass it as a 3D array." 
+ ) return F, B, b, Q, H, D, d, R @@ -187,39 +208,40 @@ def _get_params(params, num_timesteps, t): _zeros_if_none = lambda x, shape: x if x is not None else jnp.zeros(shape) -def make_lgssm_params(initial_mean, - initial_cov, - dynamics_weights, - dynamics_cov, - emissions_weights, - emissions_cov, - dynamics_bias=None, - dynamics_input_weights=None, - emissions_bias=None, - emissions_input_weights=None): +def make_lgssm_params( + initial_mean, + initial_cov, + dynamics_weights, + dynamics_cov, + emissions_weights, + emissions_cov, + dynamics_bias=None, + dynamics_input_weights=None, + emissions_bias=None, + emissions_input_weights=None, +): """Helper function to construct a ParamsLGSSM object from arguments.""" state_dim = len(initial_mean) emission_dim = emissions_cov.shape[-1] - input_dim = max(dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0, - emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0) + input_dim = max( + dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0, + emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0, + ) params = ParamsLGSSM( - initial=ParamsLGSSMInitial( - mean=initial_mean, - cov=initial_cov - ), + initial=ParamsLGSSMInitial(mean=initial_mean, cov=initial_cov), dynamics=ParamsLGSSMDynamics( weights=dynamics_weights, - bias=_zeros_if_none(dynamics_bias,state_dim), + bias=_zeros_if_none(dynamics_bias, state_dim), input_weights=_zeros_if_none(dynamics_input_weights, (state_dim, input_dim)), - cov=dynamics_cov + cov=dynamics_cov, ), emissions=ParamsLGSSMEmissions( weights=emissions_weights, bias=_zeros_if_none(emissions_bias, emission_dim), input_weights=_zeros_if_none(emissions_input_weights, (emission_dim, input_dim)), - cov=emissions_cov - ) + cov=emissions_cov, + ), ) return params @@ -227,8 +249,8 @@ def make_lgssm_params(initial_mean, def _predict(m, S, F, B, b, Q, u): r"""Predict next mean and covariance under a linear Gaussian model. - p(z_{t+1}) = int N(z_t \mid m, S) N(z_{t+1} \mid Fz_t + Bu + b, Q) - = N(z_{t+1} \mid Fm + Bu, F S F^T + Q) + p(z_{t+1}) = \int N(z_t \mid m_t, S_t) N(z_{t+1} \mid F_{t+1} z_t + B_{t+1} u_{t+1} + b_{t+1}, Q_{t+1}) d z_t + = N(z_{t+1} \mid F_{t+1} m_t + B_{t+1} u_{t+1} + b_{t+1}, F_{t+1} S_t F_{t+1}^T + Q_{t+1}) Args: m (D_hid,): prior mean. 
@@ -252,7 +274,7 @@ def _condition_on(m, P, H, D, d, R, u, y): r"""Condition a Gaussian potential on a new linear Gaussian observation p(z_t \mid y_t, u_t, y_{1:t-1}, u_{1:t-1}) propto p(z_t \mid y_{1:t-1}, u_{1:t-1}) p(y_t \mid z_t, u_t) - = N(z_t \mid m, P) N(y_t \mid H_t z_t + D_t u_t + d_t, R_t) + = N(z_t \mid m_t, P_t) N(y_t \mid H_t z_t + D_t u_t + d_t, R_t) = N(z_t \mid mm, PP) where mm = m + K*(y - yhat) = mu_cond @@ -278,20 +300,20 @@ def _condition_on(m, P, H, D, d, R, u, y): if R.ndim == 2: S = R + H @ P @ H.T K = psd_solve(S, H @ P).T - else: + else: # Optimization using Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I # (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity) I = jnp.eye(P.shape[0]) U = H @ jnp.linalg.cholesky(P) X = U / R[:, None] - S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T) + S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T) """ # Could alternatively use U=H and C=P R_inv = jnp.diag(1.0 / R) P_inv = psd_solve(P, jnp.eye(P.shape[0])) S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ H, H.T @ R_inv) """ - K = P @ H.T @ S_inv + K = P @ H.T @ S_inv S = jnp.diag(R) + H @ P @ H.T Sigma_cond = P - K @ S @ K.T @@ -324,20 +346,20 @@ def preprocess_params_and_inputs(params, num_timesteps, inputs): emissions_bias = _zeros_if_none(params.emissions.bias, (emission_dim,)) full_params = ParamsLGSSM( - initial=ParamsLGSSMInitial( - mean=params.initial.mean, - cov=params.initial.cov), + initial=ParamsLGSSMInitial(mean=params.initial.mean, cov=params.initial.cov), dynamics=ParamsLGSSMDynamics( weights=params.dynamics.weights, bias=dynamics_bias, input_weights=dynamics_input_weights, - cov=params.dynamics.cov), + cov=params.dynamics.cov, + ), emissions=ParamsLGSSMEmissions( weights=params.emissions.weights, bias=emissions_bias, input_weights=emissions_input_weights, - cov=params.emissions.cov) - ) + cov=params.emissions.cov, + ), + ) return full_params, inputs @@ -350,28 +372,26 @@ def wrapper(*args, **kwargs): # Extract the arguments by name bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() - params = bound_args.arguments['params'] - emissions = bound_args.arguments['emissions'] - inputs = bound_args.arguments['inputs'] + params = bound_args.arguments["params"] + emissions = bound_args.arguments["emissions"] + inputs = bound_args.arguments["inputs"] num_timesteps = len(emissions) full_params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs) return f(full_params, emissions, inputs=inputs) - return wrapper - + return wrapper def lgssm_joint_sample( params: ParamsLGSSM, key: PRNGKey, num_timesteps: int, - inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None -)-> Tuple[Float[Array, "num_timesteps state_dim"], - Float[Array, "num_timesteps emission_dim"]]: + inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None, +) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: r"""Sample from the joint distribution to produce state and emission trajectories. - + Args: params: model parameters inputs: optional array of inputs. 
@@ -382,15 +402,15 @@ def lgssm_joint_sample( """ params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs) - def _sample_transition(key, F, B, b, Q, x_tm1, u): - mean = F @ x_tm1 + B @ u + b + def _sample_transition(key, F, B, b, Q, x_prev, u): + mean = F @ x_prev + B @ u + b return MVN(mean, Q).sample(seed=key) def _sample_emission(key, H, D, d, R, x, u): mean = H @ x + D @ u + d - R = jnp.diag(R) if R.ndim==1 else R + R = jnp.diag(R) if R.ndim == 1 else R return MVN(mean, R).sample(seed=key) - + def _sample_initial(key, params, inputs): key1, key2 = jr.split(key) @@ -402,7 +422,7 @@ def _sample_initial(key, params, inputs): initial_emission = _sample_emission(key2, H0, D0, d0, R0, initial_state, u0) return initial_state, initial_emission - def _step(prev_state, args): + def _step(state_prev, args): key, t, inpt = args key1, key2 = jr.split(key, 2) @@ -410,14 +430,14 @@ def _step(prev_state, args): F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t) # Sample from transition and emission distributions - state = _sample_transition(key1, F, B, b, Q, prev_state, inpt) + state = _sample_transition(key1, F, B, b, Q, state_prev, inpt) emission = _sample_emission(key2, H, D, d, R, state, inpt) return state, (state, emission) # Sample the initial state key1, key2 = jr.split(key) - + initial_state, initial_emission = _sample_initial(key1, params, inputs) # Sample the remaining emissions and states @@ -434,11 +454,21 @@ def _step(prev_state, args): return states, emissions +def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y): + m = H @ pred_mean + D @ u + d + if R.ndim == 2: + S = R + H @ pred_cov @ H.T + return MVN(m, S).log_prob(y) + else: + L = H @ jnp.linalg.cholesky(pred_cov) + return MVNLowRank(m, R, L).log_prob(y) + + @preprocess_args def lgssm_filter( params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMFiltered: r"""Run a Kalman filter to produce the marginal likelihood and filtered state estimates. @@ -454,21 +484,11 @@ def lgssm_filter( num_timesteps = len(emissions) inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs - def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y): - m = H @ pred_mean + D @ u + d - if R.ndim==2: - S = R + H @ pred_cov @ H.T - return MVN(m, S).log_prob(y) - else: - L = H @ jnp.linalg.cholesky(pred_cov) - return MVNLowRank(m, R, L).log_prob(y) - - def _step(carry, t): ll, pred_mean, pred_cov = carry # Shorthand: get parameters and inputs for time index t - F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t) + _, _, _, _, H, D, d, R = _get_params(params, num_timesteps, t) u = inputs[t] y = emissions[t] @@ -479,13 +499,21 @@ def _step(carry, t): filtered_mean, filtered_cov = _condition_on(pred_mean, pred_cov, H, D, d, R, u, y) # Predict the next state - pred_mean, pred_cov = _predict(filtered_mean, filtered_cov, F, B, b, Q, u) + F_next, B_next, b_next, Q_next, _, _, _, _ = _get_params( + params, + num_timesteps, + (t + 1) % num_timesteps, # No update required (or possible) in the last time step so any params are fine. 
+        )
+        u_next = inputs[t + 1]
+
+        pred_mean, pred_cov = _predict(filtered_mean, filtered_cov, F_next, B_next, b_next, Q_next, u_next)
         return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov)
     # Run the Kalman filter
     carry = (0.0, params.initial.mean, params.initial.cov)
     (ll, _, _), (filtered_means, filtered_covs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
+
     return PosteriorGSSMFiltered(marginal_loglik=ll, filtered_means=filtered_means, filtered_covariances=filtered_covs)
@@ -493,7 +521,7 @@ def _step(carry, t):
 def lgssm_smoother(
     params: ParamsLGSSM,
     emissions: Float[Array, "ntime emission_dim"],
-    inputs: Optional[Float[Array, "ntime input_dim"]]=None
+    inputs: Optional[Float[Array, "ntime input_dim"]] = None,
 ) -> PosteriorGSSMSmoothed:
     r"""Run forward-filtering, backward-smoother to compute expectations
     under the posterior distribution on latent states. Technically, this
@@ -521,17 +549,17 @@ def _step(carry, args):
         smoothed_mean_next, smoothed_cov_next = carry
         t, filtered_mean, filtered_cov = args
-        # Get parameters and inputs for time index t
-        F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
-        u = inputs[t]
+        # Get parameters and inputs for time index t + 1
+        F_next, B_next, b_next, Q_next = _get_params(params, num_timesteps, t + 1)[:4]
+        u_next = inputs[t + 1]
         # This is like the Kalman gain but in reverse
         # See Eq 8.11 of Särkkä's "Bayesian Filtering and Smoothing"
-        G = psd_solve(Q + F @ filtered_cov @ F.T, F @ filtered_cov).T
+        G = psd_solve(Q_next + F_next @ filtered_cov @ F_next.T, F_next @ filtered_cov).T
         # Compute the smoothed mean and covariance
-        smoothed_mean = filtered_mean + G @ (smoothed_mean_next - F @ filtered_mean - B @ u - b)
-        smoothed_cov = filtered_cov + G @ (smoothed_cov_next - F @ filtered_cov @ F.T - Q) @ G.T
+        smoothed_mean = filtered_mean + G @ (smoothed_mean_next - F_next @ filtered_mean - B_next @ u_next - b_next)
+        smoothed_cov = filtered_cov + G @ (smoothed_cov_next - F_next @ filtered_cov @ F_next.T - Q_next) @ G.T
         # Compute the smoothed expectation of z_t z_{t+1}^T
         smoothed_cross = G @ smoothed_cov_next + jnp.outer(smoothed_mean, smoothed_mean_next)
@@ -560,10 +588,9 @@ def _step(carry, args):
 def lgssm_posterior_sample(
     key: PRNGKey,
     params: ParamsLGSSM,
-    emissions: Float[Array, "ntime emission_dim"],
-    inputs: Optional[Float[Array, "ntime input_dim"]]=None,
-    jitter: Optional[Scalar]=0
-
+    emissions: Float[Array, "ntime emission_dim"],
+    inputs: Optional[Float[Array, "ntime input_dim"]] = None,
+    jitter: Optional[Scalar] = 0,
 ) -> Float[Array, "ntime state_dim"]:
     r"""Run forward-filtering, backward-sampling to draw samples from
     $p(z_{1:T} \mid y_{1:T}, u_{1:T})$.
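For reference, the reverse gain `G` computed in `lgssm_smoother` above agrees with the textbook Rauch-Tung-Striebel gain: because the predicted covariance `F_next @ filtered_cov @ F_next.T + Q_next` is symmetric, `psd_solve(pred_cov, F_next @ filtered_cov).T` equals `filtered_cov @ F_next.T @ inv(pred_cov)` (Eq. 8.11 of Särkkä, 2013). A minimal numeric sketch of this identity, using arbitrary made-up matrices and `jnp.linalg.solve` in place of `psd_solve`:

    import jax.numpy as jnp
    import jax.random as jr

    D = 3
    X = jr.normal(jr.PRNGKey(0), (D, D))
    filtered_cov = X @ X.T + jnp.eye(D)                      # arbitrary SPD filtered covariance
    F = jnp.eye(D) + 0.1 * jr.normal(jr.PRNGKey(1), (D, D))  # arbitrary dynamics matrix
    Q = 0.5 * jnp.eye(D)                                     # arbitrary SPD dynamics covariance
    pred_cov = F @ filtered_cov @ F.T + Q
    G = jnp.linalg.solve(pred_cov, F @ filtered_cov).T       # the form used in the smoother
    G_ref = filtered_cov @ F.T @ jnp.linalg.inv(pred_cov)    # textbook RTS gain
    assert jnp.allclose(G, G_ref, atol=1e-4)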
@@ -590,11 +617,13 @@ def _step(carry, args):
         key, filtered_mean, filtered_cov, t = args
-        # Shorthand: get parameters and inputs for time index t
-        F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
-        u = inputs[t]
+        # Shorthand: get parameters and inputs for time index t + 1
+        F_next, B_next, b_next, Q_next = _get_params(params, num_timesteps, t + 1)[:4]
+        u_next = inputs[t + 1]
         # Condition on next state
-        smoothed_mean, smoothed_cov = _condition_on(filtered_mean, filtered_cov, F, B, b, Q, u, next_state)
+        smoothed_mean, smoothed_cov = _condition_on(
+            filtered_mean, filtered_cov, F_next, B_next, b_next, Q_next, u_next, next_state
+        )
         smoothed_cov = smoothed_cov + jnp.eye(smoothed_cov.shape[-1]) * jitter
         state = MVN(smoothed_mean, smoothed_cov).sample(seed=key)
         return state, state
diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py
index 8e88c6bf..ed73e431 100644
--- a/dynamax/linear_gaussian_ssm/models.py
+++ b/dynamax/linear_gaussian_ssm/models.py
@@ -11,8 +11,18 @@
 from typing_extensions import Protocol
 from dynamax.ssm import SSM
-from dynamax.linear_gaussian_ssm.inference import lgssm_filter, lgssm_smoother, lgssm_posterior_sample
-from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions
+from dynamax.linear_gaussian_ssm.inference import lgssm_filter as serial_lgssm_filter
+from dynamax.linear_gaussian_ssm.inference import lgssm_smoother as serial_lgssm_smoother
+from dynamax.linear_gaussian_ssm.inference import lgssm_posterior_sample as serial_lgssm_posterior_sample
+from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_filter as parallel_lgssm_filter
+from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_smoother as parallel_lgssm_smoother
+from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_posterior_sample as parallel_lgssm_posterior_sample
+from dynamax.linear_gaussian_ssm.inference import (
+    ParamsLGSSM,
+    ParamsLGSSMInitial,
+    ParamsLGSSMDynamics,
+    ParamsLGSSMEmissions,
+)
 from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
 from dynamax.parameters import ParameterProperties, ParameterSet
 from dynamax.types import PRNGKey, Scalar
@@ -22,8 +32,10 @@
 from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update
 from dynamax.utils.utils import pytree_stack, psd_solve
+
 class SuffStatsLGSSM(Protocol):
     """A :class:`NamedTuple` with sufficient statistics for LGSSM parameter estimation."""
+
     pass
@@ -33,7 +45,7 @@ class LinearGaussianSSM(SSM):
     The model is defined as follows
-    $$p(z_1) = \mathcal{N}(z_1 \mid m, S)$$
+    $$p(z_0) = \mathcal{N}(z_0 \mid m, S)$$
     $$p(z_t \mid z_{t-1}, u_t) = \mathcal{N}(z_t \mid F_t z_{t-1} + B_t u_t + b_t, Q_t)$$
     $$p(y_t \mid z_t) = \mathcal{N}(y_t \mid H_t z_t + D_t u_t + d_t, R_t)$$
@@ -56,26 +68,37 @@ class LinearGaussianSSM(SSM):
     The parameters of the model are stored in a :class:`ParamsLGSSM`.
     You can create the parameters manually, or by calling :meth:`initialize`.
+    Note that we index LGSSM parameters from $t=0$, following Murphy, K. P. (2022), "Probabilistic
+    Machine Learning: Advanced Topics", rather than from $t=1$ as in Särkkä, S. (2013), "Bayesian
+    Filtering and Smoothing"; this difference in conventions is a common source of confusion.
+    In particular, $F_0$, $B_0$, $b_0$, and $Q_0$ are always ignored, and the prior specified by
+    $m$ and $S$ is used as the distribution of the initial state.
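+
+    For example (an illustrative sketch; shapes and values are arbitrary), a time-varying
+    model can be specified by passing dynamics parameters with a leading time axis, in which
+    case the index-0 entries are unused::
+
+        import jax.numpy as jnp
+        import jax.random as jr
+
+        T, D = 100, 2
+        F = jr.normal(jr.PRNGKey(0), (T, D, D))    # F[0] is ignored
+        Q = jnp.tile(0.1 * jnp.eye(D), (T, 1, 1))  # Q[0] is ignored
+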
+ :param state_dim: Dimensionality of latent state. :param emission_dim: Dimensionality of observation vector. :param input_dim: Dimensionality of input vector. Defaults to 0. :param has_dynamics_bias: Whether model contains an offset term $b$. Defaults to True. :param has_emissions_bias: Whether model contains an offset term $d$. Defaults to True. + :param use_parallel_inference: Whether parallel algorithms are used in filtering, smoothing + and sampling instead of sequential ones. Defaults to False. """ + def __init__( self, state_dim: int, emission_dim: int, - input_dim: int=0, - has_dynamics_bias: bool=True, - has_emissions_bias: bool=True + input_dim: int = 0, + has_dynamics_bias: bool = True, + has_emissions_bias: bool = True, + use_parallel_inference: bool = False, ): self.state_dim = state_dim self.emission_dim = emission_dim self.input_dim = input_dim self.has_dynamics_bias = has_dynamics_bias self.has_emissions_bias = has_emissions_bias + self.use_parallel_inference = use_parallel_inference @property def emission_shape(self): @@ -87,8 +110,8 @@ def inputs_shape(self): def initialize( self, - key: PRNGKey =jr.PRNGKey(0), - initial_mean: Optional[Float[Array, "state_dim"]]=None, + key: PRNGKey = jr.PRNGKey(0), + initial_mean: Optional[Float[Array, "state_dim"]] = None, initial_covariance=None, dynamics_weights=None, dynamics_bias=None, @@ -97,7 +120,7 @@ def initialize( emission_weights=None, emission_bias=None, emission_input_weights=None, - emission_covariance=None + emission_covariance=None, ) -> Tuple[ParamsLGSSM, ParamsLGSSM]: r"""Initialize model parameters that are set to None, and their corresponding properties. @@ -137,41 +160,47 @@ def initialize( params = ParamsLGSSM( initial=ParamsLGSSMInitial( mean=default(initial_mean, _initial_mean), - cov=default(initial_covariance, _initial_covariance)), + cov=default(initial_covariance, _initial_covariance), + ), dynamics=ParamsLGSSMDynamics( weights=default(dynamics_weights, _dynamics_weights), bias=default(dynamics_bias, _dynamics_bias), input_weights=default(dynamics_input_weights, _dynamics_input_weights), - cov=default(dynamics_covariance, _dynamics_covariance)), + cov=default(dynamics_covariance, _dynamics_covariance), + ), emissions=ParamsLGSSMEmissions( weights=default(emission_weights, _emission_weights), bias=default(emission_bias, _emission_bias), input_weights=default(emission_input_weights, _emission_input_weights), - cov=default(emission_covariance, _emission_covariance)) - ) + cov=default(emission_covariance, _emission_covariance), + ), + ) # The keys of param_props must match those of params! 
props = ParamsLGSSM( initial=ParamsLGSSMInitial( mean=ParameterProperties(), - cov=ParameterProperties(constrainer=RealToPSDBijector())), + cov=ParameterProperties(constrainer=RealToPSDBijector()), + ), dynamics=ParamsLGSSMDynamics( weights=ParameterProperties(), bias=ParameterProperties(), input_weights=ParameterProperties(), - cov=ParameterProperties(constrainer=RealToPSDBijector())), + cov=ParameterProperties(constrainer=RealToPSDBijector()), + ), emissions=ParamsLGSSMEmissions( weights=ParameterProperties(), bias=ParameterProperties(), input_weights=ParameterProperties(), - cov=ParameterProperties(constrainer=RealToPSDBijector())) - ) + cov=ParameterProperties(constrainer=RealToPSDBijector()), + ), + ) return params, props def initial_distribution( self, params: ParamsLGSSM, - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> tfd.Distribution: return MVN(params.initial.mean, params.initial.cov) @@ -179,7 +208,7 @@ def transition_distribution( self, params: ParamsLGSSM, state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) mean = params.dynamics.weights @ state + params.dynamics.input_weights @ inputs @@ -191,7 +220,7 @@ def emission_distribution( self, params: ParamsLGSSM, state: Float[Array, "state_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) mean = params.emissions.weights @ state + params.emissions.input_weights @ inputs @@ -203,41 +232,50 @@ def marginal_log_prob( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]] = None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Scalar: - filtered_posterior = lgssm_filter(params, emissions, inputs) + filtered_posterior = self.filter(params, emissions, inputs) return filtered_posterior.marginal_loglik def filter( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]] = None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMFiltered: - return lgssm_filter(params, emissions, inputs) + if self.use_parallel_inference: + return parallel_lgssm_filter(params, emissions, inputs) + else: + return serial_lgssm_filter(params, emissions, inputs) def smoother( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]] = None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMSmoothed: - return lgssm_smoother(params, emissions, inputs) + if self.use_parallel_inference: + return parallel_lgssm_smoother(params, emissions, inputs) + else: + return serial_lgssm_smoother(params, emissions, inputs) def posterior_sample( self, key: PRNGKey, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Float[Array, "ntime state_dim"]: - return lgssm_posterior_sample(key, params, emissions, inputs) + if self.use_parallel_inference: + return parallel_lgssm_posterior_sample(key, params, emissions, inputs) + else: + return 
serial_lgssm_posterior_sample(key, params, emissions, inputs) def posterior_predictive( self, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Tuple[Float[Array, "ntime emission_dim"], Float[Array, "ntime emission_dim"]]: r"""Compute marginal posterior predictive smoothing distribution for each observation. @@ -250,32 +288,37 @@ def posterior_predictive( :posterior predictive means $\mathbb{E}[y_{t,d} \mid y_{1:T}]$ and standard deviations $\mathrm{std}[y_{t,d} \mid y_{1:T}]$ """ - posterior = lgssm_smoother(params, emissions, inputs) + posterior = self.smoother(params, emissions, inputs) H = params.emissions.weights b = params.emissions.bias R = params.emissions.cov emission_dim = R.shape[0] smoothed_emissions = posterior.smoothed_means @ H.T + b smoothed_emissions_cov = H @ posterior.smoothed_covariances @ H.T + R - smoothed_emissions_std = jnp.sqrt( - jnp.array([smoothed_emissions_cov[:, i, i] for i in range(emission_dim)])) + smoothed_emissions_std = jnp.sqrt(jnp.array([smoothed_emissions_cov[:, i, i] for i in range(emission_dim)])) return smoothed_emissions, smoothed_emissions_std # Expectation-maximization (EM) code def e_step( self, params: ParamsLGSSM, - emissions: Union[Float[Array, "num_timesteps emission_dim"], - Float[Array, "num_batches num_timesteps emission_dim"]], - inputs: Optional[Union[Float[Array, "num_timesteps input_dim"], - Float[Array, "num_batches num_timesteps input_dim"]]]=None, + emissions: Union[ + Float[Array, "num_timesteps emission_dim"], + Float[Array, "num_batches num_timesteps emission_dim"], + ], + inputs: Optional[ + Union[ + Float[Array, "num_timesteps input_dim"], + Float[Array, "num_batches num_timesteps input_dim"], + ] + ] = None, ) -> Tuple[SuffStatsLGSSM, Scalar]: num_timesteps = emissions.shape[0] if inputs is None: inputs = jnp.zeros((num_timesteps, 0)) # Run the smoother to get posterior expectations - posterior = lgssm_smoother(params, emissions, inputs) + posterior = self.smoother(params, emissions, inputs) # shorthand Ex = posterior.smoothed_means @@ -301,18 +344,17 @@ def e_step( # let zp[t] = [x[t], u[t]] for t = 0...T-2 # let xn[t] = x[t+1] for t = 0...T-2 sum_zpzpT = jnp.block([[Exp.T @ Exp, Exp.T @ up], [up.T @ Exp, up.T @ up]]) - sum_zpzpT = sum_zpzpT.at[:self.state_dim, :self.state_dim].add(Vxp.sum(0)) + sum_zpzpT = sum_zpzpT.at[: self.state_dim, : self.state_dim].add(Vxp.sum(0)) sum_zpxnT = jnp.block([[Expxn.sum(0)], [up.T @ Exn]]) sum_xnxnT = Vxn.sum(0) + Exn.T @ Exn dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1) if not self.has_dynamics_bias: - dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, - num_timesteps - 1) + dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, num_timesteps - 1) # more expected sufficient statistics for the emissions # let z[t] = [x[t], u[t]] for t = 0...T-1 sum_zzT = jnp.block([[Ex.T @ Ex, Ex.T @ u], [u.T @ Ex, u.T @ u]]) - sum_zzT = sum_zzT.at[:self.state_dim, :self.state_dim].add(Vx.sum(0)) + sum_zzT = sum_zzT.at[: self.state_dim, : self.state_dim].add(Vx.sum(0)) sum_zyT = jnp.block([[Ex.T @ y], [u.T @ y]]) sum_yyT = emissions.T @ emissions emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps) @@ -321,22 +363,12 @@ def e_step( return (init_stats, dynamics_stats, emission_stats), posterior.marginal_loglik - - def initialize_m_step_state( - self, - params: ParamsLGSSM, - props: ParamsLGSSM - ) -> Any: + 
def initialize_m_step_state(self, params: ParamsLGSSM, props: ParamsLGSSM) -> Any: return None def m_step( - self, - params: ParamsLGSSM, - props: ParamsLGSSM, - batch_stats: SuffStatsLGSSM, - m_step_state: Any + self, params: ParamsLGSSM, props: ParamsLGSSM, batch_stats: SuffStatsLGSSM, m_step_state: Any ) -> Tuple[ParamsLGSSM, Any]: - def fit_linear_regression(ExxT, ExyT, EyyT, N): # Solve a linear regression given sufficient statistics W = psd_solve(ExxT, ExyT).T @@ -353,19 +385,17 @@ def fit_linear_regression(ExxT, ExyT, EyyT, N): m = sum_x0 / N FB, Q = fit_linear_regression(*dynamics_stats) - F = FB[:, :self.state_dim] - B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ - else (FB[:, self.state_dim:], None) + F = FB[:, : self.state_dim] + B, b = (FB[:, self.state_dim : -1], FB[:, -1]) if self.has_dynamics_bias else (FB[:, self.state_dim :], None) HD, R = fit_linear_regression(*emission_stats) - H = HD[:, :self.state_dim] - D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ - else (HD[:, self.state_dim:], None) + H = HD[:, : self.state_dim] + D, d = (HD[:, self.state_dim : -1], HD[:, -1]) if self.has_emissions_bias else (HD[:, self.state_dim :], None) params = ParamsLGSSM( initial=ParamsLGSSMInitial(mean=m, cov=S), dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R), ) return params, m_step_state @@ -387,40 +417,59 @@ class LinearGaussianConjugateSSM(LinearGaussianSSM): :param has_emissions_bias: Whether model contains an offset term d. Defaults to True. """ - def __init__(self, - state_dim, - emission_dim, - input_dim=0, - has_dynamics_bias=True, - has_emissions_bias=True, - **kw_priors): - super().__init__(state_dim=state_dim, emission_dim=emission_dim, input_dim=input_dim, - has_dynamics_bias=has_dynamics_bias, has_emissions_bias=has_emissions_bias) + + def __init__( + self, + state_dim, + emission_dim, + input_dim=0, + has_dynamics_bias=True, + has_emissions_bias=True, + use_parallel_inference=False, + **kw_priors, + ): + super().__init__( + state_dim=state_dim, + emission_dim=emission_dim, + input_dim=input_dim, + has_dynamics_bias=has_dynamics_bias, + has_emissions_bias=has_emissions_bias, + use_parallel_inference=use_parallel_inference, + ) # Initialize prior distributions def default_prior(arg, default): return kw_priors[arg] if arg in kw_priors else default self.initial_prior = default_prior( - 'initial_prior', - NIW(loc=jnp.zeros(self.state_dim), - mean_concentration=1., + "initial_prior", + NIW( + loc=jnp.zeros(self.state_dim), + mean_concentration=1.0, df=self.state_dim + 0.1, - scale=jnp.eye(self.state_dim))) + scale=jnp.eye(self.state_dim), + ), + ) self.dynamics_prior = default_prior( - 'dynamics_prior', - MNIW(loc=jnp.zeros((self.state_dim, self.state_dim + self.input_dim + self.has_dynamics_bias)), - col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_dynamics_bias), - df=self.state_dim + 0.1, - scale=jnp.eye(self.state_dim))) + "dynamics_prior", + MNIW( + loc=jnp.zeros((self.state_dim, self.state_dim + self.input_dim + self.has_dynamics_bias)), + col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_dynamics_bias), + df=self.state_dim + 0.1, + scale=jnp.eye(self.state_dim), + ), + ) self.emission_prior = default_prior( - 'emission_prior', - MNIW(loc=jnp.zeros((self.emission_dim, self.state_dim + self.input_dim + 
self.has_emissions_bias)), - col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_emissions_bias), - df=self.emission_dim + 0.1, - scale=jnp.eye(self.emission_dim))) + "emission_prior", + MNIW( + loc=jnp.zeros((self.emission_dim, self.state_dim + self.input_dim + self.has_emissions_bias)), + col_precision=jnp.eye(self.state_dim + self.input_dim + self.has_emissions_bias), + df=self.emission_dim + 0.1, + scale=jnp.eye(self.emission_dim), + ), + ) @property def emission_shape(self): @@ -430,39 +479,23 @@ def emission_shape(self): def covariates_shape(self): return dict(inputs=(self.input_dim,)) if self.input_dim > 0 else dict() - def log_prior( - self, - params: ParamsLGSSM - ) -> Scalar: + def log_prior(self, params: ParamsLGSSM) -> Scalar: lp = self.initial_prior.log_prob((params.initial.cov, params.initial.mean)) # dynamics dynamics_bias = params.dynamics.bias if self.has_dynamics_bias else jnp.zeros((self.state_dim, 0)) - dynamics_matrix = jnp.column_stack((params.dynamics.weights, - params.dynamics.input_weights, - dynamics_bias)) + dynamics_matrix = jnp.column_stack((params.dynamics.weights, params.dynamics.input_weights, dynamics_bias)) lp += self.dynamics_prior.log_prob((params.dynamics.cov, dynamics_matrix)) emission_bias = params.emissions.bias if self.has_emissions_bias else jnp.zeros((self.emission_dim, 0)) - emission_matrix = jnp.column_stack((params.emissions.weights, - params.emissions.input_weights, - emission_bias)) + emission_matrix = jnp.column_stack((params.emissions.weights, params.emissions.input_weights, emission_bias)) lp += self.emission_prior.log_prob((params.emissions.cov, emission_matrix)) return lp - def initialize_m_step_state( - self, - params: ParamsLGSSM, - props: ParamsLGSSM - ) -> Any: + def initialize_m_step_state(self, params: ParamsLGSSM, props: ParamsLGSSM) -> Any: return None - def m_step( - self, - params: ParamsLGSSM, - props: ParamsLGSSM, - batch_stats: SuffStatsLGSSM, - m_step_state: Any): + def m_step(self, params: ParamsLGSSM, props: ParamsLGSSM, batch_stats: SuffStatsLGSSM, m_step_state: Any): # Sum the statistics across all batches stats = tree_map(partial(jnp.sum, axis=0), batch_stats) init_stats, dynamics_stats, emission_stats = stats @@ -473,20 +506,26 @@ def m_step( dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats) Q, FB = dynamics_posterior.mode() - F = FB[:, :self.state_dim] - B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ - else (FB[:, self.state_dim:], jnp.zeros(self.state_dim)) + F = FB[:, : self.state_dim] + B, b = ( + (FB[:, self.state_dim : -1], FB[:, -1]) + if self.has_dynamics_bias + else (FB[:, self.state_dim :], jnp.zeros(self.state_dim)) + ) emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats) R, HD = emission_posterior.mode() - H = HD[:, :self.state_dim] - D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ - else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim)) + H = HD[:, : self.state_dim] + D, d = ( + (HD[:, self.state_dim : -1], HD[:, -1]) + if self.has_emissions_bias + else (HD[:, self.state_dim :], jnp.zeros(self.emission_dim)) + ) params = ParamsLGSSM( initial=ParamsLGSSMInitial(mean=m, cov=S), dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R), ) return params, m_step_state @@ -496,7 +535,7 @@ def fit_blocked_gibbs( 
initial_params: ParamsLGSSM, sample_size: int, emissions: Float[Array, "nbatch ntime emission_dim"], - inputs: Optional[Float[Array, "nbatch ntime input_dim"]]=None + inputs: Optional[Float[Array, "nbatch ntime input_dim"]] = None, ) -> ParamsLGSSM: r"""Estimate parameter posterior using block-Gibbs sampler. @@ -532,8 +571,7 @@ def sufficient_stats_from_sample(states): sum_xnxnT = xn.T @ xn dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1) if not self.has_dynamics_bias: - dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, - num_timesteps - 1) + dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT, num_timesteps - 1) # Quantities for the emissions # Let z[t] = [x[t], u[t]] for t = 0...T-1 @@ -558,21 +596,27 @@ def lgssm_params_sample(rng, stats): # Sample the dynamics params dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats) Q, FB = dynamics_posterior.sample(seed=next(rngs)) - F = FB[:, :self.state_dim] - B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \ - else (FB[:, self.state_dim:], jnp.zeros(self.state_dim)) + F = FB[:, : self.state_dim] + B, b = ( + (FB[:, self.state_dim : -1], FB[:, -1]) + if self.has_dynamics_bias + else (FB[:, self.state_dim :], jnp.zeros(self.state_dim)) + ) # Sample the emission params emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats) R, HD = emission_posterior.sample(seed=next(rngs)) - H = HD[:, :self.state_dim] - D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \ - else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim)) + H = HD[:, : self.state_dim] + D, d = ( + (HD[:, self.state_dim : -1], HD[:, -1]) + if self.has_emissions_bias + else (HD[:, self.state_dim :], jnp.zeros(self.emission_dim)) + ) params = ParamsLGSSM( initial=ParamsLGSSMInitial(mean=m, cov=S), dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q), - emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R) + emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R), ) return params @@ -580,12 +624,11 @@ def lgssm_params_sample(rng, stats): def one_sample(_params, rng): rngs = jr.split(rng, 2) # Sample latent states - states = lgssm_posterior_sample(rngs[0], _params, emissions, inputs) + states = self.posterior_sample(rngs[0], _params, emissions, inputs) # Sample parameters _stats = sufficient_stats_from_sample(states) return lgssm_params_sample(rngs[1], _stats) - sample_of_params = [] keys = iter(jr.split(key, sample_size)) current_params = initial_params diff --git a/dynamax/linear_gaussian_ssm/models_test.py b/dynamax/linear_gaussian_ssm/models_test.py index c4394858..1cee7358 100644 --- a/dynamax/linear_gaussian_ssm/models_test.py +++ b/dynamax/linear_gaussian_ssm/models_test.py @@ -8,14 +8,17 @@ NUM_TIMESTEPS = 100 CONFIGS = [ - (LinearGaussianSSM, dict(state_dim=2, emission_dim=10), None), - (LinearGaussianConjugateSSM, dict(state_dim=2, emission_dim=10), None), + (LinearGaussianSSM, dict(state_dim=2, emission_dim=10, use_parallel_inference=False), None), + (LinearGaussianSSM, dict(state_dim=2, emission_dim=10, use_parallel_inference=True), None), + (LinearGaussianConjugateSSM, dict(state_dim=2, emission_dim=10, use_parallel_inference=False), None), + (LinearGaussianConjugateSSM, dict(state_dim=2, emission_dim=10, use_parallel_inference=True), None), ] + @pytest.mark.parametrize(["cls", "kwargs", "inputs"], CONFIGS) def test_sample_and_fit(cls, kwargs, inputs): model = cls(**kwargs) 
- #key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp()))) + # key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp()))) key1, key2 = jr.split(jr.PRNGKey(0)) params, param_props = model.initialize(key1) states, emissions = model.sample(params, key2, num_timesteps=NUM_TIMESTEPS, inputs=inputs) diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py index 2e50e8d6..a84aa6c6 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference.py @@ -1,4 +1,4 @@ -''' +""" Parallel filtering and smoothing for a lgssm. This implementation is adapted from the work of Adrien Correnflos: @@ -21,75 +21,41 @@ Dynamax - F₀,Q₀ F₁,Q₁ F₂,Q₂ + F₁,Q₁ F₂,Q₂ F₃,Q₃ Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────... | | | | | H₀,R₀ | H₁,R₁ | H₂,R₂ | H₃,R₃ | | | | Y₀ Y₁ Y₂ Y₃ -''' +""" import jax.numpy as jnp from jax import vmap, lax from jaxtyping import Array, Float -from typing import NamedTuple +from typing import NamedTuple, Optional from dynamax.types import PRNGKey from functools import partial import warnings from tensorflow_probability.substrates.jax.distributions import ( MultivariateNormalDiagPlusLowRankCovariance as MVNLowRank, - MultivariateNormalFullCovariance as MVN) + MultivariateNormalFullCovariance as MVN, +) from jax.scipy.linalg import cho_solve, cho_factor from dynamax.utils.utils import symmetrize, psd_solve from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM +from dynamax.linear_gaussian_ssm.inference import preprocess_args, _get_one_param, _get_params, _log_likelihood -def _get_one_param(x, dim, t): - """Helper function to get one parameter at time t.""" - if callable(x): - return x(t) - elif x.ndim == dim + 1: - return x[t] - else: - return x - -def _get_params(params, num_timesteps, t): - """Helper function to get parameters at time t.""" - assert not callable(params.emissions.cov), "Emission covariance cannot be a callable." - - F = _get_one_param(params.dynamics.weights, 2, t) - b = _get_one_param(params.dynamics.bias, 1, t) - Q = _get_one_param(params.dynamics.cov, 2, t) - H = _get_one_param(params.emissions.weights, 2, t+1) - d = _get_one_param(params.emissions.bias, 1, t+1) - - if len(params.emissions.cov.shape) == 1: - R = _get_one_param(params.emissions.cov, 1, t+1) - elif len(params.emissions.cov.shape) > 2: - R = _get_one_param(params.emissions.cov, 2, t+1) - elif params.emissions.cov.shape[0] != num_timesteps: - R = _get_one_param(params.emissions.cov, 2, t+1) - elif params.emissions.cov.shape[1] != num_timesteps: - R = _get_one_param(params.emissions.cov, 1, t+1) - else: - R = _get_one_param(params.emissions.cov, 2, t+1) - warnings.warn( - "Emission covariance has shape (N,N) where N is the number of timesteps. " - "The covariance will be interpreted as static and non-diagonal. To " - "specify a dynamic and diagonal covariance, pass it as a 3D array.") - - return F, b, Q, H, d, R - - -#---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# # Filtering # -#---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# + def _emissions_scale(Q, H, R): - """Compute the scale matrix for the emissions given the state covariance. + """Compute the scale matrix for the emissions given the state covariance S. 
    S_inv = inv(H @ Q @ H.T + R)
@@ -110,27 +76,10 @@
         I = jnp.eye(Q.shape[0])
         U = H @ jnp.linalg.cholesky(Q)
         X = U / R[:, None]
-        S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
+        S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
     return S_inv
-def _marginal_loglik_elem(Q, H, R, y):
-    """Compute marginal log-likelihood elements.
-
-    Args:
-        Q (state_dim, state_dim): State covariance.
-        H (emission_dim, state_dim): Emission matrix.
-        R (emission_dim, emission_dim) or (emission_dim,): Emission covariance.
-        y (emission_dim,): Emission.
-    """
-    if R.ndim == 2:
-        S = H @ Q @ H.T + R
-        return -MVN(jnp.zeros_like(y), S).log_prob(y)
-    else:
-        L = H @ jnp.linalg.cholesky(Q)
-        return -MVNLowRank(jnp.zeros_like(y), R, L).log_prob(y)
-
-
 class FilterMessage(NamedTuple):
     """
     Filtering associative scan elements.
@@ -143,57 +92,59 @@ class FilterMessage(NamedTuple):
         eta: P(z_{i-1} | y_{i:j}) mean.
         logZ: log P(y_{i:j}) marginal log-likelihood.
     """
-    A: Float[Array, "ntime state_dim state_dim"]
-    b: Float[Array, "ntime state_dim"]
-    C: Float[Array, "ntime state_dim state_dim"]
-    J: Float[Array, "ntime state_dim state_dim"]
-    eta: Float[Array, "ntime state_dim"]
+
+    A: Float[Array, "ntime state_dim state_dim"]
+    b: Float[Array, "ntime state_dim"]
+    C: Float[Array, "ntime state_dim state_dim"]
+    J: Float[Array, "ntime state_dim state_dim"]
+    eta: Float[Array, "ntime state_dim"]
     logZ: Float[Array, "ntime"]
-def _initialize_filtering_messages(params, emissions):
+def _initialize_filtering_messages(params, emissions, inputs):
     """Preprocess observations to construct input for filtering associative scan."""
     num_timesteps = emissions.shape[0]
-
-    def _first_message(params, y):
-        H, d, R = _get_params(params, num_timesteps, -1)[3:]
+
+    def _first_message(params, y, u):
+        H, D, d, R = _get_params(params, num_timesteps, 0)[4:]
         m = params.initial.mean
         P = params.initial.cov
-        S = H @ P @ H.T + (R if R.ndim==2 else jnp.diag(R))
+        S = H @ P @ H.T + (R if R.ndim == 2 else jnp.diag(R))
         S_inv = _emissions_scale(P, H, R)
         K = P @ H.T @ S_inv
-
         A = jnp.zeros_like(P)
-        b = m + K @ (y - H @ m - d)
+        b = m + K @ (y - H @ m - D @ u - d)
         C = symmetrize(P - K @ S @ K.T)
         eta = jnp.zeros_like(b)
         J = jnp.eye(len(b))
-
-        logZ = _marginal_loglik_elem(P, H, R, y)
+        logZ = -_log_likelihood(m, P, H, D, d, R, u, y)
         return A, b, C, J, eta, logZ
-
     @partial(vmap, in_axes=(None, 0, 0))
     def _generic_message(params, y, t):
-        F, b, Q, H, d, R = _get_params(params, num_timesteps, t)
+        F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t)
+        u = inputs[t]
+
+        # Adjust the bias terms according to the input
+        b = b + B @ u
+        m = b
         S_inv = _emissions_scale(Q, H, R)
         K = Q @ H.T @ S_inv
-
-        eta = F.T @ H.T @ S_inv @ (y - H @ b - d)
+
+        eta = F.T @ H.T @ S_inv @ (y - H @ b - D @ u - d)
         J = symmetrize(F.T @ H.T @ S_inv @ H @ F)
         A = F - K @ H @ F
-        b = b + K @ (y - H @ b - d)
+        b = b + K @ (y - H @ b - D @ u - d)
         C = symmetrize(Q - K @ H @ Q)
-        logZ = _marginal_loglik_elem(Q, H, R, y)
+        logZ = -_log_likelihood(m, Q, H, D, d, R, u, y)
         return A, b, C, J, eta, logZ
-    A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0])
-    At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions)-1))
+    A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0])
+    At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(1, len(emissions)))
     return FilterMessage(
         A=jnp.concatenate([A0[None], At]),
@@ -201,21 +152,21 @@ def
_generic_message(params, y, t): C=jnp.concatenate([C0[None], Ct]), J=jnp.concatenate([J0[None], Jt]), eta=jnp.concatenate([eta0[None], etat]), - logZ=jnp.concatenate([logZ0[None], logZt]) + logZ=jnp.concatenate([logZ0[None], logZt]), ) - +@preprocess_args def lgssm_filter( params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"] + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMFiltered: """A parallel version of the lgssm filtering algorithm. See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002. - - Note: This function does not yet handle `inputs` to the system. """ + @vmap def _operator(elem1, elem2): A1, b1, C1, J1, eta1, logZ1 = elem1 @@ -234,22 +185,24 @@ def _operator(elem1, elem2): J = symmetrize(temp @ J2 @ A1 + J1) mu = jnp.linalg.solve(C1, b1) - t1 = (b1 @ mu - (eta2 + mu) @ jnp.linalg.solve(I_C1J2, C1 @ eta2 + b1)) - logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1) + t1 = b1 @ mu - (eta2 + mu) @ jnp.linalg.solve(I_C1J2, C1 @ eta2 + b1) + logZ = logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1 return FilterMessage(A, b, C, J, eta, logZ) - initial_messages = _initialize_filtering_messages(params, emissions) + initial_messages = _initialize_filtering_messages(params, emissions, inputs) final_messages = lax.associative_scan(_operator, initial_messages) return PosteriorGSSMFiltered( + marginal_loglik=-final_messages.logZ[-1], filtered_means=final_messages.b, filtered_covariances=final_messages.C, - marginal_loglik=-final_messages.logZ[-1]) + ) -#---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# # Smoothing # -#---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# + class SmoothMessage(NamedTuple): """ @@ -260,12 +213,13 @@ class SmoothMessage(NamedTuple): g: P(z_i | y_{1:j}, z_{j+1}) bias. L: P(z_i | y_{1:j}, z_{j+1}) covariance. 
""" + E: Float[Array, "ntime state_dim state_dim"] g: Float[Array, "ntime state_dim"] L: Float[Array, "ntime state_dim state_dim"] -def _initialize_smoothing_messages(params, filtered_means, filtered_covariances): +def _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covariances): """Preprocess filtering output to construct input for smoothing assocative scan.""" def _last_message(m, P): @@ -275,37 +229,44 @@ def _last_message(m, P): @partial(vmap, in_axes=(None, 0, 0, 0)) def _generic_message(params, m, P, t): - F, b, Q = _get_params(params, num_timesteps, t)[:3] - CF, low = cho_factor(F @ P @ F.T + Q) - E = cho_solve((CF, low), F @ P).T - g = m - E @ (F @ m + b) - L = symmetrize(P - E @ F @ P) + F_next, B_next, b_next, Q_next = _get_params(params, num_timesteps, t + 1)[:4] + + # Adjust the bias terms accoding to the input + u_next = inputs[t + 1] + b_next = b_next + B_next @ u_next + + CF, low = cho_factor(F_next @ P @ F_next.T + Q_next) + E = cho_solve((CF, low), F_next @ P).T + g = m - E @ (F_next @ m + b_next) + L = symmetrize(P - E @ F_next @ P) return E, g, L - + En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1]) - Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1)) - + Et, gt, Lt = _generic_message( + params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means) - 1) + ) + return SmoothMessage( E=jnp.concatenate([Et, En[None]]), g=jnp.concatenate([gt, gn[None]]), - L=jnp.concatenate([Lt, Ln[None]]) + L=jnp.concatenate([Lt, Ln[None]]), ) +@preprocess_args def lgssm_smoother( params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"] + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> PosteriorGSSMSmoothed: """A parallel version of the lgssm smoothing algorithm. See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002. - - Note: This function does not yet handle `inputs` to the system. 
""" - filtered_posterior = lgssm_filter(params, emissions) + filtered_posterior = lgssm_filter(params, emissions, inputs) filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances - + @vmap def _operator(elem1, elem2): E1, g1, L1 = elem1 @@ -315,21 +276,41 @@ def _operator(elem1, elem2): L = symmetrize(E2 @ L1 @ E2.T + L2) return E, g, L - initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs) + initial_messages = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covs) final_messages = lax.associative_scan(_operator, initial_messages, reverse=True) - + G = initial_messages.E[:-1] + smoothed_means = final_messages.g + smoothed_covariances = final_messages.L + smoothed_cross_covariances = compute_smoothed_cross_covariances( + G, smoothed_means[:-1], smoothed_means[1:], smoothed_covariances[1:] + ) return PosteriorGSSMSmoothed( marginal_loglik=filtered_posterior.marginal_loglik, filtered_means=filtered_means, filtered_covariances=filtered_covs, - smoothed_means=final_messages.g, - smoothed_covariances=final_messages.L + smoothed_means=smoothed_means, + smoothed_covariances=smoothed_covariances, + smoothed_cross_covariances=smoothed_cross_covariances, ) -#---------------------------------------------------------------------------# +@vmap +def compute_smoothed_cross_covariances( + G: Float[Array, "state_dim state_dim"], + smoothed_mean: Float[Array, "state_dim"], + smoothed_mean_next: Float[Array, "state_dim"], + smoothed_cov_next: Float[Array, "state_dim state_dim"], +) -> Float[Array, "state_dim state_dim"]: + # Compute the smoothed expectation of z_t z_{t+1}^T + # This is precomputed + # G = psd_solve(Q + F @ filtered_cov @ F.T, F @ filtered_cov).T + return G @ smoothed_cov_next + jnp.outer(smoothed_mean, smoothed_mean_next) + + +# --------------------------------------------------------------------------# # Sampling # -#---------------------------------------------------------------------------# +# --------------------------------------------------------------------------# + class SampleMessage(NamedTuple): """ @@ -339,32 +320,35 @@ class SampleMessage(NamedTuple): E: z_i ~ z_{j+1} weights. h: z_i ~ z_{j+1} bias. """ + E: Float[Array, "ntime state_dim state_dim"] h: Float[Array, "ntime state_dim"] -def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances): +def _initialize_sampling_messages(key, params, inputs, filtered_means, filtered_covariances): """A parallel version of the lgssm sampling algorithm. - - Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`, + + Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`, the parallel sampling messages are `(E_i,h_i)` where `h_i ~ N(g_i, L_i)`. """ - E, g, L = _initialize_smoothing_messages(params, filtered_means, filtered_covariances) + E, g, L = _initialize_smoothing_messages(params, inputs, filtered_means, filtered_covariances) return SampleMessage(E=E, h=MVN(g, L).sample(seed=key)) def lgssm_posterior_sample( key: PRNGKey, params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"] + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]] = None, ) -> Float[Array, "ntime state_dim"]: """A parallel version of the lgssm sampling algorithm. See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002. - - Note: This function does not yet handle `inputs` to the system. 
""" - filtered_posterior = lgssm_filter(params, emissions) + num_timesteps = len(emissions) + inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs + + filtered_posterior = lgssm_filter(params, emissions, inputs) filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances @@ -377,6 +361,6 @@ def _operator(elem1, elem2): h = E2 @ h1 + h2 return E, h - initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs) + initial_messages = _initialize_sampling_messages(key, params, inputs, filtered_means, filtered_covs) _, samples = lax.associative_scan(_operator, initial_messages, reverse=True) - return samples \ No newline at end of file + return samples diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index cd6376b3..1fa67eea 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -12,86 +12,111 @@ from dynamax.linear_gaussian_ssm.inference_test import flatten_diagonal_emission_cov -def allclose(x,y, atol=1e-2): - m = jnp.abs(jnp.max(x-y)) - if m > atol: - print(m) - return False - else: - return True - +from jax.config import config + +config.update("jax_enable_x64", True) + + +allclose = partial(jnp.allclose, atol=1e-2, rtol=1e-2) + def make_static_lgssm_params(): + latent_dim = 4 + observation_dim = 2 + input_dim = 3 + + keys = jr.split(jr.PRNGKey(0), 3) + dt = 0.1 F = jnp.eye(4) + dt * jnp.eye(4, k=2) - Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2], - [dt**2/2, dt]]), - jnp.eye(2)) - + B = 0.2 * jr.normal(keys[0], (4, 3)) + b = 0.2 * jnp.arange(4) + Q = 1.0 * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(2)) + H = jnp.eye(2, 4) - R = 0.5 ** 2 * jnp.eye(2) - μ0 = jnp.array([0.,0.,1.,-1.]) + D = 0.2 * jr.normal(keys[1], (observation_dim, input_dim)) + d = 0.2 * jnp.ones(2) + R = 0.5**2 * jnp.eye(2) + + μ0 = jnp.array([0.0, 1.0, 1.0, -1.0]) Σ0 = jnp.eye(4) + lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim) + params, _ = lgssm.initialize( + keys[2], + initial_mean=μ0, + initial_covariance=Σ0, + dynamics_weights=F, + dynamics_input_weights=B, + dynamics_bias=b, + dynamics_covariance=Q, + emission_weights=H, + emission_input_weights=D, + emission_bias=d, + emission_covariance=R, + ) + return params, lgssm + + +def make_dynamic_lgssm_params(num_timesteps): latent_dim = 4 observation_dim = 2 + input_dim = 3 - lgssm = LinearGaussianSSM(latent_dim, observation_dim) - params, _ = lgssm.initialize(jr.PRNGKey(0), - initial_mean=μ0, - initial_covariance= Σ0, - dynamics_weights=F, - dynamics_covariance=Q, - emission_weights=H, - emission_covariance=R) - return params, lgssm - - -def make_dynamic_lgssm_params(num_timesteps, latent_dim=4, observation_dim=2, seed=0): - key = jr.PRNGKey(seed) - key, key_f, key_r, key_init = jr.split(key, 4) + keys = jr.split(jr.PRNGKey(1), 9) dt = 0.1 - f_scale = jr.normal(key_f, (num_timesteps,)) * 0.5 - F = f_scale[:,None,None] * jnp.tile(jnp.eye(latent_dim), (num_timesteps, 1, 1)) - F += dt * jnp.eye(latent_dim, k=2) - - Q = 1. 
+
+
+def make_dynamic_lgssm_params(num_timesteps):
     latent_dim = 4
     observation_dim = 2
+    input_dim = 3

-    lgssm = LinearGaussianSSM(latent_dim, observation_dim)
-    params, _ = lgssm.initialize(jr.PRNGKey(0),
-                                 initial_mean=μ0,
-                                 initial_covariance= Σ0,
-                                 dynamics_weights=F,
-                                 dynamics_covariance=Q,
-                                 emission_weights=H,
-                                 emission_covariance=R)
-    return params, lgssm
-
-
-def make_dynamic_lgssm_params(num_timesteps, latent_dim=4, observation_dim=2, seed=0):
-    key = jr.PRNGKey(seed)
-    key, key_f, key_r, key_init = jr.split(key, 4)
+    keys = jr.split(jr.PRNGKey(1), 9)

     dt = 0.1
-    f_scale = jr.normal(key_f, (num_timesteps,)) * 0.5
-    F = f_scale[:,None,None] * jnp.tile(jnp.eye(latent_dim), (num_timesteps, 1, 1))
-    F += dt * jnp.eye(latent_dim, k=2)
-
-    Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2],
-                                 [dt**2/2, dt]]),
-                      jnp.eye(latent_dim // 2))
-    assert Q.shape[-1] == latent_dim
-    H = jnp.eye(observation_dim, latent_dim)
-
-    r_scale = jr.normal(key_r, (num_timesteps,)) * 0.1
-    R = (r_scale**2)[:,None,None] * jnp.tile(jnp.eye(observation_dim), (num_timesteps, 1, 1))
-
-    μ0 = jnp.array([0.,0.,1.,-1.])
+
+    F = (
+        jnp.eye(4)[None]
+        + dt * jnp.eye(4, k=2)[None]
+        + 0.1 * jr.normal(keys[0], (num_timesteps, latent_dim, latent_dim))
+    )
+    B = 0.2 * jr.normal(keys[4], (num_timesteps, latent_dim, input_dim))
+    b = 0.2 * jr.normal(keys[6], (num_timesteps, latent_dim))
+    q_scale = jr.normal(keys[1], (num_timesteps, 1, 1)) ** 2
+    Q = q_scale * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(2))[None]
+
+    H = jnp.eye(2, 4)[None] * 0.1 * jr.normal(keys[2], (num_timesteps, observation_dim, latent_dim))
+    D = 0.2 * jr.normal(keys[5], (num_timesteps, observation_dim, input_dim))
+    d = 0.2 * jr.normal(keys[7], (num_timesteps, observation_dim))
+    r_scale = jr.normal(keys[3], (num_timesteps, 1, 1)) ** 2
+    R = r_scale * jnp.eye(2)[None]
+
+    μ0 = jnp.array([1.0, -2.0, 1.0, -1.0])
     Σ0 = jnp.eye(latent_dim)

-    lgssm = LinearGaussianSSM(latent_dim, observation_dim)
-    params, _ = lgssm.initialize(key_init,
-                                 initial_mean=μ0,
-                                 initial_covariance=Σ0,
-                                 dynamics_weights=F,
-                                 dynamics_covariance=Q,
-                                 emission_weights=H,
-                                 emission_covariance=R)
+    lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim)
+    params, _ = lgssm.initialize(
+        keys[8],
+        initial_mean=μ0,
+        initial_covariance=Σ0,
+        dynamics_weights=F,
+        dynamics_input_weights=B,
+        dynamics_bias=b,
+        dynamics_covariance=Q,
+        emission_weights=H,
+        emission_input_weights=D,
+        emission_bias=d,
+        emission_covariance=R,
+    )
     return params, lgssm
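make_dynamic_lgssm_params stacks one parameter per timestep along a leading axis, so F has shape (T, D, D) rather than (D, D), and it squares the sampled scales so every per-step Q[t] and R[t] remains a valid covariance. A standalone shape/PSD sanity sketch of the same construction (names and values are illustrative):

    import jax.numpy as jnp
    import jax.random as jr

    T, D, dt = 50, 4, 0.1
    F = jnp.eye(D)[None] + dt * jnp.eye(D, k=2)[None] + 0.1 * jr.normal(jr.PRNGKey(0), (T, D, D))
    assert F.shape == (T, D, D)  # one transition matrix per timestep

    # Squaring the normal draw makes every per-step scale nonnegative, so each
    # Q[t] is PSD because the Kronecker factor itself is PSD.
    q_scale = jr.normal(jr.PRNGKey(1), (T, 1, 1)) ** 2
    Q = q_scale * jnp.kron(jnp.array([[dt**3 / 3, dt**2 / 2], [dt**2 / 2, dt]]), jnp.eye(2))[None]
    assert Q.shape == (T, D, D)
    assert jnp.all(jnp.linalg.eigvalsh(Q) >= -1e-6)  # PSD up to float error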


 class TestParallelLGSSMSmoother:
-    """ Compare parallel and serial lgssm smoothing implementations."""
-
+    """Compare parallel and serial lgssm smoothing implementations."""
+
     num_timesteps = 50
-    key = jr.PRNGKey(1)
+    keys = jr.split(jr.PRNGKey(1), 2)

-    params, lgssm = make_static_lgssm_params()
+    params, lgssm = make_static_lgssm_params()
     params_diag = flatten_diagonal_emission_cov(params)
-    _, emissions = lgssm_joint_sample(params, key, num_timesteps)
+    inputs = jr.normal(keys[0], (num_timesteps, params.dynamics.input_weights.shape[-1]))
+    _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs)

-    serial_posterior = serial_lgssm_smoother(params, emissions)
-    parallel_posterior = parallel_lgssm_smoother(params, emissions)
-    parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions)
+    serial_posterior = serial_lgssm_smoother(params, emissions, inputs)
+    parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs)
+    parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions, inputs)

     def test_filtered_means(self):
         assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means)
@@ -109,28 +134,39 @@ def test_smoothed_covariances(self):
         assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior.smoothed_covariances)
         assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior_diag.smoothed_covariances)

+    def test_smoothed_cross_covariances(self):
+        x = self.serial_posterior.smoothed_cross_covariances
+        y = self.parallel_posterior.smoothed_cross_covariances
+        z = self.parallel_posterior_diag.smoothed_cross_covariances
+        matrix_norm_rel_diff = jnp.linalg.norm(x - y, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2))
+        assert allclose(matrix_norm_rel_diff, 0)
+        matrix_norm_rel_diff = jnp.linalg.norm(x - z, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2))
+        assert allclose(matrix_norm_rel_diff, 0)
+
     def test_marginal_loglik(self):
         assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=2e-1)
-        assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1)
-
-
+        assert jnp.allclose(
+            self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1
+        )


 class TestTimeVaryingParallelLGSSMSmoother:
     """Compare parallel and serial time-varying lgssm smoothing implementations.
-
+
     Vary dynamics weights and observation covariances with time.
     """
+
     num_timesteps = 50
-    key = jr.PRNGKey(1)
+    keys = jr.split(jr.PRNGKey(1), 2)

-    params, lgssm = make_dynamic_lgssm_params(num_timesteps)
+    params, lgssm = make_dynamic_lgssm_params(num_timesteps)
     params_diag = flatten_diagonal_emission_cov(params)
-    _, emissions = lgssm_joint_sample(params, key, num_timesteps)
+    inputs = jr.normal(keys[0], (num_timesteps, params.emissions.input_weights.shape[-1]))
+    _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs)

-    serial_posterior = serial_lgssm_smoother(params, emissions)
-    parallel_posterior = parallel_lgssm_smoother(params, emissions)
-    parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions)
+    serial_posterior = serial_lgssm_smoother(params, emissions, inputs)
+    parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs)
+    parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions, inputs)

     def test_filtered_means(self):
         assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means)
@@ -148,41 +184,53 @@ def test_smoothed_covariances(self):
         assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior.smoothed_covariances)
         assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior_diag.smoothed_covariances)

-    def test_marginal_loglik(self):
-        assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=2e-1)
-        assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1)
+    def test_smoothed_cross_covariances(self):
+        x = self.serial_posterior.smoothed_cross_covariances
+        y = self.parallel_posterior.smoothed_cross_covariances
+        z = self.parallel_posterior_diag.smoothed_cross_covariances
+        matrix_norm_rel_diff = jnp.linalg.norm(x - y, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2))
+        assert allclose(matrix_norm_rel_diff, 0)
+        matrix_norm_rel_diff = jnp.linalg.norm(x - z, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2))
+        assert allclose(matrix_norm_rel_diff, 0)

+    def test_marginal_loglik(self):
+        assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, rtol=2e-2)
+        assert jnp.allclose(
+            self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, rtol=2e-2
+        )
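The cross-covariance tests above compare matrices by per-timestep relative Frobenius norm rather than raw elementwise differences, which keeps the tolerance meaningful when entries differ in magnitude; note there are only T - 1 = 49 cross terms, matching the `smoothed_means[:-1]` / `smoothed_means[1:]` pairing in the smoother. The same check in a standalone sketch with stand-in arrays:

    import jax.numpy as jnp

    x = jnp.ones((49, 4, 4))  # e.g. serial smoothed cross covariances
    y = x + 1e-4              # e.g. parallel result with a tiny uniform error
    rel = jnp.linalg.norm(x - y, axis=(1, 2)) / jnp.linalg.norm(x, axis=(1, 2))
    assert jnp.allclose(rel, 0, atol=1e-2)  # same tolerance the tests use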

-class TestTimeVaryingParallelLGSSMSampler():
+class TestTimeVaryingParallelLGSSMSampler:
     """Compare parallel and serial lgssm posterior sampling implementations in expectation."""
-
+
     num_timesteps = 50
-    key = jr.PRNGKey(1)
+    keys = jr.split(jr.PRNGKey(1), 2)

-    params, lgssm = make_dynamic_lgssm_params(num_timesteps)
+    params, lgssm = make_dynamic_lgssm_params(num_timesteps)
     params_diag = flatten_diagonal_emission_cov(params)
-    _, emissions = lgssm_joint_sample(params_diag, key, num_timesteps)
+    inputs = jr.normal(keys[0], (num_timesteps, params.emissions.input_weights.shape[-1]))
+    _, emissions = lgssm_joint_sample(params, keys[1], num_timesteps, inputs)

     num_samples = 1000
     serial_keys = jr.split(jr.PRNGKey(2), num_samples)
     parallel_keys = jr.split(jr.PRNGKey(3), num_samples)

-    serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0,None,None))(
-        serial_keys, params, emissions)
-
-    parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None))(
-        parallel_keys, params, emissions)
-
-    parallel_samples_diag = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None))(
-        parallel_keys, params, emissions)
+    serial_samples = vmap(serial_lgssm_posterior_sample, in_axes=(0, None, None, None))(
+        serial_keys, params, emissions, inputs
+    )
+    parallel_samples = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None, None))(
+        parallel_keys, params, emissions, inputs
+    )
+    parallel_samples_diag = vmap(parallel_lgssm_posterior_sample, in_axes=(0, None, None, None))(
+        parallel_keys, params_diag, emissions, inputs
+    )

     def test_sampled_means(self):
         serial_mean = self.serial_samples.mean(axis=0)
         parallel_mean = self.parallel_samples.mean(axis=0)
-        parallel_mean_diag = self.parallel_samples.mean(axis=0)
+        parallel_mean_diag = self.parallel_samples_diag.mean(axis=0)
-        assert allclose(serial_mean, parallel_mean, atol=1e-1)
-        assert allclose(serial_mean, parallel_mean_diag, atol=1e-1)
+        assert allclose(serial_mean, parallel_mean, atol=1e-1, rtol=1e-1)
+        assert allclose(serial_mean, parallel_mean_diag, atol=1e-1, rtol=1e-1)

     def test_sampled_covariances(self):
         # samples have shape (N, T, D): vmap over the T axis, calculate cov over N axis
@@ -190,4 +238,4 @@ def test_sampled_covariances(self):
         serial_cov = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.serial_samples)
         parallel_cov = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.parallel_samples)
-        parallel_cov_diag = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.parallel_samples)
+        parallel_cov_diag = vmap(partial(jnp.cov, rowvar=False), in_axes=1)(self.parallel_samples_diag)
         assert allclose(serial_cov, parallel_cov, atol=1e-1)
-        assert allclose(serial_cov, parallel_cov_diag, atol=1e-1)
\ No newline at end of file
+        assert allclose(serial_cov, parallel_cov_diag, atol=1e-1)
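The sampler tests validate the two implementations in distribution rather than sample-by-sample: the draws use different keys, so only Monte Carlo summaries (mean over the N draws, covariance over the N axis) are comparable, with error shrinking like sigma / sqrt(N). The same pattern in a standalone sketch, where fake_posterior_sample is a hypothetical stand-in for a posterior sampler, not a dynamax function:

    import jax.numpy as jnp
    import jax.random as jr
    from jax import vmap

    def fake_posterior_sample(key):
        # Stand-in sampler returning one (T=50, D=4) state trajectory.
        return jr.normal(key, (50, 4))

    keys = jr.split(jr.PRNGKey(2), 1000)
    samples = vmap(fake_posterior_sample)(keys)  # (N, T, D)
    # Per-entry MC error is about 1 / sqrt(1000) ~ 0.03, so 0.15 is a ~5-sigma bound.
    assert jnp.allclose(samples.mean(axis=0), 0.0, atol=0.15)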