Support Cache Class for Even Newer Versions of Transformers Library (#1343)

Summary:
Pull Request resolved: #1343

Supports multiple versions of the transformers library, including newer releases. Also adds the `packaging` dependency so that package versions can be checked more robustly.
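
The version checks matter because naive string comparison misorders releases ("4.9.0" sorts after "4.41.0" lexicographically). Below is a minimal sketch of the comparison pattern this diff relies on (illustrative only, not part of the change; assumes transformers is importable):

from packaging.version import Version

import transformers

# String comparison gets release ordering wrong; Version compares numerically
assert "4.9.0" > "4.41.0"  # lexicographic comparison: misleading
assert Version("4.9.0") < Version("4.41.0")  # semantic comparison: correct

# Gate version-dependent behavior on the installed transformers release
if Version(transformers.__version__) >= Version("4.43.0"):
    pass  # e.g., the cache class is mandatory from v4.43.0 on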

Reviewed By: vivekmig

Differential Revision: D62468332

fbshipit-source-id: 6cbe984adc867771242dec1bc98ae1ab1962bd93
craymichael authored and facebook-github-bot committed Sep 13, 2024
1 parent 7b22550 commit 5839d52
Showing 3 changed files with 123 additions and 11 deletions.
101 changes: 94 additions & 7 deletions captum/_utils/transformers_typing.py
@@ -2,10 +2,13 @@
 
 # pyre-strict
 
-from typing import Optional, Protocol, Tuple, Type
+from typing import Any, Dict, Optional, Protocol, Tuple, Type
 
 import torch
 
+from packaging.version import Version
+from torch import nn
+
 
 class CacheLike(Protocol):
     """Protocol for cache-like objects."""
@@ -21,12 +24,96 @@ def from_legacy_cache(
     ) -> "DynamicCacheLike": ...
 
 
+transformers_installed: bool
+Cache: Optional[Type[CacheLike]]
+DynamicCache: Optional[Type[DynamicCacheLike]]
+
 try:
-    # pyre-ignore[21]: Could not find a module corresponding to import
-    # `transformers.cache_utils`
-    from transformers.cache_utils import Cache as _Cache, DynamicCache as _DynamicCache
+    # pyre-ignore[21]: Could not find a module corresponding to import `transformers`.
+    import transformers  # noqa: F401
+
+    transformers_installed = True
 except ImportError:
-    _Cache = _DynamicCache = None
+    transformers_installed = False
+
+if transformers_installed:
+    try:
+        # pyre-ignore[21]: Could not find a module corresponding to import
+        # `transformers.cache_utils`.
+        from transformers.cache_utils import (  # noqa: F401
+            Cache as _Cache,
+            DynamicCache as _DynamicCache,
+        )
+
+        Cache = _Cache
+        # pyre-ignore[9]: Incompatible variable type: DynamicCache is declared to have
+        # type `Optional[Type[DynamicCacheLike]]` but is used as type
+        # `Type[_DynamicCache]`
+        DynamicCache = _DynamicCache
+    except ImportError:
+        Cache = DynamicCache = None
+else:
+    Cache = DynamicCache = None
+
+# GenerationMixin._update_model_kwargs_for_generation
+# "cache_position" at v4.39.0 (only needed for models that support cache class)
+# "use_cache" at v4.41.0 (optional, default is True)
+# "cache_position" is mandatory at v4.43.0 ("use_cache" is still optional, default True)
+_transformers_version: Optional[Version]
+if transformers_installed:
+    _transformers_version = Version(transformers.__version__)
+else:
+    _transformers_version = None
+
+_mandated_cache_version = Version("4.43.0")
+_use_cache_version = Version("4.41.0")
+_cache_position_version = Version("4.39.0")
+
+
+def update_model_kwargs(
+    model_kwargs: Dict[str, Any],
+    model: nn.Module,
+    input_ids: torch.Tensor,
+    caching: bool,
+) -> None:
+    if not supports_caching(model):
+        return
+    assert _transformers_version is not None
+    if caching:
+        # Enable caching
+        if _transformers_version >= _cache_position_version:
+            cache_position = torch.arange(
+                input_ids.shape[1], dtype=torch.int64, device=input_ids.device
+            )
+            model_kwargs["cache_position"] = cache_position
+        # pyre-ignore[58]: Unsupported operand `>=` is not supported for operand types
+        # `Optional[Version]` and `Version`.
+        if _transformers_version >= _use_cache_version:
+            model_kwargs["use_cache"] = True
+    else:
+        # Disable caching
+        if _transformers_version >= _use_cache_version:
+            model_kwargs["use_cache"] = False
 
 
-Cache: Optional[Type[CacheLike]] = _Cache
-DynamicCache: Optional[Type[DynamicCacheLike]] = _DynamicCache
+def supports_caching(model: nn.Module) -> bool:
+    if not transformers_installed:
+        # Not a transformers model
+        return False
+    # Cache may be optional or unsupported depending on model/version
+    try:
+        # pyre-ignore[21]: Could not find a module corresponding to import
+        # `transformers.generation.utils`.
+        from transformers.generation.utils import GenerationMixin
+    except ImportError:
+        return False
+    if not isinstance(model, GenerationMixin):
+        # Model isn't a GenerationMixin, we don't support additional caching logic
+        # for it
+        return False
+    assert _transformers_version is not None
+    if _transformers_version >= _mandated_cache_version:
+        # Cache is mandatory
+        return True
+    # Fallback on _supports_cache_class attribute
+    return getattr(model, "_supports_cache_class", False)
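
Taken together, the helpers above let callers build generation kwargs without maintaining version-specific branches of their own. A minimal usage sketch (build_model_kwargs is a hypothetical caller, not part of this diff):

import torch
from captum._utils.transformers_typing import update_model_kwargs

def build_model_kwargs(model, input_ids, use_cache):
    model_kwargs = {"attention_mask": torch.ones_like(input_ids)}
    # No-op for non-transformers models or models without cache support;
    # otherwise adds "cache_position" and/or "use_cache" per installed version
    update_model_kwargs(
        model_kwargs=model_kwargs,
        model=model,
        input_ids=input_ids,
        caching=use_cache,
    )
    return model_kwargs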
21 changes: 19 additions & 2 deletions captum/attr/_core/llm_attr.py
@@ -7,7 +7,6 @@
 import numpy as np
 
 import torch
-from captum._utils.transformers_typing import Cache, DynamicCache
 from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
@@ -259,6 +258,15 @@ def _forward_func(
         use_cached_outputs: bool = False,
         _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
     ) -> Tensor:
+        # Lazily import transformers_typing to avoid importing transformers package if
+        # it isn't needed
+        from captum._utils.transformers_typing import (
+            Cache,
+            DynamicCache,
+            supports_caching,
+            update_model_kwargs,
+        )
+
         perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
         init_model_inp = perturbed_input
 
@@ -267,16 +275,25 @@
             [1, model_inp.shape[1]], dtype=torch.long, device=model_inp.device
         )
         model_kwargs = {"attention_mask": attention_mask}
+        # If applicable, update model kwargs for transformers models
+        update_model_kwargs(
+            model_kwargs=model_kwargs,
+            model=self.model,
+            input_ids=model_inp,
+            caching=use_cached_outputs,
+        )
 
         log_prob_list = []
         outputs = None
         for target_token in target_tokens:
             if use_cached_outputs:
                 if outputs is not None:
+                    # If applicable, convert past_key_values to DynamicCache for
+                    # transformers models
                     if (
                         Cache is not None
                         and DynamicCache is not None
-                        and getattr(self.model, "_supports_cache_class", False)
+                        and supports_caching(self.model)
                         and not isinstance(outputs.past_key_values, Cache)
                     ):
                         outputs.past_key_values = DynamicCache.from_legacy_cache(
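
The conversion at the end of this hunk wraps legacy tuple-format caches in the newer cache class. A minimal standalone sketch of the same pattern (assumes model and outputs come from a transformers GenerationMixin forward pass; not part of this diff):

from captum._utils.transformers_typing import Cache, DynamicCache, supports_caching

if (
    Cache is not None
    and DynamicCache is not None
    and supports_caching(model)
    and not isinstance(outputs.past_key_values, Cache)
):
    # Wrap Tuple[Tuple[Tensor, ...], ...] past_key_values in a DynamicCache
    outputs.past_key_values = DynamicCache.from_legacy_cache(outputs.past_key_values)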
12 changes: 10 additions & 2 deletions setup.py
@@ -82,7 +82,9 @@ def report(*args):
 
 # get version string from module
 with open(os.path.join(os.path.dirname(__file__), "captum/__init__.py"), "r") as f:
-    version = re.search(r"__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M).group(1)
+    version_match = re.search(r"__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
+    assert version_match is not None, "Unable to find version string."
+    version = version_match.group(1)
 report("-- Building version " + version)
 
 # read in README.md as the long description
@@ -147,7 +149,13 @@ def get_package_files(root, subdirs):
     long_description=long_description,
     long_description_content_type="text/markdown",
     python_requires=">=3.8",
-    install_requires=["matplotlib", "numpy<2.0", "torch>=1.10", "tqdm"],
+    install_requires=[
+        "matplotlib",
+        "numpy<2.0",
+        "packaging",
+        "torch>=1.10",
+        "tqdm",
+    ],
     packages=find_packages(exclude=("tests", "tests.*")),
     extras_require={
         "dev": DEV_REQUIRES,
