Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Deprecate estimators #66

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dask_glm/estimators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Models following scikit-learn's estimator API.
"""
import warnings

from sklearn.base import BaseEstimator

from . import algorithms
Expand All @@ -10,6 +12,12 @@
poisson_deviance
)

msg = ("The 'dask_glm.estimators' module is deprecated in favor of "
"'dask_ml.linear_model'. Please install 'dask-ml' and update "
"your imports.")

warnings.warn(msg, FutureWarning)


class _GLM(BaseEstimator):
""" Base estimator for Generalized Linear Models
Expand Down
106 changes: 5 additions & 101 deletions dask_glm/tests/test_estimators.py
Original file line number Diff line number Diff line change
@@ -1,106 +1,10 @@
import pytest
import dask

from dask_glm.estimators import LogisticRegression, LinearRegression, PoissonRegression
from dask_glm.datasets import make_classification, make_regression, make_poisson
from dask_glm.regularizers import Regularizer

def test_warns():
with pytest.warns(FutureWarning) as w:
import dask_glm.estimators # noqa

@pytest.fixture(params=[r() for r in Regularizer.__subclasses__()])
def solver(request):
"""Parametrized fixture for all the solver names"""
return request.param


@pytest.fixture(params=[r() for r in Regularizer.__subclasses__()])
def regularizer(request):
"""Parametrized fixture for all the regularizer names"""
return request.param


class DoNothingTransformer(object):
def fit(self, X, y=None):
return self

def transform(self, X, y=None):
return X

def fit_transform(self, X, y=None):
return X

def get_params(self, deep=True):
return {}


X, y = make_classification()


def test_lr_init(solver):
LogisticRegression(solver=solver)


def test_pr_init(solver):
PoissonRegression(solver=solver)


@pytest.mark.parametrize('fit_intercept', [True, False])
def test_fit(fit_intercept):
X, y = make_classification(n_samples=100, n_features=5, chunksize=10)
lr = LogisticRegression(fit_intercept=fit_intercept)
lr.fit(X, y)
lr.predict(X)
lr.predict_proba(X)


@pytest.mark.parametrize('fit_intercept', [True, False])
def test_lm(fit_intercept):
X, y = make_regression(n_samples=100, n_features=5, chunksize=10)
lr = LinearRegression(fit_intercept=fit_intercept)
lr.fit(X, y)
lr.predict(X)
if fit_intercept:
assert lr.intercept_ is not None


@pytest.mark.parametrize('fit_intercept', [True, False])
def test_big(fit_intercept):
with dask.config.set(scheduler='synchronous'):
X, y = make_classification()
lr = LogisticRegression(fit_intercept=fit_intercept)
lr.fit(X, y)
lr.predict(X)
lr.predict_proba(X)
if fit_intercept:
assert lr.intercept_ is not None


@pytest.mark.parametrize('fit_intercept', [True, False])
def test_poisson_fit(fit_intercept):
with dask.config.set(scheduler='synchronous'):
X, y = make_poisson()
pr = PoissonRegression(fit_intercept=fit_intercept)
pr.fit(X, y)
pr.predict(X)
pr.get_deviance(X, y)
if fit_intercept:
assert pr.intercept_ is not None


def test_in_pipeline():
from sklearn.pipeline import make_pipeline
X, y = make_classification(n_samples=100, n_features=5, chunksize=10)
pipe = make_pipeline(DoNothingTransformer(), LogisticRegression())
pipe.fit(X, y)


def test_gridsearch():
from sklearn.pipeline import make_pipeline
dcv = pytest.importorskip('dask_searchcv')

X, y = make_classification(n_samples=100, n_features=5, chunksize=10)
grid = {
'logisticregression__lamduh': [.001, .01, .1, .5]
}
pipe = make_pipeline(DoNothingTransformer(), LogisticRegression())
search = dcv.GridSearchCV(pipe, grid, cv=3)
search.fit(X, y)
assert len(w)
assert 'dask-ml' in str(w[-1])
8 changes: 0 additions & 8 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,6 @@
API Reference
-------------

.. _api.estimators:

Estimators
==========

.. automodule:: dask_glm.estimators
:members:

.. _api.families:

Families
Expand Down
37 changes: 0 additions & 37 deletions docs/estimators.rst

This file was deleted.

5 changes: 3 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ Dask-glm
*Dask-glm is a library for fitting Generalized Linear Models on large datasets*

Dask-glm builds on the `dask`_ project to fit `GLM`_'s on datasets in parallel.
It offers a `scikit-learn`_ compatible API for specifying your model.
It provides the optimizers and regularizers used by libraries like `dask-ml`_,
which builds scikit-learn-style APIs on top of those components.

.. toctree::
:maxdepth: 2
:caption: Contents:

estimators
examples
api

Expand All @@ -30,3 +30,4 @@ Indices and tables
.. _dask: http://dask.pydata.org/en/latest/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.. _dask: http://dask.pydata.org/en/latest/
.. _dask: https://docs.dask.org/

.. _GLM: https://en.wikipedia.org/wiki/Generalized_linear_model
.. _scikit-learn: http://scikit-learn.org/
.. _dask-ml: http://dask-ml.readthedocs.org/
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.. _dask-ml: http://dask-ml.readthedocs.org/
.. _dask-ml: https://ml.dask.org/