Added skip_bos Argument to LLM Attribution To Enable Wider Model Support for Attributing Against a Single Token #1322

Open · wants to merge 9 commits into master
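Background for this change (an editor's sketch, not text from the PR itself): tokenizers disagree about whether `encode` prepends a BOS/SOS id. Llama-style tokenizers do, so dropping the first id with `[1:]` is correct for them, but GPT-2-style tokenizers do not, and the same slice silently deletes the first, or only, token of the target. A short illustration, assuming Hugging Face `transformers` tokenizers:

```python
from transformers import AutoTokenizer

# GPT-2's tokenizer prepends no BOS token, so slicing off the first id
# removes a real target token; with a one-token target, nothing is left.
gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
ids = gpt2_tok.encode(" Paris")  # a single id, e.g. [6342]
print(ids[1:])                   # [] -> the target vanishes downstream
```

An empty target token list is what the new `skip_bos` argument lets callers avoid.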
20 changes: 15 additions & 5 deletions captum/attr/_core/llm_attr.py
@@ -289,9 +289,12 @@ def _forward_func(
         # 1st element is the total prob, rest are the target tokens
         # add a leading dim for batch even we only support single instance for now
         if self.include_per_token_attr:
-            target_log_probs = torch.stack(
-                [total_log_prob, *log_prob_list], dim=0  # type: ignore
-            ).unsqueeze(0)
+            try:
+                target_log_probs = torch.stack(
+                    [total_log_prob, *log_prob_list], dim=0  # type: ignore
+                ).unsqueeze(0)
+            except TypeError:
+                raise TypeError("Try using the skip_bos argument.")
         else:
             target_log_probs = total_log_prob  # type: ignore
         target_probs = torch.exp(target_log_probs)
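Why the new `try`/`except` fires (inferred from the surrounding code, which appears to accumulate per-token log probs into `log_prob_list` and sum them into `total_log_prob`): with an empty target token list, the sum degenerates to the Python int `0`, and `torch.stack` rejects a non-tensor element. A minimal reproduction sketch:

```python
import torch

# Assumed failure mode: with an empty target token list, the summed
# total log prob degenerates to the int 0 rather than a tensor.
log_prob_list = []
total_log_prob = sum(log_prob_list)  # -> 0 (int), not a torch.Tensor

# torch.stack requires tensors, so this raises:
# TypeError: expected Tensor as element 0 in argument 0, but got int
torch.stack([total_log_prob, *log_prob_list], dim=0)
```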
@@ -325,6 +328,10 @@ def attribute(
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
         num_trials: int = 1,
+        skip_bos: bool = True,
+        # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
+        # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
+        # errors.
         gen_args: Optional[Dict[str, Any]] = None,
         use_cached_outputs: bool = True,
         # internal callback hook can be used for logging
@@ -375,8 +382,11 @@ def attribute(
             assert gen_args is None, "gen_args must be None when target is given"

         if type(target) is str:
-            # exclude sos
-            target_tokens = self.tokenizer.encode(target)[1:]
+            # exclude sos / bos
+            if skip_bos:
+                target_tokens = self.tokenizer.encode(target)[1:]
+            else:
+                target_tokens = self.tokenizer.encode(target)
             target_tokens = torch.tensor(target_tokens)
         elif type(target) is torch.Tensor:
             target_tokens = target
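A usage sketch of the new argument. `LLMAttribution`, `FeatureAblation`, and `TextTokenInput` are Captum's existing LLM-attribution API; the gpt2 model choice is illustrative, and `skip_bos` is the parameter this PR adds:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTokenInput("The capital of France is", tokenizer)

# GPT-2's tokenizer adds no BOS, so keep the full encoded target by
# passing skip_bos=False; the default (True) would drop its only token.
result = llm_attr.attribute(inp, target=" Paris", skip_bos=False)
```

Keeping `skip_bos=True` as the default preserves the current behavior for Llama-style tokenizers while giving BOS-less tokenizers an opt-out.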