diff --git a/captum/_utils/common.py b/captum/_utils/common.py index 58968d747..f1b5fd9a7 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -73,17 +73,17 @@ def safe_div( @typing.overload # pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`) # is incompatible with the return type of the implementation (`bool`). -# pyre-fixme[31]: Expression `Literal[False]` is not a valid type. +# pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. -def _is_tuple(inputs: Tensor) -> Literal[False]: ... +def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ... @typing.overload # pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`) # is incompatible with the return type of the implementation (`bool`). -# pyre-fixme[31]: Expression `Literal[True]` is not a valid type. +# pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. -def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ... +def _is_tuple(inputs: Tensor) -> Literal[False]: ... def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool: @@ -277,7 +277,7 @@ def _format_additional_forward_args( @overload -def _format_additional_forward_args( +def _format_additional_forward_args( # type: ignore # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. @@ -780,10 +780,10 @@ def _reduce_list( """ assert len(val_list) > 0, "Cannot reduce empty list!" if isinstance(val_list[0], torch.Tensor): - # pyre-fixme[16]: `bool` has no attribute `device`. - first_device = val_list[0].device - # pyre-fixme[16]: `bool` has no attribute `to`. - return red_func([elem.to(first_device) for elem in val_list]) + first_device = cast(Tensor, val_list[0]).device + return red_func( + [elem.to(first_device) for elem in cast(List[Tensor], val_list)] + ) elif isinstance(val_list[0], bool): # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`. return any(val_list) diff --git a/captum/_utils/gradient.py b/captum/_utils/gradient.py index d24df1fff..cc74ef92c 100644 --- a/captum/_utils/gradient.py +++ b/captum/_utils/gradient.py @@ -159,33 +159,34 @@ def _neuron_gradients( @typing.overload # pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all -# possible arguments of overload defined on line `158`. +# possible arguments of overload defined on line `170`. def _forward_layer_eval( # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], - layer: Module, + layer: List[Module], # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, device_ids: Union[None, List[int]] = None, attribute_to_layer_input: bool = False, grad_enabled: bool = False, -) -> Tuple[Tensor, ...]: ... +) -> List[Tuple[Tensor, ...]]: ... @typing.overload # pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all -# possible arguments of overload defined on line `170`. +# possible arguments of overload defined on line `158`. def _forward_layer_eval( # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], - layer: List[Module], + layer: Module, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, device_ids: Union[None, List[int]] = None, attribute_to_layer_input: bool = False, grad_enabled: bool = False, -) -> List[Tuple[Tensor, ...]]: ... +) -> Tuple[Tensor, ...]: ... def _forward_layer_eval( @@ -434,34 +435,34 @@ def _forward_layer_eval_with_neuron_grads( @typing.overload # pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does -# not accept all possible arguments of overload defined on line `392`. +# not accept all possible arguments of overload defined on line `405`. def _forward_layer_eval_with_neuron_grads( # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], - layer: Module, + layer: List[Module], additional_forward_args: Any = None, gradient_neuron_selector: None = None, grad_enabled: bool = False, device_ids: Union[None, List[int]] = None, attribute_to_layer_input: bool = False, -) -> Tuple[Tensor, ...]: ... +) -> List[Tuple[Tensor, ...]]: ... @typing.overload # pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does -# not accept all possible arguments of overload defined on line `405`. +# not accept all possible arguments of overload defined on line `392`. def _forward_layer_eval_with_neuron_grads( # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. forward_fn: Callable, inputs: Union[Tensor, Tuple[Tensor, ...]], - layer: List[Module], + layer: Module, additional_forward_args: Any = None, gradient_neuron_selector: None = None, grad_enabled: bool = False, device_ids: Union[None, List[int]] = None, attribute_to_layer_input: bool = False, -) -> List[Tuple[Tensor, ...]]: ... +) -> Tuple[Tensor, ...]: ... def _forward_layer_eval_with_neuron_grads( diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py index e03437fb3..eda69f177 100644 --- a/captum/attr/_core/deep_lift.py +++ b/captum/attr/_core/deep_lift.py @@ -118,36 +118,36 @@ def __init__( @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `120`. + # arguments of overload defined on line `131`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[False] = False, + return_convergence_delta: Literal[True], custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, - ) -> TensorOrTupleOfTensorsGeneric: ... + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `131`. + # arguments of overload defined on line `120`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[True], + return_convergence_delta: Literal[False] = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, - ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... + ) -> TensorOrTupleOfTensorsGeneric: ... @log_usage() def attribute( # type: ignore @@ -636,7 +636,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None: # DeepLiftShap.attribute, so we ignore typing here @typing.overload # type: ignore # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `584`. + # arguments of overload defined on line `597`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -646,16 +646,16 @@ def attribute( target: TargetType = None, # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[False] = False, + return_convergence_delta: Literal[True], custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, - ) -> TensorOrTupleOfTensorsGeneric: ... + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `597`. + # arguments of overload defined on line `584`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -663,13 +663,14 @@ def attribute( TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric] ], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[True], + return_convergence_delta: Literal[False] = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, - ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... + ) -> TensorOrTupleOfTensorsGeneric: ... @log_usage() def attribute( # type: ignore diff --git a/captum/attr/_core/integrated_gradients.py b/captum/attr/_core/integrated_gradients.py index 730cfd48b..e80326293 100644 --- a/captum/attr/_core/integrated_gradients.py +++ b/captum/attr/_core/integrated_gradients.py @@ -81,7 +81,7 @@ def __init__( # a tuple with both attributions and deltas. @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `82`. + # arguments of overload defined on line `95`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, @@ -92,29 +92,30 @@ def attribute( n_steps: int = 50, method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[False] = False, - ) -> TensorOrTupleOfTensorsGeneric: ... + return_convergence_delta: Literal[True], + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `95`. + # arguments of overload defined on line `82`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", internal_batch_size: Union[None, int] = None, - *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[True], - ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... + return_convergence_delta: Literal[False] = False, + ) -> TensorOrTupleOfTensorsGeneric: ... @log_usage() def attribute( # type: ignore diff --git a/captum/attr/_core/layer/layer_deep_lift.py b/captum/attr/_core/layer/layer_deep_lift.py index 05ae49e56..2c4c10bbf 100644 --- a/captum/attr/_core/layer/layer_deep_lift.py +++ b/captum/attr/_core/layer/layer_deep_lift.py @@ -102,7 +102,7 @@ def __init__( # Ignoring mypy error for inconsistent signature with DeepLift @typing.overload # type: ignore # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `104`. + # arguments of overload defined on line `117`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -110,32 +110,33 @@ def attribute( target: TargetType = None, # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[False] = False, + return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, grad_kwargs: Optional[Dict[str, Any]] = None, - ) -> Union[Tensor, Tuple[Tensor, ...]]: ... + ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `117`. + # arguments of overload defined on line `104`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType = None, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[True], + return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, grad_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... + ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @log_usage() # pyre-fixme[43]: This definition does not have the same decorators as the @@ -452,7 +453,7 @@ def __init__( # Ignoring mypy error for inconsistent signature with DeepLiftShap @typing.overload # type: ignore # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `439`. + # arguments of overload defined on line `453`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -462,17 +463,17 @@ def attribute( target: TargetType = None, # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[False] = False, + return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, - ) -> Union[Tensor, Tuple[Tensor, ...]]: ... + ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `453`. + # arguments of overload defined on line `439`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -480,14 +481,15 @@ def attribute( Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]] ], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[True], + return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None, - ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... + ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @log_usage() # pyre-fixme[43]: This definition does not have the same decorators as the diff --git a/captum/attr/_core/layer/layer_gradient_shap.py b/captum/attr/_core/layer/layer_gradient_shap.py index c2d659aee..8c94c13b5 100644 --- a/captum/attr/_core/layer/layer_gradient_shap.py +++ b/captum/attr/_core/layer/layer_gradient_shap.py @@ -393,7 +393,7 @@ def __init__( @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `373`. + # arguments of overload defined on line `385`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], @@ -401,30 +401,31 @@ def attribute( target: TargetType = None, # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[False] = False, + return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, grad_kwargs: Optional[Dict[str, Any]] = None, - ) -> Union[Tensor, Tuple[Tensor, ...]]: ... + ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `385`. + # arguments of overload defined on line `373`. def attribute( self, inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: Union[Tensor, Tuple[Tensor, ...]], target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[True], + return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, grad_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ... + ) -> Union[Tensor, Tuple[Tensor, ...]]: ... @log_usage() def attribute( # type: ignore @@ -481,7 +482,11 @@ def attribute( # type: ignore if self.multiplies_by_inputs: input_baseline_diffs = tuple( - input - baseline for input, baseline in zip(attr_inputs, attr_baselines) + # pyre-fixme[58]: `-` is not supported for operand types + # `typing.Tuple[torch._tensor.Tensor, ...]` and + # `typing.Tuple[torch._tensor.Tensor, ...]`. + input - baseline + for input, baseline in zip(attr_inputs, attr_baselines) ) attributions = tuple( input_baseline_diff * grad diff --git a/captum/attr/_core/layer/layer_integrated_gradients.py b/captum/attr/_core/layer/layer_integrated_gradients.py index 1b10404f2..146c5c552 100644 --- a/captum/attr/_core/layer/layer_integrated_gradients.py +++ b/captum/attr/_core/layer/layer_integrated_gradients.py @@ -131,11 +131,12 @@ def attribute( @overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible # arguments of overload defined on line `126`. - def attribute( + def attribute( # type: ignore self, inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType, target: TargetType, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, n_steps: int, method: str, diff --git a/captum/attr/_core/layer/layer_lrp.py b/captum/attr/_core/layer/layer_lrp.py index 7bd272132..705cb2a91 100644 --- a/captum/attr/_core/layer/layer_lrp.py +++ b/captum/attr/_core/layer/layer_lrp.py @@ -65,39 +65,40 @@ def __init__(self, model: Module, layer: ModuleOrModuleList) -> None: @typing.overload # type: ignore # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `66`. + # arguments of overload defined on line `77`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[False] = False, + return_convergence_delta: Literal[True], attribute_to_layer_input: bool = False, verbose: bool = False, - ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ... + ) -> Tuple[ + Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], + Union[Tensor, List[Tensor]], + ]: ... @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `77`. + # arguments of overload defined on line `66`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[True], + return_convergence_delta: Literal[False] = False, attribute_to_layer_input: bool = False, verbose: bool = False, - ) -> Tuple[ - Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]], - Union[Tensor, List[Tensor]], - ]: ... + ) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]: ... def attribute( self, diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index 7911cce73..cf432007d 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -1275,7 +1275,7 @@ def _convert_output_shape( @typing.overload # pyre-fixme[43]: The implementation of `_convert_output_shape` does not accept # all possible arguments of overload defined on line `1211`. - def _convert_output_shape( + def _convert_output_shape( # type: ignore self, formatted_inp: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], diff --git a/captum/attr/_core/lrp.py b/captum/attr/_core/lrp.py index 03772d7aa..0c66c94c7 100644 --- a/captum/attr/_core/lrp.py +++ b/captum/attr/_core/lrp.py @@ -63,34 +63,34 @@ def multiplies_by_inputs(self) -> bool: @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `65`. + # arguments of overload defined on line `75`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + *, + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[False] = False, + return_convergence_delta: Literal[True], verbose: bool = False, - ) -> TensorOrTupleOfTensorsGeneric: ... + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... @typing.overload # pyre-fixme[43]: The implementation of `attribute` does not accept all possible - # arguments of overload defined on line `75`. + # arguments of overload defined on line `65`. def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, target: TargetType = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any = None, - *, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - return_convergence_delta: Literal[True], + return_convergence_delta: Literal[False] = False, verbose: bool = False, - ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ... + ) -> TensorOrTupleOfTensorsGeneric: ... @log_usage() # pyre-fixme[43]: This definition does not have the same decorators as the diff --git a/captum/attr/_utils/common.py b/captum/attr/_utils/common.py index 068749748..92c1ccafb 100644 --- a/captum/attr/_utils/common.py +++ b/captum/attr/_utils/common.py @@ -77,7 +77,7 @@ def _format_input_baseline( @typing.overload -def _format_input_baseline( +def _format_input_baseline( # type: ignore inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType ) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]: ... @@ -201,7 +201,7 @@ def _format_and_verify_sliding_window_shapes( @typing.overload # pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does -# not accept all possible arguments of overload defined on line `199`. +# not accept all possible arguments of overload defined on line `212`. def _compute_conv_delta_and_format_attrs( attr_algo: "GradientAttribution", return_convergence_delta: bool, @@ -211,28 +211,29 @@ def _compute_conv_delta_and_format_attrs( # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, target: TargetType, - # pyre-fixme[9]: is_inputs_tuple has type `Literal[]`; used as `bool`. - # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. + # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - is_inputs_tuple: Literal[False] = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: ... + is_inputs_tuple: Literal[True], +) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]: ... @typing.overload # pyre-fixme[43]: The implementation of `_compute_conv_delta_and_format_attrs` does -# not accept all possible arguments of overload defined on line `212`. +# not accept all possible arguments of overload defined on line `199`. def _compute_conv_delta_and_format_attrs( attr_algo: "GradientAttribution", return_convergence_delta: bool, attributions: Tuple[Tensor, ...], start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]], end_point: Union[Tensor, Tuple[Tensor, ...]], + # pyre-fixme[2]: Parameter annotation cannot be `Any`. additional_forward_args: Any, target: TargetType, - # pyre-fixme[31]: Expression `Literal[True]` is not a valid type. + # pyre-fixme[9]: is_inputs_tuple has type `Literal[]`; used as `bool`. + # pyre-fixme[31]: Expression `Literal[False]` is not a valid type. # pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters. - is_inputs_tuple: Literal[True], -) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]: ... + is_inputs_tuple: Literal[False] = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: ... # FIXME: GradientAttribution is provided as a string due to a circular import. diff --git a/tests/metrics/test_infidelity.py b/tests/metrics/test_infidelity.py index 82031fe2c..6516ace12 100644 --- a/tests/metrics/test_infidelity.py +++ b/tests/metrics/test_infidelity.py @@ -38,19 +38,19 @@ def _local_perturb_func_default( @typing.overload -# pyre-fixme[43]: The implementation of `_local_perturb_func` does not accept all -# possible arguments of overload defined on line `35`. -def _local_perturb_func(inputs: Tensor) -> Tuple[Tensor, Tensor]: ... - - -@typing.overload -# pyre-fixme[43]: The implementation of `_local_perturb_func` does not accept all -# possible arguments of overload defined on line `39`. +# pyre-ignore[43]: The implementation of `_local_perturb_func` does not accept all +# possible arguments of overload defined on line `43`. def _local_perturb_func( inputs: Tuple[Tensor, ...] ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: ... +@typing.overload +# pyre-ignore[43]: The implementation of `_local_perturb_func` does not accept all +# possible arguments of overload defined on line `51`. +def _local_perturb_func(inputs: Tensor) -> Tuple[Tensor, Tensor]: ... + + def _local_perturb_func( inputs: TensorOrTupleOfTensorsGeneric, ) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Union[Tensor, Tuple[Tensor, ...]]]: @@ -79,12 +79,6 @@ def _global_perturb_func1_default( return _global_perturb_func1(inputs)[1] -@typing.overload -# pyre-fixme[43]: The implementation of `_global_perturb_func1` does not accept all -# possible arguments of overload defined on line `70`. -def _global_perturb_func1(inputs: Tensor) -> Tuple[Tensor, Tensor]: ... - - @typing.overload # pyre-fixme[43]: The implementation of `_global_perturb_func1` does not accept all # possible arguments of overload defined on line `74`. @@ -93,6 +87,12 @@ def _global_perturb_func1( ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]: ... +@typing.overload +# pyre-fixme[43]: The implementation of `_global_perturb_func1` does not accept all +# possible arguments of overload defined on line `70`. +def _global_perturb_func1(inputs: Tensor) -> Tuple[Tensor, Tensor]: ... + + # sensitivity-N, N = #input features def _global_perturb_func1( inputs: TensorOrTupleOfTensorsGeneric, diff --git a/tests/metrics/test_sensitivity.py b/tests/metrics/test_sensitivity.py index a71805f9f..9fafed5e7 100644 --- a/tests/metrics/test_sensitivity.py +++ b/tests/metrics/test_sensitivity.py @@ -29,14 +29,14 @@ @typing.overload # pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible -# arguments of overload defined on line `28`. -def _perturb_func(inputs: Tensor) -> Tensor: ... +# arguments of overload defined on line `32`. +def _perturb_func(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ... @typing.overload # pyre-fixme[43]: The implementation of `_perturb_func` does not accept all possible -# arguments of overload defined on line `32`. -def _perturb_func(inputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]: ... +# arguments of overload defined on line `28`. +def _perturb_func(inputs: Tensor) -> Tensor: ... def _perturb_func(