diff --git a/dynamax/hidden_markov_model/inference.py b/dynamax/hidden_markov_model/inference.py
index 7576a707..e2df10ea 100644
--- a/dynamax/hidden_markov_model/inference.py
+++ b/dynamax/hidden_markov_model/inference.py
@@ -101,7 +101,8 @@ def hmm_filter(
     transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
                              Float[Array, "num_states num_states"]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
+    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None,
+    num_timesteps: Optional[Int] = None,
 ) -> HMMPosteriorFiltered:
     r"""Forwards filtering
 
@@ -115,12 +116,14 @@ def hmm_filter(
         transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
         log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
         transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
+        num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.
 
     Returns:
         filtered posterior distribution
 
     """
-    num_timesteps, num_states = log_likelihoods.shape
+    max_num_timesteps = log_likelihoods.shape[0]
+    num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps
 
     def _step(carry, t):
         log_normalizer, predicted_probs = carry
@@ -128,6 +131,9 @@ def _step(carry, t):
         A = get_trans_mat(transition_matrix, transition_fn, t)
         ll = log_likelihoods[t]
 
+        # Ignore observations after specified number of timesteps
+        ll = jnp.where(t < num_timesteps, ll, 0.0)
+
         filtered_probs, log_norm = _condition_on(predicted_probs, ll)
         log_normalizer += log_norm
         predicted_probs_next = _predict(filtered_probs, A)
@@ -135,7 +141,7 @@ def _step(carry, t):
         return (log_normalizer, predicted_probs_next), (filtered_probs, predicted_probs)
 
     carry = (0.0, initial_distribution)
-    (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
+    (log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(_step, carry, jnp.arange(max_num_timesteps))
 
     post = HMMPosteriorFiltered(marginal_loglik=log_normalizer,
                                 filtered_probs=filtered_probs,
@@ -149,7 +155,8 @@ def hmm_backward_filter(
     transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
                              Float[Array, "num_states num_states"]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
+    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+    num_timesteps: Optional[Int] = None
 ) -> Tuple[Float, Float[Array, "num_timesteps num_states"]]:
     r"""Run the filter backwards in time. This is the second step of the forward-backward algorithm.
 
@@ -163,12 +170,14 @@ def hmm_backward_filter(
         transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
         log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
         transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
+        num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.
 
     Returns:
         marginal log likelihood and backward messages.
 
""" - num_timesteps, num_states = log_likelihoods.shape + max_num_timesteps, num_states = log_likelihoods.shape + num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps def _step(carry, t): log_normalizer, backward_pred_probs = carry @@ -176,17 +185,18 @@ def _step(carry, t): A = get_trans_mat(transition_matrix, transition_fn, t) ll = log_likelihoods[t] + # Ignore observations after specified number of timesteps + ll = jnp.where(t < num_timesteps, ll, 0.0) + # Condition on emission at time t, being careful not to overflow. backward_filt_probs, log_norm = _condition_on(backward_pred_probs, ll) - # Update the log normalizer. - log_normalizer += log_norm + # Predict the next state (going backward in time). next_backward_pred_probs = _predict(backward_filt_probs, A.T) return (log_normalizer, next_backward_pred_probs), backward_pred_probs carry = (0.0, jnp.ones(num_states)) - (log_normalizer, _), rev_backward_pred_probs = lax.scan(_step, carry, jnp.arange(num_timesteps)[::-1]) - backward_pred_probs = rev_backward_pred_probs[::-1] + (log_normalizer, _), backward_pred_probs = lax.scan(_step, carry, jnp.arange(max_num_timesteps), reverse=True) return log_normalizer, backward_pred_probs @@ -197,6 +207,7 @@ def hmm_two_filter_smoother( Float[Array, "num_states num_states"]], log_likelihoods: Float[Array, "num_timesteps num_states"], transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None, + num_timesteps: Optional[Int] = None, compute_trans_probs: bool = True ) -> HMMPosterior: r"""Computed the smoothed state probabilities using the two-filter @@ -212,16 +223,19 @@ def hmm_two_filter_smoother( transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$ log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$. transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix. + num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays. Returns: posterior distribution """ - post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn) + # Forward + post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn, num_timesteps) ll = post.marginal_loglik filtered_probs, predicted_probs = post.filtered_probs, post.predicted_probs - _, backward_pred_probs = hmm_backward_filter(transition_matrix, log_likelihoods, transition_fn) + # Backward + _, backward_pred_probs = hmm_backward_filter(transition_matrix, log_likelihoods, transition_fn, num_timesteps) # Compute smoothed probabilities smoothed_probs = filtered_probs * backward_pred_probs @@ -251,6 +265,7 @@ def hmm_smoother( Float[Array, "num_states num_states"]], log_likelihoods: Float[Array, "num_timesteps num_states"], transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None, + num_timesteps: Optional[Int]=None, compute_trans_probs: bool = True ) -> HMMPosterior: r"""Computed the smoothed state probabilities using a general @@ -268,15 +283,17 @@ def hmm_smoother( transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$ log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$. transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix. + num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays. 
 
     Returns:
         posterior distribution
 
     """
-    num_timesteps, num_states = log_likelihoods.shape
+    max_num_timesteps, _ = log_likelihoods.shape
+    num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps
 
     # Run the HMM filter
-    post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
+    post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn, num_timesteps)
     ll = post.marginal_loglik
     filtered_probs, predicted_probs = post.filtered_probs, post.predicted_probs
 
@@ -294,16 +311,15 @@ def _step(carry, args):
                                      smoothed_probs_next / predicted_probs_next)
         smoothed_probs = filtered_probs * (A @ relative_probs_next)
         smoothed_probs /= smoothed_probs.sum()
-
         return smoothed_probs, smoothed_probs
 
     # Run the HMM smoother
     carry = filtered_probs[-1]
-    args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_probs[:-1][::-1], predicted_probs[1:][::-1])
-    _, rev_smoothed_probs = lax.scan(_step, carry, args)
+    args = (jnp.arange(max_num_timesteps - 1), filtered_probs[:-1], predicted_probs[1:])
+    _, smoothed_probs = lax.scan(_step, carry, args, reverse=True)
 
     # Reverse the arrays and return
-    smoothed_probs = jnp.vstack([rev_smoothed_probs[::-1], filtered_probs[-1]])
+    smoothed_probs = jnp.vstack([smoothed_probs, filtered_probs[-1]])
 
     # Package into a posterior
     posterior = HMMPosterior(
@@ -352,6 +368,7 @@ def hmm_fixed_lag_smoother(
         posterior distribution
 
     """
+    # TODO: Update to allow variable length time series
    num_timesteps, num_states = log_likelihoods.shape
 
     def _step(carry, t):
@@ -441,7 +458,8 @@ def hmm_posterior_mode(
     transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
                              Float[Array, "num_states num_states"]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
+    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+    num_timesteps: Optional[Int]=None,
 ) -> Int[Array, "num_timesteps"]:
     r"""Compute the most likely state sequence. This is called the Viterbi algorithm.
 
@@ -450,12 +468,14 @@ def hmm_posterior_mode(
         transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
         log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
         transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
+        num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.
 
     Returns:
         most likely state sequence
 
     """
-    num_timesteps, num_states = log_likelihoods.shape
+    max_num_timesteps, _ = log_likelihoods.shape
+    num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps
 
     # Run the backward pass
     def _backward_pass(best_next_score, t):
@@ -464,14 +484,19 @@ def _backward_pass(best_next_score, t):
         scores = jnp.log(A) + best_next_score + log_likelihoods[t + 1]
         best_next_state = jnp.argmax(scores, axis=1)
         best_next_score = jnp.max(scores, axis=1)
+
+        # Only update if log_likelihoods[t+1] is valid
+        best_next_score = jnp.where(t + 1 < num_timesteps, best_next_score, jnp.zeros(num_states))
+        best_next_state = jnp.where(t + 1 < num_timesteps, best_next_state, jnp.zeros(num_states, dtype=int))
+
         return best_next_score, best_next_state
 
     num_states = log_likelihoods.shape[1]
-    best_second_score, rev_best_next_states = lax.scan(
-        _backward_pass, jnp.zeros(num_states), jnp.arange(num_timesteps - 2, -1, -1)
+    best_second_score, best_next_states = lax.scan(
+        _backward_pass, jnp.zeros(num_states), jnp.arange(max_num_timesteps - 1),
+        reverse=True
     )
-    best_next_states = rev_best_next_states[::-1]
-
+
     # Run the forward pass
     def _forward_pass(state, best_next_state):
         next_state = best_next_state[state]
@@ -490,7 +515,8 @@ def hmm_posterior_sample(
     transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
                              Float[Array, "num_states num_states"]],
     log_likelihoods: Float[Array, "num_timesteps num_states"],
-    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
+    transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None,
+    num_timesteps: Optional[Int] = None,
 ) -> Int[Array, "num_timesteps"]:
     r"""Sample a latent sequence from the posterior.
 
@@ -505,10 +531,11 @@ def hmm_posterior_sample(
     Returns:
         :sample of the latent states, $z_{1:T}$
 
     """
-    num_timesteps, num_states = log_likelihoods.shape
+    max_num_timesteps, num_states = log_likelihoods.shape
+    num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps
 
     # Run the HMM filter
-    post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
+    post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn, num_timesteps)
     log_normalizer, filtered_probs = post.marginal_loglik, post.filtered_probs
 
     # Run the sampler backward in time
@@ -528,13 +555,13 @@ def _step(carry, args):
         return state, state
 
     # Run the HMM smoother
-    rngs = jr.split(rng, num_timesteps)
+    rngs = jr.split(rng, max_num_timesteps)
     last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1])
-    args = (jnp.arange(num_timesteps - 1, 0, -1), rngs[:-1][::-1], filtered_probs[:-1][::-1])
-    _, rev_states = lax.scan(_step, last_state, args)
+    args = (jnp.arange(max_num_timesteps - 1), rngs[:-1], filtered_probs[:-1])
+    _, states = lax.scan(_step, last_state, args, reverse=True)
 
     # Reverse the arrays and return
-    states = jnp.concatenate([rev_states[::-1], jnp.array([last_state])])
+    states = jnp.concatenate([states, jnp.array([last_state])])
     return log_normalizer, states
 
 def _compute_sum_transition_probs(
diff --git a/dynamax/hidden_markov_model/inference_test.py b/dynamax/hidden_markov_model/inference_test.py
index f72e8a58..ec22f323 100644
--- a/dynamax/hidden_markov_model/inference_test.py
+++ b/dynamax/hidden_markov_model/inference_test.py
@@ -5,6 +5,7 @@
 import dynamax.hidden_markov_model.inference as core
 import dynamax.hidden_markov_model.parallel_inference as parallel
+from jax import vmap
 from jax.scipy.special import logsumexp
 
 def big_log_joint(initial_probs, transition_matrix, log_likelihoods):
@@ -259,6 +260,43 @@ def trans_mat_callable(t):
     assert jnp.allclose(sample, sample2)
 
 
+def test_hmm_padding(key=0, num_timesteps=10, num_states=5, padding=3):
+    if isinstance(key, int):
+        key = jr.PRNGKey(key)
+
+    initial_probs, transition_matrix, log_lkhds = random_hmm_args(key, num_timesteps + padding, num_states)
+
+    # Check that the filter on truncated data matches the filter on padded data with num_timesteps
+    post = core.hmm_filter(initial_probs, transition_matrix, log_lkhds[:num_timesteps])
+    post2 = core.hmm_filter(initial_probs, transition_matrix, log_lkhds, num_timesteps=num_timesteps)
+    assert jnp.allclose(post.marginal_loglik, post2.marginal_loglik, atol=1e-4)
+    assert jnp.allclose(post.filtered_probs, post2.filtered_probs[:num_timesteps], atol=1e-4)
+
+    # Likewise for the smoother
+    post = core.hmm_smoother(initial_probs, transition_matrix, log_lkhds[:num_timesteps])
+    post2 = core.hmm_smoother(initial_probs, transition_matrix, log_lkhds, num_timesteps=num_timesteps)
+    assert jnp.allclose(post.smoothed_probs, post2.smoothed_probs[:num_timesteps], atol=1e-4)
+
+    # Run Viterbi
+    mode = core.hmm_posterior_mode(initial_probs, transition_matrix, log_lkhds[:num_timesteps])
+    mode2 = core.hmm_posterior_mode(initial_probs, transition_matrix, log_lkhds, num_timesteps=num_timesteps)
+    assert jnp.allclose(mode, mode2[:num_timesteps])
+
+
+# Test vmap
+def test_hmm_variable_length_vmap(key=0, max_num_timesteps=10, num_states=5, num_seqs=10):
+    if isinstance(key, int):
+        key = jr.PRNGKey(key)
+
+    all_args = vmap(random_hmm_args, in_axes=(0, None, None))(
+        jr.split(key, num_seqs), max_num_timesteps, num_states)
+
+    all_num_timesteps = jr.randint(key, (num_seqs,), 1, max_num_timesteps)
+
+    # Just make sure vmap runs without throwing a concretization error
+    posteriors = vmap(core.hmm_filter)(*all_args, num_timesteps=all_num_timesteps)
+
+
 def test_parallel_filter(key=0, num_timesteps=100, num_states=3):
     if isinstance(key, int):
         key = jr.PRNGKey(key)
diff --git a/dynamax/nonlinear_gaussian_ssm/inference_ekf.py b/dynamax/nonlinear_gaussian_ssm/inference_ekf.py
index c375d7b5..5e019909 100644
--- a/dynamax/nonlinear_gaussian_ssm/inference_ekf.py
+++ b/dynamax/nonlinear_gaussian_ssm/inference_ekf.py
@@ -341,5 +341,6 @@ def _step(carry, _):
         smoothed_posterior = extended_kalman_smoother(params, emissions, smoothed_prior, inputs)
         return smoothed_posterior, None
 
+    # TODO: Does this even work with None as initial carry?
     smoothed_posterior, _ = lax.scan(_step, None, jnp.arange(num_iter))
     return smoothed_posterior
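
Usage sketch: a minimal example of how the new `num_timesteps` argument is meant to compose with `vmap` over padded, variable-length sequences. The toy model below (uniform initial distribution, Dirichlet-sampled transition matrix, random log likelihoods) is invented for illustration; only `hmm_filter` and its new signature come from the patch above.

    import jax.numpy as jnp
    import jax.random as jr
    from jax import vmap
    from dynamax.hidden_markov_model.inference import hmm_filter

    num_states, max_T = 3, 20
    k1, k2 = jr.split(jr.PRNGKey(0))

    initial_probs = jnp.ones(num_states) / num_states
    transition_matrix = jr.dirichlet(k1, jnp.ones(num_states), (num_states,))

    # Two sequences padded to a common length max_T; only the first
    # `length` entries of each are valid.
    lengths = jnp.array([12, 20])
    log_lkhds = jr.normal(k2, (2, max_T, num_states))

    # Map over the batch axes of the log likelihoods and the lengths;
    # the parameters and the (unmapped) transition_fn=None are shared.
    posteriors = vmap(hmm_filter, in_axes=(None, None, 0, None, 0))(
        initial_probs, transition_matrix, log_lkhds, None, lengths)

    # Log likelihoods past each sequence's length are zeroed out inside
    # the scan, so each marginal log likelihood matches the one computed
    # on the corresponding unpadded sequence.
    print(posteriors.marginal_loglik)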
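On the TODO in the EKF hunk: `lax.scan` requires the carry to keep the same pytree structure on every iteration, so an initial carry of `None` only traces if the body also returns a `None`-structured carry, which `_step` above does not once it returns a posterior object. A standalone sketch of the constraint (the step function here is hypothetical):

    import jax.numpy as jnp
    from jax import lax

    def _step(carry, _):
        # The carry changes structure from None to an array on the first
        # iteration, which lax.scan rejects at trace time.
        return jnp.zeros(3), None

    try:
        lax.scan(_step, None, jnp.arange(2))
    except TypeError as err:
        print("scan carry structure mismatch:", err)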