MAINT add parameter validation using BaseEstimator #958

Draft: wants to merge 1 commit into base: main
26 changes: 22 additions & 4 deletions skrub/_agg_joiner.py
@@ -7,7 +7,7 @@
"""

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import BaseEstimator, TransformerMixin, _fit_context
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import check_is_fitted

@@ -67,7 +67,7 @@ class AggJoiner(TransformerMixin, BaseEstimator):
The placeholder string "X" can be provided to perform
self-aggregation on the input data.

key : str, default=None
key : str or iterable of str, default=None
Member
As a side note, we use "iterable" everywhere, but I wonder if we should say "sequence" (or "list"?) because it is more understandable for users who are less familiar with Python/programming jargon. It is also arguably a bit more accurate: we iterate over these parameters several times and sometimes index them, so some iterables would not be appropriate.

Contributor
+1 to discuss this
I had the same question on the Joiner PR

Member Author (@glemaitre, Jun 18, 2024)
For me, this should be a list (while loosely accepting a tuple as well); that is friendlier than saying "sequence", which is only meaningful to Python developers.

Member
I'm fine with "list" -- and we already use that term in a bunch of places

The column name to use for both `main_key` and `aux_key` when they
are the same. Provide either `key` or both `main_key` and `aux_key`.
If `key` is an iterable, we will perform a multi-column join.
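A quick illustration of the point raised in the review thread above (the column names below are made up): a one-shot iterable such as a generator is consumed on its first pass and cannot be indexed, so a parameter that is iterated more than once, or indexed, really needs to be a list (or tuple) rather than an arbitrary iterable.

```python
# Illustrative only, not part of this diff.
keys = (c for c in ["userId", "movieId"])  # a generator is a valid "iterable"
print(list(keys))  # first pass: ['userId', 'movieId']
print(list(keys))  # second pass: [] -- the generator is already exhausted
# keys[0] would raise TypeError: 'generator' object is not subscriptable
```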
@@ -138,6 +138,16 @@ class AggJoiner(TransformerMixin, BaseEstimator):
1 2 NY JFK DL 80.00...
"""

_parameter_constraints = {
"aux_table": "no_validation", # we should have a DataFrameLike constraint
"key": [str, "array-like", None],
"main_key": [str, "array-like", None],
"aux_key": [str, "array-like", None],
"cols": [str, "array-like", None],
"operations": [str, "array-like", None],
"suffix": [str],
}

def __init__(
self,
aux_table,
@@ -244,6 +254,7 @@ def _check_inputs(self, X):
if not isinstance(self.suffix, str):
raise ValueError(f"'suffix' must be a string. Got {self.suffix}")

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y=None):
"""Aggregate auxiliary table based on the main keys.

@@ -318,7 +329,7 @@ class AggTarget(TransformerMixin, BaseEstimator):
aggregated using each key separately, then each aggregation of
the target will be joined on the main table.

operation : str or iterable of str, optional
operation : str or iterable of str, default=None
Aggregation operations to perform on the target.

numerical : {"sum", "mean", "std", "min", "max", "hist", "value_counts"}

If set to None (the default), ["mean", "mode"] will be used.

suffix : str, optional
suffix : str, default=None
The suffix to append to the columns of the target table if the join
results in duplicates columns.
If set to None, "_target" is used.
@@ -370,6 +381,12 @@ class AggTarget(TransformerMixin, BaseEstimator):
5 6 2 ... 1 1.000000
"""

_parameter_constraints = {
"main_key": [str, "array-like", None],
"operations": [str, "array-like", None],
"suffix": [str, None],
}

def __init__(
self,
main_key,
@@ -380,6 +397,7 @@ def __init__(
self.operation = operation
self.suffix = suffix

@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X, y):
"""Aggregate the target ``y`` based on keys from ``X``.

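For readers unfamiliar with the scikit-learn machinery this PR relies on, here is a minimal, hypothetical sketch (`ToyTransformer` and its parameters are invented; only `_parameter_constraints`, `_fit_context` and `InvalidParameterError` are scikit-learn names): the class-level `_parameter_constraints` dict declares the accepted values for each constructor parameter, and the `_fit_context` decorator validates them when `fit` is called, raising `InvalidParameterError` on a mismatch.

```python
# Hypothetical sketch of scikit-learn's declarative parameter validation.
from sklearn.base import BaseEstimator, TransformerMixin, _fit_context
from sklearn.utils._param_validation import InvalidParameterError, StrOptions


class ToyTransformer(TransformerMixin, BaseEstimator):
    # Each key maps a constructor parameter to its list of accepted constraints.
    _parameter_constraints = {
        "suffix": [str],
        "resolution": [StrOptions({"year", "month", "day", "hour"}), None],
    }

    def __init__(self, suffix="_agg", resolution="hour"):
        self.suffix = suffix
        self.resolution = resolution

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y=None):
        # Parameters have already been validated when we reach this point.
        return self

    def transform(self, X):
        return X


ToyTransformer(suffix="_ok").fit([[0.0]])  # passes validation
try:
    ToyTransformer(suffix=123).fit([[0.0]])  # wrong type for 'suffix'
except InvalidParameterError as exc:
    print(exc)
```

The "no_validation" marker used for `aux_table` above simply skips validation for that parameter, which is why the inline comment suggests a dedicated DataFrameLike constraint as a follow-up.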
17 changes: 9 additions & 8 deletions skrub/_datetime_encoder.py
@@ -1,7 +1,9 @@
from datetime import datetime, timezone

import pandas as pd
from sklearn.base import _fit_context
from sklearn.utils.validation import check_is_fitted
from sklearn.utils._param_validation import StrOptions

try:
import polars as pl
@@ -255,11 +257,18 @@ class DatetimeEncoder(SingleColumnTransformer):
timezone used during ``fit`` and that we get the same result for "hour".
""" # noqa: E501

_parameter_constraints = {
"resolution": [StrOptions(set(_TIME_LEVELS)), None],
"add_weekday": ["boolean"],
"add_total_seconds": ["boolean"],
}

def __init__(self, resolution="hour", add_weekday=False, add_total_seconds=True):
self.resolution = resolution
self.add_weekday = add_weekday
self.add_total_seconds = add_total_seconds

@_fit_context(prefer_skip_nested_validation=True)
def fit_transform(self, column, y=None):
"""Fit the encoder and transform a column.

@@ -277,7 +286,6 @@ def fit_transform(self, column, y=None):
The extracted features.
"""
del y
self._check_params()
if not sbd.is_any_date(column):
raise RejectColumn(
f"Column {sbd.name(column)!r} does not have Date or Datetime dtype."
@@ -316,10 +324,3 @@ def transform(self, column):
extracted = sbd.to_float32(extracted)
all_extracted.append(extracted)
return sbd.make_dataframe_like(column, all_extracted)

def _check_params(self):
allowed = _TIME_LEVELS + [None]
if self.resolution not in allowed:
raise ValueError(
f"'resolution' options are {allowed}, got {self.resolution!r}."
)
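The removed `_check_params` duplicated what `StrOptions` already provides. Here is a rough sketch of the constraint's behaviour (the option set below is an assumption standing in for `_TIME_LEVELS`); together with `None` in the constraint list, it reproduces the old `_TIME_LEVELS + [None]` check:

```python
# Sketch only: StrOptions does the membership test _check_params did by hand.
from sklearn.utils._param_validation import StrOptions

resolution_constraint = StrOptions({"year", "month", "day", "hour"})
print(resolution_constraint.is_satisfied_by("hour"))   # True
print(resolution_constraint.is_satisfied_by("hello"))  # False -> fit_transform()
                                                       # raises InvalidParameterError
```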
9 changes: 6 additions & 3 deletions skrub/tests/test_agg_joiner.py
@@ -1,6 +1,8 @@
import pandas as pd
import pytest

from sklearn.utils._param_validation import InvalidParameterError

from skrub import _dataframe as sbd
from skrub._agg_joiner import AggJoiner, AggTarget, split_num_categ_operations
from skrub._dataframe._testing_utils import assert_frame_equal
@@ -216,7 +218,7 @@ def test_too_many_suffixes(df_module, main_table):
cols="rating",
suffix=["_user", "_movie", "_tag"],
)
with pytest.raises(ValueError, match=r"(?='suffix' must be a string.*)"):
with pytest.raises(InvalidParameterError):
agg_joiner.fit(main_table)


@@ -447,7 +449,7 @@ def test_no_aggregation_exception(main_table):
main_key="userId",
operation=[],
)
with pytest.raises(ValueError, match=r"(?=.*No aggregation)"):
with pytest.raises(ValueError, match="No aggregation to perform"):
agg_target.fit(main_table, y)


@@ -456,5 +458,6 @@ def test_wrong_args_ops(main_table):
main_key="userId",
operation="mean(2)",
)
with pytest.raises(ValueError, match=r"(?=.*'mean')(?=.*argument)"):
err_msg = "Operator 'mean' doesn't take any argument, got 2"
with pytest.raises(ValueError, match=err_msg):
agg_target.fit(main_table, y)
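The distinction behind these test edits, sketched with a hypothetical estimator (`ToyAgg` is invented; nothing below comes from skrub): violations of the declared constraints are asserted as `InvalidParameterError` without matching the message, while errors raised by the estimator's own checks keep a plain `ValueError` with an explicit `match` string.

```python
# Hypothetical sketch of the testing pattern used above.
import pytest
from sklearn.base import BaseEstimator, _fit_context
from sklearn.utils._param_validation import InvalidParameterError


class ToyAgg(BaseEstimator):
    _parameter_constraints = {"suffix": [str]}

    def __init__(self, suffix="_agg", operation=("mean",)):
        self.suffix = suffix
        self.operation = operation

    @_fit_context(prefer_skip_nested_validation=True)
    def fit(self, X, y=None):
        if len(self.operation) == 0:
            raise ValueError("No aggregation to perform")
        return self


def test_bad_type():
    with pytest.raises(InvalidParameterError):  # no message matching needed
        ToyAgg(suffix=["_user", "_movie"]).fit([[0.0]])


def test_bad_value():
    with pytest.raises(ValueError, match="No aggregation to perform"):
        ToyAgg(operation=[]).fit([[0.0]])
```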
4 changes: 3 additions & 1 deletion skrub/tests/test_datetime_encoder.py
@@ -2,6 +2,8 @@

import pytest

from sklearn.utils._param_validation import InvalidParameterError

from skrub import DatetimeEncoder
from skrub import _dataframe as sbd
from skrub import _selectors as s
@@ -153,7 +155,7 @@ def test_time_not_extracted_from_date_col(datetime_cols):


def test_invalid_resolution(datetime_cols):
with pytest.raises(ValueError, match=r".*'resolution' options are"):
with pytest.raises(InvalidParameterError):
DatetimeEncoder(resolution="hello").fit(datetime_cols.datetime)

