Skip to content

Commit

Permalink
renaming to data mask
Browse files Browse the repository at this point in the history
  • Loading branch information
ciguaran committed Oct 4, 2024
1 parent a5922ad commit 26d271e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
22 changes: 11 additions & 11 deletions blackjax/smc/partial_posteriors_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ class PartialPosteriorsSMCState(NamedTuple):
The particles' positions.
weights:
Weights of the particles, so that they represent a probability distribution
selector:
Datapoints used to calculate the posterior the particles represent, a 1D boolean
array to indicate which datapoints to include in the computation of the observed likelihood.
data_mask:
A 1D boolean array to indicate which datapoints to include
in the computation of the observed likelihood.
"""

particles: ArrayTree
weights: Array
selector: Array
data_mask: Array


def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCState:
"""num_datapoints are the number of observations that could potentially be
used in a partial posterior. Since the initial selector is all 0s, it
used in a partial posterior. Since the initial data_mask is all 0s, it
means that no likelihood term will be added (only prior).
"""
num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
Expand Down Expand Up @@ -73,11 +73,11 @@ def build_kernel(
delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy)

def step(
key, state: PartialPosteriorsSMCState, selector: Array
key, state: PartialPosteriorsSMCState, data_mask: Array
) -> Tuple[PartialPosteriorsSMCState, smc.base.SMCInfo]:
logposterior_fn = partial_logposterior_factory(selector)
logposterior_fn = partial_logposterior_factory(data_mask)

previous_logposterior_fn = partial_logposterior_factory(state.selector)
previous_logposterior_fn = partial_logposterior_factory(state.data_mask)

def log_weights_fn(x):
return logposterior_fn(x) - previous_logposterior_fn(x)
Expand All @@ -86,7 +86,7 @@ def log_weights_fn(x):
key, state, num_mcmc_steps, mcmc_parameters, logposterior_fn, log_weights_fn
)

return PartialPosteriorsSMCState(state.particles, state.weights, selector), info
return PartialPosteriorsSMCState(state.particles, state.weights, data_mask), info

return step

Expand Down Expand Up @@ -118,7 +118,7 @@ def init_fn(position: ArrayLikeTree, num_observations, rng_key=None):
del rng_key
return init(position, num_observations)

def step(key: PRNGKey, state: PartialPosteriorsSMCState, selector: Array):
return kernel(key, state, selector)
def step(key: PRNGKey, state: PartialPosteriorsSMCState, data_mask: Array):
return kernel(key, state, data_mask)

return SamplingAlgorithm(init_fn, step) # type: ignore[arg-type]
16 changes: 8 additions & 8 deletions tests/smc/test_partial_posteriors_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def test_partial_posteriors(self):

dataset_size = 1000

def partial_logposterior_factory(selector):
def partial_logposterior_factory(data_mask):
def partial_logposterior(x):
lp = logprior_fn(x)
return lp + jnp.sum(
self.logdensity_by_observation(**x, **observations)
* selector.reshape(-1, 1)
* data_mask.reshape(-1, 1)
)

return jax.jit(partial_logposterior)
Expand All @@ -60,20 +60,20 @@ def partial_logposterior(x):
init_state = init(init_particles, 1000)
smc_kernel = self.variant(kernel)

selectors = jnp.array(
data_masks = jnp.array(
[
jnp.concat([jnp.ones(selector), jnp.zeros(dataset_size - selector)])
for selector in np.arange(100, 1001, 50)
jnp.concat([jnp.ones(datapoints_chosen), jnp.zeros(dataset_size - datapoints_chosen)])
for datapoints_chosen in np.arange(100, 1001, 50)
]
)

def body_fn(carry, selector):
def body_fn(carry, data_mask):
i, state = carry
subkey = jax.random.fold_in(self.key, i)
new_state, info = smc_kernel(subkey, state, selector)
new_state, info = smc_kernel(subkey, state, data_mask)
return (i + 1, new_state), (new_state, info)

(steps, result), it = jax.lax.scan(body_fn, (0, init_state), selectors)
(steps, result), it = jax.lax.scan(body_fn, (0, init_state), data_masks)
assert steps == 19

self.assert_linear_regression_test_case(result)
Expand Down

0 comments on commit 26d271e

Please sign in to comment.