Skip to content

Commit

Permalink
Tighten assert condition in graph break tests (#458)
Browse files Browse the repository at this point in the history
Part of #452.
  • Loading branch information
akihironitta authored Sep 26, 2024
1 parent 9c6cc61 commit 6fc2761
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions test/nn/models/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch_frame.testing import withPackage


@withPackage("torch>=2.1.0")
@withPackage("torch>=2.5.0")
@pytest.mark.parametrize(
"model_cls, model_kwargs, stypes, expected_graph_breaks",
[
Expand All @@ -34,7 +34,7 @@
gamma=0.1,
),
None,
7,
2,
id="TabNet",
),
pytest.param(
Expand All @@ -47,21 +47,21 @@
ffn_dropout=0.5,
),
None,
4,
0,
id="TabTransformer",
),
pytest.param(
Trompt,
dict(channels=8, num_prompts=2),
None,
16,
4,
id="Trompt",
),
pytest.param(
ExcelFormer,
dict(in_channels=8, num_cols=3, num_heads=1),
[stype.numerical],
4,
1,
id="ExcelFormer",
),
],
Expand Down Expand Up @@ -89,4 +89,5 @@ def test_compile_graph_break(
**model_kwargs,
)
explanation = torch._dynamo.explain(model)(tf)
assert explanation.graph_break_count <= expected_graph_breaks
graph_breaks = explanation.graph_break_count
assert graph_breaks == expected_graph_breaks

0 comments on commit 6fc2761

Please sign in to comment.