Skip to content

Commit

Permalink
Break up reporting function
Browse files Browse the repository at this point in the history
  • Loading branch information
mariya committed Nov 6, 2024
1 parent bfdc061 commit 01a66e5
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions src/conformist/performance_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class PerformanceReport(OutputDir):
def __init__(self, base_output_dir):
self.base_output_dir = base_output_dir
self.create_output_dir(self.base_output_dir)

def mean_set_size(prediction_sets):
return sum(sum(prediction_set) for
Expand All @@ -33,12 +33,36 @@ def pct_trio_plus_sets(prediction_sets):
prediction_set in prediction_sets) / \
len(prediction_sets)

def report_class_statistics(self,
mean_set_sizes_by_class,
mean_fnrs_by_class):
def visualize_mean_set_sizes_by_class(self,
mean_set_sizes_by_class):
# Reset plt
plt.figure()

# Sort the dictionary by its values
mean_set_sizes = dict(sorted(mean_set_sizes_by_class.items(),
key=lambda item: item[1]))

# Convert dictionary to dataframe and transpose
df = pd.DataFrame(mean_set_sizes, index=[0]).T

# Save as csv
df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv',
index=True, header=False)

# Visualize this dict as a bar chart
sns.set_style('whitegrid')
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1])
ax.set_xticks(range(len(mean_set_sizes)))
ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical')
ax.set_ylabel('Mean set size')
ax.set_xlabel('True class')
plt.tight_layout()
plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png')

def visualize_mean_fnrs_by_class(self,
mean_fnrs_by_class):
# Setup
self.create_output_dir(self.base_output_dir)
plt.figure()

# Sort the dictionary by its values
Expand Down Expand Up @@ -66,27 +90,8 @@ def report_class_statistics(self,
df.to_csv(f'{self.output_dir}/mean_fnrs_by_class.csv',
index=True, header=False)

# Reset plt
plt.figure()

# Sort the dictionary by its values
mean_set_sizes = dict(sorted(mean_set_sizes_by_class.items(),
key=lambda item: item[1]))

# Convert dictionary to dataframe and transpose
df = pd.DataFrame(mean_set_sizes, index=[0]).T

# Save as csv
df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv',
index=True, header=False)

# Visualize this dict as a bar chart
sns.set_style('whitegrid')
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1])
ax.set_xticks(range(len(mean_set_sizes)))
ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical')
ax.set_ylabel('Mean set size')
ax.set_xlabel('True class')
plt.tight_layout()
plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png')
def report_class_statistics(self,
mean_set_sizes_by_class,
mean_fnrs_by_class):
self.visualize_mean_fnrs_by_class(mean_fnrs_by_class)
self.visualize_mean_set_sizes_by_class(mean_set_sizes_by_class)

0 comments on commit 01a66e5

Please sign in to comment.