Skip to content

Commit

Permalink
vfalu: support vfredosum vfwredosum
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaofeibao-xjtu authored and Ziyue-Zhang committed Sep 13, 2023
1 parent c51acf8 commit b60860c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 31 deletions.
47 changes: 25 additions & 22 deletions src/main/scala/yunsuan/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -426,31 +426,34 @@ package object yunsuan {
def vfredusum = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fsum_ure)
def vfredmax = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fmax_re)
def vfredmin = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fmin_re)
def vfredosum = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fsum_ore)
def vfwredosum= LiteralCat(0.U(1.W), 0.U(1.W), 1.U(1.W), VfaddOpCode.fsum_ore)
}

/** Five-bit operation selectors for the vector floating-point adder.
  *
  * Every member is a 5-bit `UInt` literal. The adder pipelines compare
  * `io.op_code` against these values (e.g. `io.op_code === VfaddOpCode.fsum_ore`)
  * to choose which result is driven out.
  *
  * Each opcode is defined exactly once; two members with the same name in
  * one object do not compile.
  *
  * Encoding notes visible from the table itself:
  *  - `b10000` has no entry here.
  *  - `dummy` is the all-ones pattern `b11111`.
  */
object VfaddOpCode {
  def dummy    = "b11111".U(5.W)

  // Element-wise arithmetic and selection.
  def fadd     = "b00000".U(5.W)
  def fsub     = "b00001".U(5.W)
  def fmin     = "b00010".U(5.W)
  def fmax     = "b00011".U(5.W)
  def fmerge   = "b00100".U(5.W)
  def fmove    = "b00101".U(5.W)

  // Sign injection.
  def fsgnj    = "b00110".U(5.W)
  def fsgnjn   = "b00111".U(5.W)
  def fsgnjx   = "b01000".U(5.W)

  // Comparisons.
  def feq      = "b01001".U(5.W)
  def fne      = "b01010".U(5.W)
  def flt      = "b01011".U(5.W)
  def fle      = "b01100".U(5.W)
  def fgt      = "b01101".U(5.W)
  def fge      = "b01110".U(5.W)

  def fclass   = "b01111".U(5.W)
  def fmv_f_s  = "b10001".U(5.W)
  def fmv_s_f  = "b10010".U(5.W)

  // Reduction opcodes; these back the vfredusum / vfredmin / vfredmax /
  // vfredosum / vfwredosum encodings built with LiteralCat.
  def fsum_ure = "b10011".U(5.W) // unordered
  def fmin_re  = "b10100".U(5.W)
  def fmax_re  = "b10101".U(5.W)
  def fsum_ore = "b10110".U(5.W) // ordered
}

object VfmaType{
Expand Down
49 changes: 40 additions & 9 deletions src/main/scala/yunsuan/vector/VectorFloatAdder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class VectorFloatAdder() extends Module {
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val is_fsum_ore = io.op_code === VfaddOpCode.fsum_ore

val fast_is_sub = io.op_code(0)

Expand Down Expand Up @@ -415,6 +416,7 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val is_fsum_ore = io.op_code === VfaddOpCode.fsum_ore
val fp_a_sign = fp_a_to32.head(1)
val fp_b_sign = fp_b_to32.head(1)
val fp_b_sign_is_greater = fp_a_sign & !fp_b_sign
Expand Down Expand Up @@ -509,12 +511,19 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
fp_a_is_NAN & !fp_a_is_SNAN
)))
val is_fsum_ure_notmasked = is_fsum_ure && io.maskForReduction.andR
val is_fsum_ure_masked = is_fsum_ure && io.maskForReduction.orR
val is_fsum_ure_masked = is_fsum_ure && !io.maskForReduction.andR
val is_fsum_ore_notmasked = is_fsum_ore && io.maskForReduction(0)
val is_fsum_ore_masked = is_fsum_ore && !io.maskForReduction(0)
val result_fsum_ure_masked = Mux(
io.maskForReduction === 0.U,
0.U(floatWidth.W),
Mux(io.maskForReduction(0), io.fp_a, io.fp_b)
)
val result_fsum_ore_masked = Mux(
io.maskForReduction(0) === 0.U,
0.U(floatWidth.W),
io.fp_b
)
val outInf = Mux(
res_is_f32,
Cat(is_fmax_re, Fill(8, 1.U), 0.U(23.W)),
Expand Down Expand Up @@ -554,6 +563,7 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
is_fsum_ure_masked,
is_fmax_re,
is_fmin_re,
is_fsum_ore_masked,
),
Seq(
result_min,
Expand All @@ -573,14 +583,15 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
result_fsum_ure_masked,
result_fmax_re,
result_fmin_re,
result_fsum_ore_masked,
)
)
val fflags_NV_stage0 = ((is_min | is_max) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_feq | is_fne) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_flt | is_fle | is_fgt | is_fge) & (fp_a_is_NAN | fp_b_is_NAN))
val fflags_stage0 = Cat(fflags_NV_stage0,0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked),float_adder_result,RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked),float_adder_fflags,RegNext(fflags_stage0))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked),float_adder_result,RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked),float_adder_fflags,RegNext(fflags_stage0))
}
else {
io.fp_c := float_adder_result
Expand Down Expand Up @@ -1680,6 +1691,7 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val is_fsum_ore = io.op_code === VfaddOpCode.fsum_ore
val fp_a_sign = io.fp_a.head(1)
val fp_b_sign = io.fp_b.head(1)
val fp_b_sign_is_greater = fp_a_sign & !fp_b_sign
Expand Down Expand Up @@ -1760,12 +1772,19 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
fp_a_is_NAN & !fp_a_is_SNAN
)))
val is_fsum_ure_notmasked = is_fsum_ure && io.maskForReduction.andR
val is_fsum_ure_masked = is_fsum_ure && io.maskForReduction.orR
val is_fsum_ure_masked = is_fsum_ure && !io.maskForReduction.andR
val is_fsum_ore_notmasked = is_fsum_ore && io.maskForReduction(0)
val is_fsum_ore_masked = is_fsum_ore && !io.maskForReduction(0)
val result_fsum_ure_masked = Mux(
io.maskForReduction === 0.U,
0.U(floatWidth.W),
Mux(io.maskForReduction(0), io.fp_a, io.fp_b)
)
val result_fsum_ore_masked = Mux(
io.maskForReduction(0) === 0.U,
0.U(floatWidth.W),
io.fp_b
)
val outInf = Cat(is_fmax_re, Fill(exponentWidth, 1.U), 0.U((significandWidth-1).W))
val re_masked_one_out = Mux(
io.maskForReduction(0),
Expand Down Expand Up @@ -1801,6 +1820,7 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
is_fsum_ure_masked,
is_fmax_re,
is_fmin_re,
is_fsum_ore_masked,
),
Seq(
result_min,
Expand All @@ -1820,14 +1840,15 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
result_fsum_ure_masked,
result_fmax_re,
result_fmin_re,
result_fsum_ore_masked,
)
)
val fflags_NV_stage0 = ((is_min | is_max) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_feq | is_fne) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_flt | is_fle | is_fgt | is_fge) & (fp_a_is_NAN | fp_b_is_NAN))
val fflags_stage0 = Cat(fflags_NV_stage0, 0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_fflags, RegNext(fflags_stage0))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked), float_adder_fflags, RegNext(fflags_stage0))
}
else {
io.fp_c := float_adder_result
Expand Down Expand Up @@ -2325,6 +2346,7 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val is_fsum_ore = io.op_code === VfaddOpCode.fsum_ore
val fp_a_sign = io.fp_a.head(1)
val fp_b_sign = io.fp_b.head(1)
val fp_b_sign_is_greater = fp_a_sign & !fp_b_sign
Expand Down Expand Up @@ -2405,12 +2427,19 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
fp_a_is_NAN & !fp_a_is_SNAN
)))
val is_fsum_ure_notmasked = is_fsum_ure && io.maskForReduction.andR
val is_fsum_ure_masked = is_fsum_ure && io.maskForReduction.orR
val is_fsum_ure_masked = is_fsum_ure && !io.maskForReduction.andR
val is_fsum_ore_notmasked = is_fsum_ore && io.maskForReduction(0)
val is_fsum_ore_masked = is_fsum_ore && !io.maskForReduction(0)
val result_fsum_ure_masked = Mux(
io.maskForReduction === 0.U,
0.U(floatWidth.W),
Mux(io.maskForReduction(0), io.fp_a, io.fp_b)
)
val result_fsum_ore_masked = Mux(
io.maskForReduction(0) === 0.U,
0.U(floatWidth.W),
io.fp_b
)
val outInf = Cat(is_fmax_re, Fill(exponentWidth, 1.U), 0.U((significandWidth-1).W))
val re_masked_one_out = Mux(
io.maskForReduction(0),
Expand Down Expand Up @@ -2446,6 +2475,7 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
is_fsum_ure_masked,
is_fmax_re,
is_fmin_re,
is_fsum_ore_masked,
),
Seq(
result_min,
Expand All @@ -2465,14 +2495,15 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
result_fsum_ure_masked,
result_fmax_re,
result_fmin_re,
result_fsum_ore_masked,
)
)
val fflags_NV_stage0 = ((is_min | is_max) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_feq | is_fne) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_flt | is_fle | is_fgt | is_fge) & (fp_a_is_NAN | fp_b_is_NAN))
val fflags_stage0 = Cat(fflags_NV_stage0, 0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_fflags, RegNext(fflags_stage0))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked), float_adder_fflags, RegNext(fflags_stage0))
}
else {
io.fp_c := float_adder_result
Expand Down

0 comments on commit b60860c

Please sign in to comment.