Skip to content

Commit

Permalink
Reduce complexity of 'sklearn_train_linear_model' (#1375)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1375

Reduce complexity of 'sklearn_train_linear_model'

Reviewed By: jsawruk

Differential Revision: D64438317

fbshipit-source-id: aa99f2ec9d9a0b349a423fc0a37e9b21b2a0ff39
  • Loading branch information
craymichael authored and facebook-github-bot committed Oct 17, 2024
1 parent 9689ccd commit 28c58da
Showing 1 changed file with 33 additions and 24 deletions.
57 changes: 33 additions & 24 deletions captum/_utils/models/linear_model/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# pyre-strict
import time
import warnings
from functools import reduce
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

import torch
Expand Down Expand Up @@ -282,14 +284,38 @@ def forward(self, x):
return (x - self.mean) / (self.std + self.eps)


def _import_sklearn() -> ModuleType:
    """Lazily import scikit-learn and the submodules used for training.

    ``sklearn.linear_model`` and ``sklearn.svm`` are imported as well so that
    callers can resolve trainers through attribute access on the returned
    package.

    Returns:
        The imported top-level ``sklearn`` module.

    Raises:
        ValueError: if sklearn (or one of the submodules above) cannot be
            imported. The original ``ImportError`` is attached as the cause.

    Warns:
        UserWarning: if the installed sklearn release is older than 0.23.
    """
    import re

    try:
        import sklearn
        import sklearn.linear_model
        import sklearn.svm
    except ImportError as err:
        raise ValueError(
            "sklearn is not available. Please install sklearn >= 0.23"
        ) from err

    # Compare the release numerically: a plain string comparison orders
    # versions lexicographically, so e.g. "0.9.0" would wrongly be treated
    # as newer than "0.23.0". Only the leading "major[.minor]" digits are
    # needed; suffixes such as "rc1" or ".dev0" are ignored.
    found = re.match(r"\s*(\d+)(?:\.(\d+))?", str(sklearn.__version__))
    if found is not None:
        release = (int(found.group(1)), int(found.group(2) or 0))
        if release < (0, 23):
            warnings.warn(
                "Must have sklearn version 0.23.0 or higher to use "
                "sample_weight in Lasso regression.",
                stacklevel=1,
            )
    # If the version string has no leading number we cannot tell how old the
    # install is, so no warning is emitted.
    return sklearn


def _import_numpy() -> ModuleType:
    """Lazily import numpy and hand back the module object.

    Returns:
        The imported ``numpy`` module.

    Raises:
        ValueError: if numpy cannot be imported.
    """
    try:
        import numpy as np
    except ImportError:
        raise ValueError("numpy is not available. Please install numpy.")
    else:
        return np


def sklearn_train_linear_model(
model: LinearModel,
dataloader: DataLoader,
construct_kwargs: Dict[str, Any],
sklearn_trainer: str = "Lasso",
norm_input: bool = False,
# pyre-fixme[2]: Parameter must be annotated.
**fit_kwargs,
**fit_kwargs: Any,
) -> Dict[str, float]:
r"""
Alternative method to train with sklearn. This does introduce some slight
Expand Down Expand Up @@ -318,26 +344,9 @@ def sklearn_train_linear_model(
fit_kwargs
Other arguments to send to `sklearn_trainer`'s `.fit` method
"""
from functools import reduce

try:
import numpy as np
except ImportError:
raise ValueError("numpy is not available. Please install numpy.")

try:
import sklearn
import sklearn.linear_model
import sklearn.svm
except ImportError:
raise ValueError("sklearn is not available. Please install sklearn >= 0.23")

if not sklearn.__version__ >= "0.23.0":
warnings.warn(
"Must have sklearn version 0.23.0 or higher to use "
"sample_weight in Lasso regression.",
stacklevel=1,
)
# Lazy imports
np = _import_numpy()
sklearn = _import_sklearn()

num_batches = 0
xs, ys, ws = [], [], []
Expand Down Expand Up @@ -369,8 +378,8 @@ def sklearn_train_linear_model(

t1 = time.time()
# pyre-fixme[29]: `str` is not a function.
sklearn_model = reduce(
lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
sklearn_model = reduce( # type: ignore
lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") # type: ignore # noqa: E501
)(**construct_kwargs)
try:
sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
Expand Down

0 comments on commit 28c58da

Please sign in to comment.