From 0cb382bc73ca6e646952e0761f0b00fb4985551b Mon Sep 17 00:00:00 2001 From: Avi Shinnar Date: Wed, 31 Jan 2024 16:30:05 -0500 Subject: [PATCH] Replace uses of if_delegate_has_method with available_if Signed-off-by: Avi Shinnar --- lale/lib/sklearn/pipeline.py | 19 ++++++------- lale/operators.py | 54 +++++++++++++++++------------------- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/lale/lib/sklearn/pipeline.py b/lale/lib/sklearn/pipeline.py index a891c1e48..50851578b 100644 --- a/lale/lib/sklearn/pipeline.py +++ b/lale/lib/sklearn/pipeline.py @@ -16,23 +16,20 @@ import typing from packaging import version +from sklearn.utils.metaestimators import available_if import lale.docstrings import lale.helpers import lale.operators from lale.schemas import Bool -try: - from sklearn.pipeline import if_delegate_has_method -except ImportError as e: - if lale.operators.sklearn_version >= version.Version("1.0"): - from sklearn.utils.metaestimators import if_delegate_has_method - else: - raise e - logger = logging.getLogger(__name__) +def _pipeline_has(attr): + return lambda self: (hasattr(self._pipeline, attr)) + + class _PipelineImpl: def __init__(self, **hyperparams): if hyperparams.get("memory", None): @@ -61,17 +58,17 @@ def fit(self, X, y=None, **fit_params): self._final_estimator = self._pipeline.get_last() return self - @if_delegate_has_method(delegate="_final_estimator") + @available_if(_pipeline_has("predict")) def predict(self, X, **predict_params): result = self._pipeline.predict(X, **predict_params) return result - @if_delegate_has_method(delegate="_final_estimator") + @available_if(_pipeline_has("predict_proba")) def predict_proba(self, X): result = self._pipeline.predict_proba(X) return result - @if_delegate_has_method(delegate="_final_estimator") + @available_if(_pipeline_has("transform")) def transform(self, X, y=None): if y is None: result = self._pipeline.transform(X) diff --git a/lale/operators.py b/lale/operators.py index 1caf4ffe9..4c45be0b4 100644 --- a/lale/operators.py +++ b/lale/operators.py @@ -182,6 +182,7 @@ import sklearn.base from packaging import version from sklearn.base import clone +from sklearn.utils.metaestimators import available_if import lale.datasets.data_schemas import lale.json_operator @@ -230,20 +231,15 @@ sklearn_version = version.parse(getattr(sklearn, "__version__")) -try: - from sklearn.pipeline import ( # pylint:disable=ungrouped-imports - if_delegate_has_method, - ) -except ImportError as imp_exc: - if sklearn_version >= version.Version("1.0"): - from sklearn.utils.metaestimators import if_delegate_has_method - else: - raise imp_exc - logger = logging.getLogger(__name__) _LALE_SKL_PIPELINE = "lale.lib.sklearn.pipeline._PipelineImpl" + +def _impl_has(attr): + return lambda self: (hasattr(self._impl, attr)) + + _combinators_docstrings = """ Methods ------- @@ -2919,7 +2915,7 @@ def __repr__(self): hyp_string = lale.pretty_print.hyperparams_to_string(hps) return name + "(" + hyp_string + ")" - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("get_pipeline")) def get_pipeline( self, pipeline_name: Optional[str] = None, astype: astype_type = "lale" ) -> Optional[TrainableOperator]: @@ -2937,7 +2933,7 @@ def get_pipeline( except AttributeError as exc: raise ValueError("Must call `fit` before `get_pipeline`.") from exc - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("summary")) def summary(self) -> pd.DataFrame: """ .. deprecated:: 0.0.0 @@ -2953,7 +2949,7 @@ def summary(self) -> pd.DataFrame: except AttributeError as exc: raise ValueError("Must call `fit` before `summary`.") from exc - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("transform")) def transform(self, X: Any, y: Any = None) -> Any: """ .. deprecated:: 0.0.0 @@ -2969,7 +2965,7 @@ def transform(self, X: Any, y: Any = None) -> Any: except AttributeError as exc: raise ValueError("Must call `fit` before `transform`.") from exc - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("predict")) def predict(self, X=None, **predict_params) -> Any: """ .. deprecated:: 0.0.0 @@ -2985,7 +2981,7 @@ def predict(self, X=None, **predict_params) -> Any: except AttributeError as exc: raise ValueError("Must call `fit` before `predict`.") from exc - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("predict_proba")) def predict_proba(self, X=None): """ .. deprecated:: 0.0.0 @@ -3001,7 +2997,7 @@ def predict_proba(self, X=None): except AttributeError as exc: raise ValueError("Must call `fit` before `predict_proba`.") from exc - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("decision_function")) def decision_function(self, X=None): """ .. deprecated:: 0.0.0 @@ -3017,7 +3013,7 @@ def decision_function(self, X=None): except AttributeError as exc: raise ValueError("Must call `fit` before `decision_function`.") from exc - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("score")) def score(self, X, y, **score_params) -> Any: """ .. deprecated:: 0.0.0 @@ -3036,7 +3032,7 @@ def score(self, X, y, **score_params) -> Any: except AttributeError as exc: raise ValueError("Must call `fit` before `score`.") from exc - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("score_samples")) def score_samples(self, X=None): """ .. deprecated:: 0.0.0 @@ -3052,7 +3048,7 @@ def score_samples(self, X=None): except AttributeError as exc: raise ValueError("Must call `fit` before `score_samples`.") from exc - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("predict_log_proba")) def predict_log_proba(self, X=None): """ .. deprecated:: 0.0.0 @@ -3219,7 +3215,7 @@ def fit(self, X: Any, y: Any = None, **fit_params) -> "TrainedIndividualOp": else: return self - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("transform")) def transform(self, X: Any, y: Any = None) -> Any: """Transform the data. @@ -3249,7 +3245,7 @@ def transform(self, X: Any, y: Any = None) -> Any: # logger.info("%s exit transform %s", time.asctime(), self.name()) return result - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("transform_X_y")) def transform_X_y(self, X: Any, y: Any) -> Any: """Transform the data and target. @@ -3281,7 +3277,7 @@ def _predict(self, X, **predict_params): result = self._validate_output_schema(raw_result, "predict") return result - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("predict")) def predict(self, X: Any = None, **predict_params) -> Any: """Make predictions. @@ -3304,7 +3300,7 @@ def predict(self, X: Any = None, **predict_params) -> Any: return strip_schema(result) # otherwise scorers return zero-dim array return result - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("predict_proba")) def predict_proba(self, X: Any = None): """Probability estimates for all classes. @@ -3325,7 +3321,7 @@ def predict_proba(self, X: Any = None): # logger.info("%s exit predict_proba %s", time.asctime(), self.name()) return result - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("decision_function")) def decision_function(self, X: Any = None): """Confidence scores for all classes. @@ -3346,7 +3342,7 @@ def decision_function(self, X: Any = None): # logger.info("%s exit decision_function %s", time.asctime(), self.name()) return result - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("score")) def score(self, X: Any, y: Any, **score_params) -> Any: """Performance evaluation with a default metric. @@ -3373,7 +3369,7 @@ def score(self, X: Any, y: Any, **score_params) -> Any: # We skip output validation for score for now return result - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("score_samples")) def score_samples(self, X: Any = None): """Scores for each sample in X. The type of scores depends on the operator. @@ -3392,7 +3388,7 @@ def score_samples(self, X: Any = None): result = self._validate_output_schema(raw_result, "score_samples") return result - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("predict_log_proba")) def predict_log_proba(self, X: Any = None): """Predicted class log-probabilities for X. @@ -3443,12 +3439,12 @@ def get_pipeline( # pylint:disable=signature-differs ) -> Optional[TrainableOperator]: ... - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("get_pipeline")) def get_pipeline(self, pipeline_name=None, astype: astype_type = "lale"): result = self._impl_instance().get_pipeline(pipeline_name, astype) return result - @if_delegate_has_method(delegate="_impl") + @available_if(_impl_has("summary")) def summary(self) -> pd.DataFrame: return self._impl_instance().summary()