[BUG] Can't access underlying models in multi-quantile regression #2596

Open
bmwilly opened this issue Nov 15, 2024 · 3 comments
Labels
bug Something isn't working feature request Use this label to request a new feature good first issue Good for newcomers

Comments

@bmwilly

bmwilly commented Nov 15, 2024

Describe the bug

I'm using a multi-quantile forecaster on multivariate target data, e.g., a CatBoostModel(likelihood='quantile', quantiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], ...).

Darts fits a separate GBT model for each quantile level and each target component. However, these aren't accessible to me.

Suppose my target series has two components and my CatBoostModel is called model. Then model.model.estimators_ returns only 2 models (corresponding to quantile 0.99 for each component).

This means that model.get_multioutput_estimator and model.get_estimator are incapable of returning estimators for any quantile other than 0.99.

Trying to access the models via model.model, model._model_container, or model._model_container[0.5] gives a similar error:

{
	"name": "RuntimeError",
	"message": "scikit-learn estimators should always specify their parameters in the signature of their __init__ (no varargs). <class 'darts.utils.multioutput.MultiOutputRegressor'> with constructor (self, *args, eval_set_name: Optional[str] = None, eval_weight_name: Optional[str] = None, **kwargs) doesn't  follow this convention.",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/projects/equilibrium/helios/.venv/lib/python3.11/site-packages/IPython/core/formatters.py:347, in BaseFormatter.__call__(self, obj)
    345     method = get_real_method(obj, self.print_method)
    346     if method is not None:
--> 347         return method()
    348     return None
    349 else:

File ~/projects/equilibrium/helios/.venv/lib/python3.11/site-packages/sklearn/base.py:693, in BaseEstimator._repr_html_inner(self)
    688 def _repr_html_inner(self):
    689     \"\"\"This function is returned by the @property `_repr_html_` to make
    690     `hasattr(estimator, \"_repr_html_\") return `True` or `False` depending
    691     on `get_config()[\"display\"]`.
    692     \"\"\"
--> 693     return estimator_html_repr(self)

File ~/projects/equilibrium/helios/.venv/lib/python3.11/site-packages/sklearn/utils/_estimator_html_repr.py:363, in estimator_html_repr(estimator)
    361 style_template = Template(_CSS_STYLE)
    362 style_with_id = style_template.substitute(id=container_id)
--> 363 estimator_str = str(estimator)
    365 # The fallback message is shown by default and loading the CSS sets
    366 # div.sk-text-repr-fallback to display: none to hide the fallback message.
    367 #
   (...)
    372 # The reverse logic applies to HTML repr div.sk-container.
    373 # div.sk-container is hidden by default and the loading the CSS displays it.
    374 fallback_msg = (
    375     \"In a Jupyter environment, please rerun this cell to show the HTML\"
    376     \" representation or trust the notebook. <br />On GitHub, the\"
    377     \" HTML representation is unable to render, please try loading this page\"
    378     \" with nbviewer.org.\"
    379 )

File ~/projects/equilibrium/helios/.venv/lib/python3.11/site-packages/sklearn/base.py:315, in BaseEstimator.__repr__(self, N_CHAR_MAX)
    307 # use ellipsis for sequences with a lot of elements
    308 pp = _EstimatorPrettyPrinter(
    309     compact=True,
    310     indent=1,
    311     indent_at_name=True,
    312     n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
    313 )
--> 315 repr_ = pp.pformat(self)
    317 # Use bruteforce ellipsis when there are a lot of non-blank characters
    318 n_nonblank = len(\"\".join(repr_.split()))

File ~/.pyenv/versions/3.11.10/lib/python3.11/pprint.py:161, in PrettyPrinter.pformat(self, object)
    159 def pformat(self, object):
    160     sio = _StringIO()
--> 161     self._format(object, sio, 0, 0, {}, 0)
    162     return sio.getvalue()

File ~/.pyenv/versions/3.11.10/lib/python3.11/pprint.py:178, in PrettyPrinter._format(self, object, stream, indent, allowance, context, level)
    176     self._readable = False
    177     return
--> 178 rep = self._repr(object, context, level)
    179 max_width = self._width - indent - allowance
    180 if len(rep) > max_width:

File ~/.pyenv/versions/3.11.10/lib/python3.11/pprint.py:458, in PrettyPrinter._repr(self, object, context, level)
    457 def _repr(self, object, context, level):
--> 458     repr, readable, recursive = self.format(object, context.copy(),
    459                                             self._depth, level)
    460     if not readable:
    461         self._readable = False

File ~/projects/equilibrium/helios/.venv/lib/python3.11/site-packages/sklearn/utils/_pprint.py:189, in _EstimatorPrettyPrinter.format(self, object, context, maxlevels, level)
    188 def format(self, object, context, maxlevels, level):
--> 189     return _safe_repr(
    190         object, context, maxlevels, level, changed_only=self._changed_only
    191     )

File ~/projects/equilibrium/helios/.venv/lib/python3.11/site-packages/sklearn/utils/_pprint.py:440, in _safe_repr(object, context, maxlevels, level, changed_only)
    438 recursive = False
    439 if changed_only:
--> 440     params = _changed_params(object)
    441 else:
    442     params = object.get_params(deep=False)

File ~/projects/equilibrium/helios/.venv/lib/python3.11/site-packages/sklearn/utils/_pprint.py:93, in _changed_params(estimator)
     89 def _changed_params(estimator):
     90     \"\"\"Return dict (param_name: value) of parameters that were given to
     91     estimator with non-default values.\"\"\"
---> 93     params = estimator.get_params(deep=False)
     94     init_func = getattr(estimator.__init__, \"deprecated_original\", estimator.__init__)
     95     init_params = inspect.signature(init_func).parameters

File ~/projects/equilibrium/helios/.venv/lib/python3.11/site-packages/sklearn/base.py:243, in BaseEstimator.get_params(self, deep)
    228 \"\"\"
    229 Get parameters for this estimator.
    230 
   (...)
    240     Parameter names mapped to their values.
    241 \"\"\"
    242 out = dict()
--> 243 for key in self._get_param_names():
    244     value = getattr(self, key)
    245     if deep and hasattr(value, \"get_params\") and not isinstance(value, type):

File ~/projects/equilibrium/helios/.venv/lib/python3.11/site-packages/sklearn/base.py:217, in BaseEstimator._get_param_names(cls)
    215 for p in parameters:
    216     if p.kind == p.VAR_POSITIONAL:
--> 217         raise RuntimeError(
    218             \"scikit-learn estimators should always \"
    219             \"specify their parameters in the signature\"
    220             \" of their __init__ (no varargs).\"
    221             \" %s with constructor %s doesn't \"
    222             \" follow this convention.\" % (cls, init_signature)
    223         )
    224 # Extract and sort argument names excluding 'self'
    225 return sorted([p.name for p in parameters])

RuntimeError: scikit-learn estimators should always specify their parameters in the signature of their __init__ (no varargs). <class 'darts.utils.multioutput.MultiOutputRegressor'> with constructor (self, *args, eval_set_name: Optional[str] = None, eval_weight_name: Optional[str] = None, **kwargs) doesn't  follow this convention."
}

To Reproduce

import numpy as np
from darts.models import CatBoostModel
from darts.utils.timeseries_generation import linear_timeseries

# Generate a synthetic multivariate time series
np.random.seed(42)
series1 = linear_timeseries(length=100, start_value=0, end_value=10)
series2 = linear_timeseries(length=100, start_value=10, end_value=0)
multivariate_series = series1.stack(series2)

# Define future covariates (optional, here just using a simple linear trend)
future_covariates = linear_timeseries(length=100, start_value=0, end_value=5)

# Initialize the CatBoostModel with quantile regression
model = CatBoostModel(
    lags=12,
    lags_future_covariates=[0],
    likelihood='quantile',
    quantiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
    random_state=42
)

# Fit the model
model.fit(multivariate_series, future_covariates=future_covariates)

# Only 2 estimators instead of 2 * 9 = 18
len(model.model.estimators_)

# Both of these give the above error when displayed (e.g. as the last line of a Jupyter cell)
model.model
model._model_container

Expected behavior

Ability to access all the underlying estimators.

System:

  • Python version: 3.11.10
  • darts version: 0.31.0
@bmwilly bmwilly added bug Something isn't working triage Issue waiting for triaging labels Nov 15, 2024
@dennisbader
Collaborator

dennisbader commented Nov 15, 2024

Hi @bmwilly, you are right that get_estimators currently does not support quantile regression.

Thanks for the MWE. For me the snippet runs fine. I do, however, receive the error when trying to print(model._model_container) (so that might be why you see it). There is probably a bug in the __str__ or __repr__ methods of our MultiOutputRegressor that then gets triggered by sklearn.
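The last frames of the traceback point at the specific check that fails: sklearn's BaseEstimator.__repr__ calls get_params(), which calls _get_param_names(), and that raises RuntimeError whenever the class's __init__ declares *args. The constructor signature printed in the error message does exactly that. The stdlib-only sketch below restates the check with inspect. It is a simplified mirror of the check shown in the final traceback frame, not sklearn's own code. Both stub classes are hypothetical: the first only copies the constructor signature from the error message, the second names every parameter explicitly.

```python
import inspect


def has_varargs_init(cls) -> bool:
    """Return True if cls.__init__ declares *args.

    Mirrors the condition under which scikit-learn's
    BaseEstimator._get_param_names raises RuntimeError (and with it
    get_params() and __repr__).
    """
    init = getattr(cls.__init__, "deprecated_original", cls.__init__)
    if init is object.__init__:
        # No custom constructor: nothing to inspect.
        return False
    params = inspect.signature(init).parameters.values()
    return any(p.kind == p.VAR_POSITIONAL for p in params)


class VarargsRegressorStub:
    """Hypothetical stand-in using the constructor signature from the error."""

    def __init__(self, *args, eval_set_name=None, eval_weight_name=None, **kwargs):
        self.eval_set_name = eval_set_name
        self.eval_weight_name = eval_weight_name


class ExplicitRegressorStub:
    """Hypothetical stand-in whose parameters are all named explicitly."""

    def __init__(
        self, estimator=None, n_jobs=None, eval_set_name=None, eval_weight_name=None
    ):
        self.estimator = estimator
        self.n_jobs = n_jobs
        self.eval_set_name = eval_set_name
        self.eval_weight_name = eval_weight_name


print(has_varargs_init(VarargsRegressorStub))   # True
print(has_varargs_init(ExplicitRegressorStub))  # False
```

Under this reading, naming the parameters explicitly in __init__ (as in the second stub) would be one way to make the repr work. That is an assumption about a possible fix, not a description of what darts does.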

Debugging with PyCharm gives the expected results:

In short, you can access the underlying estimator as follows (this should not throw an error):

q_val = 0.01  # the quantile value of interest
target_i = 0  # the target component index of interest
model._model_container[q_val].estimators_[target_i]
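To apply this workaround to every quantile and component at once, a small helper can walk the container. This is a sketch that assumes model._model_container behaves like a dict mapping each quantile value to a fitted regressor whose estimators_ list is indexed by target component, as the snippet above implies for the MWE's default output_chunk_length=1. The function names are hypothetical and not part of darts. The demo uses SimpleNamespace stand-ins instead of a fitted model.

```python
from types import SimpleNamespace


def get_quantile_estimator(model_container, quantile, target_i):
    """Return the fitted estimator for one quantile and one target component.

    Assumes model_container maps quantile -> fitted multi-output regressor
    with an ``estimators_`` list indexed by target component.
    """
    if quantile not in model_container:
        raise KeyError(
            f"no model fitted for quantile {quantile!r}; "
            f"available: {sorted(model_container)}"
        )
    return model_container[quantile].estimators_[target_i]


def all_quantile_estimators(model_container):
    """Flatten the container into {(quantile, target_i): estimator}."""
    return {
        (q, i): est
        for q, regressor in model_container.items()
        for i, est in enumerate(regressor.estimators_)
    }


# Demo with stand-in objects (illustration only, no darts model involved):
# 3 quantiles x 2 target components, each "estimator" is just a label.
container = {
    q: SimpleNamespace(estimators_=[f"q{q}_comp{i}" for i in range(2)])
    for q in [0.01, 0.5, 0.99]
}
print(len(all_quantile_estimators(container)))    # 6
print(get_quantile_estimator(container, 0.5, 1))  # q0.5_comp1
```

If the assumption holds, all_quantile_estimators(model._model_container) on the MWE's model should yield one entry per quantile and component (2 * 9 = 18). Quantile keys are floats, so look them up with exactly the values passed in quantiles.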

@dennisbader dennisbader added bug Something isn't working good first issue Good for newcomers and removed bug Something isn't working triage Issue waiting for triaging labels Nov 15, 2024
@bmwilly
Author

bmwilly commented Nov 15, 2024

@dennisbader yes, I guess the error is triggered by __repr__, which is called since I'm working in a Jupyter notebook environment.
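Jupyter displays the last expression of a cell through its repr methods (the traceback starts in IPython's formatter calling _repr_html_). The error can therefore be sidestepped by not letting the notebook display the object directly: assign it to a variable, end the line with a semicolon, or print a repr that cannot fail. The hypothetical helper below takes the last approach. It falls back to the default object.__repr__ when repr() raises RuntimeError, which never calls get_params().

```python
def safe_repr(obj) -> str:
    """Return repr(obj), or the plain object.__repr__ form if repr raises.

    Intended for estimators whose repr raises RuntimeError, such as the
    *args-constructor case in scikit-learn's BaseEstimator.__repr__.
    """
    try:
        return repr(obj)
    except RuntimeError:
        # "<module.ClassName object at 0x...>"; bypasses the failing repr.
        return object.__repr__(obj)


# In a notebook cell, e.g.:
# print(safe_repr(model._model_container[0.5]))
```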

But OK, thank you for the quick workaround! I will use model._model_container[q_val].estimators_[target_i] instead of get_estimators.

Please feel free to close this, or leave it open if you want to use it to track the required updates to get_estimators for QR.

@github-project-automation github-project-automation bot moved this to To do in darts Nov 15, 2024
@dennisbader
Collaborator

Great @bmwilly, I'll leave it open and have added it to our backlog.

@dennisbader dennisbader added the feature request Use this label to request a new feature label Nov 15, 2024
Projects
Status: To do
Development

No branches or pull requests

2 participants