From 53ca7ff45b72e1f69f453937aff6735255e6474c Mon Sep 17 00:00:00 2001 From: Dianjing Liu Date: Wed, 18 Sep 2024 14:51:30 -0700 Subject: [PATCH] Add unit test verifying compatibility with huggingface models (#1352) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1352 Our current unit tests for LLM Attribution use mocked models which are similar to huggingface transformer models (e.g. Llama, Llama2), but may have some unexpected differences such as [this](https://discuss.pytorch.org/t/trying-to-explain-zephyr-generative-llm/195262/3?fbclid=IwZXh0bgNhZW0CMTEAAR3REGbJsdhbNqG5LAyQ9_2J-82nPmNjt5avVyvNw-l8SMTWVXfI2DqIE8w_aem_GRP8EzELKtqDXDMZmox3Uw). To validate coverage and ensure compatibility with future changes to models, we would like to add tests using huggingface models directly and validate compatibility with LLM Attribution, which will help us quickly catch any breaking changes. So far we only test for model type `LlamaForCausalLM` Differential Revision: D62894898 --- scripts/install_via_conda.sh | 1 + scripts/install_via_pip.sh | 2 + tests/attr/test_llm_attr_hf_compatibility.py | 94 ++++++++++++++++++++ 3 files changed, 97 insertions(+) create mode 100644 tests/attr/test_llm_attr_hf_compatibility.py diff --git a/scripts/install_via_conda.sh b/scripts/install_via_conda.sh index a290fc3299..ad7a786ab8 100755 --- a/scripts/install_via_conda.sh +++ b/scripts/install_via_conda.sh @@ -37,6 +37,7 @@ fi # install other deps conda install -q -y pytest ipywidgets ipython scikit-learn parameterized werkzeug==2.2.2 conda install -q -y -c conda-forge matplotlib pytest-cov flask flask-compress +conda install -q -y transformers # install captum python setup.py develop diff --git a/scripts/install_via_pip.sh b/scripts/install_via_pip.sh index 4e89e63f26..613909e228 100755 --- a/scripts/install_via_pip.sh +++ b/scripts/install_via_pip.sh @@ -65,3 +65,5 @@ fi if [[ $DEPLOY == true ]]; then pip install beautifulsoup4 ipython nbconvert==5.6.1 --progress-bar off fi + +pip install transformers --progress-bar off diff --git a/tests/attr/test_llm_attr_hf_compatibility.py b/tests/attr/test_llm_attr_hf_compatibility.py new file mode 100644 index 0000000000..a42592aeaa --- /dev/null +++ b/tests/attr/test_llm_attr_hf_compatibility.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 + + +from typing import cast, Dict, Optional, Type + +import torch +from captum.attr._core.feature_ablation import FeatureAblation +from captum.attr._core.llm_attr import LLMAttribution +from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling +from captum.attr._utils.attribution import PerturbationAttribution +from captum.attr._utils.interpretable_input import TextTemplateInput +from parameterized import parameterized, parameterized_class +from tests.helpers import BaseTest +from torch import Tensor + +HAS_HF = True +try: + # pyre-fixme[21]: Could not find a module corresponding to import `transformers` + from transformers import AutoModelForCausalLM, AutoTokenizer +except ImportError: + HAS_HF = False + + +@parameterized_class( + ("device", "use_cached_outputs"), + ( + [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)] + if torch.cuda.is_available() + else [("cpu", True), ("cpu", False)] + ), +) +class TestLLMAttrHFCompatibility(BaseTest): + # pyre-fixme[13]: Attribute `device` is never initialized. + device: str + # pyre-fixme[13]: Attribute `use_cached_outputs` is never initialized. + use_cached_outputs: bool + + def setUp(self) -> None: + if not HAS_HF: + self.skipTest("transformers package not found, skipping tests") + super().setUp() + + # pyre-fixme[56]: Pyre was not able to infer the type of argument `comprehension + @parameterized.expand( + [ + ( + AttrClass, + delta, + n_samples, + ) + for AttrClass, delta, n_samples in zip( + (FeatureAblation, ShapleyValueSampling, ShapleyValues), # AttrClass + (0.001, 0.001, 0.001), # delta + (None, 1000, None), # n_samples + ) + ] + ) + def test_llm_attr_hf_compatibility( + self, + AttrClass: Type[PerturbationAttribution], + delta: float, + n_samples: Optional[int], + ) -> None: + attr_kws: Dict[str, int] = {} + if n_samples is not None: + attr_kws["n_samples"] = n_samples + + tokenizer = AutoTokenizer.from_pretrained( + "hf-internal-testing/tiny-random-LlamaForCausalLM" + ) + llm = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-LlamaForCausalLM" + ) + + llm.to(self.device) + llm.eval() + llm_attr = LLMAttribution(AttrClass(llm), tokenizer) + + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = llm_attr.attribute( + inp, + "m n o p q", + use_cached_outputs=self.use_cached_outputs, + # pyre-fixme[6]: In call `LLMAttribution.attribute`, + # for 4th positional argument, expected + # `Optional[typing.Callable[..., typing.Any]]` but got `int`. + **attr_kws, # type: ignore + ) + self.assertEqual(res.seq_attr.shape, (4,)) + self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4)) + self.assertEqual(res.input_tokens, ["a", "c", "d", "f"]) + self.assertEqual(len(res.output_tokens), 5) + self.assertEqual(res.seq_attr.device.type, self.device) + self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)