Reduce complexity of 'sgd_train_linear_model' (#1374)
Summary:
Pull Request resolved: #1374

Reduce complexity of 'sgd_train_linear_model'

Reviewed By: jsawruk

Differential Revision: D64432524

fbshipit-source-id: f9c9a4e2f3d7b8a38f3de68b8cc729f3338ffe30
craymichael authored and facebook-github-bot committed Oct 17, 2024
1 parent 8ac2898 commit 9689ccd
Showing 1 changed file with 102 additions and 68 deletions.
170 changes: 102 additions & 68 deletions captum/_utils/models/linear_model/train.py
@@ -1,7 +1,7 @@
# pyre-strict
import time
import warnings
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -17,6 +17,82 @@ def l2_loss(x1, x2, weights=None) -> torch.Tensor:
return torch.sum((weights / weights.norm(p=1)) * ((x1 - x2) ** 2)) / 2.0


+class ConvergenceTracker:
+    def __init__(self, patience: int, threshold: float) -> None:
+        self.min_avg_loss: Optional[torch.Tensor] = None
+        self.convergence_counter: int = 0
+        self.converged: bool = False
+
+        self.threshold = threshold
+        self.patience = patience
+
+    def update(self, average_loss: torch.Tensor) -> bool:
+        if self.min_avg_loss is not None:
+            # if we haven't improved by at least `threshold`
+            if average_loss > self.min_avg_loss or torch.isclose(
+                cast(torch.Tensor, self.min_avg_loss), average_loss, atol=self.threshold
+            ):
+                self.convergence_counter += 1
+                if self.convergence_counter >= self.patience:
+                    self.converged = True
+                    return True
+            else:
+                self.convergence_counter = 0
+        if self.min_avg_loss is None or self.min_avg_loss >= average_loss:
+            self.min_avg_loss = average_loss.clone()
+        return False
+
+
+class LossWindow:
+    def __init__(self, window_size: int) -> None:
+        self.loss_window: List[torch.Tensor] = []
+        self.window_size = window_size
+
+    def append(self, loss: torch.Tensor) -> None:
+        if len(self.loss_window) >= self.window_size:
+            self.loss_window = self.loss_window[-self.window_size :]
+        self.loss_window.append(loss)
+
+    def average(self) -> torch.Tensor:
+        return torch.mean(torch.stack(self.loss_window))
+
+
+def _init_linear_model(model: LinearModel, init_scheme: Optional[str] = None) -> None:
+    assert model.linear is not None
+    if init_scheme is not None:
+        assert init_scheme in ["xavier", "zeros"]
+
+        with torch.no_grad():
+            if init_scheme == "xavier":
+                # pyre-fixme[16]: `Optional` has no attribute `weight`.
+                torch.nn.init.xavier_uniform_(model.linear.weight)
+            else:
+                model.linear.weight.zero_()
+
+            # pyre-fixme[16]: `Optional` has no attribute `bias`.
+            if model.linear.bias is not None:
+                model.linear.bias.zero_()
+
+
+def _get_point(
+    datapoint: Tuple[torch.Tensor, ...],
+    device: Optional[str] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    if len(datapoint) == 2:
+        x, y = datapoint
+        w = None
+    else:
+        x, y, w = datapoint
+
+    if device is not None:
+        x = x.to(device)
+        y = y.to(device)
+        if w is not None:
+            w = w.to(device)
+
+    return x, y, w
+

def sgd_train_linear_model(
model: LinearModel,
dataloader: DataLoader,
@@ -102,31 +178,16 @@ def sgd_train_linear_model(
This will return the final training loss (averaged with
`running_loss_window`)
"""
-    loss_window: List[torch.Tensor] = []
-    min_avg_loss = None
-    convergence_counter = 0
-    converged = False
-
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def get_point(datapoint):
-        if len(datapoint) == 2:
-            x, y = datapoint
-            w = None
-        else:
-            x, y, w = datapoint
-
-        if device is not None:
-            x = x.to(device)
-            y = y.to(device)
-            if w is not None:
-                w = w.to(device)
-
-        return x, y, w
+    converge_tracker = ConvergenceTracker(patience, threshold)

# get a point and construct the model
data_iter = iter(dataloader)
-    x, y, w = get_point(next(data_iter))
+    x, y, w = _get_point(next(data_iter), device)

+    if running_loss_window is None:
+        running_loss_window = x.shape[0] * len(dataloader)
+
+    loss_window = LossWindow(running_loss_window)

model._construct_model_params(
in_features=x.shape[1],
@@ -135,21 +196,8 @@ def get_point(datapoint):
)
model.train()

-    assert model.linear is not None
-
-    if init_scheme is not None:
-        assert init_scheme in ["xavier", "zeros"]
-
-        with torch.no_grad():
-            if init_scheme == "xavier":
-                # pyre-fixme[16]: `Optional` has no attribute `weight`.
-                torch.nn.init.xavier_uniform_(model.linear.weight)
-            else:
-                model.linear.weight.zero_()
-
-            # pyre-fixme[16]: `Optional` has no attribute `bias`.
-            if model.linear.bias is not None:
-                model.linear.bias.zero_()
+    # Initialize linear model weights if applicable
+    _init_linear_model(model, init_scheme)

with torch.enable_grad():
optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
@@ -163,9 +211,6 @@ def get_point(datapoint):
i = 0
while epoch < max_epoch:
while True: # for x, y, w in dataloader
-                if running_loss_window is None:
-                    running_loss_window = x.shape[0] * len(dataloader)
-
y = y.view(x.shape[0], -1)
if w is not None:
w = w.view(x.shape[0], -1)
@@ -176,33 +221,20 @@ def get_point(datapoint):

loss = loss_fn(y, out, w)
if reg_term is not None:
-                    reg = torch.norm(model.linear.weight, p=reg_term)
+                    # pyre-fixme[16]: `Optional` has no attribute `weight`.
+                    reg = torch.norm(model.linear.weight, p=reg_term)  # type: ignore
loss += reg.sum() * alpha

-                if len(loss_window) >= running_loss_window:
-                    loss_window = loss_window[1:]
                loss_window.append(loss.clone().detach())
-                assert len(loss_window) <= running_loss_window

-                average_loss = torch.mean(torch.stack(loss_window))
-                if min_avg_loss is not None:
-                    # if we haven't improved by at least `threshold`
-                    if average_loss > min_avg_loss or torch.isclose(
-                        min_avg_loss, average_loss, atol=threshold
-                    ):
-                        convergence_counter += 1
-                        if convergence_counter >= patience:
-                            converged = True
-                            break
-                    else:
-                        convergence_counter = 0
-                if min_avg_loss is None or min_avg_loss >= average_loss:
-                    min_avg_loss = average_loss.clone()
+                average_loss = loss_window.average()
+                if converge_tracker.update(average_loss):
+                    break  # converged

if debug:
print(
f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
+ "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
f"lr={optim.param_groups[0]['lr']}, Loss={loss}, "
f"Aloss={average_loss}, "
f"min_avg_loss={converge_tracker.min_avg_loss}"
)

loss.backward()
@@ -215,19 +247,19 @@ def get_point(datapoint):
temp = next(data_iter, None)
if temp is None:
break
-                x, y, w = get_point(temp)
+                x, y, w = _get_point(temp, device)

-            if converged:
+            if converge_tracker.converged:
break

epoch += 1
data_iter = iter(dataloader)
-            x, y, w = get_point(next(data_iter))
+            x, y, w = _get_point(next(data_iter), device)

t2 = time.time()
return {
"train_time": t2 - t1,
"train_loss": torch.mean(torch.stack(loss_window)).item(),
"train_loss": loss_window.average().item(),
"train_iter": i,
"train_epoch": epoch,
}
@@ -303,7 +335,8 @@ def sklearn_train_linear_model(
if not sklearn.__version__ >= "0.23.0":
warnings.warn(
"Must have sklearn version 0.23.0 or higher to use "
"sample_weight in Lasso regression."
"sample_weight in Lasso regression.",
stacklevel=1,
)

num_batches = 0
@@ -346,7 +379,8 @@ def sklearn_train_linear_model(
warnings.warn(
"Sample weight is not supported for the provided linear model!"
" Trained model without weighting inputs. For Lasso, please"
" upgrade sklearn to a version >= 0.23.0."
" upgrade sklearn to a version >= 0.23.0.",
stacklevel=1,
)

t2 = time.time()
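
For reference, here is a minimal sketch (not part of the commit) of how the extracted ConvergenceTracker and LossWindow helpers behave on their own, using the class definitions added above. The toy loss schedule and the patience/threshold/window_size values are illustrative assumptions, not values used anywhere in captum.

import torch

from captum._utils.models.linear_model.train import ConvergenceTracker, LossWindow

# Assumes the helpers exactly as defined in this commit.
tracker = ConvergenceTracker(patience=3, threshold=1e-4)
window = LossWindow(window_size=5)

for step in range(100):
    # Stand-in for a real training step: the loss decays, then plateaus at 0.1.
    loss = torch.tensor(max(0.1, 1.0 - 0.05 * step))

    window.append(loss.clone().detach())
    # update() returns True once the windowed average has stopped improving
    # by more than `threshold` for `patience` consecutive checks.
    if tracker.update(window.average()):
        print(f"converged at step {step}, min avg loss = {tracker.min_avg_loss}")
        break

This mirrors what sgd_train_linear_model now does per batch: append the detached loss to the window, then break out of the training loop as soon as converge_tracker.update(loss_window.average()) reports convergence.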