Skip to content

Commit

Permalink
Change type hints to jaxtyping in slds code
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Oct 16, 2024
1 parent 0328712 commit ee3fdc8
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions dynamax/slds/mixture_kalman_filter_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

# Author: Gerardo Durán-Martín (@gerdm)

from dataclasses import dataclass
import jax
import jax.numpy as jnp
from jax import random
import jax.numpy as jnp
from jax.scipy.special import logit
from dataclasses import dataclass
from jaxtyping import Array, Float


@dataclass
Expand All @@ -24,12 +25,12 @@ class RBPFParamsDiscrete:
noise1_next ~ N(0, Q)
noise2_next ~ N(0, R)
"""
A: jnp.array
B: jnp.array
C: jnp.array
Q: jnp.array
R: jnp.array
transition_matrix: jnp.array
A: Float[Array, "dim_hidden dim_hidden"]
B: Float[Array, "dim_hidden dim_control"]
C: Float[Array, "dim_emission dim_hidden"]
Q: Float[Array, "dim_hidden dim_hidden"]
R: Float[Array, "dim_emission dim_emission"]
transition_matrix: Float[Array, "dim_control dim_control"]


def draw_state(val, key, params):
Expand All @@ -42,7 +43,7 @@ def draw_state(val, key, params):
----------
val: tuple (int, jnp.array)
(latent value of system, state value of system).
params: PRBPFParamsDiscrete
params: RBPFParamsDiscrete
key: PRNGKey
"""
latent_old, state_old = val
Expand Down Expand Up @@ -158,4 +159,4 @@ def rbpf_optimal(current_config, xt, params, nparticles=100):

weights_t = jnp.ones(nparticles) / nparticles

return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp)
return (key_next, mu_t, Sigma_t, weights_t, st), (mu_t, Sigma_t, weights_t, st, proposal_samp)

0 comments on commit ee3fdc8

Please sign in to comment.