From 5d67795e1943359393fbf66fd571b95a3962e87f Mon Sep 17 00:00:00 2001 From: Anuj Date: Thu, 15 Feb 2024 11:04:04 +0100 Subject: [PATCH 1/2] added k-means initialisation for Poisson HMMs --- dynamax/hidden_markov_model/models/poisson_hmm.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/dynamax/hidden_markov_model/models/poisson_hmm.py b/dynamax/hidden_markov_model/models/poisson_hmm.py index 198bf128..2e1a27b3 100644 --- a/dynamax/hidden_markov_model/models/poisson_hmm.py +++ b/dynamax/hidden_markov_model/models/poisson_hmm.py @@ -45,14 +45,21 @@ def emission_shape(self): def initialize(self, key=jr.PRNGKey(0), method="prior", - emission_rates=None): + emission_rates=None, + emissions=None): # Initialize the emission probabilities if emission_rates is None: if method.lower() == "prior": prior = tfd.Gamma(self.emission_prior_concentration, self.emission_prior_rate) emission_rates = prior.sample(seed=key, sample_shape=(self.num_states, self.emission_dim)) elif method.lower() == "kmeans": - raise NotImplementedError("kmeans initialization is not yet implemented!") + assert emissions is not None, "Need emissions to initialize the model with K-Means!" + from sklearn.cluster import KMeans + key, subkey = jr.split(key) # Create a random seed for SKLearn. + sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value. + km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim)) + ## Cluster centers, also forms the Poisson emission rate + emission_rates = jnp.array(km.cluster_centers_) else: raise Exception("invalid initialization method: {}".format(method)) else: From 359d86e523371635c2bd2b2db004f5eb16fc907a Mon Sep 17 00:00:00 2001 From: Anuj Date: Thu, 15 Feb 2024 11:20:01 +0100 Subject: [PATCH 2/2] fixed k-means initialization in PoissonHMM class --- dynamax/hidden_markov_model/models/poisson_hmm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dynamax/hidden_markov_model/models/poisson_hmm.py b/dynamax/hidden_markov_model/models/poisson_hmm.py index 2e1a27b3..e54e6397 100644 --- a/dynamax/hidden_markov_model/models/poisson_hmm.py +++ b/dynamax/hidden_markov_model/models/poisson_hmm.py @@ -143,7 +143,8 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", initial_probs: Optional[Float[Array, "num_states"]]=None, transition_matrix: Optional[Float[Array, "num_states num_states"]]=None, - emission_rates: Optional[Float[Array, "num_states emission_dim"]]=None + emission_rates: Optional[Float[Array, "num_states emission_dim"]]=None, + emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None ) -> Tuple[ParameterSet, PropertySet]: """Initialize the model parameters and their corresponding properties. @@ -167,5 +168,5 @@ def initialize(self, key=jr.PRNGKey(0), params, props = dict(), dict() params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs) params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix) - params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_rates=emission_rates) + params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_rates=emission_rates, emissions = emissions) return ParamsPoissonHMM(**params), ParamsPoissonHMM(**props)