Skip to content

Commit

Permalink
[tuner] Move constraint generation out of canddiate_gen. NFC.
Browse files Browse the repository at this point in the history
This is just code motion to make the code more modular.

Signed-off-by: Jakub Kuderski <[email protected]>
  • Loading branch information
kuhar committed Nov 15, 2024
1 parent f1bf282 commit bb024f9
Show file tree
Hide file tree
Showing 4 changed files with 368 additions and 325 deletions.
200 changes: 10 additions & 190 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import logging
import pickle
import re
import z3 # type: ignore
from dataclasses import dataclass
from os import path, makedirs
from typing import Optional
Expand All @@ -32,6 +31,7 @@
from iree.compiler import ir # type: ignore

from .common import *
from .dispatch_constraints import *
from .dispatch_parser import *

tune_logger = logging.getLogger("tune")
Expand Down Expand Up @@ -73,194 +73,6 @@ def apply_configuration(
return new_mlir


def get_mfma_intrinsic_constraints(
problem_size: ProblemSize,
intrinsic_m: z3.ArithRef,
intrinsic_n: z3.ArithRef,
intrinsic_k: z3.ArithRef,
) -> z3.BoolRef:
compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size)
assert len(compatible_intrinsics) > 0, "No compatible intrinsics found"
return z3.Or(
*(
z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k)
for mfma in compatible_intrinsics
)
)


def get_dispatch_constraints(
problem_size: ProblemSize,
tile_m: z3.ArithRef,
tile_n: z3.ArithRef,
tile_k: z3.ArithRef,
) -> list[z3.BoolRef]:
if problem_size.dispatch_kind != DispatchKind.conv:
return []

dim_info = ConvDimInfo.from_problem_size(problem_size)
conv_constraints = []
# WARNING: This sometimes makes the constraints UNSAT for some reason.
conv_constraints += [tile_m <= dim_info.ow]
conv_constraints += [tile_n <= dim_info.oc]
conv_constraints += [tile_k <= dim_info.ic]
return conv_constraints


def calculate_shared_memory_usage_in_bytes(
problem_size: ProblemSize,
m: int | z3.ArithRef,
n: int | z3.ArithRef,
k: int | z3.ArithRef,
) -> int | z3.ArithRef:
lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8)
rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8)
return lhs_memory + rhs_memory


def generate_constraints(
problem_size: ProblemSize,
tile_sizes,
num_subgroups,
subgroup_size,
intrinsic_size,
workgroup_size,
subgroup_m_count,
subgroup_n_count,
waves_per_eu,
):
M, N, K = (
problem_size.matmul_size.M,
problem_size.matmul_size.N,
problem_size.matmul_size.K,
)
m, n, k = tile_sizes
intrinsic_mn, intrinsic_k = intrinsic_size
wg_x, wg_y, wg_z = workgroup_size
wg_threads = z3.Int("wg_threads")
constraints = [wg_threads == wg_x * wg_y * wg_z]
constraints += [subgroup_size == 64, wg_threads <= 1024]
constraints += [
get_mfma_intrinsic_constraints(
problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k
)
]
subgroup_k_count = 1
constraints += [
m >= intrinsic_mn,
m <= 512,
m <= M,
]
constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0]
constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0]
for x in (subgroup_m_count, subgroup_n_count):
constraints += [x >= 1, x <= 32]

subgroup_m_tile_count = z3.Int("sg_m_tcnt")
subgroup_n_tile_count = z3.Int("sg_n_tcnt")
subgroup_k_tile_count = z3.Int("sg_k_tcnt")
for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count):
constraints += [x >= 1, x <= 32]

constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn]
constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn]
constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k]
constraints += [wg_x == subgroup_size * subgroup_n_count]
constraints += [wg_y == subgroup_m_count]
constraints += [wg_z == subgroup_k_count]
constraints += [z3.Or(wg_x <= n, wg_x <= m)]
constraints += [k % intrinsic_mn == 0]
constraints += [(k * n) % wg_threads == 0]
constraints += [(k * m) % wg_threads == 0]
subgroups = subgroup_m_count * subgroup_n_count
if num_subgroups > 0:
constraints += [subgroups == num_subgroups]
else:
constraints += [subgroups >= 1, subgroups <= 10]

constraints += [waves_per_eu == 2]
# constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)]

shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k)
constraints += [shared_memory <= 65536]

constraints += get_dispatch_constraints(problem_size, m, n, k)

return constraints


def generate_solutions(problem_size: ProblemSize, num_subgrups: int):
M, N, K = problem_size.MNK
tune_logger.info(f"{M},{N},{K}")
m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k")
subgroup_size = z3.Int("subgroup_size")
intrinsic_mn = z3.Int("intrinsic_mn")
intrinsic_k = z3.Int("intrinsic_k")
wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z")
sg_m_cnt = z3.Int("sg_m_cnt")
sg_n_cnt = z3.Int("sg_n_cnt")
waves_per_eu = z3.Int("waves_per_eu")
all_vars = [
m,
n,
k,
subgroup_size,
intrinsic_mn,
intrinsic_k,
wg_x,
wg_y,
wg_z,
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
]

solver = z3.Solver()
constraints = generate_constraints(
problem_size,
[m, n, k],
num_subgrups,
subgroup_size,
[intrinsic_mn, intrinsic_k],
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
)
solver.add(z3.simplify(z3.And(constraints)))
tune_logger.debug(f"Initial constraints: {solver}")
i = 0
while solver.check() == z3.sat:
model = solver.model()
lookup = lambda var: model[var].as_long()

config = Configuration(
lookup(subgroup_size),
[lookup(wg_x), lookup(wg_y), lookup(wg_z)],
MfmaIntrinsic(
problem_size.res_type.element_type,
lookup(intrinsic_mn),
lookup(intrinsic_mn),
lookup(intrinsic_k),
problem_size.lhs_type.element_type,
),
[lookup(m), lookup(n), lookup(k)],
lookup(sg_m_cnt),
lookup(sg_n_cnt),
GpuPipelineOptions(),
lookup(waves_per_eu),
)
solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
i += 1
yield config


def get_default_output_dir() -> str:
from datetime import datetime

return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")


class DispatchTuner(DispatchParser):
# TODO(https://github.com/nod-ai/SHARK-Platform/issues/453): Remove this in favor of configuring using transform dialect.
@abstractmethod
Expand Down Expand Up @@ -675,6 +487,12 @@ def walk_mlir_op(
return walk_result


def get_default_output_dir() -> str:
from datetime import datetime

return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")


def tune(
input: str, # Path to the mlir file to be tuned
output: str = "", # Path to the output directory, auto creates one if not given
Expand Down Expand Up @@ -722,7 +540,9 @@ def tune(
problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
tune_logger.debug(str(problem_size))
configs = []
for i, config in enumerate(generate_solutions(problem_size, num_subgroups)):
for i, config in enumerate(
generate_solutions(tuner_context, problem_size, num_subgroups)
):
if i >= limit:
break
tune_logger.info(f"Solution #{i+1}: {config}")
Expand Down
135 changes: 0 additions & 135 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,141 +14,6 @@
from . import common


def test_generate_solutions() -> None:
matmul_size = common.MatmulSize(2048, 3840, 1280)
lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16)
rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16)
res_type = common.ShapedType([2048, 3840], common.ElementType.f32)
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
configs = candidate_gen.generate_solutions(problem_size, 4)
assert configs is not None


def test_calculate_shared_memory_usage_in_bytes() -> None:
matmul_size = common.MatmulSize(1024, 1024, 1024)
lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
res_type = common.ShapedType([1024, 1024], common.ElementType.f32)
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
assert (
candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128)
== 147456
)

lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8)
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
assert (
candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128)
== 81920
)

rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32)
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
assert (
candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32)
== 12288
)


def test_generate_constraints_valid_input() -> None:
matmul_size = common.MatmulSize(1024, 1024, 1024)
lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
res_type = common.ShapedType([1024, 1024], common.ElementType.f32)
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
# Define input parameters as z3 Ints
m, n, k = (
candidate_gen.z3.Int("m"),
candidate_gen.z3.Int("n"),
candidate_gen.z3.Int("k"),
)
subgroup_size = candidate_gen.z3.Int("subgroup_size")
intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
wg_x, wg_y, wg_z = (
candidate_gen.z3.Int("wg_x"),
candidate_gen.z3.Int("wg_y"),
candidate_gen.z3.Int("wg_z"),
)
sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
waves_per_eu = candidate_gen.z3.Int("waves_per_eu")

constraints = candidate_gen.generate_constraints(
problem_size,
[m, n, k],
4,
subgroup_size,
[intrinsic_mn, intrinsic_k],
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
)

solver = candidate_gen.z3.Solver()
solver.add(constraints)

# Check if the constraints are satisfiable
assert solver.check() == candidate_gen.z3.sat


def test_generate_constraints_invalid_input() -> None:
# Define input parameters that should lead to unsatisfiable constraints
matmul_size = common.MatmulSize(1024, 1024, 1024)
lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
res_type = common.ShapedType([1024, 1024], common.ElementType.f32)
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
m, n, k = (
candidate_gen.z3.Int("m"),
candidate_gen.z3.Int("n"),
candidate_gen.z3.Int("k"),
)
subgroup_size = candidate_gen.z3.Int("subgroup_size")
intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
wg_x, wg_y, wg_z = (
candidate_gen.z3.Int("wg_x"),
candidate_gen.z3.Int("wg_y"),
candidate_gen.z3.Int("wg_z"),
)
sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
waves_per_eu = candidate_gen.z3.Int("waves_per_eu")

constraints = candidate_gen.generate_constraints(
problem_size,
[m, n, k],
4,
subgroup_size,
[intrinsic_mn, intrinsic_k],
[wg_x, wg_y, wg_z],
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
)
constraints.append(m > 1000) # Adding an additional unsatisfiable constraint

solver = candidate_gen.z3.Solver()
solver.add(constraints)

# Check if the constraints are unsatisfiable
assert solver.check() == candidate_gen.z3.unsat


def remove_comments(mlir: str) -> str:
return "\n".join(
filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines())
Expand Down
Loading

0 comments on commit bb024f9

Please sign in to comment.