From 9c6cc61091c428bfb9f6b81d826254fb6b4b7651 Mon Sep 17 00:00:00 2001 From: Andrei Ivanov <32910461+drivanov@users.noreply.github.com> Date: Wed, 25 Sep 2024 13:34:09 -0700 Subject: [PATCH 1/2] Removing the deprecated `categorical_feature` parameter from `lightgbm.train(...)` function calls. (#454) Using `categorical_feature` parameter in `lightgbm.Dataset()` instead of `lightgbm.train(...)` eliminates the following warnings: ``` test/gbdt/test_gbdt.py: 60 warnings /usr/local/lib/python3.10/dist-packages/lightgbm/engine.py:187: LGBMDeprecationWarning: Argument 'categorical_feature' to train() is deprecated and will be removed in a future release. Set 'categorical_feature' when calling lightgbm.Dataset() instead. See https://github.com/microsoft/LightGBM/issues/6435. _emit_dataset_kwarg_warning("train", "categorical_feature") ``` --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- torch_frame/gbdt/tuned_lightgbm.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/torch_frame/gbdt/tuned_lightgbm.py b/torch_frame/gbdt/tuned_lightgbm.py index 732ad741..39187322 100644 --- a/torch_frame/gbdt/tuned_lightgbm.py +++ b/torch_frame/gbdt/tuned_lightgbm.py @@ -103,7 +103,6 @@ def objective( trial: Any, # optuna.trial.Trial train_data: Any, # lightgbm.Dataset eval_data: Any, # lightgbm.Dataset - cat_features: list[int], num_boost_round: int, ) -> float: r"""Objective function to be optimized. @@ -112,8 +111,6 @@ def objective( trial (optuna.trial.Trial): Optuna trial object. train_data (lightgbm.Dataset): Train data. eval_data (lightgbm.Dataset): Validation data. - cat_features (list[int]): Array containing indexes of - categorical features. num_boost_round (int): Number of boosting round. Returns: @@ -169,8 +166,7 @@ def objective( boost = lightgbm.train( self.params, train_data, num_boost_round=num_boost_round, - categorical_feature=cat_features, valid_sets=[eval_data], - callbacks=[ + valid_sets=[eval_data], callbacks=[ lightgbm.early_stopping(stopping_rounds=50, verbose=False), lightgbm.log_evaluation(period=2000) ]) @@ -199,19 +195,18 @@ def _tune( assert train_y is not None assert val_y is not None train_data = lightgbm.Dataset(train_x, label=train_y, + categorical_feature=cat_features, free_raw_data=False) eval_data = lightgbm.Dataset(val_x, label=val_y, free_raw_data=False) study.optimize( lambda trial: self.objective(trial, train_data, eval_data, - cat_features, num_boost_round), - num_trials) + num_boost_round), num_trials) self.params.update(study.best_params) self.model = lightgbm.train( self.params, train_data, num_boost_round=num_boost_round, - categorical_feature=cat_features, valid_sets=[eval_data], - callbacks=[ + valid_sets=[eval_data], callbacks=[ lightgbm.early_stopping(stopping_rounds=50, verbose=False), lightgbm.log_evaluation(period=2000) ]) From 6fc2761ae76e74f0d6cf9af948923691913b3ba7 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Sep 2024 17:56:59 +0900 Subject: [PATCH 2/2] Tighten assert condition in graph break tests (#458) Part of #452. --- test/nn/models/test_compile.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/nn/models/test_compile.py b/test/nn/models/test_compile.py index ec53c0d7..dc22527c 100644 --- a/test/nn/models/test_compile.py +++ b/test/nn/models/test_compile.py @@ -14,7 +14,7 @@ from torch_frame.testing import withPackage -@withPackage("torch>=2.1.0") +@withPackage("torch>=2.5.0") @pytest.mark.parametrize( "model_cls, model_kwargs, stypes, expected_graph_breaks", [ @@ -34,7 +34,7 @@ gamma=0.1, ), None, - 7, + 2, id="TabNet", ), pytest.param( @@ -47,21 +47,21 @@ ffn_dropout=0.5, ), None, - 4, + 0, id="TabTransformer", ), pytest.param( Trompt, dict(channels=8, num_prompts=2), None, - 16, + 4, id="Trompt", ), pytest.param( ExcelFormer, dict(in_channels=8, num_cols=3, num_heads=1), [stype.numerical], - 4, + 1, id="ExcelFormer", ), ], @@ -89,4 +89,5 @@ def test_compile_graph_break( **model_kwargs, ) explanation = torch._dynamo.explain(model)(tf) - assert explanation.graph_break_count <= expected_graph_breaks + graph_breaks = explanation.graph_break_count + assert graph_breaks == expected_graph_breaks