From bb024f9707584aa4d01b220d3b6f664df61158c5 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 15 Nov 2024 11:05:20 -0500 Subject: [PATCH] [tuner] Move constraint generation out of candidate_gen. NFC. This is just code motion to make the code more modular. Signed-off-by: Jakub Kuderski --- tuner/tuner/candidate_gen.py | 200 ++--------------------- tuner/tuner/candidate_gen_test.py | 135 --------------- tuner/tuner/dispatch_constraints.py | 197 ++++++++++++++++++++++ tuner/tuner/dispatch_constraints_test.py | 161 ++++++++++++++++++ 4 files changed, 368 insertions(+), 325 deletions(-) create mode 100644 tuner/tuner/dispatch_constraints.py create mode 100644 tuner/tuner/dispatch_constraints_test.py diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 06ccae0e..2f21520f 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -22,7 +22,6 @@ import logging import pickle import re -import z3 # type: ignore from dataclasses import dataclass from os import path, makedirs from typing import Optional @@ -32,6 +31,7 @@ from iree.compiler import ir # type: ignore from .common import * +from .dispatch_constraints import * from .dispatch_parser import * tune_logger = logging.getLogger("tune") @@ -73,194 +73,6 @@ def apply_configuration( return new_mlir -def get_mfma_intrinsic_constraints( - problem_size: ProblemSize, - intrinsic_m: z3.ArithRef, - intrinsic_n: z3.ArithRef, - intrinsic_k: z3.ArithRef, -) -> z3.BoolRef: - compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) - assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" - return z3.Or( - *( - z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) - for mfma in compatible_intrinsics - ) - ) - - -def get_dispatch_constraints( - problem_size: ProblemSize, - tile_m: z3.ArithRef, - tile_n: z3.ArithRef, - tile_k: z3.ArithRef, -) -> list[z3.BoolRef]: - if problem_size.dispatch_kind != DispatchKind.conv: - return [] - - 
dim_info = ConvDimInfo.from_problem_size(problem_size) - conv_constraints = [] - # WARNING: This sometimes makes the constraints UNSAT for some reason. - conv_constraints += [tile_m <= dim_info.ow] - conv_constraints += [tile_n <= dim_info.oc] - conv_constraints += [tile_k <= dim_info.ic] - return conv_constraints - - -def calculate_shared_memory_usage_in_bytes( - problem_size: ProblemSize, - m: int | z3.ArithRef, - n: int | z3.ArithRef, - k: int | z3.ArithRef, -) -> int | z3.ArithRef: - lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8) - rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8) - return lhs_memory + rhs_memory - - -def generate_constraints( - problem_size: ProblemSize, - tile_sizes, - num_subgroups, - subgroup_size, - intrinsic_size, - workgroup_size, - subgroup_m_count, - subgroup_n_count, - waves_per_eu, -): - M, N, K = ( - problem_size.matmul_size.M, - problem_size.matmul_size.N, - problem_size.matmul_size.K, - ) - m, n, k = tile_sizes - intrinsic_mn, intrinsic_k = intrinsic_size - wg_x, wg_y, wg_z = workgroup_size - wg_threads = z3.Int("wg_threads") - constraints = [wg_threads == wg_x * wg_y * wg_z] - constraints += [subgroup_size == 64, wg_threads <= 1024] - constraints += [ - get_mfma_intrinsic_constraints( - problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k - ) - ] - subgroup_k_count = 1 - constraints += [ - m >= intrinsic_mn, - m <= 512, - m <= M, - ] - constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] - constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] - for x in (subgroup_m_count, subgroup_n_count): - constraints += [x >= 1, x <= 32] - - subgroup_m_tile_count = z3.Int("sg_m_tcnt") - subgroup_n_tile_count = z3.Int("sg_n_tcnt") - subgroup_k_tile_count = z3.Int("sg_k_tcnt") - for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): - constraints += [x >= 1, x <= 32] - - constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] - constraints += [n == 
subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] - constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] - constraints += [wg_x == subgroup_size * subgroup_n_count] - constraints += [wg_y == subgroup_m_count] - constraints += [wg_z == subgroup_k_count] - constraints += [z3.Or(wg_x <= n, wg_x <= m)] - constraints += [k % intrinsic_mn == 0] - constraints += [(k * n) % wg_threads == 0] - constraints += [(k * m) % wg_threads == 0] - subgroups = subgroup_m_count * subgroup_n_count - if num_subgroups > 0: - constraints += [subgroups == num_subgroups] - else: - constraints += [subgroups >= 1, subgroups <= 10] - - constraints += [waves_per_eu == 2] - # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] - - shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) - constraints += [shared_memory <= 65536] - - constraints += get_dispatch_constraints(problem_size, m, n, k) - - return constraints - - -def generate_solutions(problem_size: ProblemSize, num_subgrups: int): - M, N, K = problem_size.MNK - tune_logger.info(f"{M},{N},{K}") - m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") - subgroup_size = z3.Int("subgroup_size") - intrinsic_mn = z3.Int("intrinsic_mn") - intrinsic_k = z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z") - sg_m_cnt = z3.Int("sg_m_cnt") - sg_n_cnt = z3.Int("sg_n_cnt") - waves_per_eu = z3.Int("waves_per_eu") - all_vars = [ - m, - n, - k, - subgroup_size, - intrinsic_mn, - intrinsic_k, - wg_x, - wg_y, - wg_z, - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ] - - solver = z3.Solver() - constraints = generate_constraints( - problem_size, - [m, n, k], - num_subgrups, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - solver.add(z3.simplify(z3.And(constraints))) - tune_logger.debug(f"Initial constraints: {solver}") - i = 0 - while solver.check() == z3.sat: - model = solver.model() - 
lookup = lambda var: model[var].as_long() - - config = Configuration( - lookup(subgroup_size), - [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - MfmaIntrinsic( - problem_size.res_type.element_type, - lookup(intrinsic_mn), - lookup(intrinsic_mn), - lookup(intrinsic_k), - problem_size.lhs_type.element_type, - ), - [lookup(m), lookup(n), lookup(k)], - lookup(sg_m_cnt), - lookup(sg_n_cnt), - GpuPipelineOptions(), - lookup(waves_per_eu), - ) - solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) - i += 1 - yield config - - -def get_default_output_dir() -> str: - from datetime import datetime - - return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") - - class DispatchTuner(DispatchParser): # TODO(https://github.com/nod-ai/SHARK-Platform/issues/453): Remove this in favor of configuring using transform dialect. @abstractmethod @@ -675,6 +487,12 @@ def walk_mlir_op( return walk_result +def get_default_output_dir() -> str: + from datetime import datetime + + return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") + + def tune( input: str, # Path to the mlir file to be tuned output: str = "", # Path to the output directory, auto creates one if not given @@ -722,7 +540,9 @@ def tune( problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) tune_logger.debug(str(problem_size)) configs = [] - for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): + for i, config in enumerate( + generate_solutions(tuner_context, problem_size, num_subgroups) + ): if i >= limit: break tune_logger.info(f"Solution #{i+1}: {config}") diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 63e8441d..47e351fc 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -14,141 +14,6 @@ from . 
import common -def test_generate_solutions() -> None: - matmul_size = common.MatmulSize(2048, 3840, 1280) - lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16) - rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16) - res_type = common.ShapedType([2048, 3840], common.ElementType.f32) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - configs = candidate_gen.generate_solutions(problem_size, 4) - assert configs is not None - - -def test_calculate_shared_memory_usage_in_bytes() -> None: - matmul_size = common.MatmulSize(1024, 1024, 1024) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - res_type = common.ShapedType([1024, 1024], common.ElementType.f32) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 147456 - ) - - lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 81920 - ) - - rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) - == 12288 - ) - - -def test_generate_constraints_valid_input() -> None: - matmul_size = common.MatmulSize(1024, 1024, 1024) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - res_type = common.ShapedType([1024, 1024], common.ElementType.f32) - problem_size = common.ProblemSize( - 
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - # Define input parameters as z3 Ints - m, n, k = ( - candidate_gen.z3.Int("m"), - candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are satisfiable - assert solver.check() == candidate_gen.z3.sat - - -def test_generate_constraints_invalid_input() -> None: - # Define input parameters that should lead to unsatisfiable constraints - matmul_size = common.MatmulSize(1024, 1024, 1024) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - res_type = common.ShapedType([1024, 1024], common.ElementType.f32) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - m, n, k = ( - candidate_gen.z3.Int("m"), - candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - 
waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - constraints.append(m > 1000) # Adding an additional unsatisfiable constraint - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are unsatisfiable - assert solver.check() == candidate_gen.z3.unsat - - def remove_comments(mlir: str) -> str: return "\n".join( filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines()) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py new file mode 100644 index 00000000..ac46d8ed --- /dev/null +++ b/tuner/tuner/dispatch_constraints.py @@ -0,0 +1,197 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Builds the z3 constraints that describe valid tuning configurations +# for a dispatch and enumerates the solutions that satisfy them. 
+ +import z3 # type: ignore +from typing import Iterator + +from .common import * + + +def get_mfma_intrinsic_constraints( + problem_size: ProblemSize, + intrinsic_m: z3.ArithRef, + intrinsic_n: z3.ArithRef, + intrinsic_k: z3.ArithRef, +) -> z3.BoolRef: + compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) + assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" + return z3.Or( + *( + z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) + for mfma in compatible_intrinsics + ) + ) + + +def get_dispatch_constraints( + problem_size: ProblemSize, + tile_m: z3.ArithRef, + tile_n: z3.ArithRef, + tile_k: z3.ArithRef, +) -> list[z3.BoolRef]: + if problem_size.dispatch_kind != DispatchKind.conv: + return [] + + dim_info = ConvDimInfo.from_problem_size(problem_size) + conv_constraints = [] + # WARNING: This sometimes makes the constraints UNSAT for some reason. + conv_constraints += [tile_m <= dim_info.ow] + conv_constraints += [tile_n <= dim_info.oc] + conv_constraints += [tile_k <= dim_info.ic] + return conv_constraints + + +def calculate_shared_memory_usage_in_bytes( + problem_size: ProblemSize, + m: int | z3.ArithRef, + n: int | z3.ArithRef, + k: int | z3.ArithRef, +) -> int | z3.ArithRef: + lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8) + rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8) + return lhs_memory + rhs_memory + + +def generate_constraints( + problem_size: ProblemSize, + tile_sizes, + num_subgroups, + subgroup_size, + intrinsic_size, + workgroup_size, + subgroup_m_count, + subgroup_n_count, + waves_per_eu, +): + M, N, K = ( + problem_size.matmul_size.M, + problem_size.matmul_size.N, + problem_size.matmul_size.K, + ) + m, n, k = tile_sizes + intrinsic_mn, intrinsic_k = intrinsic_size + wg_x, wg_y, wg_z = workgroup_size + wg_threads = z3.Int("wg_threads") + constraints = [wg_threads == wg_x * wg_y * wg_z] + constraints += [subgroup_size == 64, wg_threads <= 1024] + constraints += [ + 
get_mfma_intrinsic_constraints( + problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k + ) + ] + subgroup_k_count = 1 + constraints += [ + m >= intrinsic_mn, + m <= 512, + m <= M, + ] + constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] + constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] + for x in (subgroup_m_count, subgroup_n_count): + constraints += [x >= 1, x <= 32] + + subgroup_m_tile_count = z3.Int("sg_m_tcnt") + subgroup_n_tile_count = z3.Int("sg_n_tcnt") + subgroup_k_tile_count = z3.Int("sg_k_tcnt") + for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): + constraints += [x >= 1, x <= 32] + + constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] + constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] + constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] + constraints += [wg_x == subgroup_size * subgroup_n_count] + constraints += [wg_y == subgroup_m_count] + constraints += [wg_z == subgroup_k_count] + constraints += [z3.Or(wg_x <= n, wg_x <= m)] + constraints += [k % intrinsic_mn == 0] + constraints += [(k * n) % wg_threads == 0] + constraints += [(k * m) % wg_threads == 0] + subgroups = subgroup_m_count * subgroup_n_count + if num_subgroups > 0: + constraints += [subgroups == num_subgroups] + else: + constraints += [subgroups >= 1, subgroups <= 10] + + constraints += [waves_per_eu == 2] + # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] + + shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) + constraints += [shared_memory <= 65536] + + constraints += get_dispatch_constraints(problem_size, m, n, k) + + return constraints + + +def generate_solutions( + ctx: TunerContext, problem_size: ProblemSize, num_subgrups: int +) -> Iterator[Configuration]: + M, N, K = problem_size.MNK + ctx.logger.info(f"{M},{N},{K}") + m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") + subgroup_size = 
z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z") + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + all_vars = [ + m, + n, + k, + subgroup_size, + intrinsic_mn, + intrinsic_k, + wg_x, + wg_y, + wg_z, + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ] + + solver = z3.Solver() + constraints = generate_constraints( + problem_size, + [m, n, k], + num_subgrups, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + solver.add(z3.simplify(z3.And(constraints))) + ctx.logger.debug(f"Initial constraints: {solver}") + i = 0 + while solver.check() == z3.sat: + model = solver.model() + lookup = lambda var: model[var].as_long() + + config = Configuration( + lookup(subgroup_size), + [lookup(wg_x), lookup(wg_y), lookup(wg_z)], + MfmaIntrinsic( + problem_size.res_type.element_type, + lookup(intrinsic_mn), + lookup(intrinsic_mn), + lookup(intrinsic_k), + problem_size.lhs_type.element_type, + ), + [lookup(m), lookup(n), lookup(k)], + lookup(sg_m_cnt), + lookup(sg_n_cnt), + GpuPipelineOptions(), + lookup(waves_per_eu), + ) + solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) + i += 1 + yield config diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py new file mode 100644 index 00000000..55f3a8c4 --- /dev/null +++ b/tuner/tuner/dispatch_constraints_test.py @@ -0,0 +1,161 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest dispatch_constraints_test.py +""" + +import pytest +import z3 # type: ignore + +from logging import Logger +from unittest.mock import MagicMock + +from . 
import common +from . import dispatch_constraints + + +def test_generate_solutions() -> None: + matmul_size = common.MatmulSize(2048, 3840, 1280) + lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16) + rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16) + res_type = common.ShapedType([2048, 3840], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + logger: Logger = MagicMock(spec=Logger) + ctx = common.TunerContext(None, logger) + configs = dispatch_constraints.generate_solutions(ctx, problem_size, 4) + assert configs is not None + + +def test_calculate_shared_memory_usage_in_bytes() -> None: + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, 512, 64, 128 + ) + == 147456 + ) + + lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, 512, 64, 128 + ) + == 81920 + ) + + rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, 128, 64, 32 + ) + == 12288 + ) + + +def test_generate_constraints_valid_input() -> None: + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type 
= common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + # Define input parameters as z3 Ints + m, n, k = ( + dispatch_constraints.z3.Int("m"), + z3.Int("n"), + z3.Int("k"), + ) + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + z3.Int("wg_x"), + z3.Int("wg_y"), + z3.Int("wg_z"), + ) + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + + constraints = dispatch_constraints.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + + solver = z3.Solver() + solver.add(constraints) + + # Check if the constraints are satisfiable + assert solver.check() == z3.sat + + +def test_generate_constraints_invalid_input() -> None: + # Define input parameters that should lead to unsatisfiable constraints + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + m, n, k = ( + z3.Int("m"), + z3.Int("n"), + z3.Int("k"), + ) + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + z3.Int("wg_x"), + z3.Int("wg_y"), + z3.Int("wg_z"), + ) + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + + constraints = dispatch_constraints.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + 
[intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + constraints.append(m > 1000) # Adding an additional unsatisfiable constraint + + solver = z3.Solver() + solver.add(constraints) + + # Check if the constraints are unsatisfiable + assert solver.check() == z3.unsat