Add unit test verifying compatibility with huggingface models #1352

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions scripts/install_via_conda.sh
@@ -37,6 +37,7 @@ fi
# install other deps
conda install -q -y pytest ipywidgets ipython scikit-learn parameterized werkzeug==2.2.2
conda install -q -y -c conda-forge matplotlib pytest-cov flask flask-compress
conda install -q -y transformers

# install captum
python setup.py develop
2 changes: 2 additions & 0 deletions scripts/install_via_pip.sh
@@ -65,3 +65,5 @@ fi
if [[ $DEPLOY == true ]]; then
pip install beautifulsoup4 ipython nbconvert==5.6.1 --progress-bar off
fi

pip install transformers --progress-bar off
89 changes: 89 additions & 0 deletions tests/attr/test_llm_attr_hf_compatibility.py
@@ -0,0 +1,89 @@
#!/usr/bin/env python3


from typing import cast, Dict, Optional, Type

import torch
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from parameterized import parameterized, parameterized_class
from tests.helpers import BaseTest
from torch import Tensor

HAS_HF = True
try:
    # pyre-fixme[21]: Could not find a module corresponding to import `transformers`
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
    HAS_HF = False


@parameterized_class(
    ("device", "use_cached_outputs"),
    (
        [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
        if torch.cuda.is_available()
        else [("cpu", True), ("cpu", False)]
    ),
)
class TestLLMAttrHFCompatibility(BaseTest):
    # pyre-fixme[13]: Attribute `device` is never initialized.
    device: str
    # pyre-fixme[13]: Attribute `use_cached_outputs` is never initialized.
    use_cached_outputs: bool

    def setUp(self) -> None:
        if not HAS_HF:
            self.skipTest("transformers package not found, skipping tests")
        super().setUp()

    # pyre-fixme[56]: Pyre was not able to infer the type of argument `comprehension
    @parameterized.expand(
        [
            (
                AttrClass,
                n_samples,
            )
            for AttrClass, n_samples in zip(
                (FeatureAblation, ShapleyValueSampling, ShapleyValues),  # AttrClass
                (None, 1000, None),  # n_samples
            )
        ]
    )
    def test_llm_attr_hf_compatibility(
        self,
        AttrClass: Type[PerturbationAttribution],
        n_samples: Optional[int],
    ) -> None:
        attr_kws: Dict[str, int] = {}
        if n_samples is not None:
            attr_kws["n_samples"] = n_samples

        tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )
        llm = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )

        llm.to(self.device)
        llm.eval()
        llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
        res = llm_attr.attribute(
            inp,
            "m n o p q",
            use_cached_outputs=self.use_cached_outputs,
            # pyre-fixme[6]: In call `LLMAttribution.attribute`,
            # for 4th positional argument, expected
            # `Optional[typing.Callable[..., typing.Any]]` but got `int`.
            **attr_kws,  # type: ignore
        )
        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
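
Usage note (an illustrative example, not part of the PR itself): once transformers is installed via either updated install script, the new test module can be run on its own with pytest, for example

pytest tests/attr/test_llm_attr_hf_compatibility.py -v

The hf-internal-testing/tiny-random-LlamaForCausalLM checkpoint is downloaded from the Hugging Face Hub on first run, so network access is assumed; if the transformers package is missing, setUp skips every test in the class.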