Add unit test verifying compatibility with huggingface models (#1352)
Summary:

Our current unit tests for LLM Attribution use mocked models that are similar to huggingface transformer models (e.g. Llama, Llama2) but may have unexpected differences, such as [this](https://discuss.pytorch.org/t/trying-to-explain-zephyr-generative-llm/195262/3?fbclid=IwZXh0bgNhZW0CMTEAAR3REGbJsdhbNqG5LAyQ9_2J-82nPmNjt5avVyvNw-l8SMTWVXfI2DqIE8w_aem_GRP8EzELKtqDXDMZmox3Uw). To validate coverage and ensure compatibility with future changes to models, we add tests that use huggingface models directly and validate their compatibility with LLM Attribution, which will help us quickly catch any breaking changes.

So far we only test the model type `LlamaForCausalLM`.
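For reference, the core pattern the new test exercises is condensed below: load a tiny huggingface checkpoint, wrap it with a perturbation-based attribution method, and attribute a templated prompt. This is a minimal sketch distilled from the full test in this diff (imports mirror the test file); it is not additional code in the commit, and downloading the checkpoint requires network access.

from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load a tiny, randomly initialized Llama checkpoint so the test stays fast.
checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).eval()

# Wrap a real huggingface model (not a mock) with LLM Attribution.
llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)

# Attribute the target sequence "m n o p q" to the four template values.
inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
res = llm_attr.attribute(inp, "m n o p q")
print(res.seq_attr.shape)  # torch.Size([4]): one attribution score per template value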

Differential Revision: D62894898
DianjingLiu authored and facebook-github-bot committed Sep 18, 2024
1 parent 70619a6 commit a612efc
Showing 1 changed file with 86 additions and 0 deletions.
tests/attr/test_llm_attr_hf_compatibility.py (86 additions, 0 deletions)
@@ -0,0 +1,86 @@
#!/usr/bin/env python3

# pyre-strict

from typing import cast, Dict, Optional, Type

import torch
from captum.attr._core.feature_ablation import FeatureAblation
from captum.attr._core.llm_attr import LLMAttribution
from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.interpretable_input import TextTemplateInput
from parameterized import parameterized, parameterized_class
from tests.helpers import BaseTest
from torch import Tensor

# pyre-fixme[21]: Could not find a module corresponding to import `transformers`
from transformers import AutoModelForCausalLM, AutoTokenizer


@parameterized_class(
    ("device", "use_cached_outputs"),
    (
        [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
        if torch.cuda.is_available()
        else [("cpu", True), ("cpu", False)]
    ),
)
class TestLLMAttr(BaseTest):
    # pyre-fixme[13]: Attribute `device` is never initialized.
    device: str
    # pyre-fixme[13]: Attribute `use_cached_outputs` is declared in class
    #  `TestLLMAttr` to have type `bool` but is never initialized.
    use_cached_outputs: bool

    # pyre-fixme[56]: Pyre was not able to infer the type of argument `comprehension`
    @parameterized.expand(
        [
            (
                AttrClass,
                delta,
                n_samples,
            )
            for AttrClass, delta, n_samples in zip(
                (FeatureAblation, ShapleyValueSampling, ShapleyValues),  # AttrClass
                (0.001, 0.001, 0.001),  # delta
                (None, 1000, None),  # n_samples
            )
        ]
    )
    def test_llm_attr(
        self,
        AttrClass: Type[PerturbationAttribution],
        delta: float,
        n_samples: Optional[int],
    ) -> None:
        attr_kws: Dict[str, int] = {}
        if n_samples is not None:
            attr_kws["n_samples"] = n_samples

        tokenizer = AutoTokenizer.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )
        llm = AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-LlamaForCausalLM"
        )

        llm.to(self.device)
        llm.eval()
        llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
        res = llm_attr.attribute(
            inp,
            "m n o p q",
            use_cached_outputs=self.use_cached_outputs,
            # pyre-fixme[6]: In call `LLMAttribution.attribute`, for 4th positional
            #  argument, expected `Optional[typing.Callable[..., typing.Any]]` but
            #  got `int`.
            **attr_kws,  # type: ignore
        )
        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
        self.assertEqual(len(res.output_tokens), 5)
        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
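
To run the new test locally, assuming the repository's pytest-based test setup and that the `transformers` and `parameterized` packages are installed (the tiny checkpoint is downloaded from the huggingface hub on first run):

python -m pytest tests/attr/test_llm_attr_hf_compatibility.py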
