Merge branch 'master' into wsad1-patch-1
wsad1 authored Sep 30, 2024
2 parents 700da76 + 6fc2761 commit dd63def
Showing 2 changed files with 11 additions and 15 deletions.
13 changes: 7 additions & 6 deletions test/nn/models/test_compile.py
@@ -14,7 +14,7 @@
 from torch_frame.testing import withPackage
 
 
-@withPackage("torch>=2.1.0")
+@withPackage("torch>=2.5.0")
 @pytest.mark.parametrize(
     "model_cls, model_kwargs, stypes, expected_graph_breaks",
     [
@@ -34,7 +34,7 @@
                 gamma=0.1,
             ),
             None,
-            7,
+            2,
             id="TabNet",
         ),
         pytest.param(
@@ -47,21 +47,21 @@
                 ffn_dropout=0.5,
             ),
             None,
-            4,
+            0,
             id="TabTransformer",
         ),
         pytest.param(
             Trompt,
             dict(channels=8, num_prompts=2),
             None,
-            16,
+            4,
             id="Trompt",
         ),
         pytest.param(
             ExcelFormer,
             dict(in_channels=8, num_cols=3, num_heads=1),
             [stype.numerical],
-            4,
+            1,
             id="ExcelFormer",
         ),
     ],
@@ -89,4 +89,5 @@ def test_compile_graph_break(
         **model_kwargs,
     )
     explanation = torch._dynamo.explain(model)(tf)
-    assert explanation.graph_break_count <= expected_graph_breaks
+    graph_breaks = explanation.graph_break_count
+    assert graph_breaks == expected_graph_breaks
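
Aside: the assertion now pins each model's exact graph-break count instead of an upper bound. As a minimal illustration of the API the test relies on, torch._dynamo.explain wraps a callable and reports where TorchDynamo had to split the captured graph; the toy module below is a hypothetical stand-in, not torch-frame code:

import torch

# Toy module with a data-dependent branch, a classic cause of a
# TorchDynamo graph break.
class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:  # condition depends on tensor data -> graph break
            return x * 2
        return x - 1

explanation = torch._dynamo.explain(Toy())(torch.randn(4))
print(explanation.graph_break_count)  # e.g. 1 for the branch above
print(explanation.break_reasons)      # human-readable reason per break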
13 changes: 4 additions & 9 deletions torch_frame/gbdt/tuned_lightgbm.py
@@ -103,7 +103,6 @@ def objective(
         trial: Any,  # optuna.trial.Trial
         train_data: Any,  # lightgbm.Dataset
         eval_data: Any,  # lightgbm.Dataset
-        cat_features: list[int],
         num_boost_round: int,
     ) -> float:
         r"""Objective function to be optimized.
@@ -112,8 +111,6 @@
             trial (optuna.trial.Trial): Optuna trial object.
             train_data (lightgbm.Dataset): Train data.
             eval_data (lightgbm.Dataset): Validation data.
-            cat_features (list[int]): Array containing indexes of
-                categorical features.
             num_boost_round (int): Number of boosting round.
 
         Returns:
@@ -169,8 +166,7 @@ def objective(
 
         boost = lightgbm.train(
             self.params, train_data, num_boost_round=num_boost_round,
-            categorical_feature=cat_features, valid_sets=[eval_data],
-            callbacks=[
+            valid_sets=[eval_data], callbacks=[
                 lightgbm.early_stopping(stopping_rounds=50, verbose=False),
                 lightgbm.log_evaluation(period=2000)
             ])
@@ -199,19 +195,18 @@ def _tune(
         assert train_y is not None
         assert val_y is not None
         train_data = lightgbm.Dataset(train_x, label=train_y,
+                                      categorical_feature=cat_features,
                                       free_raw_data=False)
         eval_data = lightgbm.Dataset(val_x, label=val_y, free_raw_data=False)
 
         study.optimize(
             lambda trial: self.objective(trial, train_data, eval_data,
-                                         cat_features, num_boost_round),
-            num_trials)
+                                         num_boost_round), num_trials)
         self.params.update(study.best_params)
 
         self.model = lightgbm.train(
             self.params, train_data, num_boost_round=num_boost_round,
-            categorical_feature=cat_features, valid_sets=[eval_data],
-            callbacks=[
+            valid_sets=[eval_data], callbacks=[
                 lightgbm.early_stopping(stopping_rounds=50, verbose=False),
                 lightgbm.log_evaluation(period=2000)
             ])
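
Aside: the net effect in this file is that categorical columns are now declared once, on the lightgbm.Dataset, instead of being threaded through objective() and repeated in every lightgbm.train call (LightGBM would otherwise warn that the train-time setting overrides the Dataset's). A minimal sketch of the pattern with made-up toy data, independent of torch-frame:

import lightgbm
import numpy as np

# Hypothetical toy data: column 0 is categorical, integer-encoded.
rng = np.random.default_rng(0)
train_x = np.column_stack([rng.integers(0, 3, size=100),
                           rng.normal(size=100)])
train_y = rng.integers(0, 2, size=100)

# Declare categorical columns once, on the Dataset itself.
train_data = lightgbm.Dataset(train_x, label=train_y,
                              categorical_feature=[0],
                              free_raw_data=False)

# train() inherits categorical_feature from the Dataset, so it no
# longer needs the argument.
booster = lightgbm.train({"objective": "binary", "verbose": -1},
                         train_data, num_boost_round=10)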
