Skip to content

Commit

Permalink
[spirv-reader] Enable F16 support
Browse files Browse the repository at this point in the history
Add support for `f16` values in the AST SPIR-V reader.

Bug: 377728743
Change-Id: I66525d37c11f5521367404c91f8da3dee010fd0d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/214974
Reviewed-by: David Neto <[email protected]>
Commit-Queue: dan sinclair <[email protected]>
  • Loading branch information
dj2 authored and Dawn LUCI CQ committed Nov 18, 2024
1 parent 3154277 commit 3e934f8
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/tint/cmd/tint/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,8 @@ int main(int argc, const char** argv) {
#if TINT_BUILD_SPV_READER
opts.use_ir = options.use_ir_reader;
opts.spirv_reader_options = options.spirv_reader_options;
// Allow the shader-f16 extension
opts.spirv_reader_options.allowed_features = tint::wgsl::AllowedFeatures::Everything();
#endif

auto info = tint::cmd::LoadProgramInfo(opts);
Expand Down
24 changes: 22 additions & 2 deletions src/tint/lang/spirv/reader/ast_parser/ast_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,10 @@ const Type* ASTParser::ConvertType(const spvtools::opt::analysis::Float* float_t
if (float_ty->width() == 32) {
return ty_.F32();
}
if (float_ty->width() == 16) {
Enable(wgsl::Extension::kF16);
return ty_.F16();
}
Fail() << "unhandled float width: " << float_ty->width();
return nullptr;
}
Expand Down Expand Up @@ -2103,6 +2107,22 @@ TypedExpression ASTParser::MakeConstantExpressionForScalarSpirvConstant(
return TypedExpression{};
}
},
[&](const F16*) {
auto bits = spirv_const->AsScalarConstant()->GetU32BitValue();

// Section 2.2.1 of the SPIR-V spec guarantees that all integer types
// smaller than 32-bits are automatically zero or sign extended to 32-bits.
auto val = f16::FromBits(static_cast<uint16_t>(bits));

if (auto f = core::CheckedConvert<f16>(AFloat(val)); f == Success) {
return TypedExpression{ty_.F16(), create<ast::FloatLiteralExpression>(
source, static_cast<double>(val.value),
ast::FloatLiteralExpression::Suffix::kH)};
} else {
Fail() << "value cannot be represented as 'f16': " << spirv_const->GetFloat();
return TypedExpression{};
}
},
[&](const Bool*) {
const bool value =
spirv_const->AsNullConstant() ? false : spirv_const->AsBoolConstant()->value();
Expand Down Expand Up @@ -2259,7 +2279,7 @@ const Type* ASTParser::GetSignedIntMatchingShape(const Type* other) {
if (other == nullptr) {
Fail() << "no type provided";
}
if (other->Is<F32>() || other->Is<U32>() || other->Is<I32>()) {
if (other->IsAnyOf<F32, U32, I32>()) {
return ty_.I32();
}
if (auto* vec_ty = other->As<Vector>()) {
Expand All @@ -2274,7 +2294,7 @@ const Type* ASTParser::GetUnsignedIntMatchingShape(const Type* other) {
Fail() << "no type provided";
return nullptr;
}
if (other->Is<F32>() || other->Is<U32>() || other->Is<I32>()) {
if (other->IsAnyOf<F32, U32, I32>()) {
return ty_.U32();
}
if (auto* vec_ty = other->As<Vector>()) {
Expand Down
13 changes: 12 additions & 1 deletion src/tint/lang/spirv/reader/ast_parser/constant_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ std::string Preamble() {
OpCapability Image1D
OpCapability StorageImageExtendedFormats
OpCapability ImageQuery
OpCapability Float16
OpMemoryModel Logical Simple
)";
}
Expand Down Expand Up @@ -69,6 +70,7 @@ std::string CommonTypes() {
%bool = OpTypeBool
%float = OpTypeFloat 32
%half = OpTypeFloat 16
%uint = OpTypeInt 32 0
%int = OpTypeInt 32 1
Expand All @@ -81,6 +83,9 @@ std::string CommonTypes() {
%v2float = OpTypeVector %float 2
%v3float = OpTypeVector %float 3
%v4float = OpTypeVector %float 4
%v2half = OpTypeVector %half 2
%v3half = OpTypeVector %half 3
%v4half = OpTypeVector %half 4
%true = OpConstantTrue %bool
%false = OpConstantFalse %bool
Expand All @@ -95,6 +100,9 @@ std::string CommonTypes() {
%float_minus_5 = OpConstant %float -5
%float_half = OpConstant %float 0.5
%float_ten = OpConstant %float 10
%half_minus_5 = OpConstant %half -5
%half_half = OpConstant %half 0.5
%half_ten = OpConstant %half 10
)";
}

Expand Down Expand Up @@ -149,7 +157,10 @@ INSTANTIATE_TEST_SUITE_P(Scalars,
{"%uint", "%uint_max", "4294967295u"},
{"%float", "%float_minus_5", "-5.0f"},
{"%float", "%float_half", "0.5f"},
{"%float", "%float_ten", "10.0f"}}));
{"%float", "%float_ten", "10.0f"},
{"%half", "%half_minus_5", "-5.0h"},
{"%half", "%half_half", "0.5h"},
{"%half", "%half_ten", "10.0h"}}));

} // namespace
} // namespace tint::spirv::reader::ast_parser
114 changes: 114 additions & 0 deletions src/tint/lang/spirv/reader/ast_parser/convert_type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ TEST_F(SpirvASTParserTest, ConvertType_F32) {
EXPECT_TRUE(p->error().empty());
}

TEST_F(SpirvASTParserTest, ConvertType_F16) {
auto p = parser(test::Assemble(Preamble() + "%4 = OpTypeFloat 16" + MainBody()));
EXPECT_TRUE(p->BuildInternalModule());

auto* type = p->ConvertType(4);
EXPECT_TRUE(type->Is<F16>());
EXPECT_TRUE(p->error().empty());
}

TEST_F(SpirvASTParserTest, ConvertType_BadIntWidth) {
auto p = parser(test::Assemble(Preamble() + "%5 = OpTypeInt 17 1" + MainBody()));
EXPECT_TRUE(p->BuildInternalModule());
Expand Down Expand Up @@ -214,6 +223,33 @@ TEST_F(SpirvASTParserTest, ConvertType_VecOverF32) {
EXPECT_TRUE(p->error().empty());
}

TEST_F(SpirvASTParserTest, ConvertType_VecOverF16) {
auto p = parser(test::Assemble(Preamble() + R"(
%float = OpTypeFloat 16
%20 = OpTypeVector %float 2
%30 = OpTypeVector %float 3
%40 = OpTypeVector %float 4
)" + MainBody()));
EXPECT_TRUE(p->BuildInternalModule());

auto* v2xf16 = p->ConvertType(20);
EXPECT_TRUE(v2xf16->Is<Vector>());
EXPECT_TRUE(v2xf16->As<Vector>()->type->Is<F16>());
EXPECT_EQ(v2xf16->As<Vector>()->size, 2u);

auto* v3xf16 = p->ConvertType(30);
EXPECT_TRUE(v3xf16->Is<Vector>());
EXPECT_TRUE(v3xf16->As<Vector>()->type->Is<F16>());
EXPECT_EQ(v3xf16->As<Vector>()->size, 3u);

auto* v4xf16 = p->ConvertType(40);
EXPECT_TRUE(v4xf16->Is<Vector>());
EXPECT_TRUE(v4xf16->As<Vector>()->type->Is<F16>());
EXPECT_EQ(v4xf16->As<Vector>()->size, 4u);

EXPECT_TRUE(p->error().empty());
}

TEST_F(SpirvASTParserTest, ConvertType_VecOverI32) {
auto p = parser(test::Assemble(Preamble() + R"(
%int = OpTypeInt 32 1
Expand Down Expand Up @@ -359,6 +395,84 @@ TEST_F(SpirvASTParserTest, ConvertType_MatrixOverF32) {
EXPECT_TRUE(p->error().empty());
}

TEST_F(SpirvASTParserTest, ConvertType_MatrixOverF16) {
// Matrices are only defined over floats.
auto p = parser(test::Assemble(Preamble() + R"(
%float = OpTypeFloat 16
%v2 = OpTypeVector %float 2
%v3 = OpTypeVector %float 3
%v4 = OpTypeVector %float 4
; First digit is rows
; Second digit is columns
%22 = OpTypeMatrix %v2 2
%23 = OpTypeMatrix %v2 3
%24 = OpTypeMatrix %v2 4
%32 = OpTypeMatrix %v3 2
%33 = OpTypeMatrix %v3 3
%34 = OpTypeMatrix %v3 4
%42 = OpTypeMatrix %v4 2
%43 = OpTypeMatrix %v4 3
%44 = OpTypeMatrix %v4 4
)" + MainBody()));
EXPECT_TRUE(p->BuildInternalModule());

auto* m22 = p->ConvertType(22);
EXPECT_TRUE(m22->Is<Matrix>());
EXPECT_TRUE(m22->As<Matrix>()->type->Is<F16>());
EXPECT_EQ(m22->As<Matrix>()->rows, 2u);
EXPECT_EQ(m22->As<Matrix>()->columns, 2u);

auto* m23 = p->ConvertType(23);
EXPECT_TRUE(m23->Is<Matrix>());
EXPECT_TRUE(m23->As<Matrix>()->type->Is<F16>());
EXPECT_EQ(m23->As<Matrix>()->rows, 2u);
EXPECT_EQ(m23->As<Matrix>()->columns, 3u);

auto* m24 = p->ConvertType(24);
EXPECT_TRUE(m24->Is<Matrix>());
EXPECT_TRUE(m24->As<Matrix>()->type->Is<F16>());
EXPECT_EQ(m24->As<Matrix>()->rows, 2u);
EXPECT_EQ(m24->As<Matrix>()->columns, 4u);

auto* m32 = p->ConvertType(32);
EXPECT_TRUE(m32->Is<Matrix>());
EXPECT_TRUE(m32->As<Matrix>()->type->Is<F16>());
EXPECT_EQ(m32->As<Matrix>()->rows, 3u);
EXPECT_EQ(m32->As<Matrix>()->columns, 2u);

auto* m33 = p->ConvertType(33);
EXPECT_TRUE(m33->Is<Matrix>());
EXPECT_TRUE(m33->As<Matrix>()->type->Is<F16>());
EXPECT_EQ(m33->As<Matrix>()->rows, 3u);
EXPECT_EQ(m33->As<Matrix>()->columns, 3u);

auto* m34 = p->ConvertType(34);
EXPECT_TRUE(m34->Is<Matrix>());
EXPECT_TRUE(m34->As<Matrix>()->type->Is<F16>());
EXPECT_EQ(m34->As<Matrix>()->rows, 3u);
EXPECT_EQ(m34->As<Matrix>()->columns, 4u);

auto* m42 = p->ConvertType(42);
EXPECT_TRUE(m42->Is<Matrix>());
EXPECT_TRUE(m42->As<Matrix>()->type->Is<F16>());
EXPECT_EQ(m42->As<Matrix>()->rows, 4u);
EXPECT_EQ(m42->As<Matrix>()->columns, 2u);

auto* m43 = p->ConvertType(43);
EXPECT_TRUE(m43->Is<Matrix>());
EXPECT_TRUE(m43->As<Matrix>()->type->Is<F16>());
EXPECT_EQ(m43->As<Matrix>()->rows, 4u);
EXPECT_EQ(m43->As<Matrix>()->columns, 3u);

auto* m44 = p->ConvertType(44);
EXPECT_TRUE(m44->Is<Matrix>());
EXPECT_TRUE(m44->As<Matrix>()->type->Is<F16>());
EXPECT_EQ(m44->As<Matrix>()->rows, 4u);
EXPECT_EQ(m44->As<Matrix>()->columns, 4u);

EXPECT_TRUE(p->error().empty());
}

TEST_F(SpirvASTParserTest, ConvertType_RuntimeArray) {
auto p = parser(test::Assemble(Preamble() + R"(
%uint = OpTypeInt 32 0
Expand Down
11 changes: 9 additions & 2 deletions src/tint/lang/spirv/reader/ast_parser/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3962,7 +3962,7 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
}

if (op == spv::Op::OpConvertSToF || op == spv::Op::OpConvertUToF ||
op == spv::Op::OpConvertFToS || op == spv::Op::OpConvertFToU) {
op == spv::Op::OpConvertFToS || op == spv::Op::OpConvertFToU || op == spv::Op::OpFConvert) {
return MakeNumericConversion(inst);
}

Expand All @@ -3987,7 +3987,6 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
// OpSatConvertUToS // Only in Kernel (OpenCL), not in WebGPU
// OpUConvert // Only needed when multiple widths supported
// OpSConvert // Only needed when multiple widths supported
// OpFConvert // Only needed when multiple widths supported
// OpConvertPtrToU // Not in WebGPU
// OpConvertUToPtr // Not in WebGPU
// OpPtrCastToGeneric // Not in Vulkan
Expand Down Expand Up @@ -5227,6 +5226,14 @@ TypedExpression FunctionEmitter::MakeNumericConversion(const spvtools::opt::Inst
"point scalar or vector: "
<< inst.PrettyPrint();
}
} else if (op == spv::Op::OpFConvert) {
if (arg_expr.type->IsFloatScalarOrVector()) {
expr_type = requested_type;
} else {
Fail() << "operand for conversion to float 16 must be floating "
"point scalar or vector: "
<< inst.PrettyPrint();
}
}
if (expr_type == nullptr) {
// The diagnostic has already been emitted.
Expand Down
Loading

0 comments on commit 3e934f8

Please sign in to comment.