From 16d831cf3067433860209d93572b4054cb8e848d Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Fri, 6 Oct 2023 12:05:33 +0200 Subject: [PATCH] Fix sbi imports --- brian2modelfitting/inferencer.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/brian2modelfitting/inferencer.py b/brian2modelfitting/inferencer.py index 4b93462..60f17a2 100644 --- a/brian2modelfitting/inferencer.py +++ b/brian2modelfitting/inferencer.py @@ -147,11 +147,12 @@ def calc_prior(param_names, **params): Return ------ - sbi.utils.torchutils.BoxUniform + sbi.utils.BoxUniform ``sbi``-compatible object that contains a uniform prior distribution over a given set of parameters. """ _check_sbi() + from sbi.utils import BoxUniform for param_name in param_names: if param_name not in params: raise TypeError(f'Bounds must be set for parameter {param_name}') @@ -160,8 +161,8 @@ def calc_prior(param_names, **params): for param_name in param_names: prior_min.append(float(min(params[param_name]))) prior_max.append(float(max(params[param_name]))) - prior = sbi.utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min), - high=torch.as_tensor(prior_max)) + prior = BoxUniform(low=torch.as_tensor(prior_min), + high=torch.as_tensor(prior_max)) return prior @@ -1126,13 +1127,14 @@ def pairplot(self, samples=None, points=None, limits=None, subset=None, if param_name not in self.param_names: raise AttributeError(f'Invalid parameter: {param_name}') labels = [labels[param_name] for param_name in self.param_names] - fig, axes = sbi.analysis.pairplot(samples=s, - points=points, - limits=limits, - subset=subset, - labels=labels, - ticks=ticks, - **kwargs) + from sbi.analysis import pairplot + fig, axes = pairplot(samples=s, + points=points, + limits=limits, + subset=subset, + labels=labels, + ticks=ticks, + **kwargs) return fig, axes def conditional_pairplot(self, condition, density=None, points=None,