diff --git a/captum/attr/_core/noise_tunnel.py b/captum/attr/_core/noise_tunnel.py index eb34eda85..7247ccc00 100644 --- a/captum/attr/_core/noise_tunnel.py +++ b/captum/attr/_core/noise_tunnel.py @@ -2,7 +2,7 @@ # pyre-strict from enum import Enum -from typing import Any, Callable, cast, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -180,179 +180,7 @@ def attribute( >>> attribution = nt.attribute(input, nt_type='smoothgrad', >>> nt_samples=10, target=3) """ - - def add_noise_to_inputs(nt_samples_partition: int) -> Tuple[Tensor, ...]: - if isinstance(stdevs, tuple): - assert len(stdevs) == len(inputs), ( - "The number of input tensors " - "in {} must be equal to the number of stdevs values {}".format( - len(inputs), len(stdevs) - ) - ) - else: - assert isinstance( - stdevs, float - ), "stdevs must be type float. " "Given: {}".format(type(stdevs)) - stdevs_ = (stdevs,) * len(inputs) - return tuple( - ( - add_noise_to_input( - input, stdev, nt_samples_partition - ).requires_grad_() - if self.is_gradient_method - else add_noise_to_input(input, stdev, nt_samples_partition) - ) - # pyre-fixme[61]: `stdevs_` is undefined, or not always defined. - for (input, stdev) in zip(inputs, stdevs_) - ) - - def add_noise_to_input( - input: Tensor, stdev: float, nt_samples_partition: int - ) -> Tensor: - # batch size - bsz = input.shape[0] - - # expand input size by the number of drawn samples - # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]` - # and `Size`. - input_expanded_size = (bsz * nt_samples_partition,) + input.shape[1:] - - # expand stdev for the shape of the input and number of drawn samples - stdev_expanded = torch.tensor(stdev, device=input.device).repeat( - input_expanded_size - ) - - # draws `np.prod(input_expanded_size)` samples from normal distribution - # with given input parametrization - # FIXME it look like it is very difficult to make torch.normal - # deterministic this needs an investigation - noise = torch.normal(0, stdev_expanded) - return input.repeat_interleave(nt_samples_partition, dim=0) + noise - - def update_sum_attribution_and_sq( - sum_attribution: List[Tensor], - sum_attribution_sq: List[Tensor], - attribution: Tensor, - i: int, - nt_samples_batch_size_inter: int, - ) -> None: - bsz = attribution.shape[0] // nt_samples_batch_size_inter - attribution_shape = cast( - Tuple[int, ...], (bsz, nt_samples_batch_size_inter) - ) - if len(attribution.shape) > 1: - # pyre-fixme[22]: The cast is redundant. - attribution_shape += cast(Tuple[int, ...], tuple(attribution.shape[1:])) - - attribution = attribution.view(attribution_shape) - current_attribution_sum = attribution.sum(dim=1, keepdim=False) - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - current_attribution_sq = torch.sum(attribution**2, dim=1, keepdim=False) - - sum_attribution[i] = ( - current_attribution_sum - if not isinstance(sum_attribution[i], torch.Tensor) - else sum_attribution[i] + current_attribution_sum - ) - sum_attribution_sq[i] = ( - current_attribution_sq - if not isinstance(sum_attribution_sq[i], torch.Tensor) - else sum_attribution_sq[i] + current_attribution_sq - ) - - def compute_partial_attribution( - inputs_with_noise_partition: Tuple[Tensor, ...], - # pyre-fixme[2]: Parameter annotation cannot be `Any`. 
- kwargs_partition: Any, - ) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]: - # smoothgrad_Attr(x) = 1 / n * sum(Attr(x + N(0, sigma^2)) - # NOTE: using __wrapped__ such that it does not log the inner logs - - attributions = attr_func.__wrapped__( # type: ignore - self.attribution_method, # self - ( - inputs_with_noise_partition - if is_inputs_tuple - else inputs_with_noise_partition[0] - ), - **kwargs_partition, - ) - delta = None - - if self.is_delta_supported and return_convergence_delta: - attributions, delta = attributions - - is_attrib_tuple = _is_tuple(attributions) - attributions = _format_tensor_into_tuples(attributions) - - return ( - cast(Tuple[Tensor, ...], attributions), - cast(bool, is_attrib_tuple), - delta, - ) - - # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use - # `typing.Dict[, ]` to avoid runtime subscripting - # errors. - def expand_partial(nt_samples_partition: int, kwargs_partial: dict) -> None: - # if the algorithm supports targets, baselines and/or - # additional_forward_args they will be expanded based - # on the nt_samples_partition and corresponding kwargs - # variables will be updated accordingly - _expand_and_update_additional_forward_args( - nt_samples_partition, kwargs_partial - ) - _expand_and_update_target(nt_samples_partition, kwargs_partial) - _expand_and_update_baselines( - cast(Tuple[Tensor, ...], inputs), - nt_samples_partition, - kwargs_partial, - draw_baseline_from_distrib=draw_baseline_from_distrib, - ) - _expand_and_update_feature_mask(nt_samples_partition, kwargs_partial) - - def compute_smoothing( - expected_attributions: Tuple[Union[Tensor], ...], - expected_attributions_sq: Tuple[Union[Tensor], ...], - ) -> Tuple[Tensor, ...]: - if NoiseTunnelType[nt_type] == NoiseTunnelType.smoothgrad: - return expected_attributions - - if NoiseTunnelType[nt_type] == NoiseTunnelType.smoothgrad_sq: - return expected_attributions_sq - - vargrad = tuple( - expected_attribution_sq - expected_attribution * expected_attribution - for expected_attribution, expected_attribution_sq in zip( - expected_attributions, expected_attributions_sq - ) - ) - - # pyre-fixme[22]: The cast is redundant. 
- return cast(Tuple[Tensor, ...], vargrad) - - def update_partial_attribution_and_delta( - attributions_partial: Tuple[Tensor, ...], - delta_partial: Tensor, - sum_attributions: List[Tensor], - sum_attributions_sq: List[Tensor], - delta_partial_list: List[Tensor], - nt_samples_partial: int, - ) -> None: - for i, attribution_partial in enumerate(attributions_partial): - update_sum_attribution_and_sq( - sum_attributions, - sum_attributions_sq, - attribution_partial, - i, - nt_samples_partial, - ) - if self.is_delta_supported and return_convergence_delta: - delta_partial_list.append(delta_partial) - - return_convergence_delta: bool - return_convergence_delta = ( + return_convergence_delta: bool = ( "return_convergence_delta" in kwargs and kwargs["return_convergence_delta"] ) with torch.no_grad(): @@ -373,21 +201,28 @@ def update_partial_attribution_and_delta( _validate_noise_tunnel_type(nt_type, SUPPORTED_NOISE_TUNNEL_TYPES) kwargs_copy = kwargs.copy() - expand_partial(nt_samples_batch_size, kwargs_copy) - - attr_func = self.attribution_method.attribute + self._expand_partial( + nt_samples_batch_size, kwargs_copy, inputs, draw_baseline_from_distrib + ) sum_attributions: List[Union[None, Tensor]] = [] sum_attributions_sq: List[Union[None, Tensor]] = [] delta_partial_list: List[Tensor] = [] for _ in range(nt_samples_partition): - inputs_with_noise = add_noise_to_inputs(nt_samples_batch_size) + inputs_with_noise = self._add_noise_to_inputs( + nt_samples_batch_size, inputs, stdevs + ) ( attributions_partial, is_attrib_tuple, delta_partial, - ) = compute_partial_attribution(inputs_with_noise, kwargs_copy) + ) = self._compute_partial_attribution( + inputs_with_noise, + kwargs_copy, + is_inputs_tuple, + return_convergence_delta, + ) if len(sum_attributions) == 0: # pyre-fixme[9]: sum_attributions has type @@ -397,36 +232,45 @@ def update_partial_attribution_and_delta( # `List[Optional[Tensor]]`; used as `List[None]`. sum_attributions_sq = [None] * len(attributions_partial) - update_partial_attribution_and_delta( - # pyre-fixme[22]: The cast is redundant. - cast(Tuple[Tensor, ...], attributions_partial), + self._update_partial_attribution_and_delta( + attributions_partial, cast(Tensor, delta_partial), cast(List[Tensor], sum_attributions), cast(List[Tensor], sum_attributions_sq), delta_partial_list, nt_samples_batch_size, + return_convergence_delta, ) nt_samples_remaining = ( nt_samples - nt_samples_partition * nt_samples_batch_size ) if nt_samples_remaining > 0: - inputs_with_noise = add_noise_to_inputs(nt_samples_remaining) - expand_partial(nt_samples_remaining, kwargs) + inputs_with_noise = self._add_noise_to_inputs( + nt_samples_remaining, inputs, stdevs + ) + self._expand_partial( + nt_samples_remaining, kwargs, inputs, draw_baseline_from_distrib + ) ( attributions_partial, is_attrib_tuple, delta_partial, - ) = compute_partial_attribution(inputs_with_noise, kwargs) + ) = self._compute_partial_attribution( + inputs_with_noise, + kwargs, + is_inputs_tuple, + return_convergence_delta, + ) - update_partial_attribution_and_delta( - # pyre-fixme[22]: The cast is redundant. 
-                    cast(Tuple[Tensor, ...], attributions_partial),
+                self._update_partial_attribution_and_delta(
+                    attributions_partial,
                     cast(Tensor, delta_partial),
                     cast(List[Tensor], sum_attributions),
                     cast(List[Tensor], sum_attributions_sq),
                     delta_partial_list,
                     nt_samples_remaining,
+                    return_convergence_delta,
                 )
 
         expected_attributions = tuple(
@@ -441,11 +285,10 @@ def update_partial_attribution_and_delta(
                 for sum_attribution_sq in sum_attributions_sq
             ]
         )
-        attributions = compute_smoothing(
-            # pyre-fixme[22]: The cast is redundant.
-            cast(Tuple[Tensor, ...], expected_attributions),
-            # pyre-fixme[22]: The cast is redundant.
-            cast(Tuple[Tensor, ...], expected_attributions_sq),
+        attributions = self._compute_smoothing(
+            expected_attributions,
+            expected_attributions_sq,
+            nt_type,
         )
 
         delta = None
@@ -467,6 +310,189 @@ def attribute_future(self) -> Callable:
         """
         raise NotImplementedError("attribute_future is not implemented for NoiseTunnel")
 
+    def _add_noise_to_inputs(
+        self,
+        nt_samples_partition: int,
+        inputs: Tuple[Tensor, ...],
+        stdevs: Union[float, Tuple[float, ...]],
+    ) -> Tuple[Tensor, ...]:
+        if isinstance(stdevs, tuple):
+            assert len(stdevs) == len(inputs), (
+                "The number of input tensors "
+                "({}) must be equal to the number of stdevs values ({})".format(
+                    len(inputs), len(stdevs)
+                )
+            )
+            stdevs_ = stdevs
+        else:
+            assert isinstance(
+                stdevs, float
+            ), "stdevs must be of type float. " "Given: {}".format(type(stdevs))
+            stdevs_ = (stdevs,) * len(inputs)
+        return tuple(
+            (
+                self._add_noise_to_input(
+                    input, stdev, nt_samples_partition
+                ).requires_grad_()
+                if self.is_gradient_method
+                else self._add_noise_to_input(input, stdev, nt_samples_partition)
+            )
+            for (input, stdev) in zip(inputs, stdevs_)
+        )
+
+    @staticmethod
+    def _add_noise_to_input(
+        input: Tensor, stdev: float, nt_samples_partition: int
+    ) -> Tensor:
+        # batch size
+        bsz = input.shape[0]
+
+        # expand input size by the number of drawn samples
+        # pyre-fixme[58]: `+` is not supported for operand types `Tuple[int]`
+        # and `Size`.
+        input_expanded_size = (bsz * nt_samples_partition,) + input.shape[1:]
+
+        # expand stdev to the shape of the input and the number of drawn samples
+        stdev_expanded = torch.tensor(stdev, device=input.device).repeat(
+            input_expanded_size
+        )
+
+        # draws `np.prod(input_expanded_size)` samples from a normal distribution
+        # with the given parametrization
+        # FIXME: it looks like it is very difficult to make torch.normal
+        # deterministic; this needs investigation
+        noise = torch.normal(0, stdev_expanded)
+        return input.repeat_interleave(nt_samples_partition, dim=0) + noise
+
+    @staticmethod
+    def _update_sum_attribution_and_sq(
+        sum_attribution: List[Tensor],
+        sum_attribution_sq: List[Tensor],
+        attribution: Tensor,
+        i: int,
+        nt_samples_batch_size_inter: int,
+    ) -> None:
+        bsz = attribution.shape[0] // nt_samples_batch_size_inter
+        attribution_shape = cast(Tuple[int, ...], (bsz, nt_samples_batch_size_inter))
+        if len(attribution.shape) > 1:
+            # pyre-fixme[22]: The cast is redundant.
+            attribution_shape += cast(Tuple[int, ...], tuple(attribution.shape[1:]))
+
+        attribution = attribution.view(attribution_shape)
+        current_attribution_sum = attribution.sum(dim=1, keepdim=False)
+        # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+        # `int`.
+        current_attribution_sq = torch.sum(attribution**2, dim=1, keepdim=False)
+
+        sum_attribution[i] = (
+            current_attribution_sum
+            if not isinstance(sum_attribution[i], torch.Tensor)
+            else sum_attribution[i] + current_attribution_sum
+        )
+        sum_attribution_sq[i] = (
+            current_attribution_sq
+            if not isinstance(sum_attribution_sq[i], torch.Tensor)
+            else sum_attribution_sq[i] + current_attribution_sq
+        )
+
+    def _compute_partial_attribution(
+        self,
+        inputs_with_noise_partition: Tuple[Tensor, ...],
+        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
+        kwargs_partition: Any,
+        is_inputs_tuple: bool,
+        return_convergence_delta: bool,
+    ) -> Tuple[Tuple[Tensor, ...], bool, Union[None, Tensor]]:
+        attr_func = self.attribution_method.attribute
+        # smoothgrad_Attr(x) = 1 / n * sum(Attr(x + N(0, sigma^2)))
+        # NOTE: using __wrapped__ so that the inner attribute call is not logged
+
+        attributions = attr_func.__wrapped__(  # type: ignore
+            self.attribution_method,  # self
+            (
+                inputs_with_noise_partition
+                if is_inputs_tuple
+                else inputs_with_noise_partition[0]
+            ),
+            **kwargs_partition,
+        )
+        delta = None
+
+        if self.is_delta_supported and return_convergence_delta:
+            attributions, delta = attributions
+
+        is_attrib_tuple = _is_tuple(attributions)
+        attributions = _format_tensor_into_tuples(attributions)
+
+        return (
+            cast(Tuple[Tensor, ...], attributions),
+            cast(bool, is_attrib_tuple),
+            delta,
+        )
+
+    @staticmethod
+    def _expand_partial(
+        nt_samples_partition: int,
+        kwargs_partial: Dict[str, Any],
+        inputs: Tuple[Tensor, ...],
+        draw_baseline_from_distrib: bool,
+    ) -> None:
+        # if the algorithm supports targets, baselines and/or
+        # additional_forward_args, they will be expanded based on
+        # nt_samples_partition, and the corresponding kwargs entries
+        # will be updated accordingly
+        _expand_and_update_additional_forward_args(nt_samples_partition, kwargs_partial)
+        _expand_and_update_target(nt_samples_partition, kwargs_partial)
+        _expand_and_update_baselines(
+            inputs,
+            nt_samples_partition,
+            kwargs_partial,
+            draw_baseline_from_distrib=draw_baseline_from_distrib,
+        )
+        _expand_and_update_feature_mask(nt_samples_partition, kwargs_partial)
+
+    @staticmethod
+    def _compute_smoothing(
+        expected_attributions: Tuple[Tensor, ...],
+        expected_attributions_sq: Tuple[Tensor, ...],
+        nt_type: str,
+    ) -> Tuple[Tensor, ...]:
+        if NoiseTunnelType[nt_type] == NoiseTunnelType.smoothgrad:
+            return expected_attributions
+
+        if NoiseTunnelType[nt_type] == NoiseTunnelType.smoothgrad_sq:
+            return expected_attributions_sq
+
+        vargrad = tuple(
+            expected_attribution_sq - expected_attribution * expected_attribution
+            for expected_attribution, expected_attribution_sq in zip(
+                expected_attributions, expected_attributions_sq
+            )
+        )
+
+        return vargrad
+
+    def _update_partial_attribution_and_delta(
+        self,
+        attributions_partial: Tuple[Tensor, ...],
+        delta_partial: Tensor,
+        sum_attributions: List[Tensor],
+        sum_attributions_sq: List[Tensor],
+        delta_partial_list: List[Tensor],
+        nt_samples_partial: int,
+        return_convergence_delta: bool,
+    ) -> None:
+        for i, attribution_partial in enumerate(attributions_partial):
+            self._update_sum_attribution_and_sq(
+                sum_attributions,
+                sum_attributions_sq,
+                attribution_partial,
+                i,
+                nt_samples_partial,
+            )
+        if self.is_delta_supported and return_convergence_delta:
+            delta_partial_list.append(delta_partial)
+
     def _apply_checks_and_return_attributions(
         self,
         attributions: Tuple[Tensor, ...],
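
Note on the noise expansion performed by `_add_noise_to_input`: each example is repeated `nt_samples_partition` times along the batch dimension, and every copy is perturbed with i.i.d. Gaussian noise of the given stdev. A minimal standalone sketch of that step, assuming a single input tensor (the `add_noise` helper below is illustrative only, not part of Captum's API):

import torch

def add_noise(x: torch.Tensor, stdev: float, nt_samples: int) -> torch.Tensor:
    # (bsz, *dims) -> (bsz * nt_samples, *dims): repeat each example in place
    expanded = x.repeat_interleave(nt_samples, dim=0)
    # perturb every repeated copy with i.i.d. N(0, stdev^2) noise
    return expanded + torch.normal(0.0, stdev, size=expanded.shape, device=x.device)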
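
Note on the statistics behind `_update_sum_attribution_and_sq` and `_compute_smoothing`: the running sum and sum-of-squares reduce to per-example means over the noisy draws, from which all three tunnel types follow. A minimal sketch of the same math on a single stacked tensor (names and shapes are illustrative; the patched code accumulates over partitions instead of materializing all samples at once):

import torch

def smooth(attr_samples: torch.Tensor, nt_type: str) -> torch.Tensor:
    # attr_samples: (bsz, nt_samples, *dims), one attribution per noisy draw
    expected = attr_samples.mean(dim=1)           # smoothgrad: E[Attr]
    expected_sq = (attr_samples**2).mean(dim=1)   # smoothgrad_sq: E[Attr^2]
    if nt_type == "smoothgrad":
        return expected
    if nt_type == "smoothgrad_sq":
        return expected_sq
    # vargrad: Var[Attr] = E[Attr^2] - (E[Attr])^2
    return expected_sq - expected * expected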