diff --git a/src/conformist/performance_report.py b/src/conformist/performance_report.py index 7405332..e751211 100644 --- a/src/conformist/performance_report.py +++ b/src/conformist/performance_report.py @@ -33,66 +33,65 @@ def pct_trio_plus_sets(prediction_sets): prediction_set in prediction_sets) / \ len(prediction_sets) - def visualize_mean_set_sizes_by_class(self, - mean_set_sizes_by_class): + def _class_report(self, + items_by_class, + output_file_prefix, + ylabel, + color): # Reset plt plt.figure() # Sort the dictionary by its values - mean_set_sizes = dict(sorted(mean_set_sizes_by_class.items(), - key=lambda item: item[1])) + mean_sizes = dict(sorted(items_by_class.items(), + key=lambda item: item[1])) # Convert dictionary to dataframe and transpose - df = pd.DataFrame(mean_set_sizes, index=[0]).T + df = pd.DataFrame(mean_sizes, index=[0]).T # Save as csv - df.to_csv(f'{self.output_dir}/mean_set_sizes_class.csv', + df.to_csv(f'{self.output_dir}/{output_file_prefix}.csv', index=True, header=False) # Visualize this dict as a bar chart sns.set_style('whitegrid') - palette = sns.color_palette("deep") fig, ax = plt.subplots(figsize=(10, 6)) - ax.bar(range(len(mean_set_sizes)), mean_set_sizes.values(), color=palette[1]) - ax.set_xticks(range(len(mean_set_sizes))) - ax.set_xticklabels(mean_set_sizes.keys(), rotation='vertical') - ax.set_ylabel('Mean set size') + ax.bar(range(len(mean_sizes)), mean_sizes.values(), color=color) + ax.set_xticks(range(len(mean_sizes))) + ax.set_xticklabels(mean_sizes.keys(), rotation='vertical') + ax.set_ylabel(ylabel) ax.set_xlabel('True class') plt.tight_layout() - plt.savefig(f'{self.output_dir}/mean_set_sizes_by_class.png') + plt.savefig(f'{self.output_dir}/{output_file_prefix}.png') + + def visualize_mean_set_sizes_by_class(self, + mean_set_sizes_by_class): + palette = sns.color_palette("deep") + self._class_report(mean_set_sizes_by_class, + 'mean_set_sizes_by_class', + 'Mean set size', + palette[1]) def visualize_mean_fnrs_by_class(self, mean_fnrs_by_class): - # Setup - plt.figure() - - # Sort the dictionary by its values - mean_fnrs = dict(sorted(mean_fnrs_by_class.items(), - key=lambda item: item[1])) - - # Visualize this dict as a bar chart - sns.set_style('whitegrid') palette = sns.color_palette("deep") - fig, ax = plt.subplots(figsize=(10, 6)) - ax.bar(range(len(mean_fnrs)), mean_fnrs.values(), color=palette[0]) - ax.set_xticks(range(len(mean_fnrs))) - ax.set_xticklabels(mean_fnrs.keys(), rotation='vertical') - ax.set_ylabel('Mean FNR') - ax.set_xlabel('True class') - plt.tight_layout() - - # Export as fig and text - plt.savefig(f'{self.output_dir}/mean_fnrs_by_class.png') + self._class_report(mean_fnrs_by_class, + 'mean_fnrs_by_class', + 'Mean FNR', + palette[0]) - # Convert dictionary to dataframe and transpose - df = pd.DataFrame(mean_fnrs, index=[0]).T - - # Save as csv - df.to_csv(f'{self.output_dir}/mean_fnrs_by_class.csv', - index=True, header=False) + def visualize_mean_model_fnrs_by_class(self, + mean_fnrs_by_class): + palette = sns.color_palette("deep") + self._class_report(mean_fnrs_by_class, + 'mean_model_fnrs_by_class', + 'Mean model FNR', + palette[2]) def report_class_statistics(self, mean_set_sizes_by_class, - mean_fnrs_by_class): + mean_fnrs_by_class, + mean_model_fnrs_by_class=None): self.visualize_mean_fnrs_by_class(mean_fnrs_by_class) self.visualize_mean_set_sizes_by_class(mean_set_sizes_by_class) + if mean_model_fnrs_by_class: + self.visualize_mean_model_fnrs_by_class(mean_model_fnrs_by_class) diff --git a/src/conformist/validation_run.py b/src/conformist/validation_run.py index e2935bc..8bb2d38 100644 --- a/src/conformist/validation_run.py +++ b/src/conformist/validation_run.py @@ -93,13 +93,13 @@ def mean_set_sizes_by_class(self, class_names): averages[key] = sum(sizes) / len(sizes) return averages - def mean_fnrs_by_class(self, class_names): + def mean_fnrs_by_class(self, sets, class_names): fnrs = {} - for i in range(len(self.prediction_sets)): + for i in range(len(sets)): labels = self.labels_idx[i] # Get corresponding values from class_names pset_class_names = [class_names[i] for i, label in enumerate(labels) if label == 1] - pset = np.array([int(value) for value in self.prediction_sets[i]]) + pset = np.array([int(value) for value in sets[i]]) if (pset[labels == 1] == 1).size > 0: fnr = 1 - np.mean((pset[labels == 1] == 1)) else: @@ -117,7 +117,8 @@ def mean_fnrs_by_class(self, class_names): def run_reports(self, base_output_dir): pr = PerformanceReport(base_output_dir) pr.report_class_statistics(self.mean_set_sizes_by_class(self.class_names), - self.mean_fnrs_by_class(self.class_names)) + self.mean_fnrs_by_class(self.prediction_sets, self.class_names), + self.mean_fnrs_by_class(self.model_predictions, self.class_names)) np.seterr(all='raise') self.create_output_dir(base_output_dir)