Skip to content

Commit

Permalink
Fix sbi imports
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Oct 6, 2023
1 parent 60311d4 commit 16d831c
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions brian2modelfitting/inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,12 @@ def calc_prior(param_names, **params):
Return
------
sbi.utils.torchutils.BoxUniform
sbi.utils.BoxUniform
``sbi``-compatible object that contains a uniform prior
distribution over a given set of parameters.
"""
_check_sbi()
from sbi.utils import BoxUniform
for param_name in param_names:
if param_name not in params:
raise TypeError(f'Bounds must be set for parameter {param_name}')
Expand All @@ -160,8 +161,8 @@ def calc_prior(param_names, **params):
for param_name in param_names:
prior_min.append(float(min(params[param_name])))
prior_max.append(float(max(params[param_name])))
prior = sbi.utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min),
high=torch.as_tensor(prior_max))
prior = BoxUniform(low=torch.as_tensor(prior_min),
high=torch.as_tensor(prior_max))
return prior


Expand Down Expand Up @@ -1126,13 +1127,14 @@ def pairplot(self, samples=None, points=None, limits=None, subset=None,
if param_name not in self.param_names:
raise AttributeError(f'Invalid parameter: {param_name}')
labels = [labels[param_name] for param_name in self.param_names]
fig, axes = sbi.analysis.pairplot(samples=s,
points=points,
limits=limits,
subset=subset,
labels=labels,
ticks=ticks,
**kwargs)
from sbi.analysis import pairplot
fig, axes = pairplot(samples=s,
points=points,
limits=limits,
subset=subset,
labels=labels,
ticks=ticks,
**kwargs)
return fig, axes

def conditional_pairplot(self, condition, density=None, points=None,
Expand Down

0 comments on commit 16d831c

Please sign in to comment.