Skip to content

Commit

Permalink
modify import (#3293)
Browse files Browse the repository at this point in the history
  • Loading branch information
kzkadc authored Oct 14, 2024
1 parent a5d3464 commit 8782ca4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions ignite/metrics/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@


def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> Tuple[Any, Any, Any]:
try:
from sklearn.metrics import precision_recall_curve
except ImportError:
raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.")
from sklearn.metrics import precision_recall_curve

y_true = y_targets.cpu().numpy()
y_pred = y_preds.cpu().numpy()
Expand Down Expand Up @@ -83,6 +80,11 @@ def __init__(
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
) -> None:
try:
from sklearn.metrics import precision_recall_curve # noqa: F401
except ImportError:
raise ModuleNotFoundError("This module requires scikit-learn to be installed.")

super(PrecisionRecallCurve, self).__init__(
precision_recall_curve_compute_fn, # type: ignore[arg-type]
output_transform=output_transform,
Expand Down
2 changes: 1 addition & 1 deletion tests/ignite/metrics/test_precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def mock_no_sklearn():


def test_no_sklearn(mock_no_sklearn):
with pytest.raises(ModuleNotFoundError, match=r"This contrib module requires scikit-learn to be installed."):
with pytest.raises(ModuleNotFoundError, match=r"This module requires scikit-learn to be installed."):
y = torch.tensor([1, 1])
pr_curve = PrecisionRecallCurve()
pr_curve.update((y, y))
Expand Down

0 comments on commit 8782ca4

Please sign in to comment.