Skip to content

Commit

Permalink
vfalu: support unordered vfreduction
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaofeibao-xjtu authored and Ziyue-Zhang committed Sep 2, 2023
1 parent 5060703 commit c51acf8
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 20 deletions.
9 changes: 6 additions & 3 deletions src/main/scala/yunsuan/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,9 @@ package object yunsuan {
def vfclass = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fclass)
def vfmv_f_s = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fmv_f_s)
def vfmv_s_f = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fmv_s_f)
def vfredusum = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fsum_ure)
def vfredmax = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fmax_re)
def vfredmin = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fmin_re)
}

object VfaddOpCode {
Expand All @@ -445,9 +448,9 @@ package object yunsuan {
def fclass = "b01111".U(5.W)
def fmv_f_s = "b10001".U(5.W)
def fmv_s_f = "b10010".U(5.W)
// def fsum_re = "b10000".U(5.W) // unorder
// def fmin_re = "b10001".U(5.W)
// def fmax_re = "b10010".U(5.W)
def fsum_ure = "b10011".U(5.W) // unorder
def fmin_re = "b10100".U(5.W)
def fmax_re = "b10101".U(5.W)
}

object VfmaType{
Expand Down
150 changes: 133 additions & 17 deletions src/main/scala/yunsuan/vector/VectorFloatAdder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class VectorFloatAdder() extends Module {
val op_code = Input (UInt(5.W))
val fp_aIsFpCanonicalNAN = Input (Bool())
val fp_bIsFpCanonicalNAN = Input (Bool())
val maskForReduction = Input(UInt(8.W))

val fp_result = Output(UInt(floatWidth.W))
val fflags = Output(UInt(20.W))
Expand All @@ -69,6 +70,9 @@ class VectorFloatAdder() extends Module {
val is_fgt = io.op_code === VfaddOpCode.fgt
val is_fge = io.op_code === VfaddOpCode.fge
val is_fclass = io.op_code === VfaddOpCode.fclass
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re

val fast_is_sub = io.op_code(0)

Expand All @@ -93,6 +97,7 @@ class VectorFloatAdder() extends Module {
U_F32_Mixed_0.io.op_code := io.op_code
U_F32_Mixed_0.io.fp_aIsFpCanonicalNAN := io.fp_aIsFpCanonicalNAN
U_F32_Mixed_0.io.fp_bIsFpCanonicalNAN := io.fp_bIsFpCanonicalNAN
U_F32_Mixed_0.io.maskForReduction := Cat(io.maskForReduction(4), io.maskForReduction(0))
val U_F32_0_result = U_F32_Mixed_0.io.fp_c
val U_F32_0_fflags = U_F32_Mixed_0.io.fflags
val U_F16_0_result = U_F32_Mixed_0.io.fp_c(15,0)
Expand All @@ -113,6 +118,11 @@ class VectorFloatAdder() extends Module {
U_F32_Mixed_1.io.op_code := io.op_code
U_F32_Mixed_1.io.fp_aIsFpCanonicalNAN := io.fp_aIsFpCanonicalNAN
U_F32_Mixed_1.io.fp_bIsFpCanonicalNAN := io.fp_bIsFpCanonicalNAN
U_F32_Mixed_1.io.maskForReduction := Mux(
fp_format === 1.U,
Cat(io.maskForReduction(5), io.maskForReduction(1)),
Cat(io.maskForReduction(6), io.maskForReduction(2))
)
val U_F32_1_result = U_F32_Mixed_1.io.fp_c
val U_F32_1_fflags = U_F32_Mixed_1.io.fflags
val U_F16_2_result = U_F32_Mixed_1.io.fp_c(15,0)
Expand All @@ -132,6 +142,7 @@ class VectorFloatAdder() extends Module {
U_F64_Widen_0.io.op_code := io.op_code
U_F64_Widen_0.io.fp_aIsFpCanonicalNAN := io.fp_aIsFpCanonicalNAN
U_F64_Widen_0.io.fp_bIsFpCanonicalNAN := io.fp_bIsFpCanonicalNAN
U_F64_Widen_0.io.maskForReduction := Cat(io.maskForReduction(4), io.maskForReduction(0))
val U_F64_Widen_0_result = U_F64_Widen_0.io.fp_c
val U_F64_Widen_0_fflags = U_F64_Widen_0.io.fflags

Expand All @@ -144,6 +155,7 @@ class VectorFloatAdder() extends Module {
U_F16_1.io.op_code := io.op_code
U_F16_1.io.fp_aIsFpCanonicalNAN := io.fp_aIsFpCanonicalNAN
U_F16_1.io.fp_bIsFpCanonicalNAN := io.fp_bIsFpCanonicalNAN
U_F16_1.io.maskForReduction := Cat(io.maskForReduction(5), io.maskForReduction(1))
val U_F16_1_result = U_F16_1.io.fp_c
val U_F16_1_fflags = U_F16_1.io.fflags

Expand All @@ -156,6 +168,7 @@ class VectorFloatAdder() extends Module {
U_F16_3.io.op_code := io.op_code
U_F16_3.io.fp_aIsFpCanonicalNAN := io.fp_aIsFpCanonicalNAN
U_F16_3.io.fp_bIsFpCanonicalNAN := io.fp_bIsFpCanonicalNAN
U_F16_3.io.maskForReduction := Cat(io.maskForReduction(7), io.maskForReduction(3))
val U_F16_3_result = U_F16_3.io.fp_c
val U_F16_3_fflags = U_F16_3.io.fflags

Expand Down Expand Up @@ -273,6 +286,7 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
val op_code = if (hasMinMaxCompare) Input(UInt(5.W)) else Input(UInt(0.W))
val fp_aIsFpCanonicalNAN = Input (Bool())
val fp_bIsFpCanonicalNAN = Input (Bool())
val maskForReduction = Input(UInt(2.W))
})
val res_is_f32 = io.fp_format(0).asBool
val fp_a_16as32 = Cat(io.fp_a(15), Cat(0.U(3.W),io.fp_a(14,10)), Cat(io.fp_a(9,0),0.U(13.W)))
Expand Down Expand Up @@ -398,6 +412,9 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
val is_fclass = io.op_code === VfaddOpCode.fclass
val is_fmerge = io.op_code === VfaddOpCode.fmerge
val is_fmove = (io.op_code === VfaddOpCode.fmove) || (io.op_code === VfaddOpCode.fmv_f_s) || (io.op_code === VfaddOpCode.fmv_s_f)
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val fp_a_sign = fp_a_to32.head(1)
val fp_b_sign = fp_b_to32.head(1)
val fp_b_sign_is_greater = fp_a_sign & !fp_b_sign
Expand Down Expand Up @@ -491,6 +508,33 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
fp_a_is_SNAN,
fp_a_is_NAN & !fp_a_is_SNAN
)))
val is_fsum_ure_notmasked = is_fsum_ure && io.maskForReduction.andR
val is_fsum_ure_masked = is_fsum_ure && io.maskForReduction.orR
val result_fsum_ure_masked = Mux(
io.maskForReduction === 0.U,
0.U(floatWidth.W),
Mux(io.maskForReduction(0), io.fp_a, io.fp_b)
)
val outInf = Mux(
res_is_f32,
Cat(is_fmax_re, Fill(8, 1.U), 0.U(23.W)),
Cat(0.U(16.W), is_fmax_re, Fill(5, 1.U), 0.U(10.W))
)
val re_masked_one_out = Mux(
io.maskForReduction(0),
Mux(fp_a_is_NAN, out_NAN, io.fp_a),
Mux(fp_b_is_NAN, out_NAN, io.fp_b)
)
val result_fmax_re = Mux(
io.maskForReduction === 0.U,
outInf,
Mux(io.maskForReduction.andR, result_max, re_masked_one_out)
)
val result_fmin_re = Mux(
io.maskForReduction === 0.U,
outInf,
Mux(io.maskForReduction.andR, result_min, re_masked_one_out)
)
val result_stage0 = Mux1H(
Seq(
is_min,
Expand All @@ -501,12 +545,15 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
is_fle,
is_fgt,
is_fge,
is_fsgnj,
is_fsgnj,
is_fsgnjn,
is_fsgnjx,
is_fclass,
is_fmerge,
is_fmove
is_fmove,
is_fsum_ure_masked,
is_fmax_re,
is_fmin_re,
),
Seq(
result_min,
Expand All @@ -522,15 +569,18 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
result_fsgnjx,
result_fclass,
result_fmerge,
result_fmove
result_fmove,
result_fsum_ure_masked,
result_fmax_re,
result_fmin_re,
)
)
val fflags_NV_stage0 = ((is_min | is_max) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_feq | is_fne) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_flt | is_fle | is_fgt | is_fge) & (fp_a_is_NAN | fp_b_is_NAN))
val fflags_stage0 = Cat(fflags_NV_stage0,0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub),float_adder_result,RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub),float_adder_fflags,RegNext(fflags_stage0))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked),float_adder_result,RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked),float_adder_fflags,RegNext(fflags_stage0))
}
else {
io.fp_c := float_adder_result
Expand Down Expand Up @@ -1520,6 +1570,7 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
val op_code = if (hasMinMaxCompare) Input(UInt(5.W)) else Input(UInt(0.W))
val fp_aIsFpCanonicalNAN = Input(Bool())
val fp_bIsFpCanonicalNAN = Input(Bool())
val maskForReduction = Input(UInt(2.W))
})
// val fp_a_to64_is_denormal = !io.widen_a(30,23).orR
// val fp_a_lshift = Wire(UInt(23.W))
Expand Down Expand Up @@ -1626,6 +1677,9 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
val is_fclass = io.op_code === VfaddOpCode.fclass
val is_fmerge = io.op_code === VfaddOpCode.fmerge
val is_fmove = (io.op_code === VfaddOpCode.fmove) || (io.op_code === VfaddOpCode.fmv_f_s) || (io.op_code === VfaddOpCode.fmv_s_f)
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val fp_a_sign = io.fp_a.head(1)
val fp_b_sign = io.fp_b.head(1)
val fp_b_sign_is_greater = fp_a_sign & !fp_b_sign
Expand Down Expand Up @@ -1705,6 +1759,29 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
fp_a_is_SNAN,
fp_a_is_NAN & !fp_a_is_SNAN
)))
val is_fsum_ure_notmasked = is_fsum_ure && io.maskForReduction.andR
val is_fsum_ure_masked = is_fsum_ure && io.maskForReduction.orR
val result_fsum_ure_masked = Mux(
io.maskForReduction === 0.U,
0.U(floatWidth.W),
Mux(io.maskForReduction(0), io.fp_a, io.fp_b)
)
val outInf = Cat(is_fmax_re, Fill(exponentWidth, 1.U), 0.U((significandWidth-1).W))
val re_masked_one_out = Mux(
io.maskForReduction(0),
Mux(fp_a_is_NAN, out_NAN, io.fp_a),
Mux(fp_b_is_NAN, out_NAN, io.fp_b)
)
val result_fmax_re = Mux(
io.maskForReduction === 0.U,
outInf,
Mux(io.maskForReduction.andR, result_max, re_masked_one_out)
)
val result_fmin_re = Mux(
io.maskForReduction === 0.U,
outInf,
Mux(io.maskForReduction.andR, result_min, re_masked_one_out)
)
val result_stage0 = Mux1H(
Seq(
is_min,
Expand All @@ -1715,12 +1792,15 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
is_fle,
is_fgt,
is_fge,
is_fsgnj,
is_fsgnj,
is_fsgnjn,
is_fsgnjx,
is_fclass,
is_fmerge,
is_fmove
is_fmove,
is_fsum_ure_masked,
is_fmax_re,
is_fmin_re,
),
Seq(
result_min,
Expand All @@ -1736,15 +1816,18 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
result_fsgnjx,
result_fclass,
result_fmerge,
result_fmove
result_fmove,
result_fsum_ure_masked,
result_fmax_re,
result_fmin_re,
)
)
val fflags_NV_stage0 = ((is_min | is_max) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_feq | is_fne) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_flt | is_fle | is_fgt | is_fge) & (fp_a_is_NAN | fp_b_is_NAN))
val fflags_stage0 = Cat(fflags_NV_stage0,0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub),float_adder_result,RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub),float_adder_fflags,RegNext(fflags_stage0))
val fflags_stage0 = Cat(fflags_NV_stage0, 0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_fflags, RegNext(fflags_stage0))
}
else {
io.fp_c := float_adder_result
Expand Down Expand Up @@ -2170,6 +2253,7 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
val op_code = if (hasMinMaxCompare) Input(UInt(5.W)) else Input(UInt(0.W))
val fp_aIsFpCanonicalNAN = Input(Bool())
val fp_bIsFpCanonicalNAN = Input(Bool())
val maskForReduction = Input(UInt(2.W))
})
val EOP = (io.fp_a.head(1) ^ io.is_sub ^ io.fp_b.head(1)).asBool
val U_far_path = Module(new FarPathF16Pipeline(exponentWidth = exponentWidth,significandWidth = significandWidth, is_print = is_print, hasMinMaxCompare=hasMinMaxCompare))
Expand Down Expand Up @@ -2238,6 +2322,9 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
val is_fclass = io.op_code === VfaddOpCode.fclass
val is_fmerge = io.op_code === VfaddOpCode.fmerge
val is_fmove = (io.op_code === VfaddOpCode.fmove) || (io.op_code === VfaddOpCode.fmv_f_s) || (io.op_code === VfaddOpCode.fmv_s_f)
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val fp_a_sign = io.fp_a.head(1)
val fp_b_sign = io.fp_b.head(1)
val fp_b_sign_is_greater = fp_a_sign & !fp_b_sign
Expand Down Expand Up @@ -2317,6 +2404,29 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
fp_a_is_SNAN,
fp_a_is_NAN & !fp_a_is_SNAN
)))
val is_fsum_ure_notmasked = is_fsum_ure && io.maskForReduction.andR
val is_fsum_ure_masked = is_fsum_ure && io.maskForReduction.orR
val result_fsum_ure_masked = Mux(
io.maskForReduction === 0.U,
0.U(floatWidth.W),
Mux(io.maskForReduction(0), io.fp_a, io.fp_b)
)
val outInf = Cat(is_fmax_re, Fill(exponentWidth, 1.U), 0.U((significandWidth-1).W))
val re_masked_one_out = Mux(
io.maskForReduction(0),
Mux(fp_a_is_NAN, out_NAN, io.fp_a),
Mux(fp_b_is_NAN, out_NAN, io.fp_b)
)
val result_fmax_re = Mux(
io.maskForReduction === 0.U,
outInf,
Mux(io.maskForReduction.andR, result_max, re_masked_one_out)
)
val result_fmin_re = Mux(
io.maskForReduction === 0.U,
outInf,
Mux(io.maskForReduction.andR, result_min, re_masked_one_out)
)
val result_stage0 = Mux1H(
Seq(
is_min,
Expand All @@ -2327,12 +2437,15 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
is_fle,
is_fgt,
is_fge,
is_fsgnj,
is_fsgnj,
is_fsgnjn,
is_fsgnjx,
is_fclass,
is_fmerge,
is_fmove
is_fmove,
is_fsum_ure_masked,
is_fmax_re,
is_fmin_re,
),
Seq(
result_min,
Expand All @@ -2348,15 +2461,18 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
result_fsgnjx,
result_fclass,
result_fmerge,
result_fmove
result_fmove,
result_fsum_ure_masked,
result_fmax_re,
result_fmin_re,
)
)
val fflags_NV_stage0 = ((is_min | is_max) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_feq | is_fne) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_flt | is_fle | is_fgt | is_fge) & (fp_a_is_NAN | fp_b_is_NAN))
val fflags_stage0 = Cat(fflags_NV_stage0,0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub),float_adder_result,RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub),float_adder_fflags,RegNext(fflags_stage0))
val fflags_stage0 = Cat(fflags_NV_stage0, 0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_fflags, RegNext(fflags_stage0))
}
else {
io.fp_c := float_adder_result
Expand Down
1 change: 1 addition & 0 deletions src/test/scala/top/VectorSimTop.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class SimTop() extends VPUTestModule {
vfa.io.is_vec := true.B // TODO: check it
vfa.io.fp_aIsFpCanonicalNAN := false.B
vfa.io.fp_bIsFpCanonicalNAN := false.B
vfa.io.maskForReduction := 0.U
vfa_result.result(i) := vfa.io.fp_result
vfa_result.fflags(i) := vfa.io.fflags
vfa_result.vxsat := 0.U // DontCare
Expand Down

0 comments on commit c51acf8

Please sign in to comment.