Added skip_bos Argument to LLM Attribution To Enable Wider Model Support for Attributing Against a Single Token #1322

Open · wants to merge 9 commits into master
captum/attr/_core/llm_attr.py (12 additions, 5 deletions)
@@ -290,9 +290,12 @@ def _forward_func(
         # 1st element is the total prob, rest are the target tokens
         # add a leading dim for batch even we only support single instance for now
         if self.include_per_token_attr:
-            target_log_probs = torch.stack(
-                [total_log_prob, *log_prob_list], dim=0
-            ).unsqueeze(0)
+            try:
+                target_log_probs = torch.stack(
+                    [total_log_prob, *log_prob_list], dim=0
+                ).unsqueeze(0)
+            except TypeError:
+                raise TypeError("Received an empty list of target tokens. If you are attributing against a single target token (one character or word), try setting skip_bos=False when calling attribute.")
         else:
             target_log_probs = total_log_prob
         # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int,
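For context, a minimal sketch of the failure mode this try/except guards against. It assumes a GPT-2-style tokenizer that does not prepend a BOS token, and that total_log_prob is the plain Python sum of the per-token log probabilities (an assumption about the surrounding code, hedged here; the variable names mirror the diff but the snippet is illustrative):

```python
import torch
from transformers import AutoTokenizer

# GPT-2's tokenizer adds no BOS token by default
tokenizer = AutoTokenizer.from_pretrained("gpt2")

target = "hello"
target_tokens = tokenizer.encode(target)[1:]   # drops the only token -> []

log_prob_list = [torch.tensor(0.0) for _ in target_tokens]  # empty list
total_log_prob = sum(log_prob_list)            # sum([]) == 0, a plain int, not a Tensor

# torch.stack expects tensors; the int 0 is what triggers the TypeError caught above
torch.stack([total_log_prob, *log_prob_list], dim=0)
```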
@@ -327,6 +330,7 @@ def attribute(
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
         num_trials: int = 1,
+        skip_bos: bool = True,
         # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
         # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
         # errors.
@@ -382,8 +386,11 @@ def attribute(
             assert gen_args is None, "gen_args must be None when target is given"

             if type(target) is str:
-                # exclude sos
-                target_tokens = self.tokenizer.encode(target)[1:]
+                # exclude sos / bos
+                if skip_bos:
+                    target_tokens = self.tokenizer.encode(target)[1:]
+                else:
+                    target_tokens = self.tokenizer.encode(target)
                 target_tokens = torch.tensor(target_tokens)
             elif type(target) is torch.Tensor:
                 target_tokens = target
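A usage sketch for the new argument. The GPT-2 model/tokenizer and the prompt are illustrative; LLMAttribution, FeatureAblation, and TextTemplateInput are Captum's existing LLM attribution API, and skip_bos is the argument added in this PR:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTemplateInput(
    "{} lives in {} and his favorite hobby is",
    values=["Dave", "Palm Coast"],
)

# GPT-2's tokenizer does not prepend a BOS token, so the default [1:] slice would
# drop the first (and possibly only) target token; skip_bos=False keeps it.
attr_res = llm_attr.attribute(inp, target=" fishing", skip_bos=False)
```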