From 28c58da16bd94ce06a342686f9bfff7081139e7d Mon Sep 17 00:00:00 2001 From: Zach Carmichael Date: Thu, 17 Oct 2024 11:02:04 -0700 Subject: [PATCH] Reduce complexity of 'sklearn_train_linear_model' (#1375) Summary: Pull Request resolved: https://github.com/pytorch/captum/pull/1375 Reduce complexity of 'sklearn_train_linear_model' Reviewed By: jsawruk Differential Revision: D64438317 fbshipit-source-id: aa99f2ec9d9a0b349a423fc0a37e9b21b2a0ff39 --- captum/_utils/models/linear_model/train.py | 57 +++++++++++++--------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/captum/_utils/models/linear_model/train.py b/captum/_utils/models/linear_model/train.py index 2ae6db73d..37f1507c9 100644 --- a/captum/_utils/models/linear_model/train.py +++ b/captum/_utils/models/linear_model/train.py @@ -1,6 +1,8 @@ # pyre-strict import time import warnings +from functools import reduce +from types import ModuleType from typing import Any, Callable, cast, Dict, List, Optional, Tuple import torch @@ -282,14 +284,38 @@ def forward(self, x): return (x - self.mean) / (self.std + self.eps) +def _import_sklearn() -> ModuleType: + try: + import sklearn + import sklearn.linear_model + import sklearn.svm + except ImportError: + raise ValueError("sklearn is not available. Please install sklearn >= 0.23") + + if not sklearn.__version__ >= "0.23.0": + warnings.warn( + "Must have sklearn version 0.23.0 or higher to use " + "sample_weight in Lasso regression.", + stacklevel=1, + ) + return sklearn + + +def _import_numpy() -> ModuleType: + try: + import numpy + except ImportError: + raise ValueError("numpy is not available. Please install numpy.") + return numpy + + def sklearn_train_linear_model( model: LinearModel, dataloader: DataLoader, construct_kwargs: Dict[str, Any], sklearn_trainer: str = "Lasso", norm_input: bool = False, - # pyre-fixme[2]: Parameter must be annotated. - **fit_kwargs, + **fit_kwargs: Any, ) -> Dict[str, float]: r""" Alternative method to train with sklearn. This does introduce some slight @@ -318,26 +344,9 @@ def sklearn_train_linear_model( fit_kwargs Other arguments to send to `sklearn_trainer`'s `.fit` method """ - from functools import reduce - - try: - import numpy as np - except ImportError: - raise ValueError("numpy is not available. Please install numpy.") - - try: - import sklearn - import sklearn.linear_model - import sklearn.svm - except ImportError: - raise ValueError("sklearn is not available. Please install sklearn >= 0.23") - - if not sklearn.__version__ >= "0.23.0": - warnings.warn( - "Must have sklearn version 0.23.0 or higher to use " - "sample_weight in Lasso regression.", - stacklevel=1, - ) + # Lazy imports + np = _import_numpy() + sklearn = _import_sklearn() num_batches = 0 xs, ys, ws = [], [], [] @@ -369,8 +378,8 @@ def sklearn_train_linear_model( t1 = time.time() # pyre-fixme[29]: `str` is not a function. - sklearn_model = reduce( - lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") + sklearn_model = reduce( # type: ignore + lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") # type: ignore # noqa: E501 )(**construct_kwargs) try: sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)