diff --git a/dask_glm/algorithms.py b/dask_glm/algorithms.py
index 4a583b4..b31e4ea 100644
--- a/dask_glm/algorithms.py
+++ b/dask_glm/algorithms.py
@@ -2,6 +2,7 @@
 """
 from __future__ import absolute_import, division, print_function
 
+from warnings import warn
 from dask import delayed, persist, compute, set_options
 import functools
@@ -142,53 +143,68 @@ def gradient_descent(X, y, max_iter=100, tol=1e-14, family=Logistic, **kwargs):
 
 @normalize
 def sgd(X, y, epochs=100, tol=1e-3, family=Logistic, batch_size=64,
-        initial_step=1e-4, callback=None, average=True):
+        initial_step=1e-4, callback=None, average=True, maxiter=np.inf, **kwargs):
     """Stochastic Gradient Descent.
 
     Parameters
     ----------
     X : array-like, shape (n_samples, n_features)
     y : array-like, shape (n_samples,)
     epochs : int, float
         maximum number of passes through the dataset
     tol : float
         Maximum allowed change from prior iteration required to
         declare convergence
     batch_size : int
         The batch size used to approximate the gradient. Larger batch sizes
         will approximate the gradient better.
     initial_step : float
-        The initial step size. The step size is decays like 1/k.
+        The initial step size. The step size decays like 1/k.
     callback : callable
         A callback to call every iteration that accepts keyword arguments
         `X`, `y`, `beta`, `grad`, `nit` (number of iterations) and `family`
     average : bool
         To average the parameters found or not. See [1]_.
     family : Family
+    maxiter : int, float
+        Maximum number of minibatch updates to perform, regardless of
+        `epochs` or `tol`.
+    n : int, optional keyword argument
+        The number of examples in `X`. Useful when `X.shape[0]` is not
+        known (i.e., is NaN); avoids computing `X`.
 
     Returns
     -------
     beta : array-like, shape (n_features,)
 
+    Notes
+    -----
+    The dataset is not shuffled beforehand. It is assumed that each row
+    accessed from the dataset (which is read sequentially) gives an
+    independent and identically distributed (iid) approximation to the
+    gradient. Shuffling the array beforehand should ensure this.
+
     .. _1: https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Averaging
 
     """
     gradient = family.gradient
 
     n, p = X.shape
     if np.isnan(n):
-        raise ValueError('SGD needs shape information to allow indexing. '
-                         'Possible by passing a computed array in (`X.compute()` '
-                         'or `X.values.compute()`), then doing using '
-                         '`dask.array.from_array ')
+        n = kwargs.get('n', np.nan)
+    if np.isnan(n):
+        warn('Computing X to find the number of examples. Pass the number '
+             'of examples in as a keyword argument `n` to avoid this, e.g., '
+             '`sgd(..., n=num_examples)` or `sgd(..., n=len(X.compute()))`.')
+        n = len(X.compute())
 
     beta = np.zeros(p)
     if average:
         beta_sum = np.zeros(p)
     nit = 0
-    for epoch in range(epochs):
-        j = np.random.permutation(n)
-        X = X[j]
-        y = y[j]
+
+    # step_size = O(1/sqrt(k)) from "Non-asymptotic analysis of stochastic
+    # approximation algorithms for machine learning" by Moulines and Bach,
+    # but this may require many iterations. In practice a geometric decay,
+    # step_size = lambda init, nit, decay: init * decay**(nit // n),
+    # is often used instead, but it is not tested here.
+    step_size = lambda init, nit: init / np.sqrt(nit + 1)
+    while True:
         for k in range(n // batch_size):
             beta_old = beta.copy()
             nit += 1
@@ -197,11 +213,7 @@ def sgd(X, y, epochs=100, tol=1e-3, family=Logistic, batch_size=64,
             Xbeta = dot(X[i], beta)
             grad = gradient(Xbeta, X[i], y[i]).compute()
 
-            # step_size = O(1/sqrt(k)) from "Non-asymptotic analysis of
-            # stochastic approximation algorithms for machine learning" by
-            # Moulines, Eric and Bach, Francis Rsgd
-            step_size = initial_step / np.sqrt(nit + 1)
-            beta -= step_size * (n / batch_size) * grad
+            beta -= step_size(initial_step, nit) * (n / batch_size) * grad
             if average:
                 beta_sum += beta
             if callback:
@@ -209,7 +221,7 @@ def sgd(X, y, epochs=100, tol=1e-3, family=Logistic, batch_size=64,
                          beta=beta if not average else beta_sum / nit)
 
         rel_error = LA.norm(beta_old - beta) / LA.norm(beta)
-        converged = (rel_error < tol) or (nit / n > epochs)
+        converged = (rel_error < tol) or (nit / n > epochs) or (nit > maxiter)
         if converged:
             break
     if average:
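For context, a minimal, hypothetical usage sketch of the patched `sgd` (not part of the diff): the array sizes, chunk sizes, and the name `num_examples` are illustrative assumptions, and it simply exercises the `n` and `maxiter` keyword arguments introduced above.

```python
# Hypothetical usage sketch of the patched sgd(); sizes and names here are
# illustrative, not part of the patch.
import dask.array as da
import numpy as np

from dask_glm.algorithms import sgd

rng = np.random.RandomState(0)
num_examples, num_features = 10000, 5
X = da.from_array(rng.normal(size=(num_examples, num_features)),
                  chunks=(1000, num_features))
y = da.from_array(rng.randint(0, 2, size=num_examples).astype(float),
                  chunks=1000)

# `n` matters when X.shape[0] is unknown (NaN), e.g. when X comes from a dask
# DataFrame via `.values`; passing it avoids the warning and the X.compute()
# fallback added in this patch. `maxiter` caps the number of minibatch
# updates independently of `epochs` and `tol`.
beta = sgd(X, y, epochs=10, batch_size=64, n=num_examples, maxiter=5000)
```

When `X.shape[0]` is already known, as with `da.from_array` here, the `n` keyword is simply ignored; it is only consulted after `n, p = X.shape` yields NaN.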