Skip to content

Commit

Permalink
Fix other instances of Quantity.item() use
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Oct 6, 2023
1 parent 16d831c commit 30c292f
Showing 1 changed file with 29 additions and 27 deletions.
56 changes: 29 additions & 27 deletions brian2modelfitting/inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,8 +1101,8 @@ def pairplot(self, samples=None, points=None, limits=None, subset=None,
if len(lim_vals) != 2:
raise ValueError('Invalid limits for parameter: '
f'{param_name}')
limits = [[limits[param_name][0].item(),
limits[param_name][1].item()]
limits = [[float(limits[param_name][0]),
float(limits[param_name][1])]
for param_name in self.param_names]
if subset:
for param_name in subset:
Expand All @@ -1117,8 +1117,8 @@ def pairplot(self, samples=None, points=None, limits=None, subset=None,
if len(lim_vals) != 2:
raise ValueError('Invalid limits for parameter: '
f'{param_name}')
ticks = [[ticks[param_name][0].item(),
ticks[param_name][1].item()]
ticks = [[float(ticks[param_name][0]),
float(ticks[param_name][1])]
for param_name in self.param_names]
else:
ticks = []
Expand Down Expand Up @@ -1207,12 +1207,12 @@ def conditional_pairplot(self, condition, density=None, points=None,
if len(lim_vals) != 2:
raise ValueError('Invalid limits for parameter: '
f'{param_name}')
limits = [[limits[param_name][0].item(),
limits[param_name][1].item()]
limits = [[float(limits[param_name][0]),
float(limits[param_name][1])]
for param_name in self.param_names]
else:
limits = [[self.params[param_name][0].item(),
self.params[param_name][1].item()]
limits = [[float(self.params[param_name][0]),
float(self.params[param_name][1])]
for param_name in self.param_names]
if subset:
for param_name in subset:
Expand All @@ -1227,8 +1227,8 @@ def conditional_pairplot(self, condition, density=None, points=None,
if len(lim_vals) != 2:
raise ValueError('Invalid limits for parameter: '
f'{param_name}')
ticks = [[ticks[param_name][0].item(),
ticks[param_name][1].item()]
ticks = [[float(ticks[param_name][0]),
float(ticks[param_name][1])]
for param_name in self.param_names]
else:
ticks = []
Expand All @@ -1237,14 +1237,15 @@ def conditional_pairplot(self, condition, density=None, points=None,
if param_name not in self.param_names:
raise AttributeError(f'Invalid parameter: {param_name}')
labels = [labels[param_name] for param_name in self.param_names]
fig, axes = sbi.analysis.conditional_pairplot(density=d,
condition=condition,
limits=limits,
points=points,
subset=subset,
labels=labels,
ticks=ticks,
**kwargs)
from sbi.analysis import conditional_pairplot
fig, axes = conditional_pairplot(density=d,
condition=condition,
limits=limits,
points=points,
subset=subset,
labels=labels,
ticks=ticks,
**kwargs)
return fig, axes

def conditional_corrcoeff(self, condition, density=None, limits=None,
Expand Down Expand Up @@ -1305,12 +1306,12 @@ def conditional_corrcoeff(self, condition, density=None, limits=None,
if len(lim_vals) != 2:
raise ValueError('Invalid limits for parameter: '
f'{param_name}')
limits = [[limits[param_name][0].item(),
limits[param_name][1].item()]
limits = [[float(limits[param_name][0]),
float(limits[param_name][1])]
for param_name in self.param_names]
else:
limits = [[self.params[param_name][0].item(),
self.params[param_name][1].item()]
limits = [[float(self.params[param_name][0]),
float(self.params[param_name][1])]
for param_name in self.param_names]
limits = torch.tensor(limits)
if subset:
Expand All @@ -1319,11 +1320,12 @@ def conditional_corrcoeff(self, condition, density=None, limits=None,
raise AttributeError(f'Invalid parameter: {param_name}')
subset = [self.param_names.index(param_name)
for param_name in subset]
cond_coeff = sbi.analysis.conditional_corrcoeff(density=d,
limits=limits,
condition=condition,
subset=subset,
**kwargs)
from sbi.analysis import conditional_corrcoeff
cond_coeff = conditional_corrcoeff(density=d,
limits=limits,
condition=condition,
subset=subset,
**kwargs)
return cond_coeff.numpy()

def generate_traces(self, n_samples=1, posterior=None, output_var=None,
Expand Down

0 comments on commit 30c292f

Please sign in to comment.