Commit

Add ViT classifier
Signed-off-by: Teodora Sechkova <[email protected]>
sechkova committed May 24, 2023
1 parent 0b85b79 commit a63e9d0
Showing 2 changed files with 112 additions and 1 deletion.
2 changes: 1 addition & 1 deletion art/estimators/classification/__init__.py
@@ -16,7 +16,7 @@
from art.estimators.classification.keras import KerasClassifier
from art.estimators.classification.lightgbm import LightGBMClassifier
from art.estimators.classification.mxnet import MXClassifier
-from art.estimators.classification.pytorch import PyTorchClassifier
+from art.estimators.classification.pytorch import PyTorchClassifier, PyTorchClassifierViT
from art.estimators.classification.query_efficient_bb import QueryEfficientGradientEstimationClassifier
from art.estimators.classification.scikitlearn import SklearnClassifier
from art.estimators.classification.tensorflow import (
111 changes: 111 additions & 0 deletions art/estimators/classification/pytorch.py
@@ -1208,3 +1208,114 @@ def get_layers(self) -> List[str]:

        except ImportError:  # pragma: no cover
            raise ImportError("Could not find PyTorch (`torch`) installation.") from ImportError


class PyTorchClassifierViT(PyTorchClassifier):
    """
    This class implements a ViT classifier with the PyTorch framework.
    """

    def __init__(
        self,
        model: "torch.nn.Module",
        loss: "torch.nn.modules.loss._Loss",
        input_shape: Tuple[int, ...],
        nb_classes: int,
        optimizer: Optional["torch.optim.Optimizer"] = None,  # type: ignore
        use_amp: bool = False,
        opt_level: str = "O1",
        loss_scale: Optional[Union[float, str]] = "dynamic",
        channels_first: bool = True,
        clip_values: Optional["CLIP_VALUES_TYPE"] = None,
        preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
        postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
        preprocessing: "PREPROCESSING_TYPE" = (0.0, 1.0),
        device_type: str = "gpu",
    ) -> None:
"""
Initialization specifically for the PyTorch-based implementation.
:param model: PyTorch model. The output of the model can be logits, probabilities or anything else. Logits
output should be preferred where possible to ensure attack efficiency.
:param loss: The loss function for which to compute gradients for training. The target label must be raw
categorical, i.e. not converted to one-hot encoding.
:param input_shape: The shape of one input instance.
:param optimizer: The optimizer used to train the classifier.
:param use_amp: Whether to use the automatic mixed precision tool to enable mixed precision training or
gradient computation, e.g. with loss gradient computation. When set to True, this option is
only triggered if there are GPUs available.
:param opt_level: Specify a pure or mixed precision optimization level. Used when use_amp is True. Accepted
values are `O0`, `O1`, `O2`, and `O3`.
:param loss_scale: Loss scaling. Used when use_amp is True. If passed as a string, must be a string
representing a number, e.g., “1.0”, or the string “dynamic”.
:param nb_classes: The number of classes of the model.
:param optimizer: The optimizer used to train the classifier.
:param channels_first: Set channels first or last.
:param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
maximum values allowed for features. If floats are provided, these will be used as the range of all
features. If arrays are provided, each value will be considered the bound for a feature, thus
the shape of clip values needs to match the total number of features.
:param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
:param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
:param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
used for data preprocessing. The first value will be subtracted from the input. The input will then
be divided by the second one.
:param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
"""
        super().__init__(
            model=model,
            clip_values=clip_values,
            channels_first=channels_first,
            preprocessing_defences=preprocessing_defences,
            postprocessing_defences=postprocessing_defences,
            preprocessing=preprocessing,
            device_type=device_type,
            nb_classes=nb_classes,
            input_shape=input_shape,
            loss=loss,
            optimizer=optimizer,
            use_amp=use_amp,
            opt_level=opt_level,
            loss_scale=loss_scale,
        )

    @property
    def patch_size(self) -> int:
        """Return the input patch size of the underlying ViT model."""
        return self.model.patch_size

    def get_attention_weights(self, x: Union[np.ndarray, "torch.Tensor"], batch_size: int = 128) -> "torch.Tensor":
        """
        Compute the self-attention weights of every encoder layer for the given inputs.

        :param x: Samples with shape matching the model input.
        :param batch_size: Batch size (currently unused; the samples are processed in a single batch).
        :return: Attention weights stacked along dimension 1, one entry per encoder layer.
        """
        import torch
        from torch import fx
        from torchvision.models.feature_extraction import create_feature_extractor

        # Symbolically trace the model so the attention call sites can be rewritten.
        graph: fx.Graph = fx.Tracer().trace(self.model)
        # torchvision's ViT calls self_attention with `need_weights=False`; flip the
        # flag at each call site while keeping the remaining kwargs intact.
        for node in graph.nodes:
            if node.op == "call_module" and node.target.endswith("self_attention"):
                node.kwargs = {**node.kwargs, "need_weights": True}
        graph.lint()
        new_model = fx.GraphModule(self.model, graph)

        # The torchvision ViT-B encoder has 12 encoder layers; return the output of
        # the (rewritten) self-attention module of each one.
        return_nodes = [f"encoder.layers.encoder_layer_{i}.self_attention" for i in range(12)]

        feature_extractor = create_feature_extractor(new_model, return_nodes=return_nodes)
        out = feature_extractor(x)
        # With `need_weights=True` every node yields an (attn_output, attn_weights)
        # tuple; keep only the attention weights.
        att_weights = [v[1] for v in out.values()]
        return torch.stack(att_weights, dim=1)
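
The rewrite above relies on torch.fx: tracing produces call_module nodes whose kwargs can be edited before the graph is recompiled. A minimal standalone sketch of the same pattern, not part of the commit; the Toy module and all names in it are illustrative assumptions:

import torch
from torch import fx
from torchvision.models.feature_extraction import create_feature_extractor


class Toy(torch.nn.Module):
    """Hypothetical module mirroring torchvision's encoder-block attention call."""

    def __init__(self) -> None:
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x, need_weights=False)  # weights discarded, as in torchvision
        return out


toy = Toy()
graph = fx.Tracer().trace(toy)
for node in graph.nodes:
    # Flip `need_weights` at the traced call site without touching other kwargs.
    if node.op == "call_module" and node.target.endswith("attn"):
        node.kwargs = {**node.kwargs, "need_weights": True}
graph.lint()
rewritten = fx.GraphModule(toy, graph)

# Extract the attention node's output, now an (attn_output, attn_weights) tuple.
extractor = create_feature_extractor(rewritten, return_nodes=["attn"])
_, attn_weights = extractor(torch.randn(1, 4, 8))["attn"]
print(attn_weights.shape)  # torch.Size([1, 4, 4])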

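For context, a minimal end-to-end usage sketch, also not part of the commit: it assumes torchvision >= 0.13 with a pretrained vit_b_16, and the input shape, optimizer, and variable names are illustrative assumptions:

import numpy as np
import torch
from torchvision.models import ViT_B_16_Weights, vit_b_16

from art.estimators.classification import PyTorchClassifierViT

# Wrap a pretrained torchvision ViT-B/16 (12 encoder layers, 16x16 patches).
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
classifier = PyTorchClassifierViT(
    model=model,
    loss=torch.nn.CrossEntropyLoss(),
    input_shape=(3, 224, 224),
    nb_classes=1000,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
)

x = np.random.rand(2, 3, 224, 224).astype(np.float32)
predictions = classifier.predict(x)  # shape (2, 1000)
attention = classifier.get_attention_weights(torch.from_numpy(x).to(classifier.device))
print(classifier.patch_size)  # 16
print(attention.shape)  # torch.Size([2, 12, 197, 197]): batch, layer, tokens, tokens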