From fc910e5e0289ffd856d40503d5504d73e8b28b95 Mon Sep 17 00:00:00 2001
From: Dianjing Liu
Date: Thu, 19 Sep 2024 13:57:54 -0700
Subject: [PATCH] fix LLMAttribution for old pytorch/python versions (#1353)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1353

When `use_cached_outputs=False` is set, `LLMAttribution` failed to run on some older PyTorch/Python versions.

## Error message

```
======================================================================
ERROR: test_llm_attr_hf_compatibility_0 (tests.attr.test_llm_attr_hf_compatibility.TestLLMAttrHFCompatibility_1_cpu)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/parameterized/parameterized.py", line 620, in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
  File "/data/users/liudj/captum/tests/attr/test_llm_attr_hf_compatibility.py", line 80, in test_llm_attr_hf_compatibility
    res = llm_attr.attribute(
  File "/data/users/liudj/captum/captum/attr/_core/llm_attr.py", line 461, in attribute
    cur_attr = self.attr_method.attribute(
  File "/data/users/liudj/captum/captum/log/__init__.py", line 52, in wrapper
    return func(*args, **kwargs)
  File "/data/users/liudj/captum/captum/attr/_core/feature_ablation.py", line 292, in attribute
    initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
  File "/data/users/liudj/captum/captum/_utils/common.py", line 599, in _run_forward
    output = forward_func(
  File "/data/users/liudj/captum/captum/attr/_core/llm_attr.py", line 335, in _forward_func
    outputs = self.model.forward(model_inp, **model_kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
    layer_outputs = decoder_layer(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 428, in forward
    attn_weights = attn_weights + causal_mask
RuntimeError: The size of tensor a (8) must match the size of tensor b (7) at non-singleton dimension 3
```

## Root cause

The `attention_mask` was not updated to match the input, which grows by one token per generation step, so the mask's sequence length lagged behind the input's. See the error message above.
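To make the mismatch concrete, here is a minimal sketch (the shapes are illustrative, not taken from the test; only the mask construction mirrors the fix below):

```python
import torch

# Illustrative shapes: a 7-token prompt grows to 8 tokens after one
# generation step, but the attention mask built for the prompt is reused.
model_inp = torch.ones([1, 8], dtype=torch.long)   # prompt + 1 generated token
stale_mask = torch.ones([1, 7], dtype=torch.long)  # built once for the prompt

# Inside the model, the causal mask derived from `stale_mask` keeps sequence
# length 7 while the attention scores have length 8, so older stacks fail at
# `attn_weights + causal_mask` with the RuntimeError shown above.

# The fix rebuilds the mask from the current input length on every call:
fresh_mask = torch.ones(
    [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
)
assert fresh_mask.shape[1] == model_inp.shape[1]
```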
## Impacted versions
- Python 3.8-3.10, PyTorch 1.10-1.12, transformers 4.44.2
- Python 3.8-3.11, PyTorch 1.13-2.1.0, transformers 4.44.2

{F1876426564}

Reviewed By: vivekmig

Differential Revision: D63016032

fbshipit-source-id: cb4c75a486ffb4b7c5d8d7d16c88a15fa2ce3745
---
 captum/attr/_core/llm_attr.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 73aca0d44..80e84cd8b 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -332,6 +332,11 @@ def _forward_func(
                 )
                 outputs = self.model.forward(**model_inputs)
             else:
+                # Update attention mask to adapt to input size change
+                attention_mask = torch.ones(
+                    [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
+                )
+                model_kwargs["attention_mask"] = attention_mask
                 outputs = self.model.forward(model_inp, **model_kwargs)
             new_token_logits = outputs.logits[:, -1]
             log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
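For context, this is roughly how the fixed branch is exercised end to end; a sketch assuming the public `LLMAttribution` API, with `gpt2` and the prompt as placeholder choices:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

# Placeholder model; any HF causal LM exercises the same code path.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTokenInput("The capital of France is", tokenizer)

# use_cached_outputs=False takes the branch patched above: the model is
# re-run on the full, growing input at each step, so the attention mask
# must be rebuilt to match the current input length.
res = llm_attr.attribute(inp, target=" Paris", use_cached_outputs=False)
print(res.seq_attr)  # attribution scores over the input tokens
```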