
Meeting minutes

Rémi Louf edited this page Oct 11, 2021 · 1 revision

Friday 07/10/2021

Present:

  • Junpeng Lao
  • Adrien Corenflos
  • Rémi Louf

Meta-algorithm & coupling

Pausing for now.

Inverse mass matrix

Users who will build their own matrix:

  • People who want to learn from or teach with the library: they pass their own mass matrix to show how the matrix impacts sampling.
  • Researchers working on new algorithms (mostly adaptation): they will build their own matrices by definition, and need a way to provide a matrix at a high level.
  • Applied researchers who already have a quadratic approximation of their potential.
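For the third group, here is a minimal sketch of how a quadratic approximation turns into an inverse mass matrix. It assumes the common choice of using the covariance of the Gaussian implied by the Hessian of the log-density at its mode. The function names `inverse_mass_matrix_from_laplace` and `kinetic_energy` are hypothetical and not BlackJAX API.

```python
import jax
import jax.numpy as jnp


def inverse_mass_matrix_from_laplace(logprob_fn, mode):
    """Inverse mass matrix from a quadratic (Laplace) approximation.

    Around the mode, log p(x) is approximated by a quadratic with Hessian H.
    The matching Gaussian has covariance (-H)^{-1}, which we use as M^{-1}.
    """
    hessian = jax.hessian(logprob_fn)(mode)
    return jnp.linalg.inv(-hessian)


def kinetic_energy(momentum, inverse_mass_matrix):
    # K(p) = 0.5 * p^T M^{-1} p: this is where the matrix enters HMC.
    return 0.5 * momentum @ inverse_mass_matrix @ momentum


# Toy target: a correlated 2D Gaussian centred at the origin.
cov = jnp.array([[1.0, 0.8], [0.8, 2.0]])
precision = jnp.linalg.inv(cov)


def logprob_fn(x):
    return -0.5 * x @ precision @ x


inv_mass = inverse_mass_matrix_from_laplace(logprob_fn, jnp.zeros(2))
```

For a Gaussian target the quadratic approximation is exact, so `inv_mass` recovers `cov` up to floating-point error.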

Next Actions

  • Adrien: Design document

Contributing guide

Junpeng started working on it here. We need to add a description of the high-level code architecture.

New kernel API

We should pass the parameters to the kernel at each step:

`new_state, info = nuts.one_step(rng_key, state, step_size, inverse_mass_matrix)`

Q: What can we do with this design that we could not do before? A: Modify the kernel between two HMC steps. Adaptation is currently slower than NumPyro's, which is not cool. The current design also needs a lot of partial applications, which makes the code less elegant. And now we can `vmap` across arrays of step sizes and mass matrices (-> ChEES).
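A minimal sketch of what passing parameters at each step buys us. It assumes a plain HMC kernel (not NUTS) with a diagonal inverse mass matrix stored as a 1D array; `build_kernel`, `init` and `HMCState` are hypothetical names, not BlackJAX's implementation. Because `step_size` is an ordinary argument rather than something bound by partial application, `jax.vmap` can map over an array of step sizes, which is the kind of batching the ChEES remark above refers to.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class HMCState(NamedTuple):
    position: jnp.ndarray
    logprob: jnp.ndarray
    logprob_grad: jnp.ndarray


def init(position, logprob_fn):
    logprob, logprob_grad = jax.value_and_grad(logprob_fn)(position)
    return HMCState(position, logprob, logprob_grad)


def build_kernel(logprob_fn, num_integration_steps=10):
    """Return one_step(rng_key, state, step_size, inverse_mass_matrix)."""
    grad_fn = jax.grad(logprob_fn)

    def kinetic_energy(momentum, inverse_mass_matrix):
        return 0.5 * jnp.sum(inverse_mass_matrix * momentum**2)

    def one_step(rng_key, state, step_size, inverse_mass_matrix):
        key_momentum, key_accept = jax.random.split(rng_key)
        # Momentum p ~ N(0, M), where M is the inverse of `inverse_mass_matrix`.
        momentum = jax.random.normal(key_momentum, state.position.shape) / jnp.sqrt(
            inverse_mass_matrix
        )

        def leapfrog(_, carry):
            q, p, g = carry
            p = p + 0.5 * step_size * g
            q = q + step_size * inverse_mass_matrix * p
            g = grad_fn(q)
            p = p + 0.5 * step_size * g
            return q, p, g

        q, p, _ = jax.lax.fori_loop(
            0, num_integration_steps, leapfrog, (state.position, momentum, state.logprob_grad)
        )
        proposal = init(q, logprob_fn)
        # Metropolis-Hastings correction on the total energy -logprob + K(p).
        log_ratio = (proposal.logprob - kinetic_energy(p, inverse_mass_matrix)) - (
            state.logprob - kinetic_energy(momentum, inverse_mass_matrix)
        )
        acceptance_prob = jnp.minimum(1.0, jnp.exp(log_ratio))
        accept = jax.random.uniform(key_accept) < acceptance_prob
        new_state = jax.tree_util.tree_map(
            lambda new, old: jnp.where(accept, new, old), proposal, state
        )
        return new_state, acceptance_prob

    return one_step


def logprob_fn(x):
    return -0.5 * jnp.sum(x**2)  # standard normal target


one_step = build_kernel(logprob_fn)
state = init(jnp.ones(3), logprob_fn)

# Map over step sizes; the state and inverse mass matrix are shared (in_axes=None).
step_sizes = jnp.array([0.05, 0.1, 0.5, 1.0])
keys = jax.random.split(jax.random.PRNGKey(0), step_sizes.shape[0])
new_states, accept_probs = jax.vmap(one_step, in_axes=(0, None, 0, None))(
    keys, state, step_sizes, jnp.ones(3)
)
```

Each row of `new_states.position` is one transition run with a different step size, all in a single vectorized call.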

Q: Concerns? A: With more parameters it will get more (too?) complicated.

Q: Can we somehow hide that from the users at a higher level? Code structure?

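```python
import functools as ft

def hmc(logprob_fn, *, step_size=None, inverse_mass_matrix=None):  # public API
    # Bind the parameters if the user provides both; otherwise return the low-level `one_step`.
    if step_size is not None and inverse_mass_matrix is not None:
        return ft.partial(one_step, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix)
    return one_step
```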

Q: Can we use NamedTuple for parameters?

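```python
from typing import NamedTuple, Optional

import jax.numpy as jnp


class HMCParameters(NamedTuple):
    # Defaults let users specify only some of the parameters.
    num_integration_steps: Optional[int] = None
    step_size: Optional[float] = None
    inverse_mass_matrix: Optional[jnp.ndarray] = None


params = HMCParameters(10)  # num_integration_steps=10, the other fields are None

### IS THAT OK? (proposed usage, not an existing API)
kernel = blackjax.hmc(logprob_fn)  # the function unpacks `params`
kernel.update(rng_key, init_state, params)  # is the added complexity worth it?
#  !!! We cannot jit it from the outside here
#  !!! For `vmap` we would have to duplicate the elements we're not mapping over
```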

Actually, we can build that on top of a `one_step` function that only takes `step_size` and `inverse_mass_matrix`.
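A sketch of that layering, assuming a primitive with the `one_step(rng_key, state, step_size, inverse_mass_matrix)` signature discussed above. The primitive keeps plain arguments, and a thin hypothetical wrapper (`with_parameters`) accepts a trimmed-down `HMCParameters` NamedTuple and unpacks it. `toy_one_step` is a stand-in used only to show the call pattern, not a real kernel.

```python
from typing import NamedTuple, Optional

import jax.numpy as jnp


class HMCParameters(NamedTuple):
    step_size: Optional[float] = None
    inverse_mass_matrix: Optional[jnp.ndarray] = None


def with_parameters(one_step):
    """Adapt one_step(rng_key, state, step_size, inverse_mass_matrix) to take a NamedTuple."""

    def step(rng_key, state, params: HMCParameters):
        return one_step(rng_key, state, params.step_size, params.inverse_mass_matrix)

    return step


# Hypothetical stand-in for a real kernel, just to show the call pattern.
def toy_one_step(rng_key, state, step_size, inverse_mass_matrix):
    return state + step_size * inverse_mass_matrix, {"step_size": step_size}


step = with_parameters(toy_one_step)
params = HMCParameters(step_size=0.1, inverse_mass_matrix=jnp.ones(2))
new_state, info = step(None, jnp.zeros(2), params)
```

Users who want `jax.jit` or `jax.vmap` with per-argument control can still call the primitive directly; the NamedTuple is only convenience on top.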

NamedTuples to hold functions?

```python
from typing import Callable, NamedTuple

import blackjax


class SamplingKernel(NamedTuple):
    init: Callable
    step: Callable


# Now
init, step = blackjax.hmc(logprob_fn)
state = init(position)
new_state, info = step(rng_key, state)

# With Optax design
hmc = blackjax.hmc(logprob_fn)
state = hmc.init(position)
new_state, info = hmc.step(rng_key, state)
```
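A self-contained sketch of the Optax-style design, assuming a hypothetical factory `random_walk` that returns a `SamplingKernel`. To keep it short, the sampler is a plain random-walk Metropolis kernel swapped in for HMC. Because a NamedTuple is a tuple, the current `init, step = ...` unpacking keeps working alongside attribute access.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SamplingKernel(NamedTuple):
    init: Callable
    step: Callable


def random_walk(logprob_fn, scale=0.5):
    """Hypothetical factory bundling (init, step); random-walk Metropolis, not HMC."""

    def init(position):
        return position, logprob_fn(position)

    def step(rng_key, state):
        position, logprob = state
        key_proposal, key_accept = jax.random.split(rng_key)
        proposal = position + scale * jax.random.normal(key_proposal, position.shape)
        proposal_logprob = logprob_fn(proposal)
        accept = jnp.log(jax.random.uniform(key_accept)) < proposal_logprob - logprob
        new_state = (
            jnp.where(accept, proposal, position),
            jnp.where(accept, proposal_logprob, logprob),
        )
        return new_state, {"is_accepted": accept}

    return SamplingKernel(init, step)


def logprob_fn(x):
    return -0.5 * jnp.sum(x**2)


# Current style: tuple unpacking still works.
init, step = random_walk(logprob_fn)
state = init(jnp.zeros(2))
new_state, info = step(jax.random.PRNGKey(0), state)

# Optax style: attribute access on the same object.
rw = random_walk(logprob_fn)
state = rw.init(jnp.zeros(2))
new_state_2, info_2 = rw.step(jax.random.PRNGKey(0), state)
```

With the same key, both styles produce the same transition, so the switch would not break existing user code that unpacks the pair.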

Q: What should we name the second function? Q: Are there other places in the internals where we can do that?

Decision: Ok, proceed.

Users pass their own gradient

This would not only help harmonize the API with the introduction of SGLD, but it is also something that people actually use (see #136).

```python
from typing import Callable, Optional

# Proposed signature for `blackjax.hmc`
def hmc(logprob_fn: Callable, logprob_grad_fn: Optional[Callable] = None): ...
```

If `logprob_grad_fn` is `None`, then `jax.grad` is used to compute the gradient.
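The fallback logic of this (deferred) proposal can be sketched in a few lines. `resolve_gradient` is a hypothetical helper name, not part of BlackJAX; it returns the user's gradient when one is given and `jax.grad(logprob_fn)` otherwise.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def resolve_gradient(
    logprob_fn: Callable, logprob_grad_fn: Optional[Callable] = None
) -> Callable:
    """Use the user-provided gradient if any, otherwise differentiate with JAX."""
    if logprob_grad_fn is None:
        return jax.grad(logprob_fn)
    return logprob_grad_fn


def logprob_fn(x):
    return -0.5 * jnp.sum(x**2)


def analytic_grad(x):
    # Hand-written gradient of logprob_fn, as a user might supply it.
    return -x


x = jnp.array([1.0, -2.0])
auto = resolve_gradient(logprob_fn)(x)
manual = resolve_gradient(logprob_fn, analytic_grad)(x)
```

Both paths give the same gradient here. The user-supplied path matters when the gradient comes from elsewhere, for example a stochastic estimate as in SGLD.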

Decision: Not for now.

Scope of the library

Adrien: State-space models shouldn’t be in blackjax; they can quickly become complicated to handle.

Remi:

  • Variational inference is fine if someone wants to implement it.
  • Bayesian updating is missing in most PPLs. It is currently possible to do with SMC, and we should add an example.
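As a hint of what such an example could show, here is a minimal sketch of sequential Bayesian updating with particles: reweight by the likelihood of each new batch, then resample, so that the posterior after one batch becomes the prior for the next. It assumes a conjugate Normal model (prior N(0, 1), known noise scale 1) so the answer can be checked, and it omits the MCMC rejuvenation moves a real SMC sampler would add to restore particle diversity. It does not use BlackJAX's SMC API; `log_likelihood` and `update` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def log_likelihood(theta, batch, sigma=1.0):
    # Gaussian likelihood with known noise scale; theta has shape (num_particles,).
    return jnp.sum(-0.5 * ((batch[None, :] - theta[:, None]) / sigma) ** 2, axis=1)


def update(rng_key, particles, batch):
    """Reweight particles by the likelihood of a new batch, then resample."""
    weights = jax.nn.softmax(log_likelihood(particles, batch))
    num_particles = particles.shape[0]
    idx = jax.random.choice(rng_key, num_particles, shape=(num_particles,), p=weights)
    return particles[idx]


key_prior, key_1, key_2 = jax.random.split(jax.random.PRNGKey(0), 3)
particles = jax.random.normal(key_prior, (50_000,))  # prior: theta ~ N(0, 1)

batch_1 = jnp.array([0.8, 1.2, 1.0])
batch_2 = jnp.array([0.9, 1.1])
particles = update(key_1, particles, batch_1)  # posterior after batch 1
particles = update(key_2, particles, batch_2)  # ...used as the prior for batch 2
```

With 5 observations summing to 5.0, the exact posterior mean is 5 / (5 + 1) ≈ 0.83, and the particle mean should land close to it.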

Junpeng:

  • HMC-focused, then SMC; more than inference-focused

Should we move notebooks to markdown?

Or scripts (with docstrings) built into the docs with Sphinx (see Python Optimal Transport). However, users should be able to easily convert them to notebooks; very often people start tinkering with a library from example notebooks.

Q: Can we export those as notebooks *easily*? A: Look at what matplotlib does.

Decision: Use scripts if they can easily be converted to notebooks for users to tinker with; otherwise, use jupytext.
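A sketch of the script format this decision points to, assuming jupytext's "percent" format: `# %%` comments delimit code cells and `# %% [markdown]` cells hold prose. The file runs as an ordinary Python script, and `jupytext --to notebook example.py` converts it to a notebook (the file name is hypothetical). Whether the Sphinx tooling we end up choosing reads the same cell markers is an assumption to check.

```python
# %% [markdown]
# # Sampling from a standard normal
# Prose for the rendered docs and the notebook goes in markdown cells like this one.

# %%
import jax
import jax.numpy as jnp

# %%
samples = jax.random.normal(jax.random.PRNGKey(0), (1_000,))
mean = jnp.mean(samples)
```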

Doc

  • API doc (compiler)
  • Restate the scope
  • Restate the target audience
  • Design principles, code map
  • Contributing guide
  • If there’s a question about how to do something, it should probably be raised as an issue.

Decision: We should just do it!

Formal structure

Decision: Rémi investigates.

Next actions

Adrien:

  • Design document for the inverse mass matrix

Junpeng:

  • PR on `chex` for testing

Rémi:

  • SGLD design and implementation
  • Adaptation