fix LLMAttribution for old pytorch/python versions (#1353)
Summary:
Pull Request resolved: #1353

When setting `use_cached_outputs=False`, `LLMAttribution` failed to run on some older versions of PyTorch/Python.
## Error message
```
======================================================================
ERROR: test_llm_attr_hf_compatibility_0 (tests.attr.test_llm_attr_hf_compatibility.TestLLMAttrHFCompatibility_1_cpu)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/parameterized/parameterized.py", line 620, in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
  File "/data/users/liudj/captum/tests/attr/test_llm_attr_hf_compatibility.py", line 80, in test_llm_attr_hf_compatibility
    res = llm_attr.attribute(
  File "/data/users/liudj/captum/captum/attr/_core/llm_attr.py", line 461, in attribute
    cur_attr = self.attr_method.attribute(
  File "/data/users/liudj/captum/captum/log/__init__.py", line 52, in wrapper
    return func(*args, **kwargs)
  File "/data/users/liudj/captum/captum/attr/_core/feature_ablation.py", line 292, in attribute
    initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
  File "/data/users/liudj/captum/captum/_utils/common.py", line 599, in _run_forward
    output = forward_func(
  File "/data/users/liudj/captum/captum/attr/_core/llm_attr.py", line 335, in _forward_func
    outputs = self.model.forward(model_inp, **model_kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
    layer_outputs = decoder_layer(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/liudj/local/anaconda3/envs/captum_py38/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 428, in forward
    attn_weights = attn_weights + causal_mask
RuntimeError: The size of tensor a (8) must match the size of tensor b (7) at non-singleton dimension 3
```

## Root cause
With `use_cached_outputs=False`, `_forward_func` re-runs `self.model.forward` on an input that grows by one token per generated step, but the `attention_mask` in `model_kwargs` was never updated to match the new input length. The causal mask built from the stale mask therefore ends up one position shorter than the attention weights. The full error message is in the test plan (reproduced above).
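To make the mismatch concrete, here is a toy illustration — not library code; the shapes are hypothetical, sized to match the traceback:

```python
import torch

# Hypothetical shapes matching the traceback: the input has grown to 8
# tokens, but the stale attention_mask still covers only 7, so the causal
# mask derived from it is one key position short.
attn_weights = torch.zeros(1, 32, 8, 8)  # [batch, num_heads, q_len, k_len]
causal_mask = torch.zeros(1, 1, 8, 7)    # built from the stale 7-token mask

attn_weights + causal_mask  # raises: The size of tensor a (8) must match
                            # the size of tensor b (7) at non-singleton dimension 3
```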

## Impacted versions
- Python 3.8-3.10, PyTorch 1.10-1.12, transformers 4.44.2
- Python 3.8-3.11, PyTorch 1.13-2.1.0, transformers 4.44.2


Reviewed By: vivekmig

Differential Revision: D63016032

fbshipit-source-id: cb4c75a486ffb4b7c5d8d7d16c88a15fa2ce3745
DianjingLiu authored and facebook-github-bot committed Sep 19, 2024
1 parent 7b80c5b commit fc910e5
Showing 1 changed file with 5 additions and 0 deletions.

```diff
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -332,6 +332,11 @@ def _forward_func(
             )
             outputs = self.model.forward(**model_inputs)
         else:
+            # Update attention mask to adapt to input size change
+            attention_mask = torch.ones(
+                [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
+            )
+            model_kwargs["attention_mask"] = attention_mask
             outputs = self.model.forward(model_inp, **model_kwargs)
         new_token_logits = outputs.logits[:, -1]
         log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
```
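For context, a minimal sketch of how this code path is exercised, loosely following the compatibility test in the traceback; the checkpoint name and generation arguments are placeholders, and `use_cached_outputs` is assumed to be the `attribute()` flag referenced in the summary:

```python
from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; the failing test uses a small Llama-style model.
name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)
inp = TextTokenInput("a b c d", tokenizer)

# With use_cached_outputs=False, _forward_func re-runs model.forward on the
# full (growing) input at each step; before this fix, the stale attention_mask
# in model_kwargs triggered the RuntimeError above on older torch versions.
res = llm_attr.attribute(
    inp, use_cached_outputs=False, gen_args={"max_new_tokens": 3}
)
print(res.seq_attr)
```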
