Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX kmeans implementation #371

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft

Add JAX kmeans implementation #371

wants to merge 3 commits into from

Conversation

gileshd
Copy link
Collaborator

@gileshd gileshd commented Jul 21, 2024

Summary

This PR add an implementation of the k-means algorithm written in JAX.

By replacing the current calls to the sklearn implementation with this new implementation, we can now vmap and jit code which uses kmeans for initialization. We can also remove scikit-learn as a core dependency (it is still used in some demos).

Details

The changes are currently structured into different commits as follows:

  1. Add utility function for sklearn kmeans
    • This just adds as wrapper around the sklearn kmeans implementation to dynamax.utils.cluster.
  2. Update SSMs to use kmeans utility function
    • Refactor the calls to sklearn kmeans to use the new utility function.
  3. Add jax implementation of kmeans
    • Add a new implementation of kmeans to dynamax.utils.cluster which is compatible with JAX transformations.
    • Add some basic tests for this implementation.

Further Testing

It would be nice to be able to test if the new implementation does roughly as good a job as the sklearn implementation (it is considerably less complex). From playing about with it thus far once I added k-means++ initialisation it seemed to work pretty well.

I'm currently working on some test code where I patch in the new implementation and check that we can get about the same goodness of fit.

Final changes before merging

Before merging I will replace the current calls to kmeans_sklearn with kmeans_jax (and perhaps rename it) and remove the kmeans_sklearn function from dynamax.utils.cluster.

Questions:

  1. Should the kmeans implementation be marked as reserved for internal use only? - i.e. not part of the public API?
    • This way we have a bit more leeway to change the interface if necessary at a later date without breaking the API
    • Is this in fact already (implicitly?) the case with the contents of dynamax.utils?
  2. Any bright ideas on testing this?

Related issues:

Closes #315.

@gileshd gileshd requested a review from slinderman July 21, 2024 16:18
Comment on lines +58 to +64
cluster_emissions_means = jnp.array(
[jnp.mean(flat_emissions, where=km_labels == k) for k in range(self.num_states)]
)
cluster_emissions_means = jnp.where(
jnp.isnan(cluster_emissions_means), flat_emissions.mean(), cluster_emissions_means
)
_emission_biases = tfb.Sigmoid().inverse(cluster_emissions_means)
Copy link
Collaborator Author

@gileshd gileshd Jul 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is required so that this chunk is friendly with JAX transformations - the old version has intermediate arrays with variable shape.

@gileshd gileshd self-assigned this Jul 21, 2024
@murphyk
Copy link
Member

murphyk commented Jul 22, 2024

"Should the kmeans implementation be marked as reserved for internal use only?" No!
K-means in jax could be of independent interest, so it is worth making it a first class public citizen.
But see also
https://ott-jax.readthedocs.io/en/latest/_autosummary/ott.tools.k_means.k_means.html

Possibly also relevant
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.approx_max_k.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for kmeans initialization with vmap
2 participants