Skip to content

Commit

Permalink
Order class barchart by num of examples per class
Browse files Browse the repository at this point in the history
  • Loading branch information
mariya committed Nov 12, 2024
1 parent df0a36b commit 57b0a5a
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions src/conformist/prediction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,18 +202,31 @@ def visualize_class_counts_by_dataset(self):
if num_datasets == 1:
axs = [axs]

# Group by the first level of the index (dataset) and sum the values
grouped_ccs = ccs.groupby(level=0).sum()

# Order datasets by number of items
ordered_datasets = grouped_ccs.sort_values(ascending=False).index
print(type(ordered_datasets))

# For each dataset, create a bar chart
for i, dataset in enumerate(ccs.index.get_level_values(0).unique()):
sorted = ccs.loc[dataset].sort_values(ascending=False)
for i, dataset in enumerate(ordered_datasets):
dataset_series = ccs.loc[dataset]

# Ensure dataset_series is a Series
if not isinstance(dataset_series, pd.Series):
raise ValueError(f"Expected ccs.loc[{dataset}] to be a Series")

sorted_series = dataset_series.sort_values(ascending=False)

# Set fixed width for bar
# axs[i].bar(sorted.index, sorted.values, width=0.5)
sorted.plot.bar(ax=axs[i])
# Plot bar chart with fixed width
axs[i].bar(sorted_series.index, sorted_series.values, width=0.5)
axs[i].set_title(dataset)

# Print count above each bar
for j, v in enumerate(sorted):
axs[i].text(j, v, str(v), ha='center', va='bottom')
for j, v in enumerate(sorted_series):
if np.isfinite(v):
axs[i].text(j, v, str(v), ha='center', va='bottom')

# show the plot
plt.savefig(f'{self.output_dir}/class_counts_by_dataset.png',
Expand Down

0 comments on commit 57b0a5a

Please sign in to comment.