diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 89f32747d..4163a5eca 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -1,6 +1,8 @@ # pyre-strict from copy import copy +from textwrap import shorten + from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union import matplotlib.pyplot as plt @@ -103,7 +105,10 @@ def plot_token_attr( cbar.ax.set_ylabel("Token Attribuiton", rotation=-90, va="bottom") # Show all ticks and label them with the respective list entries. - ax.set_xticks(np.arange(data.shape[1]), labels=self.input_tokens) + shortened_tokens = [ + shorten(t, width=50, placeholder="...") for t in self.input_tokens + ] + ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens) ax.set_yticks(np.arange(data.shape[0]), labels=self.output_tokens) # Let the horizontal axes labeling appear on top. @@ -149,7 +154,10 @@ def plot_seq_attr( data = self.seq_attr.cpu().numpy() - ax.set_xticks(range(data.shape[0]), labels=self.input_tokens) + shortened_tokens = [ + shorten(t, width=50, placeholder="...") for t in self.input_tokens + ] + ax.set_xticks(range(data.shape[0]), labels=shortened_tokens) ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)