Skip to content

Commit

Permalink
No gridlines on token attr plots+typo fix
Browse files Browse the repository at this point in the history
Summary:
Gridlines can make negative signs and numbers harder to read. Removing them to ease readability.

Also relevant to aesthetics: D63039687

Differential Revision: D63062341
  • Loading branch information
craymichael authored and facebook-github-bot committed Sep 19, 2024
1 parent 49d8689 commit 3e6841b
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions captum/attr/_core/llm_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def plot_token_attr(

fig, ax = plt.subplots()

# Hide the grid
ax.grid(False)

# Plot the heatmap
data = token_attr.numpy()

Expand Down Expand Up @@ -119,7 +122,7 @@ def plot_token_attr(

# Create colorbar
cbar = fig.colorbar(im, ax=ax) # type: ignore
cbar.ax.set_ylabel("Token Attribuiton", rotation=-90, va="bottom")
cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")

# Show all ticks and label them with the respective list entries.
shortened_tokens = [
Expand Down Expand Up @@ -204,7 +207,7 @@ def plot_seq_attr(
color="#d0365b",
)

ax.set_ylabel("Sequence Attribuiton", rotation=90, va="bottom")
ax.set_ylabel("Sequence Attribution", rotation=90, va="bottom")

if show:
plt.show()
Expand Down

0 comments on commit 3e6841b

Please sign in to comment.