diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 10ae3c7866..9441368d73 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,6 +85,7 @@ jobs: tests/step_methods/hmc/test_quadpotential.py - | + tests/backends/test_mcbackend.py tests/distributions/test_truncated.py tests/logprob/test_abstract.py tests/logprob/test_censoring.py diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 44c5d73ecc..757bead693 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -41,3 +41,4 @@ dependencies: - pip: - git+https://github.com/pymc-devs/pymc-sphinx-theme - numdifftools>=0.9.40 + - mcbackend>=0.4.0 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index b70ad4cee3..a915b0c7a9 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -31,3 +31,4 @@ dependencies: - types-cachetools - pip: - numdifftools>=0.9.40 + - mcbackend>=0.4.0 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 5b6e5749e5..f0ea51254a 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -38,3 +38,4 @@ dependencies: - pip: - git+https://github.com/pymc-devs/pymc-sphinx-theme - numdifftools>=0.9.40 + - mcbackend>=0.4.0 diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index fd88e054ce..ae184cdf9d 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -31,3 +31,4 @@ dependencies: - types-cachetools - pip: - numdifftools>=0.9.40 + - mcbackend>=0.4.0 diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index fbfe8914a9..a89f64d833 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -61,16 +61,32 @@ """ from copy import copy -from typing import Dict, List, Optional, Sequence, Union +from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np +from typing_extensions import TypeAlias + from pymc.backends.arviz import predictions_to_inference_data, to_inference_data from pymc.backends.base import BaseTrace, IBaseTrace from pymc.backends.ndarray import NDArray from pymc.model import Model from pymc.step_methods.compound import BlockedStep, CompoundStep +HAS_MCB = False +try: + from mcbackend import Backend, NumPyBackend, Run + + from pymc.backends.mcbackend import init_chain_adapters + + TraceOrBackend = Union[BaseTrace, Backend] + RunType: TypeAlias = Run + HAS_MCB = True +except ImportError: + TraceOrBackend = BaseTrace # type: ignore + RunType = type(None) # type: ignore + + __all__ = ["to_inference_data", "predictions_to_inference_data"] @@ -99,16 +115,27 @@ def _init_trace( def init_traces( *, - backend: Optional[BaseTrace], + backend: Optional[TraceOrBackend], chains: int, expected_length: int, step: Union[BlockedStep, CompoundStep], - var_dtypes: Dict[str, np.dtype], - var_shapes: Dict[str, Sequence[int]], + initial_point: Mapping[str, np.ndarray], model: Model, -) -> Sequence[IBaseTrace]: +) -> Tuple[Optional[RunType], Sequence[IBaseTrace]]: """Initializes a trace recorder for each chain.""" - return [ + if HAS_MCB and backend is None: + backend = NumPyBackend(preallocate=expected_length) + if HAS_MCB and isinstance(backend, Backend): + return init_chain_adapters( + backend=backend, + chains=chains, + initial_point=initial_point, + step=step, + model=model, + ) + + assert backend is None or isinstance(backend, 
BaseTrace)
+    traces = [
         _init_trace(
             expected_length=expected_length,
             stats_dtypes=step.stats_dtypes,
@@ -118,3 +145,4 @@ def init_traces(
         )
         for chain_number in range(chains)
     ]
+    return None, traces
diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py
new file mode 100644
index 0000000000..3e732101b1
--- /dev/null
+++ b/pymc/backends/mcbackend.py
@@ -0,0 +1,286 @@
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import logging
+import pickle
+
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast
+
+import hagelkorn
+import mcbackend as mcb
+import numpy as np
+
+from mcbackend.npproto.utils import ndarray_from_numpy
+from pytensor.compile.sharedvalue import SharedVariable
+from pytensor.graph.basic import Constant
+
+from pymc.backends.base import IBaseTrace
+from pymc.model import Model
+from pymc.pytensorf import PointFunc
+from pymc.step_methods.compound import (
+    BlockedStep,
+    CompoundStep,
+    StatsBijection,
+    flat_statname,
+    flatten_steps,
+)
+
+_log = logging.getLogger("pymc")
+
+
+def find_data(pmodel: Model) -> List[mcb.DataVariable]:
+    """Extracts data variables from a model."""
+    observed_rvs = {pmodel.rvs_to_values[rv] for rv in pmodel.observed_RVs}
+    dvars = []
+    # All data containers are named vars!
+    for name, var in pmodel.named_vars.items():
+        dv = mcb.DataVariable(name)
+        if isinstance(var, Constant):
+            dv.value = ndarray_from_numpy(var.data)
+        elif isinstance(var, SharedVariable):
+            dv.value = ndarray_from_numpy(var.get_value())
+        else:
+            continue
+        dv.dims = list(pmodel.named_vars_to_dims.get(name, []))
+        dv.is_observed = var in observed_rvs
+        dvars.append(dv)
+    return dvars
+
+
+def get_variables_and_point_fn(
+    model: Model, initial_point: Mapping[str, np.ndarray]
+) -> Tuple[List[mcb.Variable], PointFunc]:
+    """Get metadata on free, value and deterministic model variables."""
+    # The samplers act only on the inputs needed for the log-likelihood,
+    # but the user is interested in transformed variables and deterministics.
+    vvars = model.value_vars
+    vars = model.unobserved_value_vars
+    # Below we compile the "point function" that transforms a draw to the set
+    # of untransformed, transformed and deterministic variables that will be traced.
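+    # For example (illustrative values only): with a free Normal "a", a Uniform "b"
+    # (sampled as "b_interval__") and a Deterministic "c" = a + b, a raw draw like
+    # {"a": 0.1, "b_interval__": -0.3} is mapped to values of "a", "b_interval__",
+    # "b" and "c" in a single function call.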
+    point_fn = model.compile_fn(vars, inputs=vvars, on_unused_input="ignore", point_fn=True)
+    point_fn = cast(PointFunc, point_fn)
+    point = point_fn(initial_point)
+
+    names = [v.name for v in vars]
+    dtypes = [v.dtype for v in vars]
+    shapes = [v.shape for v in point]
+    deterministics = {d.name for d in model.deterministics}
+    variables = [
+        mcb.Variable(
+            name=name,
+            dtype=str(dtype),
+            shape=list(shape),
+            dims=list(model.named_vars_to_dims.get(name, [])),
+            is_deterministic=name in deterministics,
+        )
+        for name, dtype, shape in zip(names, dtypes, shapes)
+    ]
+    return variables, point_fn
+
+
+class ChainRecordAdapter(IBaseTrace):
+    """Wraps an McBackend ``Chain`` as an ``IBaseTrace``."""
+
+    def __init__(
+        self, chain: mcb.Chain, point_fn: PointFunc, stats_bijection: StatsBijection
+    ) -> None:
+        # Assign attributes required by IBaseTrace
+        self.chain = chain.cmeta.chain_number
+        self.varnames = [v.name for v in chain.rmeta.variables]
+        stats_dtypes = {s.name: np.dtype(s.dtype) for s in chain.rmeta.sample_stats}
+        self.sampler_vars = [
+            {sname: stats_dtypes[fname] for fname, sname, is_obj in sstats}
+            for sstats in stats_bijection._stat_groups
+        ]
+
+        self._chain = chain
+        self._point_fn = point_fn
+        self._statsbj = stats_bijection
+        super().__init__()
+
+    def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
+        values = self._point_fn(draw)
+        value_dict = {n: v for n, v in zip(self.varnames, values)}
+        stats_dict = self._statsbj.map(stats)
+        # Apply pickling to object stats
+        for fname in self._statsbj.object_stats.keys():
+            val_bytes = pickle.dumps(stats_dict[fname])
+            val = base64.encodebytes(val_bytes).decode("ascii")
+            stats_dict[fname] = np.array(val, dtype=str)
+        return self._chain.append(value_dict, stats_dict)
+
+    def __len__(self):
+        return len(self._chain)
+
+    def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
+        return self._chain.get_draws(varname, slice(burn, None, thin))
+
+    def _get_stats(self, fname: str, slc: slice) -> np.ndarray:
+        """Wraps `self._chain.get_stats` but unpickles automatically."""
+        values = self._chain.get_stats(fname, slc)
+        # Unpickle object stats
+        if fname in self._statsbj.object_stats:
+            objs = []
+            for v in values:
+                enc = str(v).encode("ascii")
+                str_ = base64.decodebytes(enc)
+                obj = pickle.loads(str_)
+                objs.append(obj)
+            return np.array(objs, dtype=object)
+        return values
+
+    def get_sampler_stats(
+        self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
+    ) -> np.ndarray:
+        slc = slice(burn, None, thin)
+        # When there's just one sampler, default to removing the sampler dimension
+        if sampler_idx is None and self._statsbj.n_samplers == 1:
+            sampler_idx = 0
+        # Fetching for a specific sampler is easy
+        if sampler_idx is not None:
+            return self._get_stats(flat_statname(sampler_idx, stat_name), slc)
+        # To fetch for all samplers, we must collect the arrays one by one.
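+        # For example, when two samplers both record "accepted", the flat stats
+        # "sampler_0__accepted" and "sampler_1__accepted" are fetched individually,
+        # regrouped via rmap(), and stacked into one (draws, samplers) array below.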
+        stats_dict = {
+            stat.name: self._get_stats(stat.name, slc)
+            for stat in self._chain.rmeta.sample_stats
+            if stat_name in stat.name
+        }
+        if not stats_dict:
+            raise KeyError(f"No stat '{stat_name}' was recorded.")
+        stats_list = self._statsbj.rmap(stats_dict)
+        stats_arrays = []
+        is_ragged = False
+        for sd in stats_list:
+            if not sd:
+                is_ragged = True
+                continue
+            else:
+                stats_arrays.append(tuple(sd.values())[0])
+
+        if is_ragged:
+            _log.debug("Stat '%s' was not recorded by all samplers.", stat_name)
+        if len(stats_arrays) == 1:
+            return stats_arrays[0]
+        return np.array(stats_arrays).T
+
+    def _slice(self, idx: slice) -> "IBaseTrace":
+        # Get the integer indices
+        start, stop, step = idx.indices(len(self))
+        indices = np.arange(start, stop, step)
+
+        # Create a NumPyChain for the sliced data
+        nchain = mcb.backends.numpy.NumPyChain(
+            self._chain.cmeta, self._chain.rmeta, preallocate=len(indices)
+        )
+
+        # Copy the draws and stats at the selected indices and append them to the new chain.
+        # This may be slow, but NumPyChain currently doesn't have a batch-insert or slice API.
+        vnames = [v.name for v in nchain.variables.values()]
+        snames = [s.name for s in nchain.sample_stats.values()]
+        for i in indices:
+            draw = self._chain.get_draws_at(i, var_names=vnames)
+            stats = self._chain.get_stats_at(i, stat_names=snames)
+            nchain.append(draw, stats)
+        return ChainRecordAdapter(nchain, self._point_fn, self._statsbj)
+
+    def point(self, idx: int) -> Dict[str, np.ndarray]:
+        return self._chain.get_draws_at(idx, [v.name for v in self._chain.variables.values()])
+
+
+def make_runmeta_and_point_fn(
+    *,
+    initial_point: Mapping[str, np.ndarray],
+    step: Union[CompoundStep, BlockedStep],
+    model: Model,
+) -> Tuple[mcb.RunMeta, PointFunc]:
+    variables, point_fn = get_variables_and_point_fn(model, initial_point)
+
+    sample_stats = [
+        mcb.Variable("tune", "bool"),
+    ]
+
+    # In PyMC the sampler stats are grouped by the sampler.
+    steps = flatten_steps(step)
+    for s, sm in enumerate(steps):
+        for statname, (dtype, shape) in sm.stats_dtypes_shapes.items():
+            sname = flat_statname(s, statname)
+            sshape = [
+                # PyMC uses None to indicate dynamic dims, MCB uses -1
+                (-1 if s is None else s)
+                for s in (shape or [])
+            ]
+            svar = mcb.Variable(
+                name=sname,
+                dtype=np.dtype(dtype).name,
+                shape=sshape,
+                undefined_ndim=shape is None,
+            )
+            sample_stats.append(svar)
+
+    coordinates = [
+        mcb.Coordinate(dname, mcb.npproto.utils.ndarray_from_numpy(np.array(cvals)))
+        for dname, cvals in model.coords.items()
+        if cvals is not None
+    ]
+    meta = mcb.RunMeta(
+        rid=hagelkorn.random(),
+        variables=variables,
+        coordinates=coordinates,
+        sample_stats=sample_stats,
+        data=find_data(model),
+    )
+    return meta, point_fn
+
+
+def init_chain_adapters(
+    *,
+    backend: mcb.Backend,
+    chains: int,
+    initial_point: Mapping[str, np.ndarray],
+    step: Union[CompoundStep, BlockedStep],
+    model: Model,
+) -> Tuple[mcb.Run, List[ChainRecordAdapter]]:
+    """Create an McBackend metadata description for the MCMC run.
+
+    Parameters
+    ----------
+    backend
+        An McBackend `Backend` instance.
+    chains
+        Number of chains to initialize.
+    initial_point
+        Dictionary mapping value variable names to initial values.
+    step : CompoundStep or BlockedStep
+        The step method that iterates the MCMC.
+    model : pm.Model
+        The current PyMC model.
+
+    Returns
+    -------
+    run
+        The McBackend run that the chains belong to.
+    adapters
+        Chain recording adapters that wrap McBackend Chains in the PyMC IBaseTrace interface.
+ """ + meta, point_fn = make_runmeta_and_point_fn(initial_point=initial_point, step=step, model=model) + run = backend.init_run(meta) + statsbj = StatsBijection(step.stats_dtypes) + adapters = [ + ChainRecordAdapter( + chain=run.init_chain(chain_number=chain_number), + point_fn=point_fn, + stats_bijection=statsbj, + ) + for chain_number in range(chains) + ] + return run, adapters diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 243b03eb5e..0449b029f8 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -32,8 +32,8 @@ import pymc as pm -from pymc.backends import init_traces -from pymc.backends.base import BaseTrace, IBaseTrace, MultiTrace, _choose_chains +from pymc.backends import RunType, TraceOrBackend, init_traces +from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain @@ -328,7 +328,7 @@ def sample( init: str = "auto", jitter_max_retries: int = 10, n_init: int = 200_000, - trace: Optional[BaseTrace] = None, + trace: Optional[TraceOrBackend] = None, discard_tuned_samples: bool = True, compute_convergence_checks: bool = True, keep_warning_stat: bool = False, @@ -609,13 +609,12 @@ def sample( _check_start_shape(model, ip) # Create trace backends for each chain - traces = init_traces( + run, traces = init_traces( backend=trace, chains=chains, expected_length=draws + tune, step=step, - var_dtypes={vn: v.dtype for vn, v in ip.items()}, - var_shapes={vn: v.shape for vn, v in ip.items()}, + initial_point=ip, model=model, ) @@ -690,6 +689,7 @@ def sample( # Packaging, validating and returning the result was extracted # into a function to make it easier to test and refactor. return _sample_return( + run=run, traces=traces, tune=tune, t_sampling=t_sampling, @@ -704,6 +704,7 @@ def sample( def _sample_return( *, + run: Optional[RunType], traces: Sequence[IBaseTrace], tune: int, t_sampling: float, diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 4b9147f526..28a1efb718 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -191,6 +191,11 @@ def stop_tuning(self): self.tune = False +def flat_statname(sampler_idx: int, sname: str) -> str: + """Get the flat-stats name for a samplers stat.""" + return f"sampler_{sampler_idx}__{sname}" + + def get_stats_dtypes_shapes_from_steps( steps: Iterable[BlockedStep], ) -> Dict[str, Tuple[StatDtype, StatShape]]: @@ -201,7 +206,7 @@ def get_stats_dtypes_shapes_from_steps( result = {} for s, step in enumerate(steps): for sname, (dtype, shape) in step.stats_dtypes_shapes.items(): - result[f"sampler_{s}__{sname}"] = (dtype, shape) + result[flat_statname(s, sname)] = (dtype, shape) return result @@ -262,10 +267,21 @@ class StatsBijection: def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None: # Keep a list of flat vs. 
original stat names - self._stat_groups: List[List[Tuple[str, str]]] = [ - [(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()] - for s, names_dtypes in enumerate(sampler_stats_dtypes) - ] + stat_groups = [] + for s, names_dtypes in enumerate(sampler_stats_dtypes): + group = [] + for statname, dtype in names_dtypes.items(): + flatname = flat_statname(s, statname) + is_obj = np.dtype(dtype) == np.dtype(object) + group.append((flatname, statname, is_obj)) + stat_groups.append(group) + self._stat_groups: List[List[Tuple[str, str, bool]]] = stat_groups + self.object_stats = { + fname: (s, sname) + for s, group in enumerate(self._stat_groups) + for fname, sname, is_obj in group + if is_obj + } @property def n_samplers(self) -> int: @@ -275,9 +291,10 @@ def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict: """Combine stats dicts of multiple samplers into one dict.""" stats_dict = {} for s, sts in enumerate(stats_list): - for statname, sval in sts.items(): - sname = f"sampler_{s}__{statname}" - stats_dict[sname] = sval + for fname, sname, is_obj in self._stat_groups[s]: + if sname not in sts: + continue + stats_dict[fname] = sts[sname] return stats_dict def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType: @@ -286,7 +303,11 @@ def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType: The ``stats_dict`` can be a subset of all sampler stats. """ stats_list = [] - for namemap in self._stat_groups: - d = {statname: stats_dict[sname] for sname, statname in namemap if sname in stats_dict} + for group in self._stat_groups: + d = {} + for fname, sname, is_obj in group: + if fname not in stats_dict: + continue + d[sname] = stats_dict[fname] stats_list.append(d) return stats_list diff --git a/requirements-dev.txt b/requirements-dev.txt index 4d3ba57f9b..8d56365965 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,6 +9,7 @@ git+https://github.com/pymc-devs/pymc-sphinx-theme h5py>=2.7 ipython>=7.16 jupyter-sphinx +mcbackend>=0.4.0 mypy==0.990 myst-nb numdifftools>=0.9.40 diff --git a/tests/backends/test_mcbackend.py b/tests/backends/test_mcbackend.py new file mode 100644 index 0000000000..2e3693c785 --- /dev/null +++ b/tests/backends/test_mcbackend.py @@ -0,0 +1,305 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
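As an aside before the new test module: the flat-stat round trip that these tests exercise can be sketched with the `StatsBijection` API introduced in `pymc/step_methods/compound.py` above. All names below come from the diff; the stat values are made up.

```python
from pymc.step_methods.compound import StatsBijection, flat_statname

# Two samplers, each with its own stats; "s2" only exists for the second one.
bij = StatsBijection([{"tune": bool, "s1": float}, {"tune": bool, "s2": float}])
flat = bij.map([{"tune": True, "s1": 0.1}, {"tune": True, "s2": 0.2}])
# map() flattens per-sampler stats into keys like "sampler_0__s1".
assert flat[flat_statname(0, "s1")] == 0.1
# rmap() restores the per-sampler grouping, tolerating missing flat keys.
assert bij.rmap(flat) == [{"tune": True, "s1": 0.1}, {"tune": True, "s2": 0.2}]
```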
+import logging + +import arviz +import numpy as np +import pytest + +import pymc as pm + +from pymc.backends import init_traces +from pymc.step_methods.arraystep import ArrayStepShared + +try: + import mcbackend as mcb + + from mcbackend.npproto.utils import ndarray_to_numpy +except ImportError: + pytest.skip("Requires McBackend to be installed.") + +from pymc.backends.mcbackend import ( + ChainRecordAdapter, + find_data, + get_variables_and_point_fn, + make_runmeta_and_point_fn, +) + + +@pytest.fixture +def simple_model(): + seconds = np.linspace(0, 5) + observations = np.random.normal(0.5 + np.random.uniform(size=3)[:, None] * seconds[None, :]) + with pm.Model( + coords={ + "condition": ["A", "B", "C"], + } + ) as pmodel: + x = pm.ConstantData("seconds", seconds, dims="time") + a = pm.Normal("scalar") + b = pm.Uniform("vector", dims="condition") + pm.Deterministic("matrix", a + b[:, None] * x[None, :], dims=("condition", "time")) + pm.Bernoulli("integer", p=0.5) + obs = pm.MutableData("obs", observations, dims=("condition", "time")) + pm.Normal("L", pmodel["matrix"], observed=obs, dims=("condition", "time")) + return pmodel + + +def test_find_data(simple_model): + dvars = find_data(simple_model) + dvardict = {d.name: d for d in dvars} + assert set(dvardict) == {"seconds", "obs"} + + secs = dvardict["seconds"] + assert isinstance(secs, mcb.DataVariable) + assert secs.dims == ["time"] + assert not secs.is_observed + np.testing.assert_array_equal(ndarray_to_numpy(secs.value), simple_model["seconds"].data) + + obs = dvardict["obs"] + assert isinstance(obs, mcb.DataVariable) + assert obs.dims == ["condition", "time"] + assert obs.is_observed + np.testing.assert_array_equal(ndarray_to_numpy(obs.value), simple_model["obs"].get_value()) + + +def test_find_data_skips_deterministics(): + data = np.array([0, 1], dtype="float32") + with pm.Model() as pmodel: + a = pm.ConstantData("a", data, dims="item") + b = pm.Normal("b") + pm.Deterministic("c", a + b, dims="item") + assert "c" in pmodel.named_vars + dvars = find_data(pmodel) + assert len(dvars) == 1 + assert dvars[0].name == "a" + assert dvars[0].dims == ["item"] + np.testing.assert_array_equal(ndarray_to_numpy(dvars[0].value), data) + assert not dvars[0].is_observed + + +def test_get_variables_and_point_fn(simple_model): + ip = simple_model.initial_point() + variables, point_fn = get_variables_and_point_fn(simple_model, ip) + assert isinstance(variables, list) + assert callable(point_fn) + vdict = {v.name: v for v in variables} + assert set(vdict) == {"integer", "scalar", "vector", "vector_interval__", "matrix"} + point = point_fn(ip) + assert len(point) == len(variables) + for v, p in zip(variables, point): + assert str(p.dtype) == v.dtype + + +def test_make_runmeta_and_point_fn(simple_model): + with simple_model: + step = pm.DEMetropolisZ() + rmeta, point_fn = make_runmeta_and_point_fn( + initial_point=simple_model.initial_point(), + step=step, + model=simple_model, + ) + assert isinstance(rmeta, mcb.RunMeta) + assert callable(point_fn) + vars = {v.name: v for v in rmeta.variables} + assert set(vars.keys()) == {"scalar", "vector", "vector_interval__", "matrix", "integer"} + # NOTE: Technically the "vector" is deterministic, but from the user perspective it is not. + # This is merely a matter of which version of transformed variables should be traced. 
+    assert not vars["vector"].is_deterministic
+    assert not vars["vector_interval__"].is_deterministic
+    assert vars["matrix"].is_deterministic
+    assert len(rmeta.sample_stats) == 1 + len(step.stats_dtypes[0])
+    pass
+
+
+def test_init_traces(simple_model):
+    with simple_model:
+        step = pm.DEMetropolisZ()
+        run, traces = init_traces(
+            backend=mcb.NumPyBackend(),
+            chains=2,
+            expected_length=70,
+            step=step,
+            initial_point=simple_model.initial_point(),
+            model=simple_model,
+        )
+    assert isinstance(run, mcb.backends.numpy.NumPyRun)
+    assert isinstance(traces, list)
+    assert len(traces) == 2
+    assert isinstance(traces[0], ChainRecordAdapter)
+    assert isinstance(traces[0]._chain, mcb.backends.numpy.NumPyChain)
+    pass
+
+
+class ToyStepper(ArrayStepShared):
+    stats_dtypes_shapes = {
+        "accepted": (bool, []),
+        "tune": (bool, []),
+        "s1": (np.float64, []),
+    }
+
+    def astep(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+class ToyStepperWithOtherStats(ToyStepper):
+    stats_dtypes_shapes = {
+        "accepted": (bool, []),
+        "tune": (bool, []),
+        "s2": (np.float64, []),
+    }
+
+
+class TestChainRecordAdapter:
+    def test_get_sampler_stats(self):
+        # Initialize a very simple toy model
+        N = 45
+        with pm.Model() as pmodel:
+            a = pm.Normal("a")
+            b = pm.Uniform("b")
+            c = pm.Deterministic("c", a + b)
+            ip = pmodel.initial_point()
+            shared = pm.make_shared_replacements(ip, [a, b], pmodel)
+            run, traces = init_traces(
+                backend=mcb.NumPyBackend(),
+                chains=1,
+                expected_length=N,
+                step=ToyStepper([a, b], shared),
+                initial_point=pmodel.initial_point(),
+                model=pmodel,
+            )
+        cra = traces[0]
+        assert isinstance(run, mcb.backends.numpy.NumPyRun)
+        assert isinstance(cra, ChainRecordAdapter)
+
+        # Simulate recording of draws and stats
+        rng = np.random.RandomState(2023)
+        for i in range(N):
+            draw = {"a": rng.normal(), "b_interval__": rng.normal()}
+            stats = [dict(tune=(i <= 5), s1=i, accepted=bool(rng.randint(0, 2)))]
+            cra.record(draw, stats)
+
+        # Check final state of the chain
+        assert len(cra) == N
+        # Variables b and c were calculated by the point function
+        draws_a = cra.get_values("a")
+        draws_b = cra.get_values("b")
+        draws_c = cra.get_values("c")
+        np.testing.assert_array_equal(draws_a + draws_b, draws_c)
+        i = np.random.randint(0, N)
+        point = cra.point(idx=i)
+        assert point["a"] == draws_a[i]
+        assert point["b"] == draws_b[i]
+        assert point["c"] == draws_c[i]
+
+        # Stats come in different shapes depending on the query
+        s1 = cra.get_sampler_stats("s1", sampler_idx=None, burn=3, thin=2)
+        assert s1.shape == (21,)
+        assert s1.dtype == np.dtype("float64")
+        np.testing.assert_array_equal(s1, np.arange(N)[3:None:2])
+
+    def test_get_sampler_stats_compound(self, caplog):
+        # Initialize a very simple toy model
+        N = 45
+        with pm.Model() as pmodel:
+            a = pm.Normal("a")
+            b = pm.Uniform("b")
+            c = pm.Deterministic("c", a + b)
+            ip = pmodel.initial_point()
+            shared_a = pm.make_shared_replacements(ip, [a], pmodel)
+            shared_b = pm.make_shared_replacements(ip, [b], pmodel)
+            stepA = ToyStepper([a], shared_a)
+            stepB = ToyStepperWithOtherStats([b], shared_b)
+            run, traces = init_traces(
+                backend=mcb.NumPyBackend(),
+                chains=1,
+                expected_length=N,
+                step=pm.CompoundStep([stepA, stepB]),
+                initial_point=pmodel.initial_point(),
+                model=pmodel,
+            )
+        cra = traces[0]
+        assert isinstance(cra, ChainRecordAdapter)
+
+        # Simulate recording of draws and stats
+        rng = np.random.RandomState(2023)
+        for i in range(N):
+            tune = i <= 5
+            draw = {"a": rng.normal(), "b_interval__": rng.normal()}
+            stats = [
dict(tune=tune, s1=i, accepted=bool(rng.randint(0, 2))), + dict(tune=tune, s2=i, accepted=bool(rng.randint(0, 2))), + ] + cra.record(draw, stats) + + # The 'accepted' stat was emitted by both samplers + assert cra.get_sampler_stats("accepted", sampler_idx=None).shape == (N, 2) + acpt_1 = cra.get_sampler_stats("accepted", sampler_idx=0, burn=3, thin=2) + acpt_2 = cra.get_sampler_stats("accepted", sampler_idx=1, burn=3, thin=2) + assert acpt_1.shape == (21,) # (N-3)/2 + assert not np.array_equal(acpt_1, acpt_2) + + # s1 and s2 were sampler specific + # they are squeezed into vectors, but warnings are logged at DEBUG level + with caplog.at_level(logging.DEBUG, logger="pymc"): + s1 = cra.get_sampler_stats("s1", burn=10) + assert s1.shape == (35,) + assert s1.dtype == np.dtype("float64") + s2 = cra.get_sampler_stats("s2", thin=5) + assert s2.shape == (9,) # N/5 + assert s2.dtype == np.dtype("float64") + assert any("'s1' was not recorded by all samplers" in r.message for r in caplog.records) + + with pytest.raises(KeyError, match="No stat"): + cra.get_sampler_stats("notastat") + + +class TestMcBackendSampling: + @pytest.mark.parametrize("discard_warmup", [False, True]) + def test_return_multitrace(self, simple_model, discard_warmup): + with simple_model: + mtrace = pm.sample( + trace=mcb.NumPyBackend(), + tune=5, + draws=7, + cores=1, + chains=3, + step=pm.Metropolis(), + discard_tuned_samples=discard_warmup, + return_inferencedata=False, + ) + assert isinstance(mtrace, pm.backends.base.MultiTrace) + tune = mtrace._straces[0].get_sampler_stats("tune") + assert isinstance(tune, np.ndarray) + if discard_warmup: + assert tune.shape == (7, 3) + else: + assert tune.shape == (12, 3) + pass + + @pytest.mark.parametrize("cores", [1, 3]) + def test_return_inferencedata(self, simple_model, cores): + with simple_model: + idata = pm.sample( + trace=mcb.NumPyBackend(), + tune=5, + draws=7, + cores=cores, + chains=3, + discard_tuned_samples=False, + ) + assert isinstance(idata, arviz.InferenceData) + assert idata.warmup_posterior.sizes["draw"] == 5 + assert idata.posterior.sizes["draw"] == 7 + pass diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index c87eb6d456..8f7e06fb5c 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -362,6 +362,7 @@ def test_sample_return_lengths(self): # MultiTrace without warmup mtrace_pst = pm.sampling.mcmc._sample_return( + run=None, traces=traces, tune=50, t_sampling=123.4, @@ -380,6 +381,7 @@ def test_sample_return_lengths(self): # InferenceData with warmup idata_w = pm.sampling.mcmc._sample_return( + run=None, traces=traces, tune=50, t_sampling=123.4, @@ -398,6 +400,7 @@ def test_sample_return_lengths(self): # InferenceData without warmup idata = pm.sampling.mcmc._sample_return( + run=None, traces=traces, tune=50, t_sampling=123.4, @@ -463,6 +466,10 @@ def test_keep_warning_stat_setting(self, keep_warning_stat): # This tests flattens so we don't have to be exact in accessing (non-)squeezed items. # Also see https://github.com/pymc-devs/pymc/issues/6207. warn_objs = list(idata.sample_stats.warning.sel(chain=0).values.flatten()) + assert warn_objs + if isinstance(warn_objs[0], np.ndarray): + # Squeeze warning stats. 
See https://github.com/pymc-devs/pymc/issues/6207 + warn_objs = [a.tolist() for a in warn_objs] assert any(isinstance(w, SamplerWarning) for w in warn_objs) assert any("Asteroid" in w.message for w in warn_objs) else: diff --git a/tests/step_methods/test_compound.py b/tests/step_methods/test_compound.py index 1e20181ed5..ba9d90634d 100644 --- a/tests/step_methods/test_compound.py +++ b/tests/step_methods/test_compound.py @@ -164,20 +164,22 @@ def test_flatten_steps(self): def test_stats_bijection(self): step_stats_dtypes = [ {"a": float, "b": int}, - {"a": float, "c": int}, + {"a": float, "c": Warning}, ] bij = StatsBijection(step_stats_dtypes) + assert bij.object_stats == {"sampler_1__c": (1, "c")} assert bij.n_samplers == 2 + w = Warning("hmm") stats_l = [ dict(a=1.5, b=3), - dict(a=2.5, c=4), + dict(a=2.5, c=w), ] stats_d = bij.map(stats_l) assert isinstance(stats_d, dict) assert stats_d["sampler_0__a"] == 1.5 assert stats_d["sampler_0__b"] == 3 assert stats_d["sampler_1__a"] == 2.5 - assert stats_d["sampler_1__c"] == 4 + assert stats_d["sampler_1__c"] == w rev = bij.rmap(stats_d) assert isinstance(rev, list) assert len(rev) == len(stats_l)
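Taken together, the changes above let any McBackend `Backend` be passed directly to `pm.sample(trace=...)`, as the new tests demonstrate. A minimal end-to-end usage sketch (the NumPy backend is shown; another `mcbackend` backend, e.g. a ClickHouse-backed one, would plug in the same way):

```python
import mcbackend
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu")
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.2, 0.3])
    idata = pm.sample(
        trace=mcbackend.NumPyBackend(),  # any mcbackend.Backend instance
        tune=500,
        draws=500,
        chains=2,
    )

# Draws were recorded through ChainRecordAdapter and converted to
# InferenceData just like with the default NDArray backend.
assert idata.posterior["mu"].shape == (2, 500)
```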