From 57b0a5ac519e8e088e911bc4c8884ef690ab870f Mon Sep 17 00:00:00 2001 From: Mariya Lysenkova Wiklander Date: Tue, 12 Nov 2024 15:44:03 +0100 Subject: [PATCH] Order class barchart by num of examples per class --- src/conformist/prediction_dataset.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/conformist/prediction_dataset.py b/src/conformist/prediction_dataset.py index ae34281..b55471d 100644 --- a/src/conformist/prediction_dataset.py +++ b/src/conformist/prediction_dataset.py @@ -202,18 +202,31 @@ def visualize_class_counts_by_dataset(self): if num_datasets == 1: axs = [axs] + # Group by the first level of the index (dataset) and sum the values + grouped_ccs = ccs.groupby(level=0).sum() + + # Order datasets by number of items + ordered_datasets = grouped_ccs.sort_values(ascending=False).index + print(type(ordered_datasets)) + # For each dataset, create a bar chart - for i, dataset in enumerate(ccs.index.get_level_values(0).unique()): - sorted = ccs.loc[dataset].sort_values(ascending=False) + for i, dataset in enumerate(ordered_datasets): + dataset_series = ccs.loc[dataset] + + # Ensure dataset_series is a Series + if not isinstance(dataset_series, pd.Series): + raise ValueError(f"Expected ccs.loc[{dataset}] to be a Series") + + sorted_series = dataset_series.sort_values(ascending=False) - # Set fixed width for bar - # axs[i].bar(sorted.index, sorted.values, width=0.5) - sorted.plot.bar(ax=axs[i]) + # Plot bar chart with fixed width + axs[i].bar(sorted_series.index, sorted_series.values, width=0.5) axs[i].set_title(dataset) # Print count above each bar - for j, v in enumerate(sorted): - axs[i].text(j, v, str(v), ha='center', va='bottom') + for j, v in enumerate(sorted_series): + if np.isfinite(v): + axs[i].text(j, v, str(v), ha='center', va='bottom') # show the plot plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png',