diff --git a/tests/vi/test_schrodinger_follmer.py b/tests/vi/test_schrodinger_follmer.py index fd58fed0a..79e7afedd 100644 --- a/tests/vi/test_schrodinger_follmer.py +++ b/tests/vi/test_schrodinger_follmer.py @@ -51,7 +51,7 @@ def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov): # Simulate the data observed = jax.random.multivariate_normal( - rng_key_observed, true_mu, true_cov, shape=(10_000,) + rng_key_observed, true_mu, true_cov, shape=(25,) ) logp_model = functools.partial(