
[RISCV] Allow swapped operands in reduction formation (#68634)
Very straightforward, but worth landing on its own in advance of a more complicated generalization.
preames authored Oct 23, 2023
1 parent aab0626 commit 25da9bb
Showing 2 changed files with 90 additions and 27 deletions.
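As a rough sketch (LLVM IR; illustrative value names, essentially the reduce_sum_4xi32_op_order test added below), the shape this change targets is a chain of scalar adds over extractelements where the extract may appear as either operand of each add; with this change the DAG combine still recognizes the chain as a single vector reduction, which lowers to one vredsum.vs on RISC-V.

; Before: serial adds; note the extracts sit in operand positions the
; previous code did not accept.
%e0 = extractelement <4 x i32> %v, i32 0
%e1 = extractelement <4 x i32> %v, i32 1
%e2 = extractelement <4 x i32> %v, i32 2
%e3 = extractelement <4 x i32> %v, i32 3
%a0 = add i32 %e1, %e0
%a1 = add i32 %e2, %a0
%a2 = add i32 %a1, %e3

; After (conceptually): equivalent to a single reduction over the vector.
%sum = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %v)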
51 changes: 28 additions & 23 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11379,16 +11379,20 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
const unsigned ReduceOpc = getVecReduceOpcode(Opc);
assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
"Inconsistent mappings");
const SDValue LHS = N->getOperand(0);
const SDValue RHS = N->getOperand(1);
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);

if (!LHS.hasOneUse() || !RHS.hasOneUse())
return SDValue();

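// Canonicalize so that when exactly one operand is an extract_vector_elt,
// it ends up as RHS; the checks below then only need to handle one
// operand order.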
if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
std::swap(LHS, RHS);

if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
!isa<ConstantSDNode>(RHS.getOperand(1)))
return SDValue();

uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
SDValue SrcVec = RHS.getOperand(0);
EVT SrcVecVT = SrcVec.getValueType();
assert(SrcVecVT.getVectorElementType() == VT);
@@ -11401,14 +11405,17 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
// match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
// reduce_op (extract_subvector [2 x VT] from V). This will form the
// root of our reduction tree. TODO: We could extend this to any two
// adjacent constant indices if desired.
// adjacent aligned constant indices if desired.
if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
isOneConstant(RHS.getOperand(1))) {
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
DAG.getVectorIdxConstant(0, DL));
return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
uint64_t LHSIdx =
cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
DAG.getVectorIdxConstant(0, DL));
return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
}
}

// Match (binop (reduce (extract_subvector V, 0),
@@ -11420,20 +11427,18 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
SDValue ReduceVec = LHS.getOperand(0);
if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
isNullConstant(ReduceVec.getOperand(1))) {
uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
// For illegal types (e.g. 3xi32), most will be combined again into a
// wider (hopefully legal) type. If this is a terminal state, we are
// relying on type legalization here to produce something reasonable
// and this lowering quality could probably be improved. (TODO)
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
DAG.getVectorIdxConstant(0, DL));
auto Flags = ReduceVec->getFlags();
Flags.intersectWith(N->getFlags());
return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
}
isNullConstant(ReduceVec.getOperand(1)) &&
ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
// For illegal types (e.g. 3xi32), most will be combined again into a
// wider (hopefully legal) type. If this is a terminal state, we are
// relying on type legalization here to produce something reasonable
// and this lowering quality could probably be improved. (TODO)
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
DAG.getVectorIdxConstant(0, DL));
auto Flags = ReduceVec->getFlags();
Flags.intersectWith(N->getFlags());
return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
}

return SDValue();
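To make the second pattern above concrete, here is a hedged, conceptual IR-level sketch (not taken from the patch; the combine itself operates on SelectionDAG nodes): a reduction over the low k lanes added to the extract of lane k is widened into a reduction over k+1 lanes, and with this change the extract may sit on either side of the add.

; Before: a reduce over lanes 0..2 plus the scalar extract of lane 3.
%lo   = shufflevector <4 x i32> %v, <4 x i32> poison, <3 x i32> <i32 0, i32 1, i32 2>
%part = call i32 @llvm.vector.reduce.add.v3i32(<3 x i32> %lo)
%e3   = extractelement <4 x i32> %v, i32 3
%sum  = add i32 %e3, %part        ; extract on the left now also matches

; After (conceptually): the reduction is simply widened by one lane.
%sum.wide = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %v)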
66 changes: 62 additions & 4 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -34,7 +34,6 @@ define i32 @reduce_sum_4xi32(<4 x i32> %v) {
ret i32 %add2
}


define i32 @reduce_sum_8xi32(<8 x i32> %v) {
; CHECK-LABEL: reduce_sum_8xi32:
; CHECK: # %bb.0:
@@ -449,6 +448,68 @@ define i32 @reduce_sum_16xi32_prefix15(ptr %p) {
ret i32 %add13
}

; Check that we can match with the operand order reversed, but the
; reduction order unchanged.
define i32 @reduce_sum_4xi32_op_order(<4 x i32> %v) {
; CHECK-LABEL: reduce_sum_4xi32_op_order:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; CHECK-NEXT: vmv.s.x v9, zero
; CHECK-NEXT: vredsum.vs v8, v8, v9
; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: ret
%e0 = extractelement <4 x i32> %v, i32 0
%e1 = extractelement <4 x i32> %v, i32 1
%e2 = extractelement <4 x i32> %v, i32 2
%e3 = extractelement <4 x i32> %v, i32 3
%add0 = add i32 %e1, %e0
%add1 = add i32 %e2, %add0
%add2 = add i32 %add1, %e3
ret i32 %add2
}

; Negative test - Reduction order isn't compatible with the current
; incremental matching scheme.
define i32 @reduce_sum_4xi32_reduce_order(<4 x i32> %v) {
; RV32-LABEL: reduce_sum_4xi32_reduce_order:
; RV32: # %bb.0:
; RV32-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; RV32-NEXT: vmv.x.s a0, v8
; RV32-NEXT: vslidedown.vi v9, v8, 1
; RV32-NEXT: vmv.x.s a1, v9
; RV32-NEXT: vslidedown.vi v9, v8, 2
; RV32-NEXT: vmv.x.s a2, v9
; RV32-NEXT: vslidedown.vi v8, v8, 3
; RV32-NEXT: vmv.x.s a3, v8
; RV32-NEXT: add a1, a1, a2
; RV32-NEXT: add a0, a0, a3
; RV32-NEXT: add a0, a0, a1
; RV32-NEXT: ret
;
; RV64-LABEL: reduce_sum_4xi32_reduce_order:
; RV64: # %bb.0:
; RV64-NEXT: vsetivli zero, 1, e32, m1, ta, ma
; RV64-NEXT: vmv.x.s a0, v8
; RV64-NEXT: vslidedown.vi v9, v8, 1
; RV64-NEXT: vmv.x.s a1, v9
; RV64-NEXT: vslidedown.vi v9, v8, 2
; RV64-NEXT: vmv.x.s a2, v9
; RV64-NEXT: vslidedown.vi v8, v8, 3
; RV64-NEXT: vmv.x.s a3, v8
; RV64-NEXT: add a1, a1, a2
; RV64-NEXT: add a0, a0, a3
; RV64-NEXT: addw a0, a0, a1
; RV64-NEXT: ret
%e0 = extractelement <4 x i32> %v, i32 0
%e1 = extractelement <4 x i32> %v, i32 1
%e2 = extractelement <4 x i32> %v, i32 2
%e3 = extractelement <4 x i32> %v, i32 3
%add0 = add i32 %e1, %e2
%add1 = add i32 %e0, %add0
%add2 = add i32 %add1, %e3
ret i32 %add2
}

;; Most of the cornercases are exercised above, the following just
;; makes sure that other opcodes work as expected.

@@ -923,6 +984,3 @@ define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
}


;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; RV32: {{.*}}
; RV64: {{.*}}
