From f8248b381b6696b311dd42e13f5f8706d5a50064 Mon Sep 17 00:00:00 2001 From: erman-gurses Date: Fri, 8 Nov 2024 17:52:20 -0600 Subject: [PATCH 1/9] Add Attention Test Suite Signed-off-by: erman-gurses --- linalg_ops/CMakeLists.txt | 21 + linalg_ops/attention/CMakeLists.txt | 113 ++++ .../attention/generate_e2e_attention_tests.py | 503 ++++++++++++++++++ .../attention/generate_test_mlir_files.sh | 68 +++ .../attention_f16_f16_f16_f16_large.mlir | 17 + ...attention_f16_f16_f16_f16_large_calls.mlir | 57 ++ .../attention_f16_f16_f16_f16_medium.mlir | 17 + ...ttention_f16_f16_f16_f16_medium_calls.mlir | 57 ++ .../attention_f16_f16_f16_f16_small.mlir | 17 + ...attention_f16_f16_f16_f16_small_calls.mlir | 57 ++ linalg_ops/iree-e2e-attention-test.cc | 491 +++++++++++++++++ 11 files changed, 1418 insertions(+) create mode 100644 linalg_ops/attention/CMakeLists.txt create mode 100644 linalg_ops/attention/generate_e2e_attention_tests.py create mode 100755 linalg_ops/attention/generate_test_mlir_files.sh create mode 100644 linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large.mlir create mode 100644 linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large_calls.mlir create mode 100644 linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium.mlir create mode 100644 linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium_calls.mlir create mode 100644 linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small.mlir create mode 100644 linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small_calls.mlir create mode 100644 linalg_ops/iree-e2e-attention-test.cc diff --git a/linalg_ops/CMakeLists.txt b/linalg_ops/CMakeLists.txt index 015b318..0a44398 100644 --- a/linalg_ops/CMakeLists.txt +++ b/linalg_ops/CMakeLists.txt @@ -134,6 +134,26 @@ iree_cc_binary( iree::vm::cc ) +iree_cc_binary( + NAME + iree-e2e-attention-test + SRCS + "iree-e2e-attention-test.cc" + DEPS + ::test_utils + iree::base + iree::base::internal + iree::base::internal::cpu + iree::base::internal::flags + iree::base::internal::path + iree::hal + iree::modules::hal + iree::tooling::context_util + iree::tooling::device_util + iree::vm + iree::vm::cc +) + #------------------------------------------------------------------------------- # Tests #------------------------------------------------------------------------------- @@ -144,3 +164,4 @@ include(iree_test_suites_runner_test) add_subdirectory(matmul) add_subdirectory(convolution) +add_subdirectory(attention) diff --git a/linalg_ops/attention/CMakeLists.txt b/linalg_ops/attention/CMakeLists.txt new file mode 100644 index 0000000..0851c89 --- /dev/null +++ b/linalg_ops/attention/CMakeLists.txt @@ -0,0 +1,113 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# TODO(scotttodd): add filtering here, in the helper functions, or in ctest to +# choose which tests to compile and run + +set(_SIZES) +list(APPEND _SIZES "large") +list(APPEND _SIZES "medium") +list(APPEND _SIZES "small") + + +set(_DTYPES_AND_LAYOUTS) +list(APPEND _DTYPES_AND_LAYOUTS "f16_f16_f16_f16") + +############################################################################### +# +# CPU - llvm-cpu on local-task, default flags. 
+# +############################################################################### + +foreach(_DTYPE_AND_LAYOUT IN LISTS _DTYPES_AND_LAYOUTS) + foreach(_SIZE IN LISTS _SIZES) + iree_test_suites_runner_test( + NAME + attention_llvm-cpu_local-task_${_DTYPE_AND_LAYOUT}_${_SIZE} + TESTS_SRC + "generated/${_DTYPE_AND_LAYOUT}/attention_${_DTYPE_AND_LAYOUT}_${_SIZE}.mlir" + CALLS_SRC + "generated/${_DTYPE_AND_LAYOUT}/attention_${_DTYPE_AND_LAYOUT}_${_SIZE}_calls.mlir" + TEST_RUNNER + iree-test-suites_iree-e2e-attention-test + TARGET_BACKEND + "llvm-cpu" + DRIVER + "local-task" + COMPILER_FLAGS + "--iree-llvmcpu-target-cpu=host" + RUNNER_FLAGS + LABELS + "hostonly" + "local" + ) + endforeach() +endforeach() + +############################################################################### +# +# GPU - ROCm/HIP, CDNA(gfx9). +# +############################################################################### + +# To distinguish between CDNA(gfx9) and RDNA3(gfx11) +if(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx9") + +foreach(_DTYPE_AND_LAYOUT IN LISTS _DTYPES_AND_LAYOUTS) + foreach(_SIZE IN LISTS _SIZES) + iree_test_suites_runner_test( + NAME + attention_rocm_hip_${_DTYPE_AND_LAYOUT}_${_SIZE} + TESTS_SRC + "generated/${_DTYPE_AND_LAYOUT}/attention_${_DTYPE_AND_LAYOUT}_${_SIZE}.mlir" + CALLS_SRC + "generated/${_DTYPE_AND_LAYOUT}/attention_${_DTYPE_AND_LAYOUT}_${_SIZE}_calls.mlir" + TEST_RUNNER + iree-test-suites_iree-e2e-attention-test + TARGET_BACKEND + "rocm" + DRIVER + "hip" + COMPILER_FLAGS + "--iree-hip-target=${IREE_HIP_TEST_TARGET_CHIP}" + RUNNER_FLAGS + LABELS + ) + endforeach() +endforeach() + +############################################################################### +# +# GPU - ROCm/HIP, CDNA(gfx11) +# +############################################################################### + +elseif(IREE_HIP_TEST_TARGET_CHIP MATCHES "^gfx11") + +foreach(_DTYPE_AND_LAYOUT IN LISTS _DTYPES_AND_LAYOUTS) + foreach(_SIZE IN LISTS _SIZES) + iree_test_suites_runner_test( + NAME + attention_rocm_hip_${_DTYPE_AND_LAYOUT}_${_SIZE} + TESTS_SRC + "generated/${_DTYPE_AND_LAYOUT}/attention_${_DTYPE_AND_LAYOUT}_${_SIZE}.mlir" + CALLS_SRC + "generated/${_DTYPE_AND_LAYOUT}/attention_${_DTYPE_AND_LAYOUT}_${_SIZE}_calls.mlir" + TEST_RUNNER + iree-test-suites_iree-e2e-attention-test + TARGET_BACKEND + "rocm" + DRIVER + "hip" + COMPILER_FLAGS + "--iree-hip-target=${IREE_HIP_TEST_TARGET_CHIP}" + RUNNER_FLAGS + LABELS + ) + endforeach() +endforeach() + +endif() diff --git a/linalg_ops/attention/generate_e2e_attention_tests.py b/linalg_ops/attention/generate_e2e_attention_tests.py new file mode 100644 index 0000000..4027c88 --- /dev/null +++ b/linalg_ops/attention/generate_e2e_attention_tests.py @@ -0,0 +1,503 @@ +#!/usr/bin/env python3 +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Generator for e2e attention tests. +""" + +import argparse +import enum +import dataclasses +import typing +import math + + +# Data type of kernel entries. The string values must match MLIR data types. +@enum.unique +class QueryElemTypeId(enum.Enum): + NONE = "" + F16 = "f16" + + +# Data type of input entries. The string values must match MLIR data types. +@enum.unique +class KeyElemTypeId(enum.Enum): + NONE = "" + F16 = "f16" + + +# Data type of input entries. The string values must match MLIR data types. 
+@enum.unique +class ValueElemTypeId(enum.Enum): + NONE = "" + F16 = "f16" + + +# Data type of input entries. The string values must match MLIR data types. +@enum.unique +class ResultElemTypeId(enum.Enum): + NONE = "" + F16 = "f16" + + +# Enumerates of the collections of shapes that we can generate tests for. +# The values are the accepted values for the --shapes= flag. +@enum.unique +class ShapesId(enum.Enum): + SMALL = "small" + MEDIUM = "medium" + LARGE = "large" + + +# batch: Batch dimension +# m: M dimension of first and second matmul +# n: N dimension of second matmul +# k1: K dimension of first matmul +# k2: K dimension of second matmul +@dataclasses.dataclass +class TestShapeAndScale: + batch: int + m: int + k1: int + k2: int + n: int + scale: float + + +# Returns the list of TestShape's to use for the collection of shapes +# identified by shapes_id. +def get_test_shapes(shapes_id: ShapesId): + if shapes_id == ShapesId.SMALL: + return [ + TestShapeAndScale(batch=2, m=256, k1=64, k2=32, n=16, scale=1.0), + ] + if shapes_id == ShapesId.MEDIUM: + return [ + TestShapeAndScale(batch=2, m=512, k1=128, k2=64, n=32, scale=1.0), + ] + if shapes_id == ShapesId.LARGE: + return [ + TestShapeAndScale(batch=2, m=1024, k1=128, k2=128, n=64, scale=1.0), + ] + + raise ValueError(shapes_id) + + +# Determines the shape of input and kernel tensors. +@dataclasses.dataclass +class TestInputTensorShapes: + batch: int + m: int + k1: int + k2: int + n: int + scale: float + + +# Helper for generate_function. Generates TestInputTensorShapes, i.e. +# converts from the runtime shape dimensions in TestShape and given dynamicity to +# the set of shapes to be used in a test function's input tensors. +def generate_shapes_and_scale(shape: TestShapeAndScale): + batch = shape.batch + m = shape.m + k1 = shape.k1 + k2 = shape.k2 + n = shape.n + scale = shape.scale + + shapes_scale = TestInputTensorShapes( + batch=batch, + m=m, + k1=k1, + k2=k2, + n=n, + scale=scale, + ) + return shapes_scale + + +# Helper to return input, kernel and output shapes based on the layout and the Attention Params. +def get_tensor_shapes( + shapes_scale: TestShapeAndScale, +): + batch = shapes_scale.batch + m = shapes_scale.m + k1 = shapes_scale.k1 + k2 = shapes_scale.k2 + n = shapes_scale.n + scale = shapes_scale.scale + + query_tensor_shape = [batch, m, k1] + key_tensor_shape = [batch, k2, k1] + value_tensor_shape = [batch, k2, n] + result_tensor_shape = [batch, m, n] + + return query_tensor_shape, key_tensor_shape, value_tensor_shape, result_tensor_shape + + +# Helper for generate_function. +# Generates a name for a test function in the generated MLIR code. +def generate_function_name( + query_type: QueryElemTypeId, + key_type: KeyElemTypeId, + value_type: ValueElemTypeId, + shapes_scale: TestInputTensorShapes, +): + query_t = query_type.value + key_t = key_type.value + value_t = value_type.value + result_t = value_type.value + + batch = shapes_scale.batch + m = shapes_scale.m + k1 = shapes_scale.k1 + k2 = shapes_scale.k2 + n = shapes_scale.n + + attention = "attention" + return ( + f"{attention}_{batch}_{m}_{k1}_{k2}_{n}" + + f"_dtype_{query_t}_{key_t}_{value_t}_{result_t}" + ) + + +# Represents a generated test function. +@dataclasses.dataclass +class MLIRFunction: + name: str + signature: str + import_declaration: str + definition: str + + +# Generates a test function in the generated MLIR code. 
+# The generated function will take the same arguments as iree_linalg_ext.attention variants +# and will just call iree_linalg_ext.attention variants with them, returning its result. +def generate_function( + query_type: QueryElemTypeId, + key_type: KeyElemTypeId, + value_type: ValueElemTypeId, + shape_scale: TestShapeAndScale, +): + shapes_scale = generate_shapes_and_scale(shape_scale) + func_name = generate_function_name( + query_type, + key_type, + value_type, + shapes_scale, + ) + + query_shape, key_shape, value_shape, result_shape = get_tensor_shapes(shapes_scale) + query_tensor_type = ( + f"tensor<{query_shape[0]}x{query_shape[1]}x{query_shape[2]}x{query_type.value}>" + ) + key_tensor_type = ( + f"tensor<{key_shape[0]}x{key_shape[1]}x{key_shape[2]}x{key_type.value}>" + ) + value_tensor_type = ( + f"tensor<{value_shape[0]}x{value_shape[1]}x{value_shape[2]}x{value_type.value}>" + ) + result_tensor_type = f"tensor<{result_shape[0]}x{result_shape[1]}x{result_shape[2]}x{value_type.value}>" + F32 = "f32" + F16 = "f16" + op_name = "iree_linalg_ext.attention" + + # Compilation info is optional; prints empty string by default. + func_definition = "" + + signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {result_tensor_type}) -> {result_tensor_type}" + import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: {F32}) -> !hal.buffer_view" + func_definition = func_definition + ( + f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %scale: {F32}) -> {result_tensor_type} {{\n" + f" %result0 = tensor.empty(): {result_tensor_type}\n" + f" %scale_f16 = arith.truncf %scale : {F32} to {F16} \n" + f" %result1 = {op_name} {{\n" + f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n" + f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n" + f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n" + f" affine_map<(batch, m, n, k1, k2) -> ()>,\n" + f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}" + f" ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16})\n" + f" outs(%result0: {result_tensor_type}) {{\n" + f" ^bb0(%score: f32): \n" + f" iree_linalg_ext.yield %score : f32\n" + f" }} -> {result_tensor_type}\n" + f" return %result1: {result_tensor_type}\n" + f"}}\n" + ) + return MLIRFunction( + name=func_name, + signature=signature, + import_declaration=import_declaration, + definition=func_definition, + ) + + +# Represents a call to a generated test function. +@dataclasses.dataclass +class TestCall: + function: MLIRFunction + op: str + + +# Enumerates ways to initialize tensor buffer contents. +@enum.unique +class TensorGenerator(enum.Enum): + ZERO = "zero" # Fill with zeros + RANDOM = "random" # Fill with (deterministic) pseudorandom values. + + +# Intentionally fixed seed! We want full reproducibility here, both across runs +# and across machines. +# Intentionally not shared with local_pseudorandom_state to limit the ways +# in which shuffling testcases changes which random values are generated. 
+pseudorandom_generator_seed = 1 + + +def contents_generator_tag(generator: TensorGenerator): + if generator == TensorGenerator.ZERO: + return "" + elif generator == TensorGenerator.RANDOM: + global pseudorandom_generator_seed + pseudorandom_generator_seed = pseudorandom_generator_seed + 1 + return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}" + else: + raise ValueError(generator) + + +# Generate a 3d tensor function argument of the given size as `%name`. +def generate_random_3d_tensor( + name: str, + tensor_shape: list, + element_type: typing.Union[QueryElemTypeId, ResultElemTypeId], +): + global pseudorandom_generator_seed + pseudorandom_generator_seed = pseudorandom_generator_seed + 1 + return ( + f" %{name}_dim0 = arith.constant {tensor_shape[0]} : i64\n" + f" %{name}_dim1 = arith.constant {tensor_shape[1]} : i64\n" + f" %{name}_dim2 = arith.constant {tensor_shape[2]} : i64\n" + f" %{name}_element_type = hal.element_type<{element_type.value}> : i32\n" + f" %{name}_seed = arith.constant {pseudorandom_generator_seed} : i32\n" + f" %{name} = call @attention_test.generate_random_tensor(%device, %{name}_dim0, %{name}_dim1, %{name}_dim2, %{name}_element_type, %{name}_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view\n" + ) + + +call_id = 0 + + +def generate_call( + function: MLIRFunction, + query_type: QueryElemTypeId, + key_type: KeyElemTypeId, + value_type: ValueElemTypeId, + shapes_scale: TestShapeAndScale, +): + global call_id + func_name = f"{function.name}_{shapes_scale.batch}_{shapes_scale.m}_{shapes_scale.k1}_{shapes_scale.k2}_{shapes_scale.n}_{shapes_scale.k1}_{shapes_scale.scale}" + func_name = f"{func_name}_{call_id}" + call_id = call_id + 1 + + description = f"Attention shape (BATCHxMxK1xK2xN): {shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}x{shapes_scale.k2}x{shapes_scale.k1}x{shapes_scale.n}" + op = ( + f"func.func @{func_name}() attributes {{\n" + f' iree.reflection = {{description = "{description}"}}\n' + "} {\n" + " %device_index = arith.constant 0 : index\n" + " %device = hal.devices.get %device_index : !hal.device\n" + ) + + query_shape, key_shape, value_shape, result_shape = get_tensor_shapes( + shapes_scale, + ) + + op = op + generate_random_3d_tensor("query", query_shape, query_type) + op = op + generate_random_3d_tensor("key", key_shape, key_type) + op = op + generate_random_3d_tensor("value", value_shape, value_type) + + global pseudorandom_generator_seed + pseudorandom_generator_seed = pseudorandom_generator_seed - 1 + op = op + ( + f" %scale = arith.constant {shapes_scale.scale} : f32\n" + f" %result = call @module.{function.name}(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n" + ) + + op = op + ( + f" %batch = arith.constant {shapes_scale.batch} : i64 \n" + f" %m = arith.constant {shapes_scale.m} : i64 \n" + f" %k1 = arith.constant {shapes_scale.k1} : i64 \n" + f" %k2 = arith.constant {shapes_scale.k2} : i64 \n" + f" %n = arith.constant {shapes_scale.n} : i64 \n" + f" %queryTensor = hal.tensor.import %query : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> \n" + f" %keyTensor = hal.tensor.import %key : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> \n" + f" %valueTensor = hal.tensor.import %value : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> \n" + f" %resultTensor = hal.tensor.import %result : !hal.buffer_view -> 
tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> \n" + f" %queryExt = arith.extf %queryTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> \n" + f" %keyExt = arith.extf %keyTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> \n" + f" %valueExt = arith.extf %valueTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> \n" + f" %resultExt = arith.extf %resultTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> \n" + f" %queryExtBufferView = hal.tensor.export %queryExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n" + f" %keyExtBufferView = hal.tensor.export %keyExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n" + f" %valueExtBufferView = hal.tensor.export %valueExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> -> !hal.buffer_view \n" + f" %resultExtBufferView = hal.tensor.export %resultExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> -> !hal.buffer_view \n" + f" call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n" + ) + + op = op + " return\n" + op = op + "}\n" + + return TestCall(function=function, op=op) + + +# Generates all output files' contents as strings. +def generate( + query_type: QueryElemTypeId, + key_type: KeyElemTypeId, + value_type: ValueElemTypeId, + shapes_id: ShapesId, +): + functions = {} + calls = [] + + for shape in get_test_shapes(shapes_id): + function = generate_function( + query_type, + key_type, + value_type, + shape, + ) + if function.name not in functions: + functions[function.name] = function + calls.append( + generate_call( + function, + query_type, + key_type, + value_type, + shape, + ) + ) + + return (functions, calls) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Generator of e2e Attention tests") + parser.add_argument( + "--output_attention_mlir", + type=str, + help="Path of output .mlir file containing the generated Attention functions", + required=True, + ) + parser.add_argument( + "--output_calls_mlir", + type=str, + help="Path of output .mlir file containing the calls", + required=True, + ) + parser.add_argument( + "--query_type", + type=str, + choices=["f16"], + help="Numeric type of query tensors ", + required=True, + ) + parser.add_argument( + "--key_type", + type=str, + choices=["f16"], + help="Numeric type of key tensors ", + required=True, + ) + parser.add_argument( + "--value_type", + type=str, + choices=["f16"], + help="Numeric type of value tensors ", + required=True, + ) + parser.add_argument( + "--shapes_scale", + type=str, + choices=[s.value for s in ShapesId], + help="Collection of tensor shapes to test", + required=True, + ) + parser.add_argument( + "--requirements", + type=str, + help="Target requirements for this module. Comma-separated. As in -iree-llvmcpu-target-cpu-features. 
If the target device does not meet all of the requirements, the test will be skipped.", + required=False, + ) + return parser.parse_args() + + +def write_code_file(functions, filename): + with open(filename, "w") as file: + for function in functions.values(): + file.write(function.definition + "\n") + + +def write_calls_file(functions, calls, filename, requirements): + # Module-level reflection information used to control the test tool. + reflection = "" + if requirements: + reflection = ( + "iree.reflection = {" + 'target_features = "' + + ",".join([req.lstrip("+") for req in requirements.split(",")]) + + '"' + "}" + ) + module_definition = ( + f"builtin.module @calls attributes {{\n" f" {reflection}\n" f"}} {{\n\n" + ) + + # Declare the custom module that generates arguments. + module_definition = module_definition + ( + "func.func private @attention_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n" + "func.func private @attention_test.check_attention_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view)\n" + "\n" + ) + + # Declare the functions that will be called. + for function in functions.values(): + module_definition = module_definition + function.import_declaration + "\n" + module_definition = module_definition + "\n" + + # Emit the test cases for each call. + for call in calls: + module_definition = module_definition + call.op + "\n" + + module_definition = module_definition + "\n}\n" + + with open(filename, "w") as file: + file.write(module_definition) + + +def main(args): + query_type = QueryElemTypeId(args.query_type) + key_type = KeyElemTypeId(args.key_type) + value_type = ValueElemTypeId(args.value_type) + shapes_id = ShapesId(args.shapes_scale) + + (functions, calls) = generate( + query_type, + key_type, + value_type, + shapes_id, + ) + + write_code_file(functions, args.output_attention_mlir) + write_calls_file( + functions, + calls, + args.output_calls_mlir, + args.requirements, + ) + + +if __name__ == "__main__": + main(parse_arguments()) diff --git a/linalg_ops/attention/generate_test_mlir_files.sh b/linalg_ops/attention/generate_test_mlir_files.sh new file mode 100755 index 0000000..5d71878 --- /dev/null +++ b/linalg_ops/attention/generate_test_mlir_files.sh @@ -0,0 +1,68 @@ +#!/bin/bash + +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This script runs generate_e2e_conv2d_tests for all argument combinations that +# we are interested in testing. +# +# The output is a 'generated' folder with contents like this: +# linalg_ops/ +# convolution/ +# generated/ +# f16_f16_f16_f16/ +# attention_f16_f16_f16_f16_large_calls.mlir +# attention_f16_f16_f16_f16_large.mlir +# attention_f16_f16_f16_f16_medium_calls.mlir +# attention_f16_f16_f16_f16_medium.mlir +# attention_f16_f16_f16_f16_small_calls.mlir +# attention_f16_f16_f16_f16_small.mlir +# Usage: +# generate_test_mlir_files.sh + +set -euo pipefail + +this_dir="$(cd $(dirname $0) && pwd)" +generated_dir_root="${this_dir}/generated" + +# Reset generated directory. 
+rm -rf ${generated_dir_root?} +mkdir -p ${generated_dir_root?} + +shapes=( + "small" + "medium" + "large" +) + +# query_type;key_type;value_type; +type_and_layout_combinations=( + "f16;f16;f16;f16" +) + +for type_and_layout_combination in ${type_and_layout_combinations[@]}; do + IFS=";" read -r -a combination <<< "${type_and_layout_combination}" + query_type="${combination[0]}" + key_type="${combination[1]}" + value_type="${combination[2]}" + scale_type="${combination[3]}" + + type_layout_name="${query_type}_${key_type}_${value_type}_${scale_type}" + type_combination_dir="${generated_dir_root}/${type_layout_name}" + mkdir -p ${type_combination_dir} + + for shape in ${shapes[@]}; do + echo "Generating attention test files for ${type_layout_name}_${shape}" + name="attention_${type_layout_name}_${shape}" + python ${this_dir}/generate_e2e_attention_tests.py \ + --output_attention_mlir=${type_combination_dir}/${name}.mlir \ + --output_calls_mlir=${type_combination_dir}/${name}_calls.mlir \ + --query_type=${query_type} \ + --key_type=${key_type} \ + --value_type=${value_type} \ + --shapes=${shape} + done +done diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large.mlir new file mode 100644 index 0000000..e416b1a --- /dev/null +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large.mlir @@ -0,0 +1,17 @@ +func.func @attention_2_1024_128_128_64_dtype_f16_f16_f16_f16(%query: tensor<2x1024x128xf16>, %key: tensor<2x128x128xf16>, %value: tensor<2x128x64xf16>, %scale: f32) -> tensor<2x1024x64xf16> { + %result0 = tensor.empty(): tensor<2x1024x64xf16> + %scale_f16 = arith.truncf %scale : f32 to f16 + %result1 = iree_linalg_ext.attention { + indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>, + affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>, + affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>, + affine_map<(batch, m, n, k1, k2) -> ()>, + affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>] +} ins(%query, %key, %value, %scale_f16: tensor<2x1024x128xf16>, tensor<2x128x128xf16>, tensor<2x128x64xf16>, f16) + outs(%result0: tensor<2x1024x64xf16>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score : f32 + } -> tensor<2x1024x64xf16> + return %result1: tensor<2x1024x64xf16> +} + diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large_calls.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large_calls.mlir new file mode 100644 index 0000000..544b0fa --- /dev/null +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large_calls.mlir @@ -0,0 +1,57 @@ +builtin.module @calls attributes { + +} { + +func.func private @attention_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view +func.func private @attention_test.check_attention_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view) + +func.func private @module.attention_2_1024_128_128_64_dtype_f16_f16_f16_f16(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: f32) -> !hal.buffer_view + +func.func @attention_2_1024_128_128_64_dtype_f16_f16_f16_f16_2_1024_128_128_64_128_1.0_0() attributes { + iree.reflection = {description = "Attention shape (BATCHxMxK1xK2xN): 
2x1024x128x128x128x64"} +} { + %device_index = arith.constant 0 : index + %device = hal.devices.get %device_index : !hal.device + %query_dim0 = arith.constant 2 : i64 + %query_dim1 = arith.constant 1024 : i64 + %query_dim2 = arith.constant 128 : i64 + %query_element_type = hal.element_type : i32 + %query_seed = arith.constant 2 : i32 + %query = call @attention_test.generate_random_tensor(%device, %query_dim0, %query_dim1, %query_dim2, %query_element_type, %query_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view + %key_dim0 = arith.constant 2 : i64 + %key_dim1 = arith.constant 128 : i64 + %key_dim2 = arith.constant 128 : i64 + %key_element_type = hal.element_type : i32 + %key_seed = arith.constant 3 : i32 + %key = call @attention_test.generate_random_tensor(%device, %key_dim0, %key_dim1, %key_dim2, %key_element_type, %key_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view + %value_dim0 = arith.constant 2 : i64 + %value_dim1 = arith.constant 128 : i64 + %value_dim2 = arith.constant 64 : i64 + %value_element_type = hal.element_type : i32 + %value_seed = arith.constant 4 : i32 + %value = call @attention_test.generate_random_tensor(%device, %value_dim0, %value_dim1, %value_dim2, %value_element_type, %value_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view + %scale = arith.constant 1.0 : f32 + %result = call @module.attention_2_1024_128_128_64_dtype_f16_f16_f16_f16(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view + %batch = arith.constant 2 : i64 + %m = arith.constant 1024 : i64 + %k1 = arith.constant 128 : i64 + %k2 = arith.constant 128 : i64 + %n = arith.constant 64 : i64 + %queryTensor = hal.tensor.import %query : !hal.buffer_view -> tensor<2x1024x128xf16> + %keyTensor = hal.tensor.import %key : !hal.buffer_view -> tensor<2x128x128xf16> + %valueTensor = hal.tensor.import %value : !hal.buffer_view -> tensor<2x128x64xf16> + %resultTensor = hal.tensor.import %result : !hal.buffer_view -> tensor<2x1024x64xf16> + %queryExt = arith.extf %queryTensor : tensor<2x1024x128xf16> to tensor<2x1024x128xf32> + %keyExt = arith.extf %keyTensor : tensor<2x128x128xf16> to tensor<2x128x128xf32> + %valueExt = arith.extf %valueTensor : tensor<2x128x64xf16> to tensor<2x128x64xf32> + %resultExt = arith.extf %resultTensor : tensor<2x1024x64xf16> to tensor<2x1024x64xf32> + %queryExtBufferView = hal.tensor.export %queryExt : tensor<2x1024x128xf32> -> !hal.buffer_view + %keyExtBufferView = hal.tensor.export %keyExt : tensor<2x128x128xf32> -> !hal.buffer_view + %valueExtBufferView = hal.tensor.export %valueExt : tensor<2x128x64xf32> -> !hal.buffer_view + %resultExtBufferView = hal.tensor.export %resultExt : tensor<2x1024x64xf32> -> !hal.buffer_view + call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> () + return +} + + +} diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium.mlir new file mode 100644 index 0000000..860595e --- /dev/null +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium.mlir @@ -0,0 +1,17 @@ +func.func @attention_2_512_128_64_32_dtype_f16_f16_f16_f16(%query: tensor<2x512x128xf16>, %key: tensor<2x64x128xf16>, 
%value: tensor<2x64x32xf16>, %scale: f32) -> tensor<2x512x32xf16> { + %result0 = tensor.empty(): tensor<2x512x32xf16> + %scale_f16 = arith.truncf %scale : f32 to f16 + %result1 = iree_linalg_ext.attention { + indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>, + affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>, + affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>, + affine_map<(batch, m, n, k1, k2) -> ()>, + affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>] +} ins(%query, %key, %value, %scale_f16: tensor<2x512x128xf16>, tensor<2x64x128xf16>, tensor<2x64x32xf16>, f16) + outs(%result0: tensor<2x512x32xf16>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score : f32 + } -> tensor<2x512x32xf16> + return %result1: tensor<2x512x32xf16> +} + diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium_calls.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium_calls.mlir new file mode 100644 index 0000000..6c77095 --- /dev/null +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium_calls.mlir @@ -0,0 +1,57 @@ +builtin.module @calls attributes { + +} { + +func.func private @attention_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view +func.func private @attention_test.check_attention_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view) + +func.func private @module.attention_2_512_128_64_32_dtype_f16_f16_f16_f16(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: f32) -> !hal.buffer_view + +func.func @attention_2_512_128_64_32_dtype_f16_f16_f16_f16_2_512_128_64_32_128_1.0_0() attributes { + iree.reflection = {description = "Attention shape (BATCHxMxK1xK2xN): 2x512x128x64x128x32"} +} { + %device_index = arith.constant 0 : index + %device = hal.devices.get %device_index : !hal.device + %query_dim0 = arith.constant 2 : i64 + %query_dim1 = arith.constant 512 : i64 + %query_dim2 = arith.constant 128 : i64 + %query_element_type = hal.element_type : i32 + %query_seed = arith.constant 2 : i32 + %query = call @attention_test.generate_random_tensor(%device, %query_dim0, %query_dim1, %query_dim2, %query_element_type, %query_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view + %key_dim0 = arith.constant 2 : i64 + %key_dim1 = arith.constant 64 : i64 + %key_dim2 = arith.constant 128 : i64 + %key_element_type = hal.element_type : i32 + %key_seed = arith.constant 3 : i32 + %key = call @attention_test.generate_random_tensor(%device, %key_dim0, %key_dim1, %key_dim2, %key_element_type, %key_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view + %value_dim0 = arith.constant 2 : i64 + %value_dim1 = arith.constant 64 : i64 + %value_dim2 = arith.constant 32 : i64 + %value_element_type = hal.element_type : i32 + %value_seed = arith.constant 4 : i32 + %value = call @attention_test.generate_random_tensor(%device, %value_dim0, %value_dim1, %value_dim2, %value_element_type, %value_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view + %scale = arith.constant 1.0 : f32 + %result = call @module.attention_2_512_128_64_32_dtype_f16_f16_f16_f16(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view + %batch = arith.constant 2 : i64 + %m = arith.constant 512 : i64 + 
%k1 = arith.constant 128 : i64 + %k2 = arith.constant 64 : i64 + %n = arith.constant 32 : i64 + %queryTensor = hal.tensor.import %query : !hal.buffer_view -> tensor<2x512x128xf16> + %keyTensor = hal.tensor.import %key : !hal.buffer_view -> tensor<2x64x128xf16> + %valueTensor = hal.tensor.import %value : !hal.buffer_view -> tensor<2x64x32xf16> + %resultTensor = hal.tensor.import %result : !hal.buffer_view -> tensor<2x512x32xf16> + %queryExt = arith.extf %queryTensor : tensor<2x512x128xf16> to tensor<2x512x128xf32> + %keyExt = arith.extf %keyTensor : tensor<2x64x128xf16> to tensor<2x64x128xf32> + %valueExt = arith.extf %valueTensor : tensor<2x64x32xf16> to tensor<2x64x32xf32> + %resultExt = arith.extf %resultTensor : tensor<2x512x32xf16> to tensor<2x512x32xf32> + %queryExtBufferView = hal.tensor.export %queryExt : tensor<2x512x128xf32> -> !hal.buffer_view + %keyExtBufferView = hal.tensor.export %keyExt : tensor<2x64x128xf32> -> !hal.buffer_view + %valueExtBufferView = hal.tensor.export %valueExt : tensor<2x64x32xf32> -> !hal.buffer_view + %resultExtBufferView = hal.tensor.export %resultExt : tensor<2x512x32xf32> -> !hal.buffer_view + call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> () + return +} + + +} diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small.mlir new file mode 100644 index 0000000..d0586d6 --- /dev/null +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small.mlir @@ -0,0 +1,17 @@ +func.func @attention_2_256_64_32_16_dtype_f16_f16_f16_f16(%query: tensor<2x256x64xf16>, %key: tensor<2x32x64xf16>, %value: tensor<2x32x16xf16>, %scale: f32) -> tensor<2x256x16xf16> { + %result0 = tensor.empty(): tensor<2x256x16xf16> + %scale_f16 = arith.truncf %scale : f32 to f16 + %result1 = iree_linalg_ext.attention { + indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>, + affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>, + affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>, + affine_map<(batch, m, n, k1, k2) -> ()>, + affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>] +} ins(%query, %key, %value, %scale_f16: tensor<2x256x64xf16>, tensor<2x32x64xf16>, tensor<2x32x16xf16>, f16) + outs(%result0: tensor<2x256x16xf16>) { + ^bb0(%score: f32): + iree_linalg_ext.yield %score : f32 + } -> tensor<2x256x16xf16> + return %result1: tensor<2x256x16xf16> +} + diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small_calls.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small_calls.mlir new file mode 100644 index 0000000..2aed007 --- /dev/null +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small_calls.mlir @@ -0,0 +1,57 @@ +builtin.module @calls attributes { + +} { + +func.func private @attention_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view +func.func private @attention_test.check_attention_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view) + +func.func private 
@module.attention_2_256_64_32_16_dtype_f16_f16_f16_f16(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: f32) -> !hal.buffer_view + +func.func @attention_2_256_64_32_16_dtype_f16_f16_f16_f16_2_256_64_32_16_64_1.0_0() attributes { + iree.reflection = {description = "Attention shape (BATCHxMxK1xK2xN): 2x256x64x32x64x16"} +} { + %device_index = arith.constant 0 : index + %device = hal.devices.get %device_index : !hal.device + %query_dim0 = arith.constant 2 : i64 + %query_dim1 = arith.constant 256 : i64 + %query_dim2 = arith.constant 64 : i64 + %query_element_type = hal.element_type : i32 + %query_seed = arith.constant 2 : i32 + %query = call @attention_test.generate_random_tensor(%device, %query_dim0, %query_dim1, %query_dim2, %query_element_type, %query_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view + %key_dim0 = arith.constant 2 : i64 + %key_dim1 = arith.constant 32 : i64 + %key_dim2 = arith.constant 64 : i64 + %key_element_type = hal.element_type : i32 + %key_seed = arith.constant 3 : i32 + %key = call @attention_test.generate_random_tensor(%device, %key_dim0, %key_dim1, %key_dim2, %key_element_type, %key_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view + %value_dim0 = arith.constant 2 : i64 + %value_dim1 = arith.constant 32 : i64 + %value_dim2 = arith.constant 16 : i64 + %value_element_type = hal.element_type : i32 + %value_seed = arith.constant 4 : i32 + %value = call @attention_test.generate_random_tensor(%device, %value_dim0, %value_dim1, %value_dim2, %value_element_type, %value_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view + %scale = arith.constant 1.0 : f32 + %result = call @module.attention_2_256_64_32_16_dtype_f16_f16_f16_f16(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view + %batch = arith.constant 2 : i64 + %m = arith.constant 256 : i64 + %k1 = arith.constant 64 : i64 + %k2 = arith.constant 32 : i64 + %n = arith.constant 16 : i64 + %queryTensor = hal.tensor.import %query : !hal.buffer_view -> tensor<2x256x64xf16> + %keyTensor = hal.tensor.import %key : !hal.buffer_view -> tensor<2x32x64xf16> + %valueTensor = hal.tensor.import %value : !hal.buffer_view -> tensor<2x32x16xf16> + %resultTensor = hal.tensor.import %result : !hal.buffer_view -> tensor<2x256x16xf16> + %queryExt = arith.extf %queryTensor : tensor<2x256x64xf16> to tensor<2x256x64xf32> + %keyExt = arith.extf %keyTensor : tensor<2x32x64xf16> to tensor<2x32x64xf32> + %valueExt = arith.extf %valueTensor : tensor<2x32x16xf16> to tensor<2x32x16xf32> + %resultExt = arith.extf %resultTensor : tensor<2x256x16xf16> to tensor<2x256x16xf32> + %queryExtBufferView = hal.tensor.export %queryExt : tensor<2x256x64xf32> -> !hal.buffer_view + %keyExtBufferView = hal.tensor.export %keyExt : tensor<2x32x64xf32> -> !hal.buffer_view + %valueExtBufferView = hal.tensor.export %valueExt : tensor<2x32x16xf32> -> !hal.buffer_view + %resultExtBufferView = hal.tensor.export %resultExt : tensor<2x256x16xf32> -> !hal.buffer_view + call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> () + return +} + + +} diff --git a/linalg_ops/iree-e2e-attention-test.cc b/linalg_ops/iree-e2e-attention-test.cc new file mode 100644 index 0000000..200a5b4 --- /dev/null +++ 
b/linalg_ops/iree-e2e-attention-test.cc @@ -0,0 +1,491 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include + +#include "iree/base/api.h" +#include "iree/base/internal/cpu.h" +#include "iree/base/internal/flags.h" +#include "iree/base/internal/math.h" +#include "iree/base/internal/path.h" +#include "iree/hal/api.h" +#include "iree/modules/hal/module.h" +#include "iree/tooling/context_util.h" +#include "iree/tooling/device_util.h" +#include "iree/vm/api.h" +#include "iree/vm/native_module_cc.h" +#include "test_utils.h" + +//===----------------------------------------------------------------------===// +// Reference Attention +//===----------------------------------------------------------------------===// + +// Helper for reference_attention. +// Function to allocate and initialize tensors +float* allocate_tensor(int dim1, int dim2, int dim3) { + const int size = dim1 * dim2 * dim3; + float* tensor = (float*)malloc(size * sizeof(float)); + for (int i = 0; i < size; ++i) { + tensor[i] = 0.0f; + } + return tensor; +} + +// Function to free allocated tensors +void free_tensor(float* tensor) { + if (tensor != nullptr) free(tensor); +} + +// Function to calculate 1D index for a 3D array +int index_3d(int i, int j, int k, int dim2, int dim3) { + return i * dim2 * dim3 + j * dim3 + k; +} + +static void reference_attention_f32_f32_f32_f32( + iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N, + iree_hal_dim_t B, const float* query_data, const float* key_data, + const float* value_data, float* result_data, iree_hal_dim_t b, + float* Attention) { + // Compute Q * K^T + for (int m = 0; m < M; ++m) { + for (int k2 = 0; k2 < K2; ++k2) { + float sum = 0.0; + for (int k1 = 0; k1 < K1; ++k1) { + int q_idx = index_3d(b, m, k1, M, K1); + int k_idx = index_3d(b, k2, k1, K2, K1); + + sum += query_data[q_idx] * key_data[k_idx]; + } + int att_idx = index_3d(0, m, k2, M, K2); + Attention[att_idx] = sum / sqrt(K1); // Scale by sqrt(K1) + } + } + + // Compute softmax on Attention + for (int m = 0; m < M; ++m) { + // Find the maximum value for the current sequence + float max_val = -FLT_MAX; + for (int k2 = 0; k2 < K2; ++k2) { + int att_idx = index_3d(0, m, k2, M, K2); + max_val = iree_max(max_val, Attention[att_idx]); + } + + // Calculate the softmax denominator + float sum = 0.0f; + for (int k2 = 0; k2 < K2; ++k2) { + int att_idx = index_3d(0, m, k2, M, K2); + sum += exp(Attention[att_idx] - max_val); + } + + // Apply softmax + for (int k2 = 0; k2 < K2; ++k2) { + int att_idx = index_3d(0, m, k2, M, K2); + Attention[att_idx] = exp(Attention[att_idx]) / sum; + } + } + + // Compute Attention * V + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float sum = 0.0; + for (int k2 = 0; k2 < K2; ++k2) { + int att_idx = index_3d(0, m, k2, M, K2); + int v_idx = index_3d(b, k2, n, K2, N); + sum += Attention[att_idx] * value_data[v_idx]; + } + int o_idx = index_3d(b, m, n, M, N); + result_data[o_idx] = sum; + } + } +} + +static iree_status_t reference_attention_element( + iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N, + iree_hal_dim_t B, iree_hal_element_type_t query_elem_type, + iree_hal_element_type_t key_elem_type, + iree_hal_element_type_t value_elem_type, void* query_data, void* key_data, + void* value_data, void* actual_data, void* 
result_data, iree_hal_dim_t b, + float* Attention) { + if (query_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 && + key_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 && + value_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { + reference_attention_f32_f32_f32_f32( + M, K1, K2, N, B, (const float*)query_data, (const float*)key_data, + (const float*)value_data, (float*)result_data, b, Attention); + + } else { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "unhandled combination of element types in attention"); + } + return iree_ok_status(); +} + +// Reference attention implementation, used to compare attention results +// against. +static iree_status_t reference_attention( + iree_hal_dim_t B, iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, + iree_hal_dim_t N, iree_hal_element_type_t query_elem_type, + iree_hal_element_type_t key_elem_type, + iree_hal_element_type_t value_elem_type, iree_byte_span_t query_contents, + iree_byte_span_t key_contents, iree_byte_span_t value_contents, + iree_byte_span_t actual_contents, iree_byte_span_t result_contents, + int compute_every) { + IREE_TRACE_ZONE_BEGIN(z0); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, B); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, M); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, K1); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, K2); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, N); + + iree_host_size_t count = 0; + float* Attention = allocate_tensor(1, M, K2); + for (iree_hal_dim_t b = 0; b < B; ++b) { + if (++count < compute_every) continue; + count = 0; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + reference_attention_element( + M, K1, K2, N, B, query_elem_type, key_elem_type, value_elem_type, + query_contents.data, key_contents.data, value_contents.data, + actual_contents.data, result_contents.data, b, Attention)); + } + free_tensor(Attention); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} +//===----------------------------------------------------------------------===// +// Attention comparison/logging +//===----------------------------------------------------------------------===// + +typedef struct { + iree_allocator_t host_allocator; + iree_hal_dim_t b; + iree_hal_dim_t m; + iree_hal_dim_t k1; + iree_hal_dim_t k2; + iree_hal_dim_t n; + iree_hal_element_type_t query_elem_type; + iree_hal_element_type_t key_elem_type; + iree_hal_element_type_t value_elem_type; + iree_hal_element_type_t result_elem_type; + iree_byte_span_t query_contents; + iree_byte_span_t key_contents; + iree_byte_span_t value_contents; + iree_byte_span_t actual_contents; + iree_byte_span_t expected_contents; +} attention_results_t; + +static void attention_results_deinitialize(attention_results_t* results); + +static iree_status_t attention_results_initialize( + iree_hal_device_t* device, iree_hal_dim_t b_size, iree_hal_dim_t m_size, + iree_hal_dim_t k1_size, iree_hal_dim_t k2_size, iree_hal_dim_t n_size, + iree_hal_buffer_view_t* query, iree_hal_buffer_view_t* key, + iree_hal_buffer_view_t* value, iree_hal_buffer_view_t* result, + iree_allocator_t host_allocator, attention_results_t* out_results) { + IREE_TRACE_ZONE_BEGIN(z0); + + memset(out_results, 0, sizeof(*out_results)); + out_results->host_allocator = host_allocator; + + out_results->b = b_size; + out_results->m = m_size; + out_results->k1 = k1_size; + out_results->k2 = k2_size; + out_results->n = n_size; + + out_results->query_elem_type = iree_hal_buffer_view_element_type(query); + out_results->key_elem_type = iree_hal_buffer_view_element_type(key); + out_results->value_elem_type = 
iree_hal_buffer_view_element_type(value); + out_results->result_elem_type = iree_hal_buffer_view_element_type(result); + + iree_hal_buffer_t* query_buffer = iree_hal_buffer_view_buffer(query); + iree_hal_buffer_t* key_buffer = iree_hal_buffer_view_buffer(key); + iree_hal_buffer_t* value_buffer = iree_hal_buffer_view_buffer(value); + iree_hal_buffer_t* result_buffer = iree_hal_buffer_view_buffer(result); + + iree_status_t status = iree_ok_status(); + + if (iree_status_is_ok(status)) { + out_results->query_contents.data_length = + iree_hal_buffer_byte_length(query_buffer); + status = iree_allocator_malloc(host_allocator, + out_results->query_contents.data_length, + (void**)&out_results->query_contents.data); + } + if (iree_status_is_ok(status)) { + status = iree_hal_device_transfer_d2h( + device, query_buffer, 0, out_results->query_contents.data, + out_results->query_contents.data_length, + IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); + } + if (iree_status_is_ok(status)) { + out_results->key_contents.data_length = + iree_hal_buffer_byte_length(key_buffer); + status = iree_allocator_malloc(host_allocator, + out_results->key_contents.data_length, + (void**)&out_results->key_contents.data); + } + if (iree_status_is_ok(status)) { + status = iree_hal_device_transfer_d2h( + device, key_buffer, 0, out_results->key_contents.data, + out_results->key_contents.data_length, + IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); + } + if (iree_status_is_ok(status)) { + out_results->value_contents.data_length = + iree_hal_buffer_byte_length(value_buffer); + status = iree_allocator_malloc(host_allocator, + out_results->value_contents.data_length, + (void**)&out_results->value_contents.data); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_device_transfer_d2h( + device, value_buffer, 0, out_results->value_contents.data, + out_results->value_contents.data_length, + IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); + } + if (iree_status_is_ok(status)) { + out_results->actual_contents.data_length = + iree_hal_buffer_byte_length(result_buffer); + status = iree_allocator_malloc(host_allocator, + out_results->actual_contents.data_length, + (void**)&out_results->actual_contents.data); + } + if (iree_status_is_ok(status)) { + status = iree_hal_device_transfer_d2h( + device, result_buffer, 0, out_results->actual_contents.data, + out_results->actual_contents.data_length, + IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); + } + if (iree_status_is_ok(status)) { + out_results->expected_contents.data_length = + iree_hal_buffer_byte_length(result_buffer); + status = iree_allocator_malloc( + host_allocator, out_results->expected_contents.data_length, + (void**)&out_results->expected_contents.data); + } + if (!iree_status_is_ok(status)) { + attention_results_deinitialize(out_results); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void attention_results_deinitialize(attention_results_t* results) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_allocator_free(results->host_allocator, results->query_contents.data); + iree_allocator_free(results->host_allocator, results->key_contents.data); + iree_allocator_free(results->host_allocator, results->value_contents.data); + iree_allocator_free(results->host_allocator, results->actual_contents.data); + iree_allocator_free(results->host_allocator, results->expected_contents.data); + + IREE_TRACE_ZONE_END(z0); +} + +// Helper for check_attention_results: the actual interesting part once we've +// obtained 
and validated the {b,m,k1,k2,n}_size values. On error, detailed +// logging is written to |file| if it is not NULL. +static iree_status_t check_attention_results_impl( + FILE* file, const attention_results_t* results, int check_every) { + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, reference_attention(results->b, results->m, results->k1, results->k2, + results->n, results->query_elem_type, + results->key_elem_type, results->value_elem_type, + results->query_contents, results->key_contents, + results->value_contents, results->actual_contents, + results->expected_contents, check_every)); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +// Given an actual attention's inputs and output (all host-local), uses a +// reference attention implementation on the same inputs to check if the output +// is correct. On error, detailed logging is written to |file| if it is not +// NULL. +static iree_status_t check_attention_results( + FILE* file, const attention_results_t* results) { + IREE_TRACE_ZONE_BEGIN(z0); + // TODO: Increase the check every param to reduce the number of comparisons. + int check_every = 1; + iree_status_t status = + check_attention_results_impl(file, results, check_every); + if (!iree_status_is_ok(status) && check_every > 1) { + // If we got a failure with check_every>1, that didn't log a useful + // numerical summary, as most of the reference matrix entries hadn't been + // computed. Rerun now with check_every=1 to get that numerical logging. + iree_status_ignore(status); + status = check_attention_results_impl(file, results, 1); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +//===----------------------------------------------------------------------===// +// `attention_test` custom module +//===----------------------------------------------------------------------===// +// This uses the C++ wrapper to keep things simple. Though easier to use it's +// got additional overhead/code-size bloat that doesn't matter in a test like +// this. Making a C module builder API that removes the boilerplate there is TBD +// so this file is written in C besides this module so that we can swap it back +// to being pure C in the future. + +namespace iree { + +class AttentionTestModuleState final { + public: + explicit AttentionTestModuleState(iree_allocator_t host_allocator) + : host_allocator_(host_allocator) {} + ~AttentionTestModuleState() = default; + + // Fills the destination span with pseudorandom values of the given + // |element_type|. The given |seed| is passed to the pseudorandom generator. + // The pseudorandom values are reproducible both across runs and across + // machines. 
+ StatusOr> GenerateRandom3dTensor( + const vm::ref device, int64_t dim0, int64_t dim1, + int64_t dim2, iree_hal_element_type_t element_type, int32_t seed) { + iree_hal_dim_t dims[3] = { + (iree_hal_dim_t)dim0, + (iree_hal_dim_t)dim1, + (iree_hal_dim_t)dim2, + }; + iree_hal_buffer_params_t buffer_params = {0}; + buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT; + buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL; + buffer_params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE; + vm::ref result_view; + struct callback_state_t { + iree_hal_element_type_t element_type; + int32_t seed; + } callback_state = { + element_type, + seed, + }; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_generate_buffer( + device.get(), iree_hal_device_allocator(device.get()), + IREE_ARRAYSIZE(dims), dims, element_type, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params, + +[](iree_hal_buffer_mapping_t* mapping, void* user_data) { + callback_state_t callback_state = *(callback_state_t*)user_data; + iree_byte_span_t span = mapping->contents; + // Generate "uniform" integer-valued numbers in the range [min, max]. + int32_t min = 0; + int32_t max = 0; + iree_test_utils_get_min_max_for_element_type( + callback_state.element_type, &min, &max); + uint32_t range = (max - min + 1); + iree_host_size_t element_byte_count = + iree_hal_element_dense_byte_count(callback_state.element_type); + uint8_t* data_end = span.data + span.data_length; + uint32_t state = callback_state.seed; + for (uint8_t* data = span.data; data < data_end; + data += element_byte_count) { + int32_t value = + (int32_t)iree_test_utils_pseudorandom_range(&state, range) + + min; + iree_test_utils_write_element(callback_state.element_type, value, + data); + } + return iree_ok_status(); + }, + &callback_state, &result_view)); + return std::move(result_view); + } + + Status CheckAttentionResults( + const vm::ref device, int64_t b, int64_t m, int64_t k1, + int64_t k2, int64_t n, const vm::ref query, + const vm::ref key, + const vm::ref value, + const vm::ref actual_result) { + attention_results_t results = {}; + IREE_RETURN_IF_ERROR(attention_results_initialize( + device.get(), (iree_hal_dim_t)b, (iree_hal_dim_t)m, (iree_hal_dim_t)k1, + (iree_hal_dim_t)k2, (iree_hal_dim_t)n, query.get(), key.get(), + value.get(), actual_result.get(), host_allocator_, &results)); + iree_status_t status = check_attention_results(stderr, &results); + attention_results_deinitialize(&results); + return status; + } + + private: + iree_allocator_t host_allocator_; +}; + +static const vm::NativeFunction + kAttentionTestModuleFunctions[] = { + vm::MakeNativeFunction( + "generate_random_tensor", + &AttentionTestModuleState::GenerateRandom3dTensor), + vm::MakeNativeFunction( + "check_attention_results", + &AttentionTestModuleState::CheckAttentionResults), +}; + +struct AttentionTestModule final + : public vm::NativeModule { + using vm::NativeModule::NativeModule; + StatusOr> CreateState( + iree_allocator_t host_allocator) override { + return std::make_unique(host_allocator); + } + StatusOr> ForkState( + AttentionTestModuleState* parent_state, + iree_allocator_t host_allocator) { + return std::make_unique(host_allocator); + } +}; + +} // namespace iree + +static iree_status_t attention_test_module_create( + iree_vm_instance_t* instance, iree_allocator_t host_allocator, + iree_vm_module_t** out_module) { + IREE_ASSERT_ARGUMENT(out_module); + *out_module = NULL; + auto module = std::make_unique( + "attention_test", /*version=*/0, instance, host_allocator, + iree::span< + const 
iree::vm::NativeFunction>( + iree::kAttentionTestModuleFunctions)); + *out_module = module.release()->interface(); + return iree_ok_status(); +} + +int main(int argc, char** argv) { + IREE_TRACE_APP_ENTER(); + + iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv); + if (argc != 1) { + fprintf(stderr, "use --module= flags to specify the modules to run\n"); + IREE_TRACE_APP_EXIT(EXIT_FAILURE); + return EXIT_FAILURE; + } + + iree_status_t status = iree_test_utils_load_and_run_e2e_tests( + iree_allocator_system(), attention_test_module_create); + int exit_code = EXIT_SUCCESS; + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + bool is_unavailable = iree_status_is_unavailable(status); + iree_status_free(status); + exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE; + } + + IREE_TRACE_APP_EXIT(exit_code); + return exit_code; +} From d4b798f0e27902f636180bfead2ef1e7a9cc2e00 Mon Sep 17 00:00:00 2001 From: erman-gurses Date: Fri, 8 Nov 2024 18:01:11 -0600 Subject: [PATCH 2/9] Add commment for the input types Signed-off-by: erman-gurses --- linalg_ops/attention/generate_test_mlir_files.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linalg_ops/attention/generate_test_mlir_files.sh b/linalg_ops/attention/generate_test_mlir_files.sh index 5d71878..ddd4db4 100755 --- a/linalg_ops/attention/generate_test_mlir_files.sh +++ b/linalg_ops/attention/generate_test_mlir_files.sh @@ -38,7 +38,7 @@ shapes=( "large" ) -# query_type;key_type;value_type; +# query_type;key_type;value_type;scale_type type_and_layout_combinations=( "f16;f16;f16;f16" ) From 2075ea1c33a1c539b92f9d37e6c5ba2ef50d970d Mon Sep 17 00:00:00 2001 From: erman-gurses Date: Fri, 8 Nov 2024 22:28:42 -0600 Subject: [PATCH 3/9] Update comments Signed-off-by: erman-gurses --- linalg_ops/attention/generate_test_mlir_files.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linalg_ops/attention/generate_test_mlir_files.sh b/linalg_ops/attention/generate_test_mlir_files.sh index ddd4db4..8fdc9ef 100755 --- a/linalg_ops/attention/generate_test_mlir_files.sh +++ b/linalg_ops/attention/generate_test_mlir_files.sh @@ -6,7 +6,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# This script runs generate_e2e_conv2d_tests for all argument combinations that +# This script runs generate_e2e_attention_tests for all argument combinations that # we are interested in testing. 
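One detail of main() above worth calling out: an "unavailable" status is translated to EXIT_SUCCESS so that machines without the requested driver or device report a skip rather than a failure. A minimal sketch of that policy, using an illustrative RunnerStatus enum in place of iree_status_t:

#include <cstdlib>

// Illustrative stand-in for the status outcomes the runner distinguishes.
enum class RunnerStatus { kOk, kUnavailable, kOtherError };

// Mirrors the exit-code policy in main(): a missing backend is a skip, not a
// test failure; every other error fails the invocation.
int ToExitCode(RunnerStatus status) {
  if (status == RunnerStatus::kOk) return EXIT_SUCCESS;
  return status == RunnerStatus::kUnavailable ? EXIT_SUCCESS : EXIT_FAILURE;
}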
# # The output is a 'generated' folder with contents like this: From d32d1b0eb5ebba457eb83e2a03bc7d8484999752 Mon Sep 17 00:00:00 2001 From: erman-gurses Date: Mon, 11 Nov 2024 17:41:14 -0600 Subject: [PATCH 4/9] Fix comment Signed-off-by: erman-gurses --- linalg_ops/attention/generate_test_mlir_files.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linalg_ops/attention/generate_test_mlir_files.sh b/linalg_ops/attention/generate_test_mlir_files.sh index 8fdc9ef..edaa650 100755 --- a/linalg_ops/attention/generate_test_mlir_files.sh +++ b/linalg_ops/attention/generate_test_mlir_files.sh @@ -11,7 +11,7 @@ # # The output is a 'generated' folder with contents like this: # linalg_ops/ -# convolution/ +# attention/ # generated/ # f16_f16_f16_f16/ # attention_f16_f16_f16_f16_large_calls.mlir From f703f2349ec40ed620c0db2602991cc230483a50 Mon Sep 17 00:00:00 2001 From: erman-gurses Date: Mon, 11 Nov 2024 17:53:08 -0600 Subject: [PATCH 5/9] Remove unnecessary newline from the IR generation Signed-off-by: erman-gurses --- linalg_ops/attention/generate_e2e_attention_tests.py | 2 +- .../f16_f16_f16_f16/attention_f16_f16_f16_f16_large.mlir | 1 - .../f16_f16_f16_f16/attention_f16_f16_f16_f16_medium.mlir | 1 - .../f16_f16_f16_f16/attention_f16_f16_f16_f16_small.mlir | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/linalg_ops/attention/generate_e2e_attention_tests.py b/linalg_ops/attention/generate_e2e_attention_tests.py index 4027c88..8e79192 100644 --- a/linalg_ops/attention/generate_e2e_attention_tests.py +++ b/linalg_ops/attention/generate_e2e_attention_tests.py @@ -225,7 +225,7 @@ def generate_function( f" iree_linalg_ext.yield %score : f32\n" f" }} -> {result_tensor_type}\n" f" return %result1: {result_tensor_type}\n" - f"}}\n" + f"}}" ) return MLIRFunction( name=func_name, diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large.mlir index e416b1a..bd9dafa 100644 --- a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large.mlir +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large.mlir @@ -14,4 +14,3 @@ func.func @attention_2_1024_128_128_64_dtype_f16_f16_f16_f16(%query: tensor<2x10 } -> tensor<2x1024x64xf16> return %result1: tensor<2x1024x64xf16> } - diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium.mlir index 860595e..8d94920 100644 --- a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium.mlir +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium.mlir @@ -14,4 +14,3 @@ func.func @attention_2_512_128_64_32_dtype_f16_f16_f16_f16(%query: tensor<2x512x } -> tensor<2x512x32xf16> return %result1: tensor<2x512x32xf16> } - diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small.mlir index d0586d6..cb6f3e6 100644 --- a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small.mlir +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small.mlir @@ -14,4 +14,3 @@ func.func @attention_2_256_64_32_16_dtype_f16_f16_f16_f16(%query: tensor<2x256x6 } -> tensor<2x256x16xf16> return %result1: tensor<2x256x16xf16> } - From 
1b86fc8fc3a4556d46fac2dc8b665254daca7292 Mon Sep 17 00:00:00 2001 From: erman-gurses Date: Mon, 11 Nov 2024 19:18:18 -0600 Subject: [PATCH 6/9] Remove extra newlines for the generated calls Signed-off-by: erman-gurses --- linalg_ops/attention/generate_e2e_attention_tests.py | 4 ++-- .../attention_f16_f16_f16_f16_large_calls.mlir | 2 -- .../attention_f16_f16_f16_f16_medium_calls.mlir | 2 -- .../attention_f16_f16_f16_f16_small_calls.mlir | 2 -- 4 files changed, 2 insertions(+), 8 deletions(-) diff --git a/linalg_ops/attention/generate_e2e_attention_tests.py b/linalg_ops/attention/generate_e2e_attention_tests.py index 8e79192..c6ed70b 100644 --- a/linalg_ops/attention/generate_e2e_attention_tests.py +++ b/linalg_ops/attention/generate_e2e_attention_tests.py @@ -469,9 +469,9 @@ def write_calls_file(functions, calls, filename, requirements): # Emit the test cases for each call. for call in calls: - module_definition = module_definition + call.op + "\n" + module_definition = module_definition + call.op + "" - module_definition = module_definition + "\n}\n" + module_definition = module_definition + "}\n" with open(filename, "w") as file: file.write(module_definition) diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large_calls.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large_calls.mlir index 544b0fa..38206bb 100644 --- a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large_calls.mlir +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_large_calls.mlir @@ -52,6 +52,4 @@ func.func @attention_2_1024_128_128_64_dtype_f16_f16_f16_f16_2_1024_128_128_64_1 call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> () return } - - } diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium_calls.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium_calls.mlir index 6c77095..256eda2 100644 --- a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium_calls.mlir +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_medium_calls.mlir @@ -52,6 +52,4 @@ func.func @attention_2_512_128_64_32_dtype_f16_f16_f16_f16_2_512_128_64_32_128_1 call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> () return } - - } diff --git a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small_calls.mlir b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small_calls.mlir index 2aed007..a1b85a4 100644 --- a/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small_calls.mlir +++ b/linalg_ops/attention/generated/f16_f16_f16_f16/attention_f16_f16_f16_f16_small_calls.mlir @@ -52,6 +52,4 @@ func.func @attention_2_256_64_32_16_dtype_f16_f16_f16_f16_2_256_64_32_16_64_1.0_ call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, 
!hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> () return } - - } From 84ee3dd7005ba97151eb0de846aeb9d87ff201c2 Mon Sep 17 00:00:00 2001 From: erman-gurses Date: Mon, 11 Nov 2024 22:04:05 -0600 Subject: [PATCH 7/9] Refactor tensor name and status check Signed-off-by: erman-gurses --- linalg_ops/iree-e2e-attention-test.cc | 153 ++++++++++++++------------ 1 file changed, 81 insertions(+), 72 deletions(-) diff --git a/linalg_ops/iree-e2e-attention-test.cc b/linalg_ops/iree-e2e-attention-test.cc index 200a5b4..49ef7f1 100644 --- a/linalg_ops/iree-e2e-attention-test.cc +++ b/linalg_ops/iree-e2e-attention-test.cc @@ -29,9 +29,9 @@ // Helper for reference_attention. // Function to allocate and initialize tensors -float* allocate_tensor(int dim1, int dim2, int dim3) { +float *allocate_tensor(int dim1, int dim2, int dim3) { const int size = dim1 * dim2 * dim3; - float* tensor = (float*)malloc(size * sizeof(float)); + float *tensor = (float *)malloc(size * sizeof(float)); for (int i = 0; i < size; ++i) { tensor[i] = 0.0f; } @@ -39,8 +39,9 @@ float* allocate_tensor(int dim1, int dim2, int dim3) { } // Function to free allocated tensors -void free_tensor(float* tensor) { - if (tensor != nullptr) free(tensor); +void free_tensor(float *tensor) { + if (tensor != nullptr) + free(tensor); } // Function to calculate 1D index for a 3D array @@ -50,9 +51,9 @@ int index_3d(int i, int j, int k, int dim2, int dim3) { static void reference_attention_f32_f32_f32_f32( iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N, - iree_hal_dim_t B, const float* query_data, const float* key_data, - const float* value_data, float* result_data, iree_hal_dim_t b, - float* Attention) { + iree_hal_dim_t B, const float *query_data, const float *key_data, + const float *value_data, float *result_data, iree_hal_dim_t b, + float *qk_cache) { // Compute Q * K^T for (int m = 0; m < M; ++m) { for (int k2 = 0; k2 < K2; ++k2) { @@ -64,30 +65,30 @@ static void reference_attention_f32_f32_f32_f32( sum += query_data[q_idx] * key_data[k_idx]; } int att_idx = index_3d(0, m, k2, M, K2); - Attention[att_idx] = sum / sqrt(K1); // Scale by sqrt(K1) + qk_cache[att_idx] = sum / sqrt(K1); // Scale by sqrt(K1) } } - // Compute softmax on Attention + // Compute softmax on qk_cache for (int m = 0; m < M; ++m) { // Find the maximum value for the current sequence float max_val = -FLT_MAX; for (int k2 = 0; k2 < K2; ++k2) { int att_idx = index_3d(0, m, k2, M, K2); - max_val = iree_max(max_val, Attention[att_idx]); + max_val = iree_max(max_val, qk_cache[att_idx]); } // Calculate the softmax denominator float sum = 0.0f; for (int k2 = 0; k2 < K2; ++k2) { int att_idx = index_3d(0, m, k2, M, K2); - sum += exp(Attention[att_idx] - max_val); + sum += exp(qk_cache[att_idx] - max_val); } // Apply softmax for (int k2 = 0; k2 < K2; ++k2) { int att_idx = index_3d(0, m, k2, M, K2); - Attention[att_idx] = exp(Attention[att_idx]) / sum; + qk_cache[att_idx] = exp(qk_cache[att_idx]) / sum; } } @@ -98,7 +99,7 @@ static void reference_attention_f32_f32_f32_f32( for (int k2 = 0; k2 < K2; ++k2) { int att_idx = index_3d(0, m, k2, M, K2); int v_idx = index_3d(b, k2, n, K2, N); - sum += Attention[att_idx] * value_data[v_idx]; + sum += qk_cache[att_idx] * value_data[v_idx]; } int o_idx = index_3d(b, m, n, M, N); result_data[o_idx] = sum; @@ -110,15 +111,15 @@ static iree_status_t reference_attention_element( iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N, iree_hal_dim_t B, iree_hal_element_type_t query_elem_type, 
iree_hal_element_type_t key_elem_type, - iree_hal_element_type_t value_elem_type, void* query_data, void* key_data, - void* value_data, void* actual_data, void* result_data, iree_hal_dim_t b, - float* Attention) { + iree_hal_element_type_t value_elem_type, void *query_data, void *key_data, + void *value_data, void *actual_data, void *result_data, iree_hal_dim_t b, + float *qk_cache) { if (query_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 && key_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 && value_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { reference_attention_f32_f32_f32_f32( - M, K1, K2, N, B, (const float*)query_data, (const float*)key_data, - (const float*)value_data, (float*)result_data, b, Attention); + M, K1, K2, N, B, (const float *)query_data, (const float *)key_data, + (const float *)value_data, (float *)result_data, b, qk_cache); } else { return iree_make_status( @@ -146,21 +147,25 @@ static iree_status_t reference_attention( IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, N); iree_host_size_t count = 0; - float* Attention = allocate_tensor(1, M, K2); + float *qk_cache = allocate_tensor(1, M, K2); + iree_status_t status = iree_ok_status(); for (iree_hal_dim_t b = 0; b < B; ++b) { - if (++count < compute_every) continue; + if (++count < compute_every) + continue; count = 0; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, - reference_attention_element( - M, K1, K2, N, B, query_elem_type, key_elem_type, value_elem_type, - query_contents.data, key_contents.data, value_contents.data, - actual_contents.data, result_contents.data, b, Attention)); + if (iree_status_is_ok(status)) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + status = reference_attention_element( + M, K1, K2, N, B, query_elem_type, key_elem_type, value_elem_type, + query_contents.data, key_contents.data, value_contents.data, + actual_contents.data, result_contents.data, b, qk_cache)); + } } - free_tensor(Attention); + free_tensor(qk_cache); IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); + return status; } //===----------------------------------------------------------------------===// // Attention comparison/logging @@ -184,14 +189,14 @@ typedef struct { iree_byte_span_t expected_contents; } attention_results_t; -static void attention_results_deinitialize(attention_results_t* results); +static void attention_results_deinitialize(attention_results_t *results); static iree_status_t attention_results_initialize( - iree_hal_device_t* device, iree_hal_dim_t b_size, iree_hal_dim_t m_size, + iree_hal_device_t *device, iree_hal_dim_t b_size, iree_hal_dim_t m_size, iree_hal_dim_t k1_size, iree_hal_dim_t k2_size, iree_hal_dim_t n_size, - iree_hal_buffer_view_t* query, iree_hal_buffer_view_t* key, - iree_hal_buffer_view_t* value, iree_hal_buffer_view_t* result, - iree_allocator_t host_allocator, attention_results_t* out_results) { + iree_hal_buffer_view_t *query, iree_hal_buffer_view_t *key, + iree_hal_buffer_view_t *value, iree_hal_buffer_view_t *result, + iree_allocator_t host_allocator, attention_results_t *out_results) { IREE_TRACE_ZONE_BEGIN(z0); memset(out_results, 0, sizeof(*out_results)); @@ -208,10 +213,10 @@ static iree_status_t attention_results_initialize( out_results->value_elem_type = iree_hal_buffer_view_element_type(value); out_results->result_elem_type = iree_hal_buffer_view_element_type(result); - iree_hal_buffer_t* query_buffer = iree_hal_buffer_view_buffer(query); - iree_hal_buffer_t* key_buffer = iree_hal_buffer_view_buffer(key); - iree_hal_buffer_t* value_buffer = iree_hal_buffer_view_buffer(value); - 
iree_hal_buffer_t* result_buffer = iree_hal_buffer_view_buffer(result); + iree_hal_buffer_t *query_buffer = iree_hal_buffer_view_buffer(query); + iree_hal_buffer_t *key_buffer = iree_hal_buffer_view_buffer(key); + iree_hal_buffer_t *value_buffer = iree_hal_buffer_view_buffer(value); + iree_hal_buffer_t *result_buffer = iree_hal_buffer_view_buffer(result); iree_status_t status = iree_ok_status(); @@ -220,7 +225,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(query_buffer); status = iree_allocator_malloc(host_allocator, out_results->query_contents.data_length, - (void**)&out_results->query_contents.data); + (void **)&out_results->query_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( @@ -233,7 +238,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(key_buffer); status = iree_allocator_malloc(host_allocator, out_results->key_contents.data_length, - (void**)&out_results->key_contents.data); + (void **)&out_results->key_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( @@ -246,7 +251,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(value_buffer); status = iree_allocator_malloc(host_allocator, out_results->value_contents.data_length, - (void**)&out_results->value_contents.data); + (void **)&out_results->value_contents.data); } if (iree_status_is_ok(status)) { @@ -260,7 +265,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(result_buffer); status = iree_allocator_malloc(host_allocator, out_results->actual_contents.data_length, - (void**)&out_results->actual_contents.data); + (void **)&out_results->actual_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( @@ -273,7 +278,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(result_buffer); status = iree_allocator_malloc( host_allocator, out_results->expected_contents.data_length, - (void**)&out_results->expected_contents.data); + (void **)&out_results->expected_contents.data); } if (!iree_status_is_ok(status)) { attention_results_deinitialize(out_results); @@ -282,7 +287,7 @@ static iree_status_t attention_results_initialize( return status; } -static void attention_results_deinitialize(attention_results_t* results) { +static void attention_results_deinitialize(attention_results_t *results) { IREE_TRACE_ZONE_BEGIN(z0); iree_allocator_free(results->host_allocator, results->query_contents.data); iree_allocator_free(results->host_allocator, results->key_contents.data); @@ -296,8 +301,9 @@ static void attention_results_deinitialize(attention_results_t* results) { // Helper for check_attention_results: the actual interesting part once we've // obtained and validated the {b,m,k1,k2,n}_size values. On error, detailed // logging is written to |file| if it is not NULL. -static iree_status_t check_attention_results_impl( - FILE* file, const attention_results_t* results, int check_every) { +static iree_status_t +check_attention_results_impl(FILE *file, const attention_results_t *results, + int check_every) { IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -316,8 +322,8 @@ static iree_status_t check_attention_results_impl( // reference attention implementation on the same inputs to check if the output // is correct. On error, detailed logging is written to |file| if it is not // NULL. 
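For readers following the reference path touched by this patch: per batch it computes out = softmax(Q*K^T / sqrt(K1)) * V, with Q of shape MxK1, K of shape K2xK1, V of shape K2xN, and the output MxN, all row-major. The sketch below is a self-contained restatement of that math over std::vector<float>, not the test's own helper; note that it subtracts the row maximum inside the softmax numerator as well as the denominator, which is the usual way to keep the exponentials numerically stable.

#include <algorithm>
#include <cmath>
#include <vector>

// Standalone single-batch reference: S = Q*K^T / sqrt(K1), P = row-wise
// softmax(S), O = P*V. Q is MxK1, K is K2xK1, V is K2xN, O is MxN (row-major).
void ReferenceAttention(int M, int K1, int K2, int N,
                        const std::vector<float>& Q,
                        const std::vector<float>& K,
                        const std::vector<float>& V, std::vector<float>& O) {
  std::vector<float> S(static_cast<size_t>(M) * K2, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int k2 = 0; k2 < K2; ++k2) {
      float dot = 0.0f;
      for (int k1 = 0; k1 < K1; ++k1) dot += Q[m * K1 + k1] * K[k2 * K1 + k1];
      S[m * K2 + k2] = dot / std::sqrt(static_cast<float>(K1));
    }
  }
  for (int m = 0; m < M; ++m) {
    // Numerically stable softmax: subtract the row max before exponentiating,
    // in both the numerator and the denominator.
    float max_val = S[m * K2];
    for (int k2 = 1; k2 < K2; ++k2) max_val = std::max(max_val, S[m * K2 + k2]);
    float denom = 0.0f;
    for (int k2 = 0; k2 < K2; ++k2) denom += std::exp(S[m * K2 + k2] - max_val);
    for (int k2 = 0; k2 < K2; ++k2)
      S[m * K2 + k2] = std::exp(S[m * K2 + k2] - max_val) / denom;
  }
  O.assign(static_cast<size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k2 = 0; k2 < K2; ++k2) acc += S[m * K2 + k2] * V[k2 * N + n];
      O[m * N + n] = acc;
    }
  }
}

The batched form in the test follows the same steps, simply offsetting into Q, K, V, and the output by the batch index b.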
-static iree_status_t check_attention_results( - FILE* file, const attention_results_t* results) { +static iree_status_t +check_attention_results(FILE *file, const attention_results_t *results) { IREE_TRACE_ZONE_BEGIN(z0); // TODO: Increase the check every param to reduce the number of comparisons. int check_every = 1; @@ -346,7 +352,7 @@ static iree_status_t check_attention_results( namespace iree { class AttentionTestModuleState final { - public: +public: explicit AttentionTestModuleState(iree_allocator_t host_allocator) : host_allocator_(host_allocator) {} ~AttentionTestModuleState() = default; @@ -355,9 +361,10 @@ class AttentionTestModuleState final { // |element_type|. The given |seed| is passed to the pseudorandom generator. // The pseudorandom values are reproducible both across runs and across // machines. - StatusOr> GenerateRandom3dTensor( - const vm::ref device, int64_t dim0, int64_t dim1, - int64_t dim2, iree_hal_element_type_t element_type, int32_t seed) { + StatusOr> + GenerateRandom3dTensor(const vm::ref device, int64_t dim0, + int64_t dim1, int64_t dim2, + iree_hal_element_type_t element_type, int32_t seed) { iree_hal_dim_t dims[3] = { (iree_hal_dim_t)dim0, (iree_hal_dim_t)dim1, @@ -379,8 +386,8 @@ class AttentionTestModuleState final { device.get(), iree_hal_device_allocator(device.get()), IREE_ARRAYSIZE(dims), dims, element_type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params, - +[](iree_hal_buffer_mapping_t* mapping, void* user_data) { - callback_state_t callback_state = *(callback_state_t*)user_data; + +[](iree_hal_buffer_mapping_t *mapping, void *user_data) { + callback_state_t callback_state = *(callback_state_t *)user_data; iree_byte_span_t span = mapping->contents; // Generate "uniform" integer-valued numbers in the range [min, max]. 
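Returning to check_attention_results at the top of this hunk: its retry is purely about diagnostics. A sparse check (check_every > 1) skips most reference entries, so on failure the check is rerun densely to produce a complete numerical log. A stripped-down sketch of that control flow, with an illustrative callback in place of the IREE status plumbing:

#include <functional>

// check(check_every) returns true when the comparison at that stride passes.
bool CheckWithDenseRerun(const std::function<bool(int)>& check,
                         int check_every) {
  if (check(check_every)) return true;
  if (check_every <= 1) return false;  // Already dense; nothing more to log.
  // Rerun densely so the failure report covers every element.
  return check(1);
}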
int32_t min = 0; @@ -390,9 +397,9 @@ class AttentionTestModuleState final { uint32_t range = (max - min + 1); iree_host_size_t element_byte_count = iree_hal_element_dense_byte_count(callback_state.element_type); - uint8_t* data_end = span.data + span.data_length; + uint8_t *data_end = span.data + span.data_length; uint32_t state = callback_state.seed; - for (uint8_t* data = span.data; data < data_end; + for (uint8_t *data = span.data; data < data_end; data += element_byte_count) { int32_t value = (int32_t)iree_test_utils_pseudorandom_range(&state, range) + @@ -406,12 +413,13 @@ class AttentionTestModuleState final { return std::move(result_view); } - Status CheckAttentionResults( - const vm::ref device, int64_t b, int64_t m, int64_t k1, - int64_t k2, int64_t n, const vm::ref query, - const vm::ref key, - const vm::ref value, - const vm::ref actual_result) { + Status + CheckAttentionResults(const vm::ref device, int64_t b, + int64_t m, int64_t k1, int64_t k2, int64_t n, + const vm::ref query, + const vm::ref key, + const vm::ref value, + const vm::ref actual_result) { attention_results_t results = {}; IREE_RETURN_IF_ERROR(attention_results_initialize( device.get(), (iree_hal_dim_t)b, (iree_hal_dim_t)m, (iree_hal_dim_t)k1, @@ -422,7 +430,7 @@ class AttentionTestModuleState final { return status; } - private: +private: iree_allocator_t host_allocator_; }; @@ -439,22 +447,23 @@ static const vm::NativeFunction struct AttentionTestModule final : public vm::NativeModule { using vm::NativeModule::NativeModule; - StatusOr> CreateState( - iree_allocator_t host_allocator) override { + StatusOr> + CreateState(iree_allocator_t host_allocator) override { return std::make_unique(host_allocator); } - StatusOr> ForkState( - AttentionTestModuleState* parent_state, - iree_allocator_t host_allocator) { + StatusOr> + ForkState(AttentionTestModuleState *parent_state, + iree_allocator_t host_allocator) { return std::make_unique(host_allocator); } }; -} // namespace iree +} // namespace iree -static iree_status_t attention_test_module_create( - iree_vm_instance_t* instance, iree_allocator_t host_allocator, - iree_vm_module_t** out_module) { +static iree_status_t +attention_test_module_create(iree_vm_instance_t *instance, + iree_allocator_t host_allocator, + iree_vm_module_t **out_module) { IREE_ASSERT_ARGUMENT(out_module); *out_module = NULL; auto module = std::make_unique( @@ -466,7 +475,7 @@ static iree_status_t attention_test_module_create( return iree_ok_status(); } -int main(int argc, char** argv) { +int main(int argc, char **argv) { IREE_TRACE_APP_ENTER(); iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv); From cd045b743491278ec0a835fc746c0f56cc35ba42 Mon Sep 17 00:00:00 2001 From: erman-gurses Date: Mon, 11 Nov 2024 22:16:34 -0600 Subject: [PATCH 8/9] Removing traling spaces Signed-off-by: erman-gurses --- linalg_ops/attention/generate_e2e_attention_tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/linalg_ops/attention/generate_e2e_attention_tests.py b/linalg_ops/attention/generate_e2e_attention_tests.py index c6ed70b..03367ef 100644 --- a/linalg_ops/attention/generate_e2e_attention_tests.py +++ b/linalg_ops/attention/generate_e2e_attention_tests.py @@ -401,21 +401,21 @@ def parse_arguments(): "--query_type", type=str, choices=["f16"], - help="Numeric type of query tensors ", + help="Numeric type of query tensors", required=True, ) parser.add_argument( "--key_type", type=str, choices=["f16"], - help="Numeric type of key tensors ", + help="Numeric type of 
key tensors", required=True, ) parser.add_argument( "--value_type", type=str, choices=["f16"], - help="Numeric type of value tensors ", + help="Numeric type of value tensors", required=True, ) parser.add_argument( From 55fda1f736725a7a0aaf9a757bdb7a49f64a6e77 Mon Sep 17 00:00:00 2001 From: erman-gurses Date: Fri, 15 Nov 2024 17:37:50 -0600 Subject: [PATCH 9/9] Format pointer type variables Signed-off-by: erman-gurses --- linalg_ops/iree-e2e-attention-test.cc | 74 +++++++++++++-------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/linalg_ops/iree-e2e-attention-test.cc b/linalg_ops/iree-e2e-attention-test.cc index 49ef7f1..c27d11f 100644 --- a/linalg_ops/iree-e2e-attention-test.cc +++ b/linalg_ops/iree-e2e-attention-test.cc @@ -29,9 +29,9 @@ // Helper for reference_attention. // Function to allocate and initialize tensors -float *allocate_tensor(int dim1, int dim2, int dim3) { +float* allocate_tensor(int dim1, int dim2, int dim3) { const int size = dim1 * dim2 * dim3; - float *tensor = (float *)malloc(size * sizeof(float)); + float* tensor = (float*) malloc(size * sizeof(float)); for (int i = 0; i < size; ++i) { tensor[i] = 0.0f; } @@ -39,7 +39,7 @@ float *allocate_tensor(int dim1, int dim2, int dim3) { } // Function to free allocated tensors -void free_tensor(float *tensor) { +void free_tensor(float* tensor) { if (tensor != nullptr) free(tensor); } @@ -51,9 +51,9 @@ int index_3d(int i, int j, int k, int dim2, int dim3) { static void reference_attention_f32_f32_f32_f32( iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N, - iree_hal_dim_t B, const float *query_data, const float *key_data, - const float *value_data, float *result_data, iree_hal_dim_t b, - float *qk_cache) { + iree_hal_dim_t B, const float* query_data, const float* key_data, + const float* value_data, float* result_data, iree_hal_dim_t b, + float* qk_cache) { // Compute Q * K^T for (int m = 0; m < M; ++m) { for (int k2 = 0; k2 < K2; ++k2) { @@ -111,15 +111,15 @@ static iree_status_t reference_attention_element( iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N, iree_hal_dim_t B, iree_hal_element_type_t query_elem_type, iree_hal_element_type_t key_elem_type, - iree_hal_element_type_t value_elem_type, void *query_data, void *key_data, - void *value_data, void *actual_data, void *result_data, iree_hal_dim_t b, - float *qk_cache) { + iree_hal_element_type_t value_elem_type, void* query_data, void* key_data, + void* value_data, void* actual_data, void* result_data, iree_hal_dim_t b, + float* qk_cache) { if (query_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 && key_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 && value_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { reference_attention_f32_f32_f32_f32( - M, K1, K2, N, B, (const float *)query_data, (const float *)key_data, - (const float *)value_data, (float *)result_data, b, qk_cache); + M, K1, K2, N, B, (const float*)query_data, (const float*)key_data, + (const float*)value_data, (float*)result_data, b, qk_cache); } else { return iree_make_status( @@ -147,7 +147,7 @@ static iree_status_t reference_attention( IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, N); iree_host_size_t count = 0; - float *qk_cache = allocate_tensor(1, M, K2); + float* qk_cache = allocate_tensor(1, M, K2); iree_status_t status = iree_ok_status(); for (iree_hal_dim_t b = 0; b < B; ++b) { if (++count < compute_every) @@ -189,14 +189,14 @@ typedef struct { iree_byte_span_t expected_contents; } attention_results_t; -static void 
attention_results_deinitialize(attention_results_t *results); +static void attention_results_deinitialize(attention_results_t* results); static iree_status_t attention_results_initialize( - iree_hal_device_t *device, iree_hal_dim_t b_size, iree_hal_dim_t m_size, + iree_hal_device_t* device, iree_hal_dim_t b_size, iree_hal_dim_t m_size, iree_hal_dim_t k1_size, iree_hal_dim_t k2_size, iree_hal_dim_t n_size, - iree_hal_buffer_view_t *query, iree_hal_buffer_view_t *key, - iree_hal_buffer_view_t *value, iree_hal_buffer_view_t *result, - iree_allocator_t host_allocator, attention_results_t *out_results) { + iree_hal_buffer_view_t* query, iree_hal_buffer_view_t* key, + iree_hal_buffer_view_t* value, iree_hal_buffer_view_t* result, + iree_allocator_t host_allocator, attention_results_t* out_results) { IREE_TRACE_ZONE_BEGIN(z0); memset(out_results, 0, sizeof(*out_results)); @@ -213,10 +213,10 @@ static iree_status_t attention_results_initialize( out_results->value_elem_type = iree_hal_buffer_view_element_type(value); out_results->result_elem_type = iree_hal_buffer_view_element_type(result); - iree_hal_buffer_t *query_buffer = iree_hal_buffer_view_buffer(query); - iree_hal_buffer_t *key_buffer = iree_hal_buffer_view_buffer(key); - iree_hal_buffer_t *value_buffer = iree_hal_buffer_view_buffer(value); - iree_hal_buffer_t *result_buffer = iree_hal_buffer_view_buffer(result); + iree_hal_buffer_t* query_buffer = iree_hal_buffer_view_buffer(query); + iree_hal_buffer_t* key_buffer = iree_hal_buffer_view_buffer(key); + iree_hal_buffer_t* value_buffer = iree_hal_buffer_view_buffer(value); + iree_hal_buffer_t* result_buffer = iree_hal_buffer_view_buffer(result); iree_status_t status = iree_ok_status(); @@ -225,7 +225,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(query_buffer); status = iree_allocator_malloc(host_allocator, out_results->query_contents.data_length, - (void **)&out_results->query_contents.data); + (void**)&out_results->query_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( @@ -238,7 +238,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(key_buffer); status = iree_allocator_malloc(host_allocator, out_results->key_contents.data_length, - (void **)&out_results->key_contents.data); + (void**)&out_results->key_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( @@ -251,7 +251,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(value_buffer); status = iree_allocator_malloc(host_allocator, out_results->value_contents.data_length, - (void **)&out_results->value_contents.data); + (void**)&out_results->value_contents.data); } if (iree_status_is_ok(status)) { @@ -265,7 +265,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(result_buffer); status = iree_allocator_malloc(host_allocator, out_results->actual_contents.data_length, - (void **)&out_results->actual_contents.data); + (void**)&out_results->actual_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( @@ -278,7 +278,7 @@ static iree_status_t attention_results_initialize( iree_hal_buffer_byte_length(result_buffer); status = iree_allocator_malloc( host_allocator, out_results->expected_contents.data_length, - (void **)&out_results->expected_contents.data); + (void**)&out_results->expected_contents.data); } if (!iree_status_is_ok(status)) { attention_results_deinitialize(out_results); @@ -287,7 +287,7 @@ 
static iree_status_t attention_results_initialize( return status; } -static void attention_results_deinitialize(attention_results_t *results) { +static void attention_results_deinitialize(attention_results_t* results) { IREE_TRACE_ZONE_BEGIN(z0); iree_allocator_free(results->host_allocator, results->query_contents.data); iree_allocator_free(results->host_allocator, results->key_contents.data); @@ -302,7 +302,7 @@ static void attention_results_deinitialize(attention_results_t *results) { // obtained and validated the {b,m,k1,k2,n}_size values. On error, detailed // logging is written to |file| if it is not NULL. static iree_status_t -check_attention_results_impl(FILE *file, const attention_results_t *results, +check_attention_results_impl(FILE* file, const attention_results_t* results, int check_every) { IREE_TRACE_ZONE_BEGIN(z0); @@ -323,7 +323,7 @@ check_attention_results_impl(FILE *file, const attention_results_t *results, // is correct. On error, detailed logging is written to |file| if it is not // NULL. static iree_status_t -check_attention_results(FILE *file, const attention_results_t *results) { +check_attention_results(FILE* file, const attention_results_t* results) { IREE_TRACE_ZONE_BEGIN(z0); // TODO: Increase the check every param to reduce the number of comparisons. int check_every = 1; @@ -386,8 +386,8 @@ class AttentionTestModuleState final { device.get(), iree_hal_device_allocator(device.get()), IREE_ARRAYSIZE(dims), dims, element_type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params, - +[](iree_hal_buffer_mapping_t *mapping, void *user_data) { - callback_state_t callback_state = *(callback_state_t *)user_data; + +[](iree_hal_buffer_mapping_t* mapping, void* user_data) { + callback_state_t callback_state = *(callback_state_t*)user_data; iree_byte_span_t span = mapping->contents; // Generate "uniform" integer-valued numbers in the range [min, max]. int32_t min = 0; @@ -397,9 +397,9 @@ class AttentionTestModuleState final { uint32_t range = (max - min + 1); iree_host_size_t element_byte_count = iree_hal_element_dense_byte_count(callback_state.element_type); - uint8_t *data_end = span.data + span.data_length; + uint8_t* data_end = span.data + span.data_length; uint32_t state = callback_state.seed; - for (uint8_t *data = span.data; data < data_end; + for (uint8_t* data = span.data; data < data_end; data += element_byte_count) { int32_t value = (int32_t)iree_test_utils_pseudorandom_range(&state, range) + @@ -452,7 +452,7 @@ struct AttentionTestModule final return std::make_unique(host_allocator); } StatusOr> - ForkState(AttentionTestModuleState *parent_state, + ForkState(AttentionTestModuleState* parent_state, iree_allocator_t host_allocator) { return std::make_unique(host_allocator); } @@ -461,9 +461,9 @@ struct AttentionTestModule final } // namespace iree static iree_status_t -attention_test_module_create(iree_vm_instance_t *instance, +attention_test_module_create(iree_vm_instance_t* instance, iree_allocator_t host_allocator, - iree_vm_module_t **out_module) { + iree_vm_module_t** out_module) { IREE_ASSERT_ARGUMENT(out_module); *out_module = NULL; auto module = std::make_unique( @@ -475,7 +475,7 @@ attention_test_module_create(iree_vm_instance_t *instance, return iree_ok_status(); } -int main(int argc, char **argv) { +int main(int argc, char** argv) { IREE_TRACE_APP_ENTER(); iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
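Since this final patch also reformats the indexing helpers, it is worth restating the layout convention they encode: index_3d flattens a [dim1 x dim2 x dim3] row-major tensor so that element (i, j, k) lives at i*dim2*dim3 + j*dim3 + k. A tiny self-contained check of that arithmetic; Index3D here is a local restatement for illustration, not the helper from the test file.

#include <cassert>

// Row-major flattening of a 3-D index.
int Index3D(int i, int j, int k, int dim2, int dim3) {
  return i * dim2 * dim3 + j * dim3 + k;
}

int main() {
  // Element (1, 2, 3) of a 4x5x6 tensor: 1*5*6 + 2*6 + 3 = 45.
  assert(Index3D(1, 2, 3, /*dim2=*/5, /*dim3=*/6) == 45);
  return 0;
}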