From 3e934f8f50a1d307ba4ae1b96927fd91497d00c7 Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Mon, 18 Nov 2024 17:24:14 +0000 Subject: [PATCH] [spirv-reader] Enable F16 support Add support for `f16` values in the AST SPIR-V reader. Bug: 377728743 Change-Id: I66525d37c11f5521367404c91f8da3dee010fd0d Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/214974 Reviewed-by: David Neto Commit-Queue: dan sinclair --- src/tint/cmd/tint/main.cc | 2 + .../spirv/reader/ast_parser/ast_parser.cc | 24 +++- .../spirv/reader/ast_parser/constant_test.cc | 13 +- .../reader/ast_parser/convert_type_test.cc | 114 ++++++++++++++++++ .../lang/spirv/reader/ast_parser/function.cc | 11 +- .../ast_parser/function_conversion_test.cc | 92 +++++++++++++- src/tint/lang/spirv/reader/ast_parser/type.cc | 32 ++++- src/tint/lang/spirv/reader/ast_parser/type.h | 14 +++ .../lang/spirv/reader/ast_parser/type_test.cc | 1 + 9 files changed, 293 insertions(+), 10 deletions(-) diff --git a/src/tint/cmd/tint/main.cc b/src/tint/cmd/tint/main.cc index 7925c0272c7..48483cb83bd 100644 --- a/src/tint/cmd/tint/main.cc +++ b/src/tint/cmd/tint/main.cc @@ -1288,6 +1288,8 @@ int main(int argc, const char** argv) { #if TINT_BUILD_SPV_READER opts.use_ir = options.use_ir_reader; opts.spirv_reader_options = options.spirv_reader_options; + // Allow the shader-f16 extension + opts.spirv_reader_options.allowed_features = tint::wgsl::AllowedFeatures::Everything(); #endif auto info = tint::cmd::LoadProgramInfo(opts); diff --git a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc index de9615a6184..930a6c721b9 100644 --- a/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc +++ b/src/tint/lang/spirv/reader/ast_parser/ast_parser.cc @@ -987,6 +987,10 @@ const Type* ASTParser::ConvertType(const spvtools::opt::analysis::Float* float_t if (float_ty->width() == 32) { return ty_.F32(); } + if (float_ty->width() == 16) { + Enable(wgsl::Extension::kF16); + return ty_.F16(); + } Fail() << "unhandled float width: " << float_ty->width(); return nullptr; } @@ -2103,6 +2107,22 @@ TypedExpression ASTParser::MakeConstantExpressionForScalarSpirvConstant( return TypedExpression{}; } }, + [&](const F16*) { + auto bits = spirv_const->AsScalarConstant()->GetU32BitValue(); + + // Section 2.2.1 of the SPIR-V spec guarantees that all integer types + // smaller than 32-bits are automatically zero or sign extended to 32-bits. + auto val = f16::FromBits(static_cast(bits)); + + if (auto f = core::CheckedConvert(AFloat(val)); f == Success) { + return TypedExpression{ty_.F16(), create( + source, static_cast(val.value), + ast::FloatLiteralExpression::Suffix::kH)}; + } else { + Fail() << "value cannot be represented as 'f16': " << spirv_const->GetFloat(); + return TypedExpression{}; + } + }, [&](const Bool*) { const bool value = spirv_const->AsNullConstant() ? false : spirv_const->AsBoolConstant()->value(); @@ -2259,7 +2279,7 @@ const Type* ASTParser::GetSignedIntMatchingShape(const Type* other) { if (other == nullptr) { Fail() << "no type provided"; } - if (other->Is() || other->Is() || other->Is()) { + if (other->IsAnyOf()) { return ty_.I32(); } if (auto* vec_ty = other->As()) { @@ -2274,7 +2294,7 @@ const Type* ASTParser::GetUnsignedIntMatchingShape(const Type* other) { Fail() << "no type provided"; return nullptr; } - if (other->Is() || other->Is() || other->Is()) { + if (other->IsAnyOf()) { return ty_.U32(); } if (auto* vec_ty = other->As()) { diff --git a/src/tint/lang/spirv/reader/ast_parser/constant_test.cc b/src/tint/lang/spirv/reader/ast_parser/constant_test.cc index 6d17707c2f3..188c6d7362d 100644 --- a/src/tint/lang/spirv/reader/ast_parser/constant_test.cc +++ b/src/tint/lang/spirv/reader/ast_parser/constant_test.cc @@ -42,6 +42,7 @@ std::string Preamble() { OpCapability Image1D OpCapability StorageImageExtendedFormats OpCapability ImageQuery + OpCapability Float16 OpMemoryModel Logical Simple )"; } @@ -69,6 +70,7 @@ std::string CommonTypes() { %bool = OpTypeBool %float = OpTypeFloat 32 + %half = OpTypeFloat 16 %uint = OpTypeInt 32 0 %int = OpTypeInt 32 1 @@ -81,6 +83,9 @@ std::string CommonTypes() { %v2float = OpTypeVector %float 2 %v3float = OpTypeVector %float 3 %v4float = OpTypeVector %float 4 + %v2half = OpTypeVector %half 2 + %v3half = OpTypeVector %half 3 + %v4half = OpTypeVector %half 4 %true = OpConstantTrue %bool %false = OpConstantFalse %bool @@ -95,6 +100,9 @@ std::string CommonTypes() { %float_minus_5 = OpConstant %float -5 %float_half = OpConstant %float 0.5 %float_ten = OpConstant %float 10 + %half_minus_5 = OpConstant %half -5 + %half_half = OpConstant %half 0.5 + %half_ten = OpConstant %half 10 )"; } @@ -149,7 +157,10 @@ INSTANTIATE_TEST_SUITE_P(Scalars, {"%uint", "%uint_max", "4294967295u"}, {"%float", "%float_minus_5", "-5.0f"}, {"%float", "%float_half", "0.5f"}, - {"%float", "%float_ten", "10.0f"}})); + {"%float", "%float_ten", "10.0f"}, + {"%half", "%half_minus_5", "-5.0h"}, + {"%half", "%half_half", "0.5h"}, + {"%half", "%half_ten", "10.0h"}})); } // namespace } // namespace tint::spirv::reader::ast_parser diff --git a/src/tint/lang/spirv/reader/ast_parser/convert_type_test.cc b/src/tint/lang/spirv/reader/ast_parser/convert_type_test.cc index 8e43f1273e7..f56862af99e 100644 --- a/src/tint/lang/spirv/reader/ast_parser/convert_type_test.cc +++ b/src/tint/lang/spirv/reader/ast_parser/convert_type_test.cc @@ -157,6 +157,15 @@ TEST_F(SpirvASTParserTest, ConvertType_F32) { EXPECT_TRUE(p->error().empty()); } +TEST_F(SpirvASTParserTest, ConvertType_F16) { + auto p = parser(test::Assemble(Preamble() + "%4 = OpTypeFloat 16" + MainBody())); + EXPECT_TRUE(p->BuildInternalModule()); + + auto* type = p->ConvertType(4); + EXPECT_TRUE(type->Is()); + EXPECT_TRUE(p->error().empty()); +} + TEST_F(SpirvASTParserTest, ConvertType_BadIntWidth) { auto p = parser(test::Assemble(Preamble() + "%5 = OpTypeInt 17 1" + MainBody())); EXPECT_TRUE(p->BuildInternalModule()); @@ -214,6 +223,33 @@ TEST_F(SpirvASTParserTest, ConvertType_VecOverF32) { EXPECT_TRUE(p->error().empty()); } +TEST_F(SpirvASTParserTest, ConvertType_VecOverF16) { + auto p = parser(test::Assemble(Preamble() + R"( + %float = OpTypeFloat 16 + %20 = OpTypeVector %float 2 + %30 = OpTypeVector %float 3 + %40 = OpTypeVector %float 4 + )" + MainBody())); + EXPECT_TRUE(p->BuildInternalModule()); + + auto* v2xf16 = p->ConvertType(20); + EXPECT_TRUE(v2xf16->Is()); + EXPECT_TRUE(v2xf16->As()->type->Is()); + EXPECT_EQ(v2xf16->As()->size, 2u); + + auto* v3xf16 = p->ConvertType(30); + EXPECT_TRUE(v3xf16->Is()); + EXPECT_TRUE(v3xf16->As()->type->Is()); + EXPECT_EQ(v3xf16->As()->size, 3u); + + auto* v4xf16 = p->ConvertType(40); + EXPECT_TRUE(v4xf16->Is()); + EXPECT_TRUE(v4xf16->As()->type->Is()); + EXPECT_EQ(v4xf16->As()->size, 4u); + + EXPECT_TRUE(p->error().empty()); +} + TEST_F(SpirvASTParserTest, ConvertType_VecOverI32) { auto p = parser(test::Assemble(Preamble() + R"( %int = OpTypeInt 32 1 @@ -359,6 +395,84 @@ TEST_F(SpirvASTParserTest, ConvertType_MatrixOverF32) { EXPECT_TRUE(p->error().empty()); } +TEST_F(SpirvASTParserTest, ConvertType_MatrixOverF16) { + // Matrices are only defined over floats. + auto p = parser(test::Assemble(Preamble() + R"( + %float = OpTypeFloat 16 + %v2 = OpTypeVector %float 2 + %v3 = OpTypeVector %float 3 + %v4 = OpTypeVector %float 4 + ; First digit is rows + ; Second digit is columns + %22 = OpTypeMatrix %v2 2 + %23 = OpTypeMatrix %v2 3 + %24 = OpTypeMatrix %v2 4 + %32 = OpTypeMatrix %v3 2 + %33 = OpTypeMatrix %v3 3 + %34 = OpTypeMatrix %v3 4 + %42 = OpTypeMatrix %v4 2 + %43 = OpTypeMatrix %v4 3 + %44 = OpTypeMatrix %v4 4 + )" + MainBody())); + EXPECT_TRUE(p->BuildInternalModule()); + + auto* m22 = p->ConvertType(22); + EXPECT_TRUE(m22->Is()); + EXPECT_TRUE(m22->As()->type->Is()); + EXPECT_EQ(m22->As()->rows, 2u); + EXPECT_EQ(m22->As()->columns, 2u); + + auto* m23 = p->ConvertType(23); + EXPECT_TRUE(m23->Is()); + EXPECT_TRUE(m23->As()->type->Is()); + EXPECT_EQ(m23->As()->rows, 2u); + EXPECT_EQ(m23->As()->columns, 3u); + + auto* m24 = p->ConvertType(24); + EXPECT_TRUE(m24->Is()); + EXPECT_TRUE(m24->As()->type->Is()); + EXPECT_EQ(m24->As()->rows, 2u); + EXPECT_EQ(m24->As()->columns, 4u); + + auto* m32 = p->ConvertType(32); + EXPECT_TRUE(m32->Is()); + EXPECT_TRUE(m32->As()->type->Is()); + EXPECT_EQ(m32->As()->rows, 3u); + EXPECT_EQ(m32->As()->columns, 2u); + + auto* m33 = p->ConvertType(33); + EXPECT_TRUE(m33->Is()); + EXPECT_TRUE(m33->As()->type->Is()); + EXPECT_EQ(m33->As()->rows, 3u); + EXPECT_EQ(m33->As()->columns, 3u); + + auto* m34 = p->ConvertType(34); + EXPECT_TRUE(m34->Is()); + EXPECT_TRUE(m34->As()->type->Is()); + EXPECT_EQ(m34->As()->rows, 3u); + EXPECT_EQ(m34->As()->columns, 4u); + + auto* m42 = p->ConvertType(42); + EXPECT_TRUE(m42->Is()); + EXPECT_TRUE(m42->As()->type->Is()); + EXPECT_EQ(m42->As()->rows, 4u); + EXPECT_EQ(m42->As()->columns, 2u); + + auto* m43 = p->ConvertType(43); + EXPECT_TRUE(m43->Is()); + EXPECT_TRUE(m43->As()->type->Is()); + EXPECT_EQ(m43->As()->rows, 4u); + EXPECT_EQ(m43->As()->columns, 3u); + + auto* m44 = p->ConvertType(44); + EXPECT_TRUE(m44->Is()); + EXPECT_TRUE(m44->As()->type->Is()); + EXPECT_EQ(m44->As()->rows, 4u); + EXPECT_EQ(m44->As()->columns, 4u); + + EXPECT_TRUE(p->error().empty()); +} + TEST_F(SpirvASTParserTest, ConvertType_RuntimeArray) { auto p = parser(test::Assemble(Preamble() + R"( %uint = OpTypeInt 32 0 diff --git a/src/tint/lang/spirv/reader/ast_parser/function.cc b/src/tint/lang/spirv/reader/ast_parser/function.cc index 475e3393ae0..cbc8af00b14 100644 --- a/src/tint/lang/spirv/reader/ast_parser/function.cc +++ b/src/tint/lang/spirv/reader/ast_parser/function.cc @@ -3962,7 +3962,7 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( } if (op == spv::Op::OpConvertSToF || op == spv::Op::OpConvertUToF || - op == spv::Op::OpConvertFToS || op == spv::Op::OpConvertFToU) { + op == spv::Op::OpConvertFToS || op == spv::Op::OpConvertFToU || op == spv::Op::OpFConvert) { return MakeNumericConversion(inst); } @@ -3987,7 +3987,6 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( // OpSatConvertUToS // Only in Kernel (OpenCL), not in WebGPU // OpUConvert // Only needed when multiple widths supported // OpSConvert // Only needed when multiple widths supported - // OpFConvert // Only needed when multiple widths supported // OpConvertPtrToU // Not in WebGPU // OpConvertUToPtr // Not in WebGPU // OpPtrCastToGeneric // Not in Vulkan @@ -5227,6 +5226,14 @@ TypedExpression FunctionEmitter::MakeNumericConversion(const spvtools::opt::Inst "point scalar or vector: " << inst.PrettyPrint(); } + } else if (op == spv::Op::OpFConvert) { + if (arg_expr.type->IsFloatScalarOrVector()) { + expr_type = requested_type; + } else { + Fail() << "operand for conversion to float 16 must be floating " + "point scalar or vector: " + << inst.PrettyPrint(); + } } if (expr_type == nullptr) { // The diagnostic has already been emitted. diff --git a/src/tint/lang/spirv/reader/ast_parser/function_conversion_test.cc b/src/tint/lang/spirv/reader/ast_parser/function_conversion_test.cc index 2402b072965..b186bee0b6a 100644 --- a/src/tint/lang/spirv/reader/ast_parser/function_conversion_test.cc +++ b/src/tint/lang/spirv/reader/ast_parser/function_conversion_test.cc @@ -39,6 +39,7 @@ using ::testing::HasSubstr; std::string Preamble() { return R"( OpCapability Shader + OpCapability Float16 OpMemoryModel Logical Simple OpEntryPoint Fragment %100 "main" OpExecutionMode %100 OriginUpperLeft @@ -50,6 +51,7 @@ std::string Preamble() { %uint = OpTypeInt 32 0 %int = OpTypeInt 32 1 %float = OpTypeFloat 32 + %half = OpTypeFloat 16 %true = OpConstantTrue %bool %false = OpConstantFalse %bool @@ -62,14 +64,18 @@ std::string Preamble() { %int_40 = OpConstant %int 40 %float_50 = OpConstant %float 50 %float_60 = OpConstant %float 60 + %half_50 = OpConstant %half 50 + %half_60 = OpConstant %half 60 %ptr_uint = OpTypePointer Function %uint %ptr_int = OpTypePointer Function %int %ptr_float = OpTypePointer Function %float + %ptr_half = OpTypePointer Function %half %v2uint = OpTypeVector %uint 2 %v2int = OpTypeVector %int 2 %v2float = OpTypeVector %float 2 + %v2half = OpTypeVector %half 2 %v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20 %v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10 @@ -77,6 +83,8 @@ std::string Preamble() { %v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30 %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60 %v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50 + %v2half_50_60 = OpConstantComposite %v2half %half_50 %half_60 + %v2half_60_50 = OpConstantComposite %v2half %half_60 %half_50 )"; } @@ -621,9 +629,91 @@ OpFunctionEnd EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("let x_82 = u32(x_600);")); } +TEST_F(SpvUnaryConversionTest, FConvert_BadArg) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpFConvert %float %void + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_FALSE(fe.EmitBody()); + EXPECT_THAT(p->error(), HasSubstr("unhandled expression for ID 2\n%2 = OpTypeVoid")); +} + +TEST_F(SpvUnaryConversionTest, FConvertFToH_Scalar) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %float %float_50 + %1 = OpFConvert %half %30 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("let x_1 = f16(x_30);")); +} + +TEST_F(SpvUnaryConversionTest, FConvertHToF_Scalar) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %half %half_50 + %1 = OpFConvert %float %30 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("let x_1 = f32(x_30);")); +} + +TEST_F(SpvUnaryConversionTest, FConvertFToH_Vector) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %v2float %v2float_50_60 + %1 = OpFConvert %v2half %30 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("let x_1 = vec2h(x_30);")); +} + +TEST_F(SpvUnaryConversionTest, FConvertHToF_Vector) { + const auto assembly = Preamble() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %30 = OpCopyObject %v2half %v2half_50_60 + %1 = OpFConvert %v2float %30 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + auto fe = p->function_emitter(100); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + auto ast_body = fe.ast_body(); + EXPECT_THAT(test::ToString(p->program(), ast_body), HasSubstr("let x_1 = vec2f(x_30);")); +} + // TODO(dneto): OpSConvert // only if multiple widths // TODO(dneto): OpUConvert // only if multiple widths -// TODO(dneto): OpFConvert // only if multiple widths // TODO(dneto): OpSatConvertSToU // Kernel (OpenCL), not in WebGPU // TODO(dneto): OpSatConvertUToS // Kernel (OpenCL), not in WebGPU diff --git a/src/tint/lang/spirv/reader/ast_parser/type.cc b/src/tint/lang/spirv/reader/ast_parser/type.cc index 5c62f29fc42..a001c86a25b 100644 --- a/src/tint/lang/spirv/reader/ast_parser/type.cc +++ b/src/tint/lang/spirv/reader/ast_parser/type.cc @@ -47,6 +47,7 @@ TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::ast_parser::Void); TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::ast_parser::Bool); TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::ast_parser::U32); TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::ast_parser::F32); +TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::ast_parser::F16); TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::ast_parser::I32); TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::ast_parser::Pointer); TINT_INSTANTIATE_TYPEINFO(tint::spirv::reader::ast_parser::Reference); @@ -180,6 +181,10 @@ ast::Type F32::Build(ProgramBuilder& b) const { return b.ty.f32(); } +ast::Type F16::Build(ProgramBuilder& b) const { + return b.ty.f16(); +} + ast::Type I32::Build(ProgramBuilder& b) const { return b.ty.i32(); } @@ -222,6 +227,7 @@ ast::Type Vector::Build(ProgramBuilder& b) const { [&](const I32*) { return b.ty(prefix + "i"); }, [&](const U32*) { return b.ty(prefix + "u"); }, [&](const F32*) { return b.ty(prefix + "f"); }, + [&](const F16*) { return b.ty(prefix + "h"); }, [&](Default) { return b.ty.vec(type->Build(b), size); }); } @@ -229,9 +235,14 @@ Matrix::Matrix(const Type* t, uint32_t c, uint32_t r) : type(t), columns(c), row Matrix::Matrix(const Matrix&) = default; ast::Type Matrix::Build(ProgramBuilder& b) const { - if (type->Is()) { + if (type->IsAnyOf()) { std::ostringstream ss; - ss << "mat" << columns << "x" << rows << "f"; + ss << "mat" << columns << "x" << rows; + if (type->Is()) { + ss << "f"; + } else { + ss << "h"; + } return b.ty(ss.str()); } return b.ty.mat(type->Build(b), columns, rows); @@ -334,6 +345,8 @@ struct TypeManager::State { ast_parser::U32 const* u32_ = nullptr; /// The lazily-created F32 type ast_parser::F32 const* f32_ = nullptr; + /// The lazily-created F16 type + ast_parser::F16 const* f16_ = nullptr; /// The lazily-created I32 type ast_parser::I32 const* i32_ = nullptr; /// Unique Pointer instances @@ -405,7 +418,7 @@ const Type* Type::UnwrapAll() const { } bool Type::IsFloatScalar() const { - return Is(); + return IsAnyOf(); } bool Type::IsFloatScalarOrVector() const { @@ -425,7 +438,7 @@ bool Type::IsIntegerScalarOrVector() const { } bool Type::IsScalar() const { - return IsAnyOf(); + return IsAnyOf(); } bool Type::IsSignedIntegerVector() const { @@ -478,6 +491,13 @@ const ast_parser::F32* TypeManager::F32() { return state->f32_; } +const ast_parser::F16* TypeManager::F16() { + if (!state->f16_) { + state->f16_ = state->allocator_.Create(); + } + return state->f16_; +} + const ast_parser::I32* TypeManager::I32() { if (!state->i32_) { state->i32_ = state->allocator_.Create(); @@ -579,6 +599,10 @@ std::string F32::String() const { return "f32"; } +std::string F16::String() const { + return "f16"; +} + std::string I32::String() const { return "i32"; } diff --git a/src/tint/lang/spirv/reader/ast_parser/type.h b/src/tint/lang/spirv/reader/ast_parser/type.h index 07d331893e6..4f2456eb636 100644 --- a/src/tint/lang/spirv/reader/ast_parser/type.h +++ b/src/tint/lang/spirv/reader/ast_parser/type.h @@ -157,6 +157,18 @@ struct F32 final : public Castable { #endif // NDEBUG }; +/// `f16` type +struct F16 final : public Castable { + /// @param b the ProgramBuilder used to construct the AST types + /// @returns the constructed ast::Type node for the given type + ast::Type Build(ProgramBuilder& b) const override; + +#ifndef NDEBUG + /// @returns a string representation of the type, for debug purposes only + std::string String() const override; +#endif // NDEBUG +}; + /// `i32` type struct I32 final : public Castable { /// @param b the ProgramBuilder used to construct the AST types @@ -549,6 +561,8 @@ class TypeManager { const ast_parser::U32* U32(); /// @return a F32 type. Repeated calls will return the same pointer. const ast_parser::F32* F32(); + /// @return a F16 type. Repeated calls will return the same pointer. + const ast_parser::F16* F16(); /// @return a I32 type. Repeated calls will return the same pointer. const ast_parser::I32* I32(); /// @param ty the input type. diff --git a/src/tint/lang/spirv/reader/ast_parser/type_test.cc b/src/tint/lang/spirv/reader/ast_parser/type_test.cc index aa171b790ca..caeb43f5287 100644 --- a/src/tint/lang/spirv/reader/ast_parser/type_test.cc +++ b/src/tint/lang/spirv/reader/ast_parser/type_test.cc @@ -41,6 +41,7 @@ TEST(SpirvASTParserTypeTest, SameArgumentsGivesSamePointer) { EXPECT_EQ(ty.Bool(), ty.Bool()); EXPECT_EQ(ty.U32(), ty.U32()); EXPECT_EQ(ty.F32(), ty.F32()); + EXPECT_EQ(ty.F16(), ty.F16()); EXPECT_EQ(ty.I32(), ty.I32()); EXPECT_EQ(ty.Pointer(core::AddressSpace::kUndefined, ty.I32()), ty.Pointer(core::AddressSpace::kUndefined, ty.I32()));