Pretty-print tokens in llm_attr methods (#1348)
Summary:
Pull Request resolved: #1348

Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:
 huggingface/transformers#4786 and
https://discuss.huggingface.co/t/bpe-tokenizers-and-spaces-before-words/475/2

This is the preferred function over tokenizer.convert_ids_to_tokens() for user-facing data.

Quote from links:
    > Spaces are converted in a special character (the Ġ) in the tokenizer prior to
    > BPE splitting mostly to avoid digesting spaces since the standard BPE algorithm
    > used spaces in its process
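
For illustration, a minimal sketch of the difference, assuming a GPT-2-style BPE
tokenizer from Hugging Face transformers (the model name is only an example):

    from transformers import AutoTokenizer

    # Any BPE tokenizer that marks leading spaces with "Ġ" behaves the same way.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    ids = tokenizer.encode("Hello world")

    # Raw BPE tokens keep the space marker, e.g. ['Hello', 'Ġworld'].
    print(tokenizer.convert_ids_to_tokens(ids))

    # Decoding each id individually yields readable pieces, e.g. ['Hello', ' world'].
    print([tokenizer.decode(i) for i in ids])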

Reviewed By: csauper

Differential Revision: D62672912

fbshipit-source-id: 1bafdd7231131cf5b12864a73aaf81e40e6d206c
craymichael authored and facebook-github-bot committed Sep 16, 2024
1 parent b7ca840 commit 6636f4d
Showing 6 changed files with 203 additions and 43 deletions.
38 changes: 35 additions & 3 deletions captum/_utils/typing.py
@@ -2,15 +2,24 @@

# pyre-strict

from typing import List, Optional, Protocol, Tuple, TYPE_CHECKING, TypeVar, Union
from typing import (
List,
Optional,
overload,
Protocol,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)

from torch import Tensor
from torch.nn import Module

if TYPE_CHECKING:
from typing import Literal
else:
Literal = {True: bool, False: bool, (True, False): bool}
Literal = {True: bool, False: bool, (True, False): bool, "pt": str}

TensorOrTupleOfTensorsGeneric = TypeVar(
"TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
@@ -39,8 +48,31 @@ class TokenizerLike(Protocol):
"""A protocol for tokenizer-like objects that can be used with Captum
LLM attribution methods."""

@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
@overload
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

def encode(
self, text: str, return_tensors: Optional[str] = None
) -> Union[List[int], Tensor]: ...

def decode(self, token_ids: Tensor) -> str: ...
def convert_ids_to_tokens(self, token_ids: Tensor) -> List[str]: ...

@overload
def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]: ...
@overload
def convert_ids_to_tokens(self, token_ids: int) -> str: ...

def convert_ids_to_tokens(
self, token_ids: Union[List[int], int]
) -> Union[List[str], str]: ...

@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...

def convert_tokens_to_ids(
self, tokens: Union[List[str], str]
) -> Union[List[int], int]: ...
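
With these overloads, a static type checker can narrow the return type of encode()
at the call site. A rough sketch of what that enables (the helper functions below are
illustrative, not part of this diff):

from typing import List

from torch import Tensor

from captum._utils.typing import TokenizerLike


def count_tokens(tokenizer: TokenizerLike, text: str) -> int:
    # With return_tensors=None, the overload resolves to List[int].
    ids: List[int] = tokenizer.encode(text)
    return len(ids)


def encode_as_batch(tokenizer: TokenizerLike, text: str) -> Tensor:
    # With return_tensors="pt", the overload resolves to Tensor
    # (typically shape (1, seq_len) for Hugging Face tokenizers).
    return tokenizer.encode(text, return_tensors="pt")
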
21 changes: 19 additions & 2 deletions captum/attr/_core/llm_attr.py
@@ -173,6 +173,23 @@ def plot_seq_attr(
return fig, ax


def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List[str]:
"""
Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:
https://github.com/huggingface/transformers/issues/4786 and
https://discuss.huggingface.co/t/bpe-tokenizers-and-spaces-before-words/475/2
This is the preferred function over tokenizer.convert_ids_to_tokens() for
user-facing data.
Quote from links:
> Spaces are converted in a special character (the Ġ) in the tokenizer prior to
> BPE splitting mostly to avoid digesting spaces since the standard BPE algorithm
> used spaces in its process
"""
return [tokenizer.decode(id_) for id_ in ids]


class LLMAttribution(Attribution):
"""
Attribution class for large language models. It wraps a perturbation-based
@@ -461,7 +478,7 @@ def attribute(
attr[1:] if self.include_per_token_attr else None
), # shape(n_output_token, n_input_features)
inp.values,
self.tokenizer.convert_ids_to_tokens(target_tokens),
_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
)

# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
@@ -641,7 +658,7 @@ def attribute(
seq_attr,
attr, # shape(n_output_token, n_input_features)
inp.values,
self.tokenizer.convert_ids_to_tokens(target_tokens),
_convert_ids_to_pretty_tokens(target_tokens, self.tokenizer),
)

# pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
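
With this change, the output token labels on the attribution result are decoded one
id at a time via _convert_ids_to_pretty_tokens. A rough usage sketch (the model,
tokenizer, prompt, and target below are illustrative placeholders, not part of this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTokenInput("The capital of France is", tokenizer)
res = llm_attr.attribute(inp, target=" Paris")

# Output token labels are decoded per id, so they read " Paris" rather than "ĠParis".
print(res.output_tokens)
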
46 changes: 20 additions & 26 deletions captum/attr/_utils/interpretable_input.py
@@ -1,8 +1,10 @@
# pyre-strict
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

import torch

from captum._utils.typing import TokenizerLike
from torch import Tensor


@@ -215,9 +217,8 @@ def __init__(
), f"the values must be either a list or a dict, received: {type(values)}"
dict_keys = []

self.values = values
# pyre-fixme[4]: Attribute must be annotated.
self.dict_keys = dict_keys
self.values: List[str] = values
self.dict_keys: List[str] = dict_keys

n_features = len(values)

@@ -393,25 +394,22 @@ class TextTokenInput(InterpretableInput):
def __init__(
self,
text: str,
# pyre-fixme[2]: Parameter must be annotated.
tokenizer,
tokenizer: TokenizerLike,
baselines: Union[int, str] = 0, # usually UNK
skip_tokens: Union[List[int], List[str], None] = None,
) -> None:
inp_tensor = tokenizer.encode(text, return_tensors="pt")

# input tensor into the model of token ids
# pyre-fixme[4]: Attribute must be annotated.
self.inp_tensor = inp_tensor
self.inp_tensor: Tensor = inp_tensor
# tensor of interpretable token ids
# pyre-fixme[4]: Attribute must be annotated.
self.itp_tensor = inp_tensor
self.itp_tensor: Tensor = inp_tensor
# interpretable mask
# pyre-fixme[4]: Attribute must be annotated.
self.itp_mask = None
self.itp_mask: Optional[Tensor] = None

if skip_tokens:
if isinstance(skip_tokens[0], str):
skip_tokens = cast(List[str], skip_tokens)
skip_tokens = tokenizer.convert_tokens_to_ids(skip_tokens)
assert isinstance(skip_tokens, list)

@@ -428,18 +426,16 @@ def __init__(
self.skip_tokens = skip_tokens

# features values, the tokens
# pyre-fixme[4]: Attribute must be annotated.
self.values = tokenizer.convert_ids_to_tokens(self.itp_tensor[0].tolist())
# pyre-fixme[4]: Attribute must be annotated.
self.tokenizer = tokenizer
# pyre-fixme[4]: Attribute must be annotated.
self.n_itp_features = len(self.values)
self.values: List[str] = tokenizer.convert_ids_to_tokens(
self.itp_tensor[0].tolist()
)
self.tokenizer: TokenizerLike = tokenizer
self.n_itp_features: int = len(self.values)

# pyre-fixme[4]: Attribute must be annotated.
self.baselines = (
self.baselines: int = (
baselines
if type(baselines) is int
else tokenizer.convert_tokens_to_ids([baselines])[0]
else tokenizer.convert_tokens_to_ids([baselines])[0] # type: ignore
)

def to_tensor(self) -> torch.Tensor:
@@ -448,8 +444,7 @@ def to_tensor(self) -> torch.Tensor:

# pyre-fixme[14]: `to_model_input` overrides method defined in
# `InterpretableInput` inconsistently.
# pyre-fixme[2]: Parameter must be annotated.
def to_model_input(self, perturbed_tensor=None) -> torch.Tensor:
def to_model_input(self, perturbed_tensor: Optional[Tensor] = None) -> Tensor:
if perturbed_tensor is None:
return self.inp_tensor

@@ -467,13 +462,12 @@ def to_model_input(self, perturbed_tensor=None) -> torch.Tensor:
if self.itp_mask is None:
return perturb_itp_tensor

perturb_inp_tensor = self.inp_tensor.expand(*expand_shape).clone().to(device)
itp_mask = self.itp_mask.expand(*expand_shape).to(device)
perturb_inp_tensor = self.inp_tensor.expand(*expand_shape).clone().to(device)

perturb_inp_tensor[itp_mask] = perturb_itp_tensor.view(-1)

return perturb_inp_tensor

# pyre-fixme[3]: Return type must be annotated.
def format_attr(self, itp_attr: torch.Tensor):
def format_attr(self, itp_attr: Tensor) -> Tensor:
return itp_attr
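
With the tokenizer parameter now typed as TokenizerLike, a TextTokenInput can be
constructed as in this rough sketch (the Hugging Face tokenizer here is used purely
as an example; any TokenizerLike implementation works):

from transformers import AutoTokenizer

from captum.attr import TextTokenInput

tokenizer = AutoTokenizer.from_pretrained("gpt2")

inp = TextTokenInput(
    "Palm Coast is a city in Florida",
    tokenizer,
    baselines=0,  # token id used as the ablation baseline (often UNK)
    skip_tokens=["<|endoftext|>"],  # tokens excluded from the interpretable features
)
print(inp.values)  # the interpretable features, one per input token
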
50 changes: 46 additions & 4 deletions tests/attr/test_interpretable_input.py
@@ -2,7 +2,10 @@

# pyre-unsafe

from typing import List, Optional, overload, Union

import torch
from captum._utils.typing import Literal
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
from parameterized import parameterized
from tests.helpers import BaseTest
@@ -16,20 +19,59 @@ def __init__(self, vocab_list) -> None:
self.id_to_token = vocab_list
self.unk_idx = len(vocab_list) + 1

def encode(self, text, **kwargs) -> Tensor:
@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
@overload
# pyre-fixme[43]: Incompatible overload. The implementation of
# `DummyTokenizer.encode` does not accept all possible arguments of overload.
# pyre-ignore[11]: Annotation `pt` is not defined as a type
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

def encode(
self, text: str, return_tensors: Optional[str] = "pt"
) -> Union[List[int], Tensor]:
assert return_tensors == "pt"
return torch.tensor([self.convert_tokens_to_ids(text.split(" "))])

def convert_ids_to_tokens(self, ids):
@overload
def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]: ...
@overload
def convert_ids_to_tokens(self, token_ids: int) -> str: ...

def convert_ids_to_tokens(
self, token_ids: Union[List[int], int]
) -> Union[List[str], str]:
if isinstance(token_ids, int):
return (
self.id_to_token[token_ids]
if token_ids < len(self.id_to_token)
else "[UNK]"
)
return [
(self.id_to_token[i] if i < len(self.id_to_token) else "[UNK]") for i in ids
(self.id_to_token[i] if i < len(self.id_to_token) else "[UNK]")
for i in token_ids
]

def convert_tokens_to_ids(self, tokens):
@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...

def convert_tokens_to_ids(
self, tokens: Union[List[str], str]
) -> Union[List[int], int]:
if isinstance(tokens, str):
return (
self.token_to_id[tokens] if tokens in self.token_to_id else self.unk_idx
)
return [
(self.token_to_id[t] if t in self.token_to_id else self.unk_idx)
for t in tokens
]

def decode(self, token_ids: Tensor) -> str:
raise NotImplementedError


class TestTextTemplateInput(BaseTest):
@parameterized.expand(
49 changes: 46 additions & 3 deletions tests/attr/test_llm_attr.py
@@ -3,10 +3,22 @@
# pyre-strict

import copy
from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple, Type, Union
from typing import (
Any,
cast,
Dict,
List,
NamedTuple,
Optional,
overload,
Tuple,
Type,
Union,
)

import torch
from captum._utils.models.linear_model import SkLearnLasso
from captum._utils.typing import Literal
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.kernel_shap import KernelShap
from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap
@@ -29,6 +41,14 @@ class DummyTokenizer:
unk: int = 1
special_tokens: Dict[int, str] = {sos: "<sos>", unk: "<unk>"}

@overload
def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
@overload
# pyre-fixme[43]: Incompatible overload. The implementation of
# `DummyTokenizer.encode` does not accept all possible arguments of overload.
# pyre-ignore[11]: Annotation `pt` is not defined as a type
def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...

def encode(
self, text: str, return_tensors: Optional[str] = None
) -> Union[List[int], Tensor]:
@@ -45,14 +65,37 @@ def encode(
return torch.tensor([tokens_ids])
return tokens_ids

def convert_ids_to_tokens(self, token_ids: Tensor) -> List[str]:
@overload
def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]: ...
@overload
def convert_ids_to_tokens(self, token_ids: int) -> str: ...

def convert_ids_to_tokens(
self, token_ids: Union[List[int], int]
) -> Union[List[str], str]:
if isinstance(token_ids, int):
return (
self.special_tokens[token_ids]
if token_ids in self.special_tokens
else chr(token_ids)
)
return [
(self.special_tokens[tid] if tid in self.special_tokens else chr(tid))
for tid in token_ids
]

@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...

def convert_tokens_to_ids(
self, tokens: Union[List[str], str]
) -> Union[List[int], int]:
raise NotImplementedError

def decode(self, token_ids: Tensor) -> str:
return " ".join(self.convert_ids_to_tokens(token_ids))
return " ".join(self.convert_ids_to_tokens(token_ids.tolist()))


class Result(NamedTuple):
