Skip to content

Commit

Permalink
Address instances of "Overloaded function signature x will never be m…
Browse files Browse the repository at this point in the history
…atched" + minor typing fixes

Summary: Many overloads produced false positives or required changing order due to mypy breaking ties by picking the first matching variant (https://mypy.readthedocs.io/en/stable/more_types.html). This fixes or suppresses these errors. Created T204932142 to address Literal-related issues.

Differential Revision: D64517613
  • Loading branch information
Zach Carmichael authored and facebook-github-bot committed Oct 17, 2024
1 parent a510bf6 commit e5d6d6b
Show file tree
Hide file tree
Showing 13 changed files with 150 additions and 137 deletions.
18 changes: 9 additions & 9 deletions captum/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ def safe_div(
@typing.overload
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
# is incompatible with the return type of the implementation (`bool`).
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def _is_tuple(inputs: Tensor) -> Literal[False]: ...
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...


@typing.overload
# pyre-fixme[43]: The return type of overloaded function `_is_tuple` (`Literal[]`)
# is incompatible with the return type of the implementation (`bool`).
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]: ...
def _is_tuple(inputs: Tensor) -> Literal[False]: ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
Expand Down Expand Up @@ -277,7 +277,7 @@ def _format_additional_forward_args(


@overload
def _format_additional_forward_args(
def _format_additional_forward_args( # type: ignore
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any,
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
Expand Down Expand Up @@ -780,10 +780,10 @@ def _reduce_list(
"""
assert len(val_list) > 0, "Cannot reduce empty list!"
if isinstance(val_list[0], torch.Tensor):
# pyre-fixme[16]: `bool` has no attribute `device`.
first_device = val_list[0].device
# pyre-fixme[16]: `bool` has no attribute `to`.
return red_func([elem.to(first_device) for elem in val_list])
first_device = cast(Tensor, val_list[0]).device
return red_func(
[elem.to(first_device) for elem in cast(List[Tensor], val_list)]
)
elif isinstance(val_list[0], bool):
# pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`.
return any(val_list)
Expand Down
25 changes: 13 additions & 12 deletions captum/_utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,33 +159,34 @@ def _neuron_gradients(

@typing.overload
# pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all
# possible arguments of overload defined on line `158`.
# possible arguments of overload defined on line `170`.
def _forward_layer_eval(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: Module,
layer: List[Module],
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
grad_enabled: bool = False,
) -> Tuple[Tensor, ...]: ...
) -> List[Tuple[Tensor, ...]]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_forward_layer_eval` does not accept all
# possible arguments of overload defined on line `170`.
# possible arguments of overload defined on line `158`.
def _forward_layer_eval(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: List[Module],
layer: Module,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
grad_enabled: bool = False,
) -> List[Tuple[Tensor, ...]]: ...
) -> Tuple[Tensor, ...]: ...


def _forward_layer_eval(
Expand Down Expand Up @@ -434,34 +435,34 @@ def _forward_layer_eval_with_neuron_grads(

@typing.overload
# pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does
# not accept all possible arguments of overload defined on line `392`.
# not accept all possible arguments of overload defined on line `405`.
def _forward_layer_eval_with_neuron_grads(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: Module,
layer: List[Module],
additional_forward_args: Any = None,
gradient_neuron_selector: None = None,
grad_enabled: bool = False,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
) -> Tuple[Tensor, ...]: ...
) -> List[Tuple[Tensor, ...]]: ...


@typing.overload
# pyre-fixme[43]: The implementation of `_forward_layer_eval_with_neuron_grads` does
# not accept all possible arguments of overload defined on line `405`.
# not accept all possible arguments of overload defined on line `392`.
def _forward_layer_eval_with_neuron_grads(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_fn: Callable,
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: List[Module],
layer: Module,
additional_forward_args: Any = None,
gradient_neuron_selector: None = None,
grad_enabled: bool = False,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
) -> List[Tuple[Tensor, ...]]: ...
) -> Tuple[Tensor, ...]: ...


def _forward_layer_eval_with_neuron_grads(
Expand Down
43 changes: 22 additions & 21 deletions captum/attr/_core/deep_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,36 +118,36 @@ def __init__(

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `120`.
# arguments of overload defined on line `131`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
return_convergence_delta: Literal[True],
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> TensorOrTupleOfTensorsGeneric: ...
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `131`.
# arguments of overload defined on line `120`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
return_convergence_delta: Literal[False] = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
) -> TensorOrTupleOfTensorsGeneric: ...

@log_usage()
def attribute( # type: ignore
Expand Down Expand Up @@ -636,7 +636,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
# DeepLiftShap.attribute, so we ignore typing here
@typing.overload # type: ignore
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `584`.
# arguments of overload defined on line `597`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
Expand All @@ -646,30 +646,31 @@ def attribute(
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
return_convergence_delta: Literal[True],
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> TensorOrTupleOfTensorsGeneric: ...
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `597`.
# arguments of overload defined on line `584`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: Union[
TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
],
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
return_convergence_delta: Literal[False] = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
) -> TensorOrTupleOfTensorsGeneric: ...

@log_usage()
def attribute( # type: ignore
Expand Down
21 changes: 11 additions & 10 deletions captum/attr/_core/integrated_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
# a tuple with both attributions and deltas.
@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `82`.
# arguments of overload defined on line `95`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
Expand All @@ -92,29 +92,30 @@ def attribute(
n_steps: int = 50,
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
) -> TensorOrTupleOfTensorsGeneric: ...
return_convergence_delta: Literal[True],
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `95`.
# arguments of overload defined on line `82`.
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
n_steps: int = 50,
method: str = "gausslegendre",
internal_batch_size: Union[None, int] = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: ...
return_convergence_delta: Literal[False] = False,
) -> TensorOrTupleOfTensorsGeneric: ...

@log_usage()
def attribute( # type: ignore
Expand Down
42 changes: 22 additions & 20 deletions captum/attr/_core/layer/layer_deep_lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,40 +102,41 @@ def __init__(
# Ignoring mypy error for inconsistent signature with DeepLift
@typing.overload # type: ignore
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `104`.
# arguments of overload defined on line `117`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
grad_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `117`.
# arguments of overload defined on line `104`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
baselines: BaselineType = None,
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
grad_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
) -> Union[Tensor, Tuple[Tensor, ...]]: ...

@log_usage()
# pyre-fixme[43]: This definition does not have the same decorators as the
Expand Down Expand Up @@ -452,7 +453,7 @@ def __init__(
# Ignoring mypy error for inconsistent signature with DeepLiftShap
@typing.overload # type: ignore
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `439`.
# arguments of overload defined on line `453`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
Expand All @@ -462,32 +463,33 @@ def attribute(
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[False] = False,
return_convergence_delta: Literal[True],
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> Union[Tensor, Tuple[Tensor, ...]]: ...
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...

@typing.overload
# pyre-fixme[43]: The implementation of `attribute` does not accept all possible
# arguments of overload defined on line `453`.
# arguments of overload defined on line `439`.
def attribute(
self,
inputs: Union[Tensor, Tuple[Tensor, ...]],
baselines: Union[
Tensor, Tuple[Tensor, ...], Callable[..., Union[Tensor, Tuple[Tensor, ...]]]
],
target: TargetType = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
additional_forward_args: Any = None,
*,
# pyre-fixme[31]: Expression `Literal[True]` is not a valid type.
# pyre-fixme[9]: return_convergence_delta has type `Literal[]`; used as `bool`.
# pyre-fixme[31]: Expression `Literal[False]` is not a valid type.
# pyre-fixme[24]: Non-generic type `typing.Literal` cannot take parameters.
return_convergence_delta: Literal[True],
return_convergence_delta: Literal[False] = False,
attribute_to_layer_input: bool = False,
custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
) -> Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]: ...
) -> Union[Tensor, Tuple[Tensor, ...]]: ...

@log_usage()
# pyre-fixme[43]: This definition does not have the same decorators as the
Expand Down
Loading

0 comments on commit e5d6d6b

Please sign in to comment.