From 01a66e5b6c3a003196ab7b88ba6d0a29cfa9c926 Mon Sep 17 00:00:00 2001 From: Mariya Lysenkova Wiklander Date: Wed, 6 Nov 2024 15:48:26 +0100 Subject: [PATCH] Break up reporting function --- src/conformist/performance_report.py | 63 +++++++++++++++------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/src/conformist/performance_report.py b/src/conformist/performance_report.py index ee20ab3..2be04d5 100644 --- a/src/conformist/performance_report.py +++ b/src/conformist/performance_report.py @@ -6,7 +6,7 @@ class PerformanceReport(OutputDir): def __init__(self, base_output_dir): - self.base_output_dir = base_output_dir + self.create_output_dir(base_output_dir) def mean_set_size(prediction_sets): return sum(sum(prediction_set) for @@ -33,12 +33,36 @@ def pct_trio_plus_sets(prediction_sets): prediction_set in prediction_sets) / \ len(prediction_sets) - def report_class_statistics(self, - mean_set_sizes_by_class, - mean_fnrs_by_class): + def visualize_mean_set_sizes_by_class(self, + mean_set_sizes_by_class): + # Reset plt + plt.figure() + + # Sort the dictionary by its values + mean_set_sizes = dict(sorted(mean_set_sizes_by_class.items(), + key=lambda item: item[1])) + + # Convert dictionary to dataframe and transpose + df = pd.DataFrame(mean_set_sizes, index=[0]).T + + # Save as csv + df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv', + index=True, header=False) + + # Visualize this dict as a bar chart + sns.set_style('whitegrid') + fig, ax = plt.subplots(figsize=(10, 6)) + ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1]) + ax.set_xticks(range(len(mean_set_sizes))) + ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical') + ax.set_ylabel('Mean set size') + ax.set_xlabel('True class') + plt.tight_layout() + plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png') + def visualize_mean_fnrs_by_class(self, + mean_fnrs_by_class): # Setup - self.create_output_dir(self.base_output_dir) plt.figure() # 
Sort the dictionary by its values @@ -66,27 +90,8 @@ def report_class_statistics(self, df.to_csv(f'{self.output_dir}/mean_fnrs_by_class.csv', index=True, header=False) - # Reset plt - plt.figure() - - # Sort the dictionary by its values - mean_set_sizes = dict(sorted(mean_set_sizes_by_class.items(), - key=lambda item: item[1])) - - # Convert dictionary to dataframe and transpose - df = pd.DataFrame(mean_set_sizes, index=[0]).T - - # Save as csv - df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv', - index=True, header=False) - - # Visualize this dict as a bar chart - sns.set_style('whitegrid') - fig, ax = plt.subplots(figsize=(10, 6)) - ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1]) - ax.set_xticks(range(len(mean_set_sizes))) - ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical') - ax.set_ylabel('Mean set size') - ax.set_xlabel('True class') - plt.tight_layout() - plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png') + def report_class_statistics(self, + mean_set_sizes_by_class, + mean_fnrs_by_class): + self.visualize_mean_fnrs_by_class(mean_fnrs_by_class) + self.visualize_mean_set_sizes_by_class(mean_set_sizes_by_class)