diff --git a/art/estimators/classification/__init__.py b/art/estimators/classification/__init__.py
index 476cce1fc9..d1582fde92 100644
--- a/art/estimators/classification/__init__.py
+++ b/art/estimators/classification/__init__.py
@@ -16,7 +16,7 @@
 from art.estimators.classification.keras import KerasClassifier
 from art.estimators.classification.lightgbm import LightGBMClassifier
 from art.estimators.classification.mxnet import MXClassifier
-from art.estimators.classification.pytorch import PyTorchClassifier
+from art.estimators.classification.pytorch import PyTorchClassifier, PyTorchClassifierViT
 from art.estimators.classification.query_efficient_bb import QueryEfficientGradientEstimationClassifier
 from art.estimators.classification.scikitlearn import SklearnClassifier
 from art.estimators.classification.tensorflow import (
diff --git a/art/estimators/classification/pytorch.py b/art/estimators/classification/pytorch.py
index fea9171419..e90f4d2996 100644
--- a/art/estimators/classification/pytorch.py
+++ b/art/estimators/classification/pytorch.py
@@ -1208,3 +1208,114 @@ def get_layers(self) -> List[str]:
         except ImportError:  # pragma: no cover
             raise ImportError("Could not find PyTorch (`torch`) installation.") from ImportError

+
+
+class PyTorchClassifierViT(PyTorchClassifier):
+    """
+    This class implements a classifier for Vision Transformer (ViT) models with the PyTorch framework.
+    """
+
+    def __init__(
+        self,
+        model: "torch.nn.Module",
+        loss: "torch.nn.modules.loss._Loss",
+        input_shape: Tuple[int, ...],
+        nb_classes: int,
+        optimizer: Optional["torch.optim.Optimizer"] = None,  # type: ignore
+        use_amp: bool = False,
+        opt_level: str = "O1",
+        loss_scale: Optional[Union[float, str]] = "dynamic",
+        channels_first: bool = True,
+        clip_values: Optional["CLIP_VALUES_TYPE"] = None,
+        preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
+        postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
+        preprocessing: "PREPROCESSING_TYPE" = (0.0, 1.0),
+        device_type: str = "gpu",
+    ) -> None:
+        """
+        Initialization specifically for the PyTorch-based implementation.
+
+        :param model: PyTorch ViT model. The output of the model can be logits, probabilities or anything else.
+               Logits output should be preferred where possible to ensure attack efficiency.
+        :param loss: The loss function for which to compute gradients for training. The target label must be raw
+               categorical, i.e. not converted to one-hot encoding.
+        :param input_shape: The shape of one input instance.
+        :param nb_classes: The number of classes of the model.
+        :param optimizer: The optimizer used to train the classifier.
+        :param use_amp: Whether to use the automatic mixed precision tool to enable mixed precision training or
+               gradient computation, e.g. with loss gradient computation. When set to True, this option is
+               only triggered if there are GPUs available.
+        :param opt_level: Specify a pure or mixed precision optimization level. Used when use_amp is True. Accepted
+               values are `O0`, `O1`, `O2`, and `O3`.
+        :param loss_scale: Loss scaling. Used when use_amp is True. If passed as a string, must be a string
+               representing a number, e.g., “1.0”, or the string “dynamic”.
+        :param channels_first: Set channels first or last.
+        :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
+               maximum values allowed for features. If floats are provided, these will be used as the range of all
+               features. If arrays are provided, each value will be considered the bound for a feature, thus
+               the shape of clip values needs to match the total number of features.
+        :param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
+        :param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
+        :param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
+               used for data preprocessing. The first value will be subtracted from the input. The input will then
+               be divided by the second one.
+        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
+        """
+        super().__init__(
+            model=model,
+            clip_values=clip_values,
+            channels_first=channels_first,
+            preprocessing_defences=preprocessing_defences,
+            postprocessing_defences=postprocessing_defences,
+            preprocessing=preprocessing,
+            device_type=device_type,
+            nb_classes=nb_classes,
+            input_shape=input_shape,
+            loss=loss,
+            optimizer=optimizer,
+            use_amp=use_amp,
+            opt_level=opt_level,
+            loss_scale=loss_scale,
+        )
+
+    @property
+    def patch_size(self) -> int:
+        """
+        Return the patch size of the underlying ViT model.
+        """
+        return self.model.patch_size
+
+    def get_attention_weights(self, x: Union[np.ndarray, "torch.Tensor"]) -> "torch.Tensor":
+        """
+        Extract the attention weights of every encoder layer for the input samples `x`.
+
+        :param x: Samples of shape `(batch_size,) + input_shape`.
+        :return: Stacked attention weights of shape `(batch_size, nb_layers, ...)`.
+        """
+        import torch
+        from torch import fx
+        from torchvision.models.feature_extraction import create_feature_extractor
+
+        # torchvision's ViT encoder blocks call `self_attention` with `need_weights=False`, which
+        # discards the attention maps. Re-trace the model and enable the flag so that every
+        # `MultiheadAttention` module also returns its attention weights.
+        graph: fx.Graph = fx.Tracer().trace(self.model)
+        for node in graph.nodes:
+            if node.op == "call_module" and node.target.endswith("self_attention"):
+                node.kwargs = {**node.kwargs, "need_weights": True}
+        graph.lint()
+        new_model = fx.GraphModule(self.model, graph)
+
+        # Capture the output of the attention module of every encoder layer.
+        nb_layers = len(self.model.encoder.layers)
+        return_nodes = [f"encoder.layers.encoder_layer_{i}.self_attention" for i in range(nb_layers)]
+        feature_extractor = create_feature_extractor(new_model, return_nodes=return_nodes)
+
+        if isinstance(x, np.ndarray):
+            x = torch.from_numpy(x).to(self._device)
+
+        # Each captured value is a tuple `(attn_output, attn_weights)`; keep the weights and stack
+        # them across layers into shape `(batch_size, nb_layers, ...)`.
+        out = feature_extractor(x)
+        att_weights = [v[1] for v in out.values()]
+        return torch.stack(att_weights, dim=1)
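
A minimal usage sketch for `PyTorchClassifierViT`, not part of the diff itself. It assumes a torchvision `vit_b_16`, which follows the `encoder.layers.encoder_layer_{i}.self_attention` module naming that `get_attention_weights` relies on; the weight enum, input shape, batch size, and optimizer below are illustrative choices, not requirements:

import numpy as np
import torch
import torchvision

from art.estimators.classification import PyTorchClassifierViT

# Wrap a pretrained ViT-B/16 (12 encoder layers, 16x16 patches) in the new classifier.
model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()  # deterministic attention maps (disables dropout)

classifier = PyTorchClassifierViT(
    model=model,
    loss=torch.nn.CrossEntropyLoss(),
    input_shape=(3, 224, 224),
    nb_classes=1000,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
)

x = np.random.rand(2, 3, 224, 224).astype(np.float32)
att = classifier.get_attention_weights(x)
# With PyTorch's default `average_attn_weights=True`, each layer yields one head-averaged
# (tokens, tokens) map, so for 224x224 inputs this prints torch.Size([2, 12, 197, 197])
# (196 patch tokens + 1 class token).
print(att.shape)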