Address remaining mypy errors (#1383)
Summary:
Pull Request resolved: #1383

Misc fixes to remaining mypy errors

Reviewed By: vivekmig

Differential Revision: D64518879
craymichael authored and facebook-github-bot committed Oct 17, 2024
1 parent 6afbf2c commit 982f27b
Showing 5 changed files with 9 additions and 6 deletions.
2 changes: 2 additions & 0 deletions captum/attr/_core/llm_attr.py
@@ -508,6 +508,7 @@ def attribute(
                 skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
         else:
             skip_tokens = []
+        skip_tokens = cast(List[int], skip_tokens)
 
         if isinstance(target, str):
             encoded = self.tokenizer.encode(target)
@@ -700,6 +701,7 @@ def attribute(
                 skip_tokens = self.tokenizer.convert_tokens_to_ids(skip_tokens)
         else:
             skip_tokens = []
+        skip_tokens = cast(List[int], skip_tokens)
 
         if isinstance(target, str):
            encoded = self.tokenizer.encode(target)
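Both hunks apply the same pattern: after the if/else, the checker cannot narrow skip_tokens to List[int] on its own, so a typing.cast, which is a no-op at runtime, pins the type. A minimal standalone sketch of the pattern, assuming a tokenizer with a HuggingFace-style convert_tokens_to_ids; the helper name resolve_skip_tokens is hypothetical:

    from typing import List, Optional, Union, cast

    def resolve_skip_tokens(
        tokenizer, skip_tokens: Optional[Union[List[str], List[int]]]
    ) -> List[int]:
        if skip_tokens:
            if isinstance(skip_tokens[0], str):
                # convert_tokens_to_ids is typed loosely by most tokenizers,
                # so mypy cannot infer List[int] from this branch alone.
                skip_tokens = tokenizer.convert_tokens_to_ids(skip_tokens)
        else:
            skip_tokens = []
        # No runtime conversion; this only informs the type checker.
        return cast(List[int], skip_tokens)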
2 changes: 1 addition & 1 deletion captum/attr/_core/occlusion.py
@@ -384,7 +384,7 @@ def _occlusion_mask(
     def _get_feature_range_and_mask(
         self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any
     ) -> Tuple[int, int, Union[None, Tensor, Tuple[Tensor, ...]]]:
-        feature_max = np.prod(kwargs["shift_counts"])
+        feature_max = int(np.prod(kwargs["shift_counts"]))
         return 0, feature_max, None
 
     def _get_feature_counts(
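Why the int() wrapper helps, as a tiny illustration (not repository code): np.prod returns a NumPy scalar rather than a plain int, and mypy rejects it where the annotation demands int.

    import numpy as np

    def feature_count(shift_counts) -> int:
        # np.prod yields a NumPy scalar (e.g. np.intp); int() converts it to a
        # plain Python int so the annotated return type is satisfied.
        return int(np.prod(shift_counts))

    print(feature_count((2, 3, 4)))  # 24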
4 changes: 3 additions & 1 deletion captum/attr/_utils/attribution.py
@@ -367,7 +367,9 @@ def multiplies_by_inputs(self) -> bool:
         return True
 
 
-class InternalAttribution(Attribution, Generic[ModuleOrModuleList]):
+# mypy false positive "Free type variable expected in Generic[...]" but
+# ModuleOrModuleList is a TypeVar
+class InternalAttribution(Attribution, Generic[ModuleOrModuleList]):  # type: ignore
     r"""
     Shared base class for LayerAttribution and NeuronAttribution,
     attribution types that require a model and a particular layer.
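A sketch of the suppressed pattern, assuming a constrained TypeVar analogous to Captum's ModuleOrModuleList; the LayerHolder class is purely illustrative. Per the comment added in this hunk, some mypy versions report "Free type variable expected in Generic[...]" for this usage even though the TypeVar is legitimate, hence the targeted type: ignore on the class line.

    from typing import Generic, List, TypeVar

    from torch.nn import Module

    # Assumed to mirror Captum's ModuleOrModuleList: constrained to either a
    # single Module or a list of Modules.
    ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])

    class LayerHolder(Generic[ModuleOrModuleList]):  # type: ignore
        """Illustrative container that is generic over a layer or list of layers."""

        def __init__(self, layer: ModuleOrModuleList) -> None:
            self.layer = layer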
5 changes: 2 additions & 3 deletions captum/influence/_core/tracincp_fast_rand_proj.py
@@ -189,7 +189,7 @@ def __init__(
         self.vectorize = vectorize
 
         # TODO: restore prior state
-        self.final_fc_layer = final_fc_layer  # type: ignore
+        self.final_fc_layer = cast(Module, final_fc_layer)
         for param in self.final_fc_layer.parameters():
             param.requires_grad = True
 
@@ -212,8 +212,7 @@ def final_fc_layer(self) -> Module:
         return self._final_fc_layer
 
     @final_fc_layer.setter
-    # pyre-fixme[3]: Return type must be annotated.
-    def final_fc_layer(self, layer: Union[Module, str]):
+    def final_fc_layer(self, layer: Union[Module, str]) -> None:
         if isinstance(layer, str):
             try:
                 self._final_fc_layer = _get_module_from_name(self.model, layer)
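A condensed sketch of the property shape involved, not repository code: the getter returns Module while the setter accepts Union[Module, str], and mypy (at least in versions current at the time) checks assignments against the getter's type, so the constructor uses cast, a runtime no-op, rather than a blanket type: ignore; the setter also gains its missing -> None annotation.

    from typing import Union, cast

    from torch.nn import Linear, Module

    class Wrapper:
        def __init__(self, final_fc_layer: Union[Module, str]) -> None:
            # cast() only satisfies the type checker; the setter below still
            # resolves the str case at runtime.
            self.final_fc_layer = cast(Module, final_fc_layer)

        @property
        def final_fc_layer(self) -> Module:
            return self._final_fc_layer

        @final_fc_layer.setter
        def final_fc_layer(self, layer: Union[Module, str]) -> None:
            if isinstance(layer, str):
                # Stand-in for Captum's _get_module_from_name lookup.
                layer = Linear(4, 2)
            self._final_fc_layer = layer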
2 changes: 1 addition & 1 deletion tests/attr/test_interpretable_input.py
@@ -25,7 +25,7 @@ def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
     # pyre-fixme[43]: Incompatible overload. The implementation of
     # `DummyTokenizer.encode` does not accept all possible arguments of overload.
     # pyre-ignore[11]: Annotation `pt` is not defined as a type
-    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...
+    def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...  # type: ignore # noqa: E501 line too long
 
     def encode(
         self, text: str, return_tensors: Optional[str] = "pt"
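A self-contained sketch of the overload shape being suppressed (illustrative, not the actual test file): the Literal["pt"] overload draws complaints from both pyre and mypy, so the suppression comments are attached to that single overload.

    from typing import List, Literal, Optional, Union, overload

    from torch import Tensor, tensor

    class ToyTokenizer:
        @overload
        def encode(self, text: str, return_tensors: None = None) -> List[int]: ...
        @overload
        def encode(self, text: str, return_tensors: Literal["pt"]) -> Tensor: ...  # type: ignore  # noqa: E501

        def encode(
            self, text: str, return_tensors: Optional[str] = None
        ) -> Union[List[int], Tensor]:
            ids = [ord(c) for c in text]  # toy "encoding"
            return tensor(ids) if return_tensors == "pt" else ids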
