Skip to content

Commit

Permalink
Replace uses of if_delegate_has_method with available_if
Browse files Browse the repository at this point in the history
Signed-off-by: Avi Shinnar <[email protected]>
  • Loading branch information
shinnar committed Feb 1, 2024
1 parent c66a423 commit d3a736a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 40 deletions.
19 changes: 8 additions & 11 deletions lale/lib/sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,20 @@
import typing

from packaging import version
from sklearn.utils.metaestimators import available_if

import lale.docstrings
import lale.helpers
import lale.operators
from lale.schemas import Bool

try:
from sklearn.pipeline import if_delegate_has_method
except ImportError as e:
if lale.operators.sklearn_version >= version.Version("1.0"):
from sklearn.utils.metaestimators import if_delegate_has_method
else:
raise e

logger = logging.getLogger(__name__)


def _pipeline_has(attr):
return lambda self: (hasattr(self._pipeline, attr))


class _PipelineImpl:
def __init__(self, **hyperparams):
if hyperparams.get("memory", None):
Expand Down Expand Up @@ -61,17 +58,17 @@ def fit(self, X, y=None, **fit_params):
self._final_estimator = self._pipeline.get_last()
return self

@if_delegate_has_method(delegate="_final_estimator")
@available_if(_pipeline_has("predict"))
def predict(self, X, **predict_params):
result = self._pipeline.predict(X, **predict_params)
return result

@if_delegate_has_method(delegate="_final_estimator")
@available_if(_pipeline_has("predict_proba"))
def predict_proba(self, X):
result = self._pipeline.predict_proba(X)
return result

@if_delegate_has_method(delegate="_final_estimator")
@available_if(_pipeline_has("transform"))
def transform(self, X, y=None):
if y is None:
result = self._pipeline.transform(X)
Expand Down
54 changes: 25 additions & 29 deletions lale/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
import sklearn.base
from packaging import version
from sklearn.base import clone
from sklearn.utils.metaestimators import available_if

import lale.datasets.data_schemas
import lale.json_operator
Expand Down Expand Up @@ -230,20 +231,15 @@

sklearn_version = version.parse(getattr(sklearn, "__version__"))

try:
from sklearn.pipeline import ( # pylint:disable=ungrouped-imports
if_delegate_has_method,
)
except ImportError as imp_exc:
if sklearn_version >= version.Version("1.0"):
from sklearn.utils.metaestimators import if_delegate_has_method
else:
raise imp_exc

logger = logging.getLogger(__name__)

_LALE_SKL_PIPELINE = "lale.lib.sklearn.pipeline._PipelineImpl"


def _impl_has(attr):
return lambda self: (hasattr(self._impl, attr))


_combinators_docstrings = """
Methods
-------
Expand Down Expand Up @@ -2919,7 +2915,7 @@ def __repr__(self):
hyp_string = lale.pretty_print.hyperparams_to_string(hps)
return name + "(" + hyp_string + ")"

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("get_pipeline"))
def get_pipeline(
self, pipeline_name: Optional[str] = None, astype: astype_type = "lale"
) -> Optional[TrainableOperator]:
Expand All @@ -2937,7 +2933,7 @@ def get_pipeline(
except AttributeError as exc:
raise ValueError("Must call `fit` before `get_pipeline`.") from exc

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("summary"))
def summary(self) -> pd.DataFrame:
"""
.. deprecated:: 0.0.0
Expand All @@ -2953,7 +2949,7 @@ def summary(self) -> pd.DataFrame:
except AttributeError as exc:
raise ValueError("Must call `fit` before `summary`.") from exc

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("transform"))
def transform(self, X: Any, y: Any = None) -> Any:
"""
.. deprecated:: 0.0.0
Expand All @@ -2969,7 +2965,7 @@ def transform(self, X: Any, y: Any = None) -> Any:
except AttributeError as exc:
raise ValueError("Must call `fit` before `transform`.") from exc

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("predict"))
def predict(self, X=None, **predict_params) -> Any:
"""
.. deprecated:: 0.0.0
Expand All @@ -2985,7 +2981,7 @@ def predict(self, X=None, **predict_params) -> Any:
except AttributeError as exc:
raise ValueError("Must call `fit` before `predict`.") from exc

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("predict_proba"))
def predict_proba(self, X=None):
"""
.. deprecated:: 0.0.0
Expand All @@ -3001,7 +2997,7 @@ def predict_proba(self, X=None):
except AttributeError as exc:
raise ValueError("Must call `fit` before `predict_proba`.") from exc

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("decision_function"))
def decision_function(self, X=None):
"""
.. deprecated:: 0.0.0
Expand All @@ -3017,7 +3013,7 @@ def decision_function(self, X=None):
except AttributeError as exc:
raise ValueError("Must call `fit` before `decision_function`.") from exc

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("score"))
def score(self, X, y, **score_params) -> Any:
"""
.. deprecated:: 0.0.0
Expand All @@ -3036,7 +3032,7 @@ def score(self, X, y, **score_params) -> Any:
except AttributeError as exc:
raise ValueError("Must call `fit` before `score`.") from exc

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("score_samples"))
def score_samples(self, X=None):
"""
.. deprecated:: 0.0.0
Expand All @@ -3052,7 +3048,7 @@ def score_samples(self, X=None):
except AttributeError as exc:
raise ValueError("Must call `fit` before `score_samples`.") from exc

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("predict_log_proba"))
def predict_log_proba(self, X=None):
"""
.. deprecated:: 0.0.0
Expand Down Expand Up @@ -3219,7 +3215,7 @@ def fit(self, X: Any, y: Any = None, **fit_params) -> "TrainedIndividualOp":
else:
return self

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("transform"))
def transform(self, X: Any, y: Any = None) -> Any:
"""Transform the data.
Expand Down Expand Up @@ -3249,7 +3245,7 @@ def transform(self, X: Any, y: Any = None) -> Any:
# logger.info("%s exit transform %s", time.asctime(), self.name())
return result

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("transform_X_y"))
def transform_X_y(self, X: Any, y: Any) -> Any:
"""Transform the data and target.
Expand Down Expand Up @@ -3281,7 +3277,7 @@ def _predict(self, X, **predict_params):
result = self._validate_output_schema(raw_result, "predict")
return result

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("predict"))
def predict(self, X: Any = None, **predict_params) -> Any:
"""Make predictions.
Expand All @@ -3304,7 +3300,7 @@ def predict(self, X: Any = None, **predict_params) -> Any:
return strip_schema(result) # otherwise scorers return zero-dim array
return result

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("predict_proba"))
def predict_proba(self, X: Any = None):
"""Probability estimates for all classes.
Expand All @@ -3325,7 +3321,7 @@ def predict_proba(self, X: Any = None):
# logger.info("%s exit predict_proba %s", time.asctime(), self.name())
return result

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("decision_function"))
def decision_function(self, X: Any = None):
"""Confidence scores for all classes.
Expand All @@ -3346,7 +3342,7 @@ def decision_function(self, X: Any = None):
# logger.info("%s exit decision_function %s", time.asctime(), self.name())
return result

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("score"))
def score(self, X: Any, y: Any, **score_params) -> Any:
"""Performance evaluation with a default metric.
Expand All @@ -3373,7 +3369,7 @@ def score(self, X: Any, y: Any, **score_params) -> Any:
# We skip output validation for score for now
return result

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("score_samples"))
def score_samples(self, X: Any = None):
"""Scores for each sample in X. The type of scores depends on the operator.
Expand All @@ -3392,7 +3388,7 @@ def score_samples(self, X: Any = None):
result = self._validate_output_schema(raw_result, "score_samples")
return result

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("predict_log_proba"))
def predict_log_proba(self, X: Any = None):
"""Predicted class log-probabilities for X.
Expand Down Expand Up @@ -3443,12 +3439,12 @@ def get_pipeline( # pylint:disable=signature-differs
) -> Optional[TrainableOperator]:
...

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("get_pipeline"))
def get_pipeline(self, pipeline_name=None, astype: astype_type = "lale"):
result = self._impl_instance().get_pipeline(pipeline_name, astype)
return result

@if_delegate_has_method(delegate="_impl")
@available_if(_impl_has("summary"))
def summary(self) -> pd.DataFrame:
return self._impl_instance().summary()

Expand Down

0 comments on commit d3a736a

Please sign in to comment.