Skip to content

Commit

Permalink
Format strip plot
Browse files Browse the repository at this point in the history
  • Loading branch information
mariya committed Nov 13, 2024
1 parent 00e12ce commit 77b45f5
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/conformist/prediction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch

from .output_dir import OutputDir

Expand Down Expand Up @@ -165,6 +166,7 @@ def run_reports(self, base_output_dir):
self.visualize_class_counts()
self.visualize_class_counts_by_dataset()
self.visualize_prediction_heatmap()
self.visualize_prediction_stripplot()
print(f'Reports saved to {self.output_dir}')

def _class_colors(self):
Expand Down Expand Up @@ -349,11 +351,18 @@ def visualize_prediction_stripplot(self):
x='Softmax score',
y='True class',
hue='Predicted class',
jitter=0.2,
alpha=0.5,
jitter=0.5,
alpha=0.75,
dodge=True,
palette=self._class_colors(),
size=5)
size=4)

# Create custom legend handles
class_to_color = self._class_colors()
legend_handles = [Patch(color=class_to_color[cls], label=cls) for cls in new_df['Predicted class'].unique()]

# Position the legend to the right of the plot with bars instead of dots
plt.legend(handles=legend_handles, title="Predicted Classes", bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

# Save the plot to a file
plt.tight_layout()
Expand Down

0 comments on commit 77b45f5

Please sign in to comment.