diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py
index 2eedc89dc7..68ef6ea5bb 100644
--- a/torchbenchmark/util/backends/torchdynamo.py
+++ b/torchbenchmark/util/backends/torchdynamo.py
@@ -58,6 +58,11 @@ def parse_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', dy
         action='store_true',
         help="enable batch fusion in Inductor"
     )
+    parser.add_argument(
+        "--torchinductor_enable_split_cat_fx_pass",
+        action='store_true',
+        help="enable split_cat_fx_pass in Inductor"
+    )
     parser.add_argument(
         "--dynamo_disable_optimizer_step",
         type=distutils.util.strtobool,
@@ -92,7 +97,10 @@ def apply_torchdynamo_args(model: 'torchbenchmark.util.model.BenchmarkModel', ar
     if args.torchinductor_enable_group_fusion:
         torchinductor.config.group_fusion = True
     if args.torchinductor_enable_batch_fusion:
+        torchinductor.config.pattern_matcher = True
         torchinductor.config.batch_fusion = True
+    if args.torchinductor_enable_split_cat_fx_pass:
+        torchinductor.config.split_cat_fx_passes = True
     # used for correctness checks, to avoid triton rand() behaving differently from torch rand().
     torchinductor.config.fallback_random = bool(args.torchinductor_fallback_random)
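
Usage note (not part of the patch): a minimal sketch of how the new flag is expected to flip the Inductor config knob, assuming `args` comes from parse_torchdynamo_args and that torch._inductor.config is the config module referenced as torchinductor.config above. The helper name and invocation below are hypothetical illustrations, not code from this file.

    import torch._inductor.config as inductor_config

    def enable_split_cat_passes(args) -> None:
        # Mirrors the branch added in this patch: the argparse flag name
        # --torchinductor_enable_split_cat_fx_pass becomes this attribute.
        if getattr(args, "torchinductor_enable_split_cat_fx_pass", False):
            # Turn on Inductor's split/cat FX graph rewrite passes.
            inductor_config.split_cat_fx_passes = True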