Skip to content

Commit

Permalink
Add legend to bar chart
Browse files Browse the repository at this point in the history
  • Loading branch information
mariya committed Nov 12, 2024
1 parent 57b0a5a commit b1fd715
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions src/conformist/prediction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,15 @@ def visualize_class_counts_by_dataset(self):
# create a bar chart
ccs = self.class_counts_by_dataset()

# Get all unique classes
all_classes = ccs.index.get_level_values(1).unique()

# Define a colormap
colormap = plt.cm.get_cmap('tab20')

# Create a dictionary to map each class to a color
class_to_color = {cls: colormap(i) for i, cls in enumerate(all_classes)}

# Count how many datasets and create a grid of plots
num_datasets = len(ccs.index.get_level_values(0).unique())
fig, axs = plt.subplots(num_datasets, 1, figsize=(10, 8 * num_datasets))
Expand All @@ -207,7 +216,6 @@ def visualize_class_counts_by_dataset(self):

# Order datasets by number of items
ordered_datasets = grouped_ccs.sort_values(ascending=False).index
print(type(ordered_datasets))

# For each dataset, create a bar chart
for i, dataset in enumerate(ordered_datasets):
Expand All @@ -219,15 +227,28 @@ def visualize_class_counts_by_dataset(self):

sorted_series = dataset_series.sort_values(ascending=False)

# Get colors for the bars
bar_colors = [class_to_color[cls] for cls in sorted_series.index]

# Plot bar chart with fixed width
axs[i].bar(sorted_series.index, sorted_series.values, width=0.5)
bars = axs[i].bar(sorted_series.index,
sorted_series.values,
width=0.5,
color=bar_colors)
axs[i].set_title(dataset)

# Print count above each bar
for j, v in enumerate(sorted_series):
if np.isfinite(v):
axs[i].text(j, v, str(v), ha='center', va='bottom')

# Remove x-axis labels
axs[i].set_xticklabels([])

# Add legend
if i == 0:
axs[i].legend(bars, sorted_series.index, title="Classes")

# show the plot
plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png',
bbox_inches='tight')
Expand Down

0 comments on commit b1fd715

Please sign in to comment.