Merge pull request #12 from simon-hirsch:fix_gram_not_initialised
Fix_gram_not_initialised
BerriJ authored Jul 25, 2024
2 parents 14e5ee2 + 7648abb, commit decf44b
Showing 3 changed files with 49 additions and 18 deletions.
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "rolch"
-version = "0.1.6"
+version = "0.1.7"
 authors = [
     {name="Simon Hirsch", email="[email protected]"},
     {name="Jonathan Berrisch", email="[email protected]"},
src/rolch/online_gamlss.py (46 changes: 29 additions & 17 deletions)

@@ -8,10 +8,16 @@
     online_coordinate_descent,
     online_coordinate_descent_path,
 )
-from rolch.gram import init_gram, init_inverted_gram, init_y_gram, update_inverted_gram
+from rolch.gram import (
+    init_gram,
+    init_inverted_gram,
+    init_y_gram,
+    update_inverted_gram,
+    init_forget_vector,
+)
 from rolch.information_criteria import select_best_model_by_information_criterion
 from rolch.scaler import OnlineScaler
-from rolch.utils import calculate_effective_training_length
+from rolch.utils import calculate_effective_training_length, online_mean_update


 class OnlineGamlss:
@@ -135,6 +141,9 @@ def fit_beta(
         iteration_inner,
         param,
     ):
+
+        f = init_forget_vector(self.forget, self.n_obs)
+
         if self.method == "ols":
             lambda_max = None
             lambda_path = None
@@ -143,10 +152,7 @@
             beta = (x_gram @ y_gram).flatten()
             residuals = y - X @ beta.T

-            if self.method == "ols" or self.intercept_only[param]:
-                rss = np.sum(residuals**2 * w) / np.mean(w)
-            else:
-                rss = np.sum(residuals**2 * w[:, None], axis=0) / np.mean(w)
+            rss = np.sum(residuals**2 * w * f) / np.mean(w * f)

         elif (self.method == "lasso") & self.intercept_only[param]:
             lambda_max = None
@@ -167,10 +173,7 @@
             )[0]
             residuals = y - X @ beta.T

-            if self.method == "ols" or self.intercept_only[param]:
-                rss = np.sum(residuals**2 * w, axis=0) / np.mean(w)
-            elif self.method == "lasso":
-                rss = np.sum(residuals**2 * w[:, None], axis=0) / np.mean(w)
+            rss = np.sum(residuals**2 * w * f, axis=0) / np.mean(w * f)

         elif self.method == "lasso":
             intercept = (
@@ -197,10 +200,9 @@

             residuals = y[:, None] - X @ beta_path.T

-            if self.method == "ols" or self.intercept_only[param]:
-                rss = np.sum(residuals**2 * w, axis=0) / np.mean(w)
-            elif self.method == "lasso":
-                rss = np.sum(residuals**2 * w[:, None], axis=0) / np.mean(w)
+            rss = np.sum(residuals**2 * w[:, None] * f[:, None], axis=0) / np.mean(
+                w * f
+            )

             model_params_n = np.sum(~np.isclose(beta_path, 0), axis=1)
             best_ic = select_best_model_by_information_criterion(
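
The fit_beta hunks above replace the redundant per-method RSS branches with a single expression per estimator that discounts past observations by an explicit forget vector f. The diff does not show init_forget_vector itself; the following is a minimal sketch under the assumption that it returns exponential decay weights with the newest observation weighted 1 (an illustration, not the rolch implementation):

    import numpy as np

    def forget_vector_sketch(forget: float, n_obs: int) -> np.ndarray:
        # Hypothetical stand-in for rolch.gram.init_forget_vector (assumed
        # semantics): weight (1 - forget) ** k for a point k steps in the past.
        return (1.0 - forget) ** np.arange(n_obs - 1, -1, -1)

    f = forget_vector_sketch(forget=0.1, n_obs=5)  # [0.6561, 0.729, 0.81, 0.9, 1.0]
    w = np.ones(5)                                 # working weights of the IRLS step
    residuals = np.array([0.5, -0.2, 0.1, 0.4, -0.3])
    rss = np.sum(residuals**2 * w * f) / np.mean(w * f)

Dividing by np.mean(w * f) renormalises the discounted sum, keeping it comparable to an unweighted RSS over the same number of observations.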
@@ -228,6 +230,11 @@ def update_beta(
         iteration_inner,
         param,
     ):
+
+        denom = online_mean_update(
+            self.mean_of_weights[param], w, self.forget, self.n_obs
+        )
+
         if self.method == "ols":
             # Not relevant for OLS
             lambda_max = None
@@ -240,7 +247,7 @@
             rss = (
                 (residuals**2).flatten() * w
                 + (1 - self.forget) * (self.rss[param] * self.mean_of_weights[param])
-            ) / (self.mean_of_weights[param] * (1 - self.forget) + w)
+            ) / denom

         elif (self.method == "lasso") & self.intercept_only[param]:
             lambda_max = None
@@ -264,7 +271,7 @@
             rss = (
                 (residuals**2).flatten() * w
                 + (1 - self.forget) * (self.rss[param] * self.mean_of_weights[param])
-            ) / (self.mean_of_weights[param] * (1 - self.forget) + w)
+            ) / denom

         elif self.method == "lasso":
             intercept = (
@@ -293,7 +300,7 @@
             rss = (
                 (residuals**2).flatten() * w
                 + (1 - self.forget) * (self.rss[param] * self.mean_of_weights[param])
-            ) / (self.mean_of_weights[param] * (1 - self.forget) + w)
+            ) / denom

             model_params_n = np.sum(np.isclose(beta_path, 0), axis=1)
             best_ic = select_best_model_by_information_criterion(
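
In update_beta, all three estimation branches divide the same recursive RSS numerator by the same quantity, so it is now computed once up front. The statistic tracked here has the recursive form

    rss_new = ((residuals**2).flatten() * w
               + (1 - forget) * rss_old * mean_w_old) / denom
    # denom = online_mean_update(mean_w_old, w, forget, n_obs),
    # the online mean of the working weights

with the division now going through the new helper rather than the previous inline expression self.mean_of_weights[param] * (1 - self.forget) + w (see the utils.py hunk below for online_mean_update itself).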
@@ -823,6 +830,11 @@ def _inner_update(
             lambda_max = lambda_max[param]
             lambda_path = lambda_path[param]

+        ## TODO REFACTOR: Will be returned if we converge in the first iteration
+        ## Will be overwritten if we don't converge in the first iteration
+        x_gram_it = self.x_gram_inner[param]
+        y_gram_it = self.y_gram_inner[param]
+
         di = -2 * np.log(self.distribution.pdf(y, fv))
         dv = (1 - self.forget) * self.global_dev + np.sum(di * w)
         olddv = dv + 1
src/rolch/utils.py (19 changes: 19 additions & 0 deletions)

@@ -1,4 +1,5 @@
 import sys
+import numpy as np


 def calculate_asymptotic_training_length(forget: float):
@@ -14,3 +15,21 @@ def calculate_effective_training_length(forget: float, n_obs: int):
         return n_obs
     else:
         return (1 - (1 - forget) ** n_obs) / forget
+
+
+def online_mean_update(avg: float, value: float, forget: float, n_seen: int):
+
+    n_asymmptotic = calculate_asymptotic_training_length(forget)
+    n_eff = calculate_effective_training_length(forget, n_seen)
+
+    forget_scaled = forget * np.maximum(n_asymmptotic / n_eff, 1.0)
+
+    diff = value - avg
+    incr = forget_scaled * diff
+
+    if forget_scaled > 0:
+        new_avg = avg + incr
+    else:
+        new_avg = avg + diff / n_seen
+
+    return new_avg
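
online_mean_update implements an exponentially weighted mean with a warm-up correction: while the effective training length n_eff is still below the asymptotic length, the gain forget_scaled is inflated by the ratio n_asymmptotic / n_eff, and when the gain is zero the update falls back to the plain cumulative mean diff / n_seen. A small usage sketch, assuming the import path above and that calculate_asymptotic_training_length (whose body is outside this diff) returns roughly 1 / forget for forget > 0 and a large finite sentinel at forget == 0:

    from rolch.utils import online_mean_update

    # forget == 0: the zero gain falls through to the running-mean branch
    avg = 0.0
    for n, x in enumerate([2.0, 4.0, 6.0], start=1):
        avg = online_mean_update(avg, x, forget=0.0, n_seen=n)
    print(avg)  # 4.0, the ordinary cumulative mean

    # forget > 0: recent values weigh more; after the very first observation
    # the inflated gain makes the average equal that observation exactly
    avg = 0.0
    for n, x in enumerate([2.0, 4.0, 6.0], start=1):
        avg = online_mean_update(avg, x, forget=0.1, n_seen=n)
    print(avg)  # > 4.0, pulled towards the latest value 6.0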
