Skip to content

Commit

Permalink
Add cmf_10x_batch_fusion test to Sandcastle and AIBench
Browse files Browse the repository at this point in the history
Summary: As the title says. We want to add a test to cover batch fusion feature on cmf.

Reviewed By: jackiexu1992

Differential Revision: D48622074

fbshipit-source-id: 73c6cb52206805b9c6c3795789228bfacc35f80a
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Aug 24, 2023
1 parent cffa057 commit c63dbd6
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions torchbenchmark/util/backends/torchdynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def parse_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', dy
action='store_true',
help="enable batch fusion in Inductor"
)
parser.add_argument(
"--torchinductor_enable_split_cat_fx_pass",
action='store_true',
help="enable split_cat_fx_pass in Inductor"
)
parser.add_argument(
"--dynamo_disable_optimizer_step",
type=distutils.util.strtobool,
Expand Down Expand Up @@ -92,7 +97,10 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
if args.torchinductor_enable_group_fusion:
torchinductor.config.group_fusion = True
if args.torchinductor_enable_batch_fusion:
torchinductor.config.pattern_matcher = True
torchinductor.config.batch_fusion = True
if args.torchinductor_enable_split_cat_fx_pass:
torchinductor.config.split_cat_fx_passes = True

# used for correctness checks, to avoid triton rand() behaving differently from torch rand().
torchinductor.config.fallback_random = bool(args.torchinductor_fallback_random)
Expand Down

0 comments on commit c63dbd6

Please sign in to comment.