Skip to content

Commit

Permalink
Refine MarsDMatrix & support more parameters for XGB classifier and r…
Browse files Browse the repository at this point in the history
…egressor (#2498) (#2501)
  • Loading branch information
Xuye (Chris) Qin authored Oct 9, 2021
1 parent 164944e commit 5e1fdf8
Show file tree
Hide file tree
Showing 10 changed files with 372 additions and 236 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/core-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
source ./ci/reload-env.sh
export DEFAULT_VENV=$VIRTUAL_ENV
if [[ ! "$PYTHON" =~ "3.9" ]]; then
if [[ ! "$PYTHON" =~ "3.6" ]]; then
conda install -n test --quiet --yes -c conda-forge python=$PYTHON numba
fi
Expand Down
13 changes: 6 additions & 7 deletions mars/learn/contrib/xgboost/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from xgboost.sklearn import XGBClassifierBase

from .... import tensor as mt
from .dmatrix import MarsDMatrix
from .core import evaluation_matrices
from .core import wrap_evaluation_matrices
from .train import train
from .predict import predict

Expand All @@ -31,14 +30,16 @@ class XGBClassifier(XGBScikitLearnBase, XGBClassifierBase):
Implementation of the scikit-learn API for XGBoost classification.
"""

def fit(self, X, y, sample_weights=None, eval_set=None, sample_weight_eval_set=None, **kw):
def fit(self, X, y, sample_weight=None, base_margin=None,
eval_set=None, sample_weight_eval_set=None, base_margin_eval_set=None, **kw):
session = kw.pop('session', None)
run_kwargs = kw.pop('run_kwargs', dict())
if kw:
raise TypeError(f"fit got an unexpected keyword argument '{next(iter(kw))}'")

dtrain = MarsDMatrix(X, label=y, weight=sample_weights,
session=session, run_kwargs=run_kwargs)
dtrain, evals = wrap_evaluation_matrices(
None, X, y, sample_weight, base_margin, eval_set,
sample_weight_eval_set, base_margin_eval_set)
params = self.get_xgb_params()

self.classes_ = mt.unique(y, aggregate_size=1).to_numpy(session=session, **run_kwargs)
Expand All @@ -50,8 +51,6 @@ def fit(self, X, y, sample_weights=None, eval_set=None, sample_weight_eval_set=N
else:
params['objective'] = 'binary:logistic'

evals = evaluation_matrices(eval_set, sample_weight_eval_set,
session=session, run_kwargs=run_kwargs)
self.evals_result_ = dict()
result = train(params, dtrain, num_boost_round=self.get_num_boosting_rounds(),
evals=evals, evals_result=self.evals_result_,
Expand Down
108 changes: 79 additions & 29 deletions mars/learn/contrib/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, List, Optional, Tuple

try:
import xgboost
except ImportError:
Expand Down Expand Up @@ -61,34 +63,82 @@ def predict(self, data, **kw):
"""
raise NotImplementedError

def evaluation_matrices(validation_set, sample_weights, session=None, run_kwargs=None):
"""
Parameters
----------
validation_set: list of tuples
Each tuple contains a validation dataset including input X and label y.
E.g.:
.. code-block:: python
[(X_0, y_0), (X_1, y_1), ... ]
sample_weights: list of arrays
The weight vector for validation data.
session:
Session to run
run_kwargs:
kwargs for session.run
Returns
-------
evals: list of validation MarsDMatrix
def wrap_evaluation_matrices(
missing: float,
X: Any,
y: Any,
sample_weight: Optional[Any],
base_margin: Optional[Any],
eval_set: Optional[List[Tuple[Any, Any]]],
sample_weight_eval_set: Optional[List[Any]],
base_margin_eval_set: Optional[List[Any]],
label_transform: Callable = lambda x: x,
) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
"""
evals = []
if validation_set is not None:
assert isinstance(validation_set, list)
for i, e in enumerate(validation_set):
w = (sample_weights[i]
if sample_weights is not None else None)
dmat = MarsDMatrix(e[0], label=e[1], weight=w,
session=session, run_kwargs=run_kwargs)
evals.append((dmat, f'validation_{i}'))
train_dmatrix = MarsDMatrix(
data=X,
label=label_transform(y),
weight=sample_weight,
base_margin=base_margin,
missing=missing,
)

n_validation = 0 if eval_set is None else len(eval_set)

def validate_or_none(meta: Optional[List], name: str) -> List:
if meta is None:
return [None] * n_validation
if len(meta) != n_validation:
raise ValueError(
f"{name}'s length does not equal `eval_set`'s length, " +
f"expecting {n_validation}, got {len(meta)}"
)
return meta

if eval_set is not None:
sample_weight_eval_set = validate_or_none(
sample_weight_eval_set, "sample_weight_eval_set"
)
base_margin_eval_set = validate_or_none(
base_margin_eval_set, "base_margin_eval_set"
)

evals = []
for i, (valid_X, valid_y) in enumerate(eval_set):
# Skip the duplicated entry.
if all(
(
valid_X is X, valid_y is y,
sample_weight_eval_set[i] is sample_weight,
base_margin_eval_set[i] is base_margin,
)
):
evals.append(train_dmatrix)
else:
m = MarsDMatrix(
data=valid_X,
label=label_transform(valid_y),
weight=sample_weight_eval_set[i],
base_margin=base_margin_eval_set[i],
missing=missing,
)
evals.append(m)
nevals = len(evals)
eval_names = [f"validation_{i}" for i in range(nevals)]
evals = list(zip(evals, eval_names))
else:
evals = None
return evals
if any(
meta is not None
for meta in [
sample_weight_eval_set,
base_margin_eval_set,
]
):
raise ValueError(
"`eval_set` is not set but one of the other evaluation meta info is "
"not None."
)
evals = []

return train_dmatrix, evals
Loading

0 comments on commit 5e1fdf8

Please sign in to comment.