Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX kmeans implementation #371

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions dynamax/hidden_markov_model/models/arhmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dynamax.parameters import ParameterProperties
from dynamax.types import Scalar
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.cluster import kmeans_sklearn
from tensorflow_probability.substrates import jax as tfp
from typing import NamedTuple, Optional, Tuple, Union

Expand Down Expand Up @@ -42,12 +43,8 @@ def initialize(self,
emissions=None):
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
_emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.emission_dim * self.num_lags))
_emission_biases = jnp.array(km.cluster_centers_)
_emission_biases, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
_emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))

elif method.lower() == "prior":
Expand Down
9 changes: 3 additions & 6 deletions dynamax/hidden_markov_model/models/gamma_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
from dynamax.types import Scalar
from dynamax.utils.cluster import kmeans_sklearn
import optax
from typing import NamedTuple, Optional, Tuple, Union

Expand Down Expand Up @@ -38,13 +39,9 @@ def initialize(self,

if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, 1))

cluster_centers, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, 1), key)
_emission_concentrations = jnp.ones((self.num_states,))
_emission_rates = jnp.ravel(1.0 / km.cluster_centers_)
_emission_rates = jnp.ravel(1.0 / cluster_centers)

elif method.lower() == "prior":
_emission_concentrations = jnp.ones((self.num_states,))
Expand Down
32 changes: 6 additions & 26 deletions dynamax/hidden_markov_model/models/gaussian_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dynamax.utils.distributions import niw_posterior_update
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.utils import pytree_sum
from dynamax.utils.cluster import kmeans_sklearn
from typing import NamedTuple, Optional, Tuple, Union


Expand Down Expand Up @@ -70,12 +71,7 @@ def initialize(self, key=jr.PRNGKey(0),
emissions=None):
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))

_emission_means = jnp.array(km.cluster_centers_)
_emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
_emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))

elif method.lower() == "prior":
Expand Down Expand Up @@ -168,11 +164,7 @@ def initialize(self, key=jr.PRNGKey(0),

if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
_emission_means = jnp.array(km.cluster_centers_)
_emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
_emission_scale_diags = jnp.ones((self.num_states, self.emission_dim))

elif method.lower() == "prior":
Expand Down Expand Up @@ -289,11 +281,7 @@ def initialize(self, key=jr.PRNGKey(0),
"""
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
_emission_means = jnp.array(km.cluster_centers_)
_emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
_emission_scales = jnp.ones((self.num_states,))

elif method.lower() == "prior":
Expand Down Expand Up @@ -391,11 +379,7 @@ def initialize(self, key=jr.PRNGKey(0),
"""
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
_emission_means = jnp.array(km.cluster_centers_)
_emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
_emission_cov = jnp.eye(self.emission_dim)

elif method.lower() == "prior":
Expand Down Expand Up @@ -513,11 +497,7 @@ def initialize(self, key=jr.PRNGKey(0),
"""
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
_emission_means = jnp.array(km.cluster_centers_)
_emission_means, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
_emission_cov_diag_factors = jnp.ones((self.num_states, self.emission_dim))
_emission_cov_low_rank_factors = jnp.zeros((self.num_states, self.emission_dim, self.emission_rank))

Expand Down
15 changes: 5 additions & 10 deletions dynamax/hidden_markov_model/models/gmm_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.utils import pytree_sum
from dynamax.utils.cluster import kmeans_sklearn
from dynamax.types import Scalar
from typing import NamedTuple, Optional, Tuple, Union

Expand Down Expand Up @@ -77,12 +78,9 @@ def initialize(self, key=jr.PRNGKey(0),
emissions=None):
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
cluster_centers, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
_emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
_emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
_emission_means = jnp.tile(jnp.array(cluster_centers)[:, None, :], (1, self.num_components, 1))
_emission_covs = jnp.tile(jnp.eye(self.emission_dim), (self.num_states, self.num_components, 1, 1))

elif method.lower() == "prior":
Expand Down Expand Up @@ -299,12 +297,9 @@ def initialize(self, key=jr.PRNGKey(0),
emissions=None):
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
cluster_centers, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
_emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
_emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
_emission_means = jnp.tile(jnp.array(cluster_centers)[:, None, :], (1, self.num_components, 1))
_emission_scale_diags = jnp.ones((self.num_states, self.num_components, self.emission_dim))

elif method.lower() == "prior":
Expand Down
7 changes: 2 additions & 5 deletions dynamax/hidden_markov_model/models/linreg_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dynamax.types import Scalar
from dynamax.utils.utils import pytree_sum
from dynamax.utils.bijectors import RealToPSDBijector
from dynamax.utils.cluster import kmeans_sklearn
from tensorflow_probability.substrates import jax as tfp
from typing import NamedTuple, Optional, Tuple, Union

Expand Down Expand Up @@ -58,12 +59,8 @@ def initialize(self,
emissions=None):
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
from sklearn.cluster import KMeans
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
_emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.input_dim))
_emission_biases = jnp.array(km.cluster_centers_)
_emission_biases, _ = kmeans_sklearn(self.num_states, emissions.reshape(-1, self.emission_dim), key)
_emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))

elif method.lower() == "prior":
Expand Down
16 changes: 10 additions & 6 deletions dynamax/hidden_markov_model/models/logreg_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
from dynamax.types import Scalar
from dynamax.utils.cluster import kmeans_sklearn
import optax
from typing import NamedTuple, Optional, Tuple, Union

Expand Down Expand Up @@ -48,16 +49,19 @@ def initialize(self,
if method.lower() == "kmeans":
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
assert inputs is not None, "Need inputs to initialize the model with K-Means!"
from sklearn.cluster import KMeans

flat_emissions = emissions.reshape(-1,)
flat_inputs = inputs.reshape(-1, self.input_dim)
key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(flat_inputs)

_, km_labels = kmeans_sklearn(self.num_states, flat_inputs, key)
_emission_weights = jnp.zeros((self.num_states, self.input_dim))
_emission_biases = jnp.array([tfb.Sigmoid().inverse(flat_emissions[km.labels_ == k].mean())
for k in range(self.num_states)])
cluster_emissions_means = jnp.array(
[jnp.mean(flat_emissions, where=km_labels == k) for k in range(self.num_states)]
)
cluster_emissions_means = jnp.where(
jnp.isnan(cluster_emissions_means), flat_emissions.mean(), cluster_emissions_means
)
_emission_biases = tfb.Sigmoid().inverse(cluster_emissions_means)
Comment on lines +58 to +64
Copy link
Collaborator Author

@gileshd gileshd Jul 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is required so that this chunk is friendly with JAX transformations - the old version has intermediate arrays with variable shape.


elif method.lower() == "prior":
# TODO: Use an MNIW prior
Expand Down
102 changes: 102 additions & 0 deletions dynamax/utils/cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from functools import partial
from jax import lax, jit
from jax import numpy as jnp
from jax import random as jr
from jaxtyping import Array, Int, Float
from typing import NamedTuple, Tuple


def kmeans_sklearn(
k: int, X: Float[Array, "num_samples state_dim"], key: Array
) -> Tuple[Float[Array, "num_states state_dim"], Float[Array, "num_samples"]]:
"""
Compute the cluster centers and assignments using the sklearn K-means algorithm.

Args:
k (int): The number of clusters.
X (Array(N, D)): The input data array. N samples of dimension D.
key (Array): The random seed array.

Returns:
Array(k, D), Array(N,): The cluster centers and labels
"""
from sklearn.cluster import KMeans

key, subkey = jr.split(key) # Create a random seed for SKLearn.
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
km = KMeans(k, random_state=int(sklearn_key)).fit(X)
return jnp.array(km.cluster_centers_), jnp.array(km.labels_)


class KMeansState(NamedTuple):
centroids: Float[Array, "num_states state_dim"]
assignments: Int[Array, "num_samples"]
prev_centroids: Float[Array, "num_states state_dim"]
itr: int


@partial(jit, static_argnums=(1, 3))
def kmeans_jax(
X: Float[Array, "num_samples state_dim"],
k: int,
key: Array = jr.PRNGKey(0),
max_iters: int = 1000,
) -> KMeansState:
"""
Perform k-means clustering using JAX.

K-means++ initialization is used to initialize the centroids.

Args:
X (Array): The input data array of shape (n_samples, n_features).
k (int): The number of clusters.
max_iters (int, optional): The maximum number of iterations. Defaults to 1000.
key (PRNGKey, optional): The random key for initialization. Defaults to jr.PRNGKey(0).

Returns:
KMeansState: A named tuple containing the final centroids array of shape (k, n_features),
the assignments array of shape (n_samples,) indicating the cluster index for each sample,
the previous centroids array of shape (k, n_features), and the number of iterations.
"""

def _update_centroids(X: Array, assignments: Array):
new_centroids = jnp.array([jnp.mean(X, axis=0, where=(assignments == i)[:, None]) for i in range(k)])
return new_centroids

def _update_assignments(X, centroids):
return jnp.argmin(jnp.linalg.norm(X[:, None] - centroids, axis=2), axis=1)

def body(carry: KMeansState):
centroids, assignments, *_ = carry
new_centroids = _update_centroids(X, assignments)
new_assignments = _update_assignments(X, new_centroids)
return KMeansState(new_centroids, new_assignments, centroids, carry.itr + 1)

def cond(carry: KMeansState):
return jnp.any(carry.centroids != carry.prev_centroids) & (carry.itr < max_iters)

def init(key):
"""kmeans++ initialization of centroids

Iteratively sample new centroids with probability proportional to the squared distance
from the closest centroid. This initialization method is more stable than random
initialization and leads to faster convergence.
Ref: Arthur, D., & Vassilvitskii, S. (2006).
"""
centroids = jnp.zeros((k, X.shape[1]))
centroids = centroids.at[0, :].set(jr.choice(key, X))
for i in range(1, k):
squared_diffs = jnp.sum((X[:, None, :] - centroids[None, :i, :]) ** 2, axis=2)
min_squared_dists = jnp.min(squared_diffs, axis=1)
probs = min_squared_dists / jnp.sum(min_squared_dists)
centroids = centroids.at[i, :].set(jr.choice(key, X, p=probs))
assignments = _update_assignments(X, centroids)
# Perform one iteration to update centroids
updated_centroids = _update_centroids(X, assignments)
updated_assignments = _update_assignments(X, updated_centroids)
return KMeansState(updated_centroids, updated_assignments, centroids, 1)

init_state = init(key)
state = lax.while_loop(cond, body, init_state)

return state
50 changes: 50 additions & 0 deletions dynamax/utils/cluster_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from jax import numpy as jnp
from jax import random as jr
from jax import vmap

from dynamax.utils.cluster import kmeans_jax


def test_kmeans_jax_toy():
"""Checks that kmeans works against toy example.

Ref: scikit-learn tests
"""

key = jr.PRNGKey(101)
x = jnp.array([[0, 0], [0.5, 0], [0.5, 1], [1, 1]])

centroids, assignments, *_ = kmeans_jax(x, 2, key)

# There are two possible solutions for the centroids and assignments
try:
expected_labels = jnp.array([0, 0, 1, 1])
expected_centers = jnp.array([[0.25, 0], [0.75, 1]])
assert jnp.all(assignments == expected_labels)
assert jnp.allclose(centroids, expected_centers)
except AssertionError:
expected_labels = jnp.array([1, 1, 0, 0])
expected_centers = jnp.array([[0.75, 1.0], [0.25, 0.0]])
assert jnp.all(assignments == expected_labels)
assert jnp.allclose(centroids, expected_centers)


def test_kmeans_jax_vmap():
"""Test that kmeans_jax works with vmap."""

def _gen_data(key):
"""Generate 3 clusters of 10 samples each."""
subkeys = jr.split(key, 3)
means = jnp.array([-2., 0., 2.])
_2D_normal = lambda key, mean: jr.normal(key, (10, 2))*0.2 + mean
return vmap(_2D_normal)(subkeys, means).reshape(-1, 2)

key = jr.PRNGKey(5)
key, *data_subkeys = jr.split(key,3)
# Generate 2 samples of the 3-cluster data
x = vmap(_gen_data)(jnp.array(data_subkeys))

alg_subkeys = jr.split(key, 2)
_, assignments, *_ = vmap(kmeans_jax, (0, None, 0))(x, 3, alg_subkeys)
# Check that the assignments are the same for both samples (clusters are very distinct)
assert jnp.all(assignments[0] == assignments[1])
Loading