Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Introducing test file for polars support for estimators #370

Merged
merged 27 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
539fd51
create test_polars.py file
julian-fong May 26, 2024
4712da9
updates
julian-fong May 28, 2024
fe5333b
initial commit
julian-fong May 30, 2024
5f578fd
added polars eager table to allowed mtypes in base regressor
julian-fong May 30, 2024
cf8a0d5
added draft version of testing fit and predict in polars dataframe
julian-fong May 30, 2024
9357486
fixed to use skpro check soft dependencies
julian-fong May 30, 2024
1a23ee0
updated tests
julian-fong Jun 2, 2024
89079f6
added test for predict_quantiles
julian-fong Jun 2, 2024
02f699f
fixed naming of pandas datafarmes
julian-fong Jun 2, 2024
c49ed0e
Merge branch 'sktime:main' into polars_support
julian-fong Jun 2, 2024
be084ef
added test for check_polars_table
julian-fong Jun 3, 2024
5c3697e
updates to pr
julian-fong Jun 7, 2024
32e700a
updated estimator to be a pytest fixture for one estimator
julian-fong Jun 10, 2024
0470817
Merge branch 'sktime:main' into polars_support
julian-fong Jun 11, 2024
497e1ef
bug fix
julian-fong Jun 11, 2024
8d3b541
update
julian-fong Jun 11, 2024
782e714
update
julian-fong Jun 11, 2024
39590f7
updates
julian-fong Jun 11, 2024
20643c5
updates
julian-fong Jun 11, 2024
05e96bf
updates
julian-fong Jun 11, 2024
ad697a3
updates
julian-fong Jun 11, 2024
00ac2bf
updates
julian-fong Jun 11, 2024
78d5d46
Merge branch 'sktime:main' into polars_support
julian-fong Jun 13, 2024
f464b7f
Merge branch 'sktime:main' into polars_support
julian-fong Jun 13, 2024
5eba103
Merge branch 'sktime:main' into polars_support
julian-fong Jun 14, 2024
227d623
updates to remove unnecessary skipifs and changed the estimator used …
julian-fong Jun 14, 2024
0b51616
Merge branch 'sktime:main' into polars_support
julian-fong Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions skpro/regression/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"pd_Series_Table",
"numpy1D",
"numpy2D",
"polars_eager_table",
]


Expand Down
37 changes: 37 additions & 0 deletions skpro/tests/test_polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Test file for polars dataframes"""

import pytest
from sktime.utils.validation._dependencies import _check_soft_dependencies
julian-fong marked this conversation as resolved.
Show resolved Hide resolved

from skpro.datatypes._table._convert import convert_pandas_to_polars_eager


@pytest.mark.skipif(
not _check_soft_dependencies(["polars", "pyarrow"], severity="none"),
reason="skip test if polars/pyarrow is not installed in environment",
)
def test_polars_dataframe_in_fit(estimator):
import pandas as pd
import polars as pl
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True, as_frame=True)
X = X.iloc[:75]
y = y.iloc[:75]
y = pd.DataFrame(y)

X_train, X_test, y_train, _ = train_test_split(
X, y, test_size=0.33, random_state=42
)
X_train_pl = convert_pandas_to_polars_eager(X_train)
julian-fong marked this conversation as resolved.
Show resolved Hide resolved
X_test_pl = convert_pandas_to_polars_eager(X_test)
y_train_pl = convert_pandas_to_polars_eager(y_train)

estimator.fit(X_train_pl, y_train_pl)

# test predict output contract
y_pred = estimator.predict(X_test_pl)

assert isinstance(y_pred, pl.DataFrame)
assert (y_pred.columns == y_train_pl.columns).all()
Loading