diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h index 37653631cc238..cb5a4c14b364c 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h @@ -522,6 +522,13 @@ std::optional isConstantOrConstantSplatVector(MachineInstr &MI, const MachineRegisterInfo &MRI); +/// Determines if \p MI defines a float constant integer or a splat vector of +/// float constant integers. +/// \returns the float constant or std::nullopt. +std::optional +isConstantOrConstantSplatVectorFP(MachineInstr &MI, + const MachineRegisterInfo &MRI); + /// Attempt to match a unary predicate against a scalar/splat constant or every /// element of a constant G_BUILD_VECTOR. If \p ConstVal is null, the source /// value was undef. diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp index 8c1e41ea106ec..79382933a1f42 100644 --- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp +++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp @@ -1517,6 +1517,18 @@ llvm::isConstantOrConstantSplatVector(MachineInstr &MI, return APInt(ScalarSize, *MaybeCst, true); } +std::optional +llvm::isConstantOrConstantSplatVectorFP(MachineInstr &MI, + const MachineRegisterInfo &MRI) { + Register Def = MI.getOperand(0).getReg(); + if (auto FpConst = getFConstantVRegValWithLookThrough(Def, MRI)) + return FpConst->Value; + auto MaybeCstFP = getFConstantSplat(Def, MRI, /*allowUndef=*/false); + if (!MaybeCstFP) + return std::nullopt; + return MaybeCstFP->Value; +} + bool llvm::isNullOrNullSplat(const MachineInstr &MI, const MachineRegisterInfo &MRI, bool AllowUndefs) { switch (MI.getOpcode()) { diff --git a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp index 1ff7fd956d015..9163663c2b776 100644 --- a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp @@ -77,6 +77,15 @@ static const LLT NXV3P0 = LLT::scalable_vector(3, P0); static const LLT NXV4P0 = LLT::scalable_vector(4, P0); static const LLT NXV12P0 = LLT::scalable_vector(12, P0); +static void collectNonCopyMI(SmallVectorImpl &MIList, + MachineFunction *MF) { + for (auto &MBB : *MF) + for (MachineInstr &MI : MBB) { + if (MI.getOpcode() != TargetOpcode::COPY) + MIList.push_back(&MI); + } +} + TEST(GISelUtilsTest, getGCDType) { EXPECT_EQ(S1, getGCDType(S1, S1)); EXPECT_EQ(S32, getGCDType(S32, S32)); @@ -408,4 +417,90 @@ TEST_F(AArch64GISelMITest, ConstFalseTest) { } } } + +TEST_F(AMDGPUGISelMITest, isConstantOrConstantSplatVectorFP) { + StringRef MIRString = + " %cst0:_(s32) = G_FCONSTANT float 2.000000e+00\n" + " %cst1:_(s32) = G_FCONSTANT float 0.0\n" + " %cst2:_(s64) = G_FCONSTANT double 3.000000e-02\n" + " %cst3:_(s32) = G_CONSTANT i32 2\n" + " %cst4:_(<2 x s32>) = G_BUILD_VECTOR %cst0(s32), %cst0(s32)\n" + " %cst5:_(<2 x s32>) = G_BUILD_VECTOR %cst1(s32), %cst0(s32)\n" + " %cst6:_(<2 x s64>) = G_BUILD_VECTOR %cst2(s64), %cst2(s64)\n" + " %cst7:_(<2 x s32>) = G_BUILD_VECTOR %cst3(s32), %cst3:_(s32)\n" + " %cst8:_(<4 x s32>) = G_CONCAT_VECTORS %cst4:_(<2 x s32>), %cst4:_(<2 " + "x s32>)\n" + " %cst9:_(<4 x s64>) = G_CONCAT_VECTORS %cst6:_(<2 x s64>), %cst6:_(<2 " + "x s64>)\n" + " %cst10:_(<4 x s32>) = G_CONCAT_VECTORS %cst4:_(<2 x s32>), %cst5:_(<2 " + "x s32>)\n" + " %cst11:_(<4 x s32>) = G_CONCAT_VECTORS %cst7:_(<2 x s32>), %cst7:_(<2 " + "x s32>)\n" + " %cst12:_(s32) = G_IMPLICIT_DEF \n" + " %cst13:_(<2 x s32>) = G_BUILD_VECTOR %cst12(s32), %cst12(s32)\n" + " %cst14:_(<2 x s32>) = G_BUILD_VECTOR %cst0(s32), %cst12(s32)\n" + " %cst15:_(<4 x s32>) = G_CONCAT_VECTORS %cst4:_(<2 x s32>), " + "%cst14:_(<2 " + "x s32>)\n"; + + SmallVector MIList; + + setUp(MIRString); + if (!TM) + GTEST_SKIP(); + + collectNonCopyMI(MIList, MF); + + EXPECT_TRUE(isConstantOrConstantSplatVectorFP(*MIList[0], *MRI).has_value()); + auto val = isConstantOrConstantSplatVectorFP(*MIList[0], *MRI).value(); + EXPECT_EQ(2.0, val.convertToFloat()); + + EXPECT_TRUE(isConstantOrConstantSplatVectorFP(*MIList[1], *MRI).has_value()); + val = isConstantOrConstantSplatVectorFP(*MIList[1], *MRI).value(); + EXPECT_EQ(0.0, val.convertToFloat()); + + EXPECT_TRUE(isConstantOrConstantSplatVectorFP(*MIList[2], *MRI).has_value()); + val = isConstantOrConstantSplatVectorFP(*MIList[2], *MRI).value(); + EXPECT_EQ(0.03, val.convertToDouble()); + + EXPECT_FALSE(isConstantOrConstantSplatVectorFP(*MIList[3], *MRI).has_value()); + + EXPECT_TRUE(isConstantOrConstantSplatVectorFP(*MIList[4], *MRI).has_value()); + val = isConstantOrConstantSplatVectorFP(*MIList[4], *MRI).value(); + EXPECT_EQ(2.0, val.convertToFloat()); + + EXPECT_FALSE(isConstantOrConstantSplatVectorFP(*MIList[5], *MRI).has_value()); + + EXPECT_TRUE(isConstantOrConstantSplatVectorFP(*MIList[6], *MRI).has_value()); + val = isConstantOrConstantSplatVectorFP(*MIList[6], *MRI).value(); + EXPECT_EQ(0.03, val.convertToDouble()); + + EXPECT_FALSE(isConstantOrConstantSplatVectorFP(*MIList[7], *MRI).has_value()); + + EXPECT_TRUE(isConstantOrConstantSplatVectorFP(*MIList[8], *MRI).has_value()); + val = isConstantOrConstantSplatVectorFP(*MIList[8], *MRI).value(); + EXPECT_EQ(2.0, val.convertToFloat()); + + EXPECT_TRUE(isConstantOrConstantSplatVectorFP(*MIList[9], *MRI).has_value()); + val = isConstantOrConstantSplatVectorFP(*MIList[9], *MRI).value(); + EXPECT_EQ(0.03, val.convertToDouble()); + + EXPECT_FALSE( + isConstantOrConstantSplatVectorFP(*MIList[10], *MRI).has_value()); + + EXPECT_FALSE( + isConstantOrConstantSplatVectorFP(*MIList[11], *MRI).has_value()); + + EXPECT_FALSE( + isConstantOrConstantSplatVectorFP(*MIList[12], *MRI).has_value()); + + EXPECT_FALSE( + isConstantOrConstantSplatVectorFP(*MIList[13], *MRI).has_value()); + + EXPECT_FALSE( + isConstantOrConstantSplatVectorFP(*MIList[14], *MRI).has_value()); + + EXPECT_FALSE( + isConstantOrConstantSplatVectorFP(*MIList[15], *MRI).has_value()); +} }