Implement Fullrank vi #479
__all__ = ["FullrankVIState", "FullrankVIInfo", "sample", "generate_fullrank_logdensity", "step"]


def _real_vector_to_cholesky(X):
Is this jitable? I think we might need to make m and n kwargs, and create a closure below when we are setting up the parameters.
I am not sure if this is jitable, but JAX seems to be happy with it. I didn't see any warnings.
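For reference, the closure approach suggested above could look roughly like the sketch below. It is not the code from this PR: the factory name `make_real_vector_to_cholesky`, the row-major fill order, and the exponentiated diagonal are all assumptions made for illustration. The point is only that the matrix size is a plain Python int captured by the closure, so the output shape is static when the inner function is jitted.

```python
import jax
import jax.numpy as jnp


def make_real_vector_to_cholesky(n):
    """Hypothetical factory: returns a jitted map from a flat vector of
    length n * (n + 1) // 2 to an (n, n) lower-triangular factor."""
    # n is a plain Python int captured by the closure, so these index
    # arrays are concrete and the output shape is static under jit.
    rows, cols = jnp.tril_indices(n)
    diag = jnp.diag_indices(n)

    @jax.jit
    def real_vector_to_cholesky(x):
        # Fill the lower triangle (including the diagonal) from the vector.
        L = jnp.zeros((n, n), dtype=x.dtype).at[rows, cols].set(x)
        # Assumption: exponentiate the diagonal so it is strictly positive.
        return L.at[diag].set(jnp.exp(jnp.diag(L)))

    return real_vector_to_cholesky


to_cholesky = make_real_vector_to_cholesky(3)
L = to_cholesky(jnp.arange(6.0))
```

If the size were passed as a regular traced argument instead, `jnp.tril_indices` would fail under `jax.jit` because it needs a concrete integer; a closure or a static argument avoids that.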
    **optimizer_kwargs
) -> FullrankVIState:
    """Initialize the fullrank VI state."""
    mu = jax.tree_map(jnp.zeros_like, position)  # Is this a good init strategy?
We should also allow random initialization for both mu and L, maybe by letting the user pass a callable that takes random_key and shape as input (for zeros and ones we can just ignore the random_key).
Sounds good, shall we also make the changes to MFVI?
I think we need to make a choice and let users initialise manually if they want something different.
In fact, the init function here only serves as a "default initialization option". If users want to specify their own strategy, they can replace the call to "fullrank_vi.init()" (https://github.com/blackjax-devs/blackjax/pull/479/files#diff-8923a0a4ea42b4d3c2e1756e182a25f482809bc5b1c23a601423f5908f0f03e2R36) with their own function.
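To make the initializer-callable idea from this thread concrete, here is a minimal sketch. None of these names exist in blackjax: `zeros_init`, `normal_init`, and `init_mu` are hypothetical, and the signature `(rng_key, shape)` simply follows the suggestion above.

```python
import jax
import jax.numpy as jnp


def zeros_init(rng_key, shape):
    # Deterministic initializer: the key is accepted but ignored.
    return jnp.zeros(shape)


def normal_init(rng_key, shape, scale=0.1):
    # Random initializer: small Gaussian noise around zero.
    return scale * jax.random.normal(rng_key, shape)


def init_mu(position, rng_key, initializer=zeros_init):
    """Hypothetical helper: build mu with the same pytree structure as
    position, calling initializer(rng_key, shape) once per leaf."""
    leaves, treedef = jax.tree_util.tree_flatten(position)
    keys = jax.random.split(rng_key, len(leaves))
    new_leaves = [
        initializer(key, jnp.shape(leaf)) for key, leaf in zip(keys, leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


position = {"a": jnp.ones(3), "b": jnp.ones((2, 2))}
mu_default = init_mu(position, jax.random.PRNGKey(0))
mu_random = init_mu(position, jax.random.PRNGKey(0), initializer=normal_init)
```

With this shape of API the zeros default stays the simple path, and a user who wants something different passes their own callable rather than replacing the whole init function.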
@rlouf @junpenglao
Sorry for my late reaction. I think that the issue that you're facing with the covariance matrix is part of a more general discussion we're having at the library level (we have similar issues with the mass matrix for the HMC algorithms). I suggest we leave it as is for now to get the ball rolling on this PR and get it in a mergeable state.
Is there still an interest in full rank VI?
Yes, I think we are nearly there, but maybe it is better to start a new PR afresh.
Close in favor of #720
Fullrank VI
TODO 1: Better test cases that verify posterior covariance recovery
TODO 2: Currently users do not have access to the _real_to_vector function and therefore cannot convert the unnormalized real-space vector to the covariance matrix. This is not sensible and should be fixed.
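As a rough illustration of what a covariance-recovery check for TODO 1 might look like, the sketch below draws samples from a full-rank Gaussian given a Cholesky factor and compares the empirical covariance to L Lᵀ. The helper names `covariance_from_cholesky` and `sample_fullrank`, the example values, and the tolerance are assumptions for illustration, not part of this PR.

```python
import jax
import jax.numpy as jnp


def covariance_from_cholesky(L):
    # The covariance implied by a lower-triangular factor L is L @ L.T.
    return L @ L.T


def sample_fullrank(rng_key, mu, L, num_samples):
    # Reparameterized draws: x = mu + L @ eps with eps ~ N(0, I).
    eps = jax.random.normal(rng_key, (num_samples, mu.shape[0]))
    return mu + eps @ L.T


# Illustrative values only.
mu = jnp.array([1.0, -2.0])
L = jnp.array([[1.0, 0.0], [0.5, 2.0]])

samples = sample_fullrank(jax.random.PRNGKey(0), mu, L, 50_000)
empirical_cov = jnp.cov(samples, rowvar=False)
target_cov = covariance_from_cholesky(L)
```

A real test would fit the approximation to a target with a known posterior covariance and compare against that, but the same comparison of an empirical covariance to L Lᵀ would apply.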