diff --git a/brian2modelfitting/inferencer.py b/brian2modelfitting/inferencer.py index 60f17a2..fadc6ae 100644 --- a/brian2modelfitting/inferencer.py +++ b/brian2modelfitting/inferencer.py @@ -1101,8 +1101,8 @@ def pairplot(self, samples=None, points=None, limits=None, subset=None, if len(lim_vals) != 2: raise ValueError('Invalid limits for parameter: ' f'{param_name}') - limits = [[limits[param_name][0].item(), - limits[param_name][1].item()] + limits = [[float(limits[param_name][0]), + float(limits[param_name][1])] for param_name in self.param_names] if subset: for param_name in subset: @@ -1117,8 +1117,8 @@ def pairplot(self, samples=None, points=None, limits=None, subset=None, if len(lim_vals) != 2: raise ValueError('Invalid limits for parameter: ' f'{param_name}') - ticks = [[ticks[param_name][0].item(), - ticks[param_name][1].item()] + ticks = [[float(ticks[param_name][0]), + float(ticks[param_name][1])] for param_name in self.param_names] else: ticks = [] @@ -1207,12 +1207,12 @@ def conditional_pairplot(self, condition, density=None, points=None, if len(lim_vals) != 2: raise ValueError('Invalid limits for parameter: ' f'{param_name}') - limits = [[limits[param_name][0].item(), - limits[param_name][1].item()] + limits = [[float(limits[param_name][0]), + float(limits[param_name][1])] for param_name in self.param_names] else: - limits = [[self.params[param_name][0].item(), - self.params[param_name][1].item()] + limits = [[float(self.params[param_name][0]), + float(self.params[param_name][1])] for param_name in self.param_names] if subset: for param_name in subset: @@ -1227,8 +1227,8 @@ def conditional_pairplot(self, condition, density=None, points=None, if len(lim_vals) != 2: raise ValueError('Invalid limits for parameter: ' f'{param_name}') - ticks = [[ticks[param_name][0].item(), - ticks[param_name][1].item()] + ticks = [[float(ticks[param_name][0]), + float(ticks[param_name][1])] for param_name in self.param_names] else: ticks = [] @@ -1237,14 +1237,15 @@ def conditional_pairplot(self, condition, density=None, points=None, if param_name not in self.param_names: raise AttributeError(f'Invalid parameter: {param_name}') labels = [labels[param_name] for param_name in self.param_names] - fig, axes = sbi.analysis.conditional_pairplot(density=d, - condition=condition, - limits=limits, - points=points, - subset=subset, - labels=labels, - ticks=ticks, - **kwargs) + from sbi.analysis import conditional_pairplot + fig, axes = conditional_pairplot(density=d, + condition=condition, + limits=limits, + points=points, + subset=subset, + labels=labels, + ticks=ticks, + **kwargs) return fig, axes def conditional_corrcoeff(self, condition, density=None, limits=None, @@ -1305,12 +1306,12 @@ def conditional_corrcoeff(self, condition, density=None, limits=None, if len(lim_vals) != 2: raise ValueError('Invalid limits for parameter: ' f'{param_name}') - limits = [[limits[param_name][0].item(), - limits[param_name][1].item()] + limits = [[float(limits[param_name][0]), + float(limits[param_name][1])] for param_name in self.param_names] else: - limits = [[self.params[param_name][0].item(), - self.params[param_name][1].item()] + limits = [[float(self.params[param_name][0]), + float(self.params[param_name][1])] for param_name in self.param_names] limits = torch.tensor(limits) if subset: @@ -1319,11 +1320,12 @@ def conditional_corrcoeff(self, condition, density=None, limits=None, raise AttributeError(f'Invalid parameter: {param_name}') subset = [self.param_names.index(param_name) for param_name in subset] - cond_coeff = sbi.analysis.conditional_corrcoeff(density=d, - limits=limits, - condition=condition, - subset=subset, - **kwargs) + from sbi.analysis import conditional_corrcoeff + cond_coeff = conditional_corrcoeff(density=d, + limits=limits, + condition=condition, + subset=subset, + **kwargs) return cond_coeff.numpy() def generate_traces(self, n_samples=1, posterior=None, output_var=None,