diff --git a/src/conformist/prediction_dataset.py b/src/conformist/prediction_dataset.py index b55471d..a552aa5 100644 --- a/src/conformist/prediction_dataset.py +++ b/src/conformist/prediction_dataset.py @@ -195,6 +195,15 @@ def visualize_class_counts_by_dataset(self): # create a bar chart ccs = self.class_counts_by_dataset() + # Get all unique classes + all_classes = ccs.index.get_level_values(1).unique() + + # Define a colormap + colormap = plt.cm.get_cmap('tab20') + + # Create a dictionary to map each class to a color + class_to_color = {cls: colormap(i) for i, cls in enumerate(all_classes)} + # Count how many datasets and create a grid of plots num_datasets = len(ccs.index.get_level_values(0).unique()) fig, axs = plt.subplots(num_datasets, 1, figsize=(10, 8 * num_datasets)) @@ -207,7 +216,6 @@ def visualize_class_counts_by_dataset(self): # Order datasets by number of items ordered_datasets = grouped_ccs.sort_values(ascending=False).index - print(type(ordered_datasets)) # For each dataset, create a bar chart for i, dataset in enumerate(ordered_datasets): @@ -219,8 +227,14 @@ def visualize_class_counts_by_dataset(self): sorted_series = dataset_series.sort_values(ascending=False) + # Get colors for the bars + bar_colors = [class_to_color[cls] for cls in sorted_series.index] + # Plot bar chart with fixed width - axs[i].bar(sorted_series.index, sorted_series.values, width=0.5) + bars = axs[i].bar(sorted_series.index, + sorted_series.values, + width=0.5, + color=bar_colors) axs[i].set_title(dataset) # Print count above each bar @@ -228,6 +242,13 @@ def visualize_class_counts_by_dataset(self): if np.isfinite(v): axs[i].text(j, v, str(v), ha='center', va='bottom') + # Remove x-axis labels + axs[i].set_xticklabels([]) + + # Add legend + if i == 0: + axs[i].legend(bars, sorted_series.index, title="Classes") + # show the plot plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png', bbox_inches='tight')