Skip to content

Commit

Permalink
Report on uncalibrated model FNRs by class
Browse files Browse the repository at this point in the history
  • Loading branch information
mariya committed Nov 7, 2024
1 parent ed5687b commit 3d5b380
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 42 deletions.
75 changes: 37 additions & 38 deletions src/conformist/performance_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,66 +33,65 @@ def pct_trio_plus_sets(prediction_sets):
prediction_set in prediction_sets) / \
len(prediction_sets)

def visualize_mean_set_sizes_by_class(self,
mean_set_sizes_by_class):
def _class_report(self,
items_by_class,
output_file_prefix,
ylabel,
color):
# Reset plt
plt.figure()

# Sort the dictionary by its values
mean_set_sizes = dict(sorted(mean_set_sizes_by_class.items(),
key=lambda item: item[1]))
mean_sizes = dict(sorted(items_by_class.items(),
key=lambda item: item[1]))

# Convert dictionary to dataframe and transpose
df = pd.DataFrame(mean_set_sizes, index=[0]).T
df = pd.DataFrame(mean_sizes, index=[0]).T

# Save as csv
df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv',
df.to_csv(f'{self.output_dir}/{output_file_prefix}.csv',
index=True, header=False)

# Visualize this dict as a bar chart
sns.set_style('whitegrid')
palette = sns.color_palette("deep")
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1])
ax.set_xticks(range(len(mean_set_sizes)))
ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical')
ax.set_ylabel('Mean set size')
ax.bar(range(len(mean_sizes)), mean_sizes.values(), color=color)
ax.set_xticks(range(len(mean_sizes)))
ax.set_xticklabels(mean_sizes.keys(), rotation='vertical')
ax.set_ylabel(ylabel)
ax.set_xlabel('True class')
plt.tight_layout()
plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png')
plt.savefig(f'{self.output_dir}/{output_file_prefix}.png')

def visualize_mean_set_sizes_by_class(self,
mean_set_sizes_by_class):
palette = sns.color_palette("deep")
self._class_report(mean_set_sizes_by_class,
'mean_set_sizes_by_class',
'Mean set size',
palette[1])

def visualize_mean_fnrs_by_class(self,
mean_fnrs_by_class):
# Setup
plt.figure()

# Sort the dictionary by its values
mean_fnrs = dict(sorted(mean_fnrs_by_class.items(),
key=lambda item: item[1]))

# Visualize this dict as a bar chart
sns.set_style('whitegrid')
palette = sns.color_palette("deep")
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(range(len(mean_fnrs)), mean_fnrs.values(), color=palette[0])
ax.set_xticks(range(len(mean_fnrs)))
ax.set_xticklabels(mean_fnrs.keys(), rotation='vertical')
ax.set_ylabel('Mean FNR')
ax.set_xlabel('True class')
plt.tight_layout()

# Export as fig and text
plt.savefig(f'{self.output_dir}/mean_fnrs_by_class.png')
self._class_report(mean_fnrs_by_class,
'mean_fnrs_by_class',
'Mean FNR',
palette[0])

# Convert dictionary to dataframe and transpose
df = pd.DataFrame(mean_fnrs, index=[0]).T

# Save as csv
df.to_csv(f'{self.output_dir}/mean_fnrs_by_class.csv',
index=True, header=False)
def visualize_mean_model_fnrs_by_class(self,
mean_fnrs_by_class):
palette = sns.color_palette("deep")
self._class_report(mean_fnrs_by_class,
'mean_model_fnrs_by_class',
'Mean model FNR',
palette[2])

def report_class_statistics(self,
mean_set_sizes_by_class,
mean_fnrs_by_class):
mean_fnrs_by_class,
mean_model_fnrs_by_class=None):
self.visualize_mean_fnrs_by_class(mean_fnrs_by_class)
self.visualize_mean_set_sizes_by_class(mean_set_sizes_by_class)
if mean_model_fnrs_by_class:
self.visualize_mean_model_fnrs_by_class(mean_model_fnrs_by_class)
9 changes: 5 additions & 4 deletions src/conformist/validation_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ def mean_set_sizes_by_class(self, class_names):
averages[key] = sum(sizes) / len(sizes)
return averages

def mean_fnrs_by_class(self, class_names):
def mean_fnrs_by_class(self, sets, class_names):
fnrs = {}
for i in range(len(self.prediction_sets)):
for i in range(len(sets)):
labels = self.labels_idx[i]
# Get corresponding values from class_names
pset_class_names = [class_names[i] for i, label in enumerate(labels) if label == 1]
pset = np.array([int(value) for value in self.prediction_sets[i]])
pset = np.array([int(value) for value in sets[i]])
if (pset[labels == 1] == 1).size > 0:
fnr = 1 - np.mean((pset[labels == 1] == 1))
else:
Expand All @@ -117,7 +117,8 @@ def mean_fnrs_by_class(self, class_names):
def run_reports(self, base_output_dir):
pr = PerformanceReport(base_output_dir)
pr.report_class_statistics(self.mean_set_sizes_by_class(self.class_names),
self.mean_fnrs_by_class(self.class_names))
self.mean_fnrs_by_class(self.prediction_sets, self.class_names),
self.mean_fnrs_by_class(self.model_predictions, self.class_names))

np.seterr(all='raise')
self.create_output_dir(base_output_dir)
Expand Down

0 comments on commit 3d5b380

Please sign in to comment.