Skip to content

Commit

Permalink
Reduce complexity of FeatureAblation.attribute_future (#1368)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1368

Reduce complexity of FeatureAblation.attribute_future by refactoring inner function

Reviewed By: cyrjano

Differential Revision: D64361191

fbshipit-source-id: 642df64df6b3dce9d32e7c0f9c26392a963d45e2
  • Loading branch information
craymichael authored and facebook-github-bot committed Oct 15, 2024
1 parent fd758e0 commit 4cb2808
Showing 1 changed file with 65 additions and 65 deletions.
130 changes: 65 additions & 65 deletions captum/attr/_core/feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,73 +555,9 @@ def attribute_future(
]
)

def eval_fut_to_ablated_out_fut(
# pyre-ignore Invalid type parameters [24]
eval_futs: Future[List[Future[List[object]]]],
current_inputs: Tuple[Tensor, ...],
current_mask: Tensor,
i: int,
perturbations_per_eval: int,
num_examples: int,
formatted_inputs: Tuple[Tensor, ...],
) -> Tuple[List[Tensor], List[Tensor]]:
try:
modified_eval = cast(Tensor, eval_futs.value()[1].value())
initial_eval_tuple = cast(
Tuple[
List[Tensor],
List[Tensor],
Tensor,
Tensor,
int,
dtype,
],
eval_futs.value()[0].value(),
)
if len(initial_eval_tuple) != 6:
raise AssertionError(
"eval_fut_to_ablated_out_fut: "
"initial_eval_tuple should have 6 elements: "
"total_attrib, weights, initial_eval, "
"flattened_initial_eval, n_outputs, attrib_type "
)
if not isinstance(modified_eval, Tensor):
raise AssertionError(
"eval_fut_to_ablated_out_fut: "
"modified eval should be a Tensor"
)
(
total_attrib,
weights,
initial_eval,
flattened_initial_eval,
n_outputs,
attrib_type,
) = initial_eval_tuple
result = self._process_ablated_out( # type: ignore # noqa: E501 line too long
modified_eval=modified_eval,
current_inputs=current_inputs,
current_mask=current_mask,
perturbations_per_eval=perturbations_per_eval,
num_examples=num_examples,
initial_eval=initial_eval,
flattened_initial_eval=flattened_initial_eval,
inputs=formatted_inputs,
n_outputs=n_outputs,
total_attrib=total_attrib,
weights=weights,
i=i,
attrib_type=attrib_type,
)
except FeatureAblationFutureError as e:
raise FeatureAblationFutureError(
"eval_fut_to_ablated_out_fut func failed)"
) from e
return result

ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
eval_futs.then(
lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: eval_fut_to_ablated_out_fut( # type: ignore # noqa: E501 line too long
lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut( # type: ignore # noqa: E501 line too long
eval_futs=eval_futs,
current_inputs=current_inputs,
current_mask=current_mask,
Expand Down Expand Up @@ -660,6 +596,70 @@ def _attribute_progress_setup(
)
return attr_progress

def _eval_fut_to_ablated_out_fut(
self,
# pyre-ignore Invalid type parameters [24]
eval_futs: Future[List[Future[List[object]]]],
current_inputs: Tuple[Tensor, ...],
current_mask: Tensor,
i: int,
perturbations_per_eval: int,
num_examples: int,
formatted_inputs: Tuple[Tensor, ...],
) -> Tuple[List[Tensor], List[Tensor]]:
try:
modified_eval = cast(Tensor, eval_futs.value()[1].value())
initial_eval_tuple = cast(
Tuple[
List[Tensor],
List[Tensor],
Tensor,
Tensor,
int,
dtype,
],
eval_futs.value()[0].value(),
)
if len(initial_eval_tuple) != 6:
raise AssertionError(
"eval_fut_to_ablated_out_fut: "
"initial_eval_tuple should have 6 elements: "
"total_attrib, weights, initial_eval, "
"flattened_initial_eval, n_outputs, attrib_type "
)
if not isinstance(modified_eval, Tensor):
raise AssertionError(
"eval_fut_to_ablated_out_fut: " "modified eval should be a Tensor"
)
(
total_attrib,
weights,
initial_eval,
flattened_initial_eval,
n_outputs,
attrib_type,
) = initial_eval_tuple
result = self._process_ablated_out( # type: ignore # noqa: E501 line too long
modified_eval=modified_eval,
current_inputs=current_inputs,
current_mask=current_mask,
perturbations_per_eval=perturbations_per_eval,
num_examples=num_examples,
initial_eval=initial_eval,
flattened_initial_eval=flattened_initial_eval,
inputs=formatted_inputs,
n_outputs=n_outputs,
total_attrib=total_attrib,
weights=weights,
i=i,
attrib_type=attrib_type,
)
except FeatureAblationFutureError as e:
raise FeatureAblationFutureError(
"eval_fut_to_ablated_out_fut func failed)"
) from e
return result

# pyre-fixme[3]: Return type must be specified as type that does not contain `Any`
def _ith_input_ablation_generator(
self,
Expand Down

0 comments on commit 4cb2808

Please sign in to comment.