Reduce complexity of logging try-catch + add typing #1380

Closed · wants to merge 7 commits
223 changes: 133 additions & 90 deletions captum/_utils/models/linear_model/train.py
@@ -1,7 +1,9 @@
# pyre-strict
import time
import warnings
from typing import Any, Callable, Dict, List, Optional
from functools import reduce
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -17,6 +19,82 @@ def l2_loss(x1, x2, weights=None) -> torch.Tensor:
return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0


class ConvergenceTracker:
def __init__(self, patience: int, threshold: float) -> None:
self.min_avg_loss: Optional[torch.Tensor] = None
self.convergence_counter: int = 0
self.converged: bool = False

self.threshold = threshold
self.patience = patience

def update(self, average_loss: torch.Tensor) -> bool:
if self.min_avg_loss is not None:
# if we haven't improved by at least `threshold`
if average_loss > self.min_avg_loss or torch.isclose(
cast(torch.Tensor, self.min_avg_loss), average_loss, atol=self.threshold
):
self.convergence_counter += 1
if self.convergence_counter >= self.patience:
self.converged = True
return True
else:
self.convergence_counter = 0
if self.min_avg_loss is None or self.min_avg_loss >= average_loss:
self.min_avg_loss = average_loss.clone()
return False


class LossWindow:
def __init__(self, window_size: int) -> None:
self.loss_window: List[torch.Tensor] = []
self.window_size = window_size

    def append(self, loss: torch.Tensor) -> None:
        self.loss_window.append(loss)
        # Keep only the most recent `window_size` losses.
        if len(self.loss_window) > self.window_size:
            self.loss_window = self.loss_window[-self.window_size :]

def average(self) -> torch.Tensor:
return torch.mean(torch.stack(self.loss_window))
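# Illustrative sketch of how the two helpers above compose in a training
# loop. `per_step_losses` is a hypothetical iterable of scalar loss tensors;
# the parameter values are made up.
#
#   tracker = ConvergenceTracker(patience=3, threshold=1e-4)
#   window = LossWindow(window_size=100)
#   for loss in per_step_losses:
#       window.append(loss.detach())
#       if tracker.update(window.average()):
#           break  # no improvement beyond `threshold` for `patience` checks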


def _init_linear_model(model: LinearModel, init_scheme: Optional[str] = None) -> None:
assert model.linear is not None
if init_scheme is not None:
assert init_scheme in ["xavier", "zeros"]

with torch.no_grad():
if init_scheme == "xavier":
# pyre-fixme[16]: `Optional` has no attribute `weight`.
torch.nn.init.xavier_uniform_(model.linear.weight)
else:
model.linear.weight.zero_()

# pyre-fixme[16]: `Optional` has no attribute `bias`.
if model.linear.bias is not None:
model.linear.bias.zero_()
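# Illustrative calls (assumed usage): `_init_linear_model(model, "xavier")`
# re-initializes the weights in place, while `init_scheme=None` leaves the
# parameters produced by `_construct_model_params` untouched.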


def _get_point(
datapoint: Tuple[torch.Tensor, ...],
device: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if len(datapoint) == 2:
x, y = datapoint
w = None
else:
x, y, w = datapoint

if device is not None:
x = x.to(device)
y = y.to(device)
if w is not None:
w = w.to(device)

return x, y, w
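# Illustrative only (hypothetical tensors): a dataloader batch is either
# (x, y) or (x, y, w), so the sample-weight tensor is optional.
#
#   x, y, w = _get_point((x_batch, y_batch))                     # w is None
#   x, y, w = _get_point((x_batch, y_batch, w_batch), device="cuda:0")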


def sgd_train_linear_model(
model: LinearModel,
dataloader: DataLoader,
@@ -102,31 +180,16 @@ def sgd_train_linear_model(
This will return the final training loss (averaged with
`running_loss_window`)
"""
loss_window: List[torch.Tensor] = []
min_avg_loss = None
convergence_counter = 0
converged = False

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def get_point(datapoint):
if len(datapoint) == 2:
x, y = datapoint
w = None
else:
x, y, w = datapoint

if device is not None:
x = x.to(device)
y = y.to(device)
if w is not None:
w = w.to(device)

return x, y, w
converge_tracker = ConvergenceTracker(patience, threshold)

# get a point and construct the model
data_iter = iter(dataloader)
x, y, w = get_point(next(data_iter))
x, y, w = _get_point(next(data_iter), device)

if running_loss_window is None:
running_loss_window = x.shape[0] * len(dataloader)

loss_window = LossWindow(running_loss_window)

model._construct_model_params(
in_features=x.shape[1],
@@ -135,21 +198,8 @@ def get_point(datapoint):
)
model.train()

assert model.linear is not None

if init_scheme is not None:
assert init_scheme in ["xavier", "zeros"]

with torch.no_grad():
if init_scheme == "xavier":
# pyre-fixme[16]: `Optional` has no attribute `weight`.
torch.nn.init.xavier_uniform_(model.linear.weight)
else:
model.linear.weight.zero_()

# pyre-fixme[16]: `Optional` has no attribute `bias`.
if model.linear.bias is not None:
model.linear.bias.zero_()
# Initialize linear model weights if applicable
_init_linear_model(model, init_scheme)

with torch.enable_grad():
optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
@@ -163,9 +213,6 @@ def get_point(datapoint):
i = 0
while epoch < max_epoch:
while True: # for x, y, w in dataloader
if running_loss_window is None:
running_loss_window = x.shape[0] * len(dataloader)

y = y.view(x.shape[0], -1)
if w is not None:
w = w.view(x.shape[0], -1)
@@ -176,33 +223,20 @@ def get_point(datapoint):

loss = loss_fn(y, out, w)
if reg_term is not None:
reg = torch.norm(model.linear.weight, p=reg_term)
# pyre-fixme[16]: `Optional` has no attribute `weight`.
reg = torch.norm(model.linear.weight, p=reg_term) # type: ignore
loss += reg.sum() * alpha

if len(loss_window) >= running_loss_window:
loss_window = loss_window[1:]
loss_window.append(loss.clone().detach())
assert len(loss_window) <= running_loss_window

average_loss = torch.mean(torch.stack(loss_window))
if min_avg_loss is not None:
# if we haven't improved by at least `threshold`
if average_loss > min_avg_loss or torch.isclose(
min_avg_loss, average_loss, atol=threshold
):
convergence_counter += 1
if convergence_counter >= patience:
converged = True
break
else:
convergence_counter = 0
if min_avg_loss is None or min_avg_loss >= average_loss:
min_avg_loss = average_loss.clone()
                loss_window.append(loss.clone().detach())
                average_loss = loss_window.average()
                if converge_tracker.update(average_loss):
                    break  # converged

if debug:
print(
f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
+ "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
f"lr={optim.param_groups[0]['lr']}, Loss={loss}, "
f"Aloss={average_loss}, "
f"min_avg_loss={converge_tracker.min_avg_loss}"
)

loss.backward()
@@ -215,19 +249,19 @@ def get_point(datapoint):
temp = next(data_iter, None)
if temp is None:
break
x, y, w = get_point(temp)
x, y, w = _get_point(temp, device)

if converged:
if converge_tracker.converged:
break

epoch += 1
data_iter = iter(dataloader)
x, y, w = get_point(next(data_iter))
x, y, w = _get_point(next(data_iter), device)

t2 = time.time()
return {
"train_time": t2 - t1,
"train_loss": torch.mean(torch.stack(loss_window)).item(),
"train_loss": loss_window.average().item(),
"train_iter": i,
"train_epoch": epoch,
}
@@ -250,14 +284,38 @@ def forward(self, x):
return (x - self.mean) / (self.std + self.eps)
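# Note: `forward` computes the z-score, (x - mean) / (std + eps), where
# `eps` guards against division by zero for near-constant features.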


def _import_sklearn() -> ModuleType:
try:
import sklearn
import sklearn.linear_model
import sklearn.svm
except ImportError:
raise ValueError("sklearn is not available. Please install sklearn >= 0.23")

if not sklearn.__version__ >= "0.23.0":
warnings.warn(
"Must have sklearn version 0.23.0 or higher to use "
"sample_weight in Lasso regression.",
stacklevel=1,
)
return sklearn


def _import_numpy() -> ModuleType:
try:
import numpy
except ImportError:
raise ValueError("numpy is not available. Please install numpy.")
return numpy
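# Illustrative only: both helpers return the imported module so call sites
# can keep the import lazy.
#
#   np = _import_numpy()         # raises ValueError if numpy is missing
#   sklearn = _import_sklearn()  # raises ValueError if sklearn is missing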


def sklearn_train_linear_model(
model: LinearModel,
dataloader: DataLoader,
construct_kwargs: Dict[str, Any],
sklearn_trainer: str = "Lasso",
norm_input: bool = False,
# pyre-fixme[2]: Parameter must be annotated.
**fit_kwargs,
**fit_kwargs: Any,
) -> Dict[str, float]:
r"""
Alternative method to train with sklearn. This does introduce some slight
@@ -286,25 +344,9 @@ def sklearn_train_linear_model(
fit_kwargs
Other arguments to send to `sklearn_trainer`'s `.fit` method
"""
from functools import reduce

try:
import numpy as np
except ImportError:
raise ValueError("numpy is not available. Please install numpy.")

try:
import sklearn
import sklearn.linear_model
import sklearn.svm
except ImportError:
raise ValueError("sklearn is not available. Please install sklearn >= 0.23")

if not sklearn.__version__ >= "0.23.0":
warnings.warn(
"Must have sklearn version 0.23.0 or higher to use "
"sample_weight in Lasso regression."
)
# Lazy imports
np = _import_numpy()
sklearn = _import_sklearn()
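    # Illustrative call (hypothetical keyword values): resolve a dotted
    # trainer name under the sklearn namespace and forward fit arguments.
    #
    #   sklearn_train_linear_model(
    #       model, dataloader,
    #       construct_kwargs={"alpha": 1.0},
    #       sklearn_trainer="linear_model.Lasso",
    #       norm_input=True,
    #   )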

num_batches = 0
xs, ys, ws = [], [], []
@@ -336,8 +378,8 @@ def sklearn_train_linear_model(

t1 = time.time()
# pyre-fixme[29]: `str` is not a function.
sklearn_model = reduce(
lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
sklearn_model = reduce( # type: ignore
lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".") # type: ignore # noqa: E501
)(**construct_kwargs)
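    # e.g. sklearn_trainer="linear_model.Lasso" resolves, via the reduce
    # above, to getattr(getattr(sklearn, "linear_model"), "Lasso").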
try:
sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
@@ -346,7 +388,8 @@ def sklearn_train_linear_model(
warnings.warn(
"Sample weight is not supported for the provided linear model!"
" Trained model without weighting inputs. For Lasso, please"
" upgrade sklearn to a version >= 0.23.0."
" upgrade sklearn to a version >= 0.23.0.",
stacklevel=1,
)

t2 = time.time()