Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vfalu: support vfredosum vfwredosum #82

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions src/main/scala/yunsuan/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -426,31 +426,34 @@ package object yunsuan {
def vfredusum = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fsum_ure)
def vfredmax = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fmax_re)
def vfredmin = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fmin_re)
def vfredosum = LiteralCat(0.U(1.W), 0.U(1.W), 0.U(1.W), VfaddOpCode.fsum_ore)
def vfwredosum= LiteralCat(0.U(1.W), 0.U(1.W), 1.U(1.W), VfaddOpCode.fsum_ore)
}

/** 5-bit operation codes understood by the vector floating-point adder.
  *
  * Each opcode is defined exactly once. Encodings are unchanged; the only
  * addition relative to the older list is `fsum_ore` (ordered sum reduction).
  * Encoding "b10000" is not assigned here, and "b11111" is used by `dummy`.
  */
object VfaddOpCode {
  def dummy    = "b11111".U(5.W)

  // Element-wise arithmetic / min / max
  def fadd     = "b00000".U(5.W)
  def fsub     = "b00001".U(5.W)
  def fmin     = "b00010".U(5.W)
  def fmax     = "b00011".U(5.W)

  // Merge / move / sign injection
  def fmerge   = "b00100".U(5.W)
  def fmove    = "b00101".U(5.W)
  def fsgnj    = "b00110".U(5.W)
  def fsgnjn   = "b00111".U(5.W)
  def fsgnjx   = "b01000".U(5.W)

  // Comparisons and classification
  def feq      = "b01001".U(5.W)
  def fne      = "b01010".U(5.W)
  def flt      = "b01011".U(5.W)
  def fle      = "b01100".U(5.W)
  def fgt      = "b01101".U(5.W)
  def fge      = "b01110".U(5.W)
  def fclass   = "b01111".U(5.W)

  // Scalar <-> vector moves
  def fmv_f_s  = "b10001".U(5.W)
  def fmv_s_f  = "b10010".U(5.W)

  // Reductions (used by the vfred* / vfwred* encodings)
  def fsum_ure = "b10011".U(5.W) // unordered
  def fmin_re  = "b10100".U(5.W)
  def fmax_re  = "b10101".U(5.W)
  def fsum_ore = "b10110".U(5.W) // ordered
}

object VfmaType{
Expand Down
49 changes: 40 additions & 9 deletions src/main/scala/yunsuan/vector/VectorFloatAdder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class VectorFloatAdder() extends Module {
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val is_fsum_ore = io.op_code === VfaddOpCode.fsum_ore

val fast_is_sub = io.op_code(0)

Expand Down Expand Up @@ -415,6 +416,7 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val is_fsum_ore = io.op_code === VfaddOpCode.fsum_ore
val fp_a_sign = fp_a_to32.head(1)
val fp_b_sign = fp_b_to32.head(1)
val fp_b_sign_is_greater = fp_a_sign & !fp_b_sign
Expand Down Expand Up @@ -509,12 +511,19 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
fp_a_is_NAN & !fp_a_is_SNAN
)))
val is_fsum_ure_notmasked = is_fsum_ure && io.maskForReduction.andR
val is_fsum_ure_masked = is_fsum_ure && io.maskForReduction.orR
val is_fsum_ure_masked = is_fsum_ure && !io.maskForReduction.andR
val is_fsum_ore_notmasked = is_fsum_ore && io.maskForReduction(0)
val is_fsum_ore_masked = is_fsum_ore && !io.maskForReduction(0)
val result_fsum_ure_masked = Mux(
io.maskForReduction === 0.U,
0.U(floatWidth.W),
Mux(io.maskForReduction(0), io.fp_a, io.fp_b)
)
// Candidate result for the ordered-sum (fsum_ore) case when the element is masked off.
// NOTE(review): is_fsum_ore_masked is defined as is_fsum_ore && !io.maskForReduction(0).
// If this value is only selected under that condition, bit 0 of the mask is always 0 here,
// so the io.fp_b arm can never be chosen and the result is always zero.
// TODO confirm that zero (rather than passing an operand through) is the intended output.
val result_fsum_ore_masked = Mux(
io.maskForReduction(0) === 0.U,
0.U(floatWidth.W),
io.fp_b
)
val outInf = Mux(
res_is_f32,
Cat(is_fmax_re, Fill(8, 1.U), 0.U(23.W)),
Expand Down Expand Up @@ -554,6 +563,7 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
is_fsum_ure_masked,
is_fmax_re,
is_fmin_re,
is_fsum_ore_masked,
),
Seq(
result_min,
Expand All @@ -573,14 +583,15 @@ private[vector] class FloatAdderF32WidenF16MixedPipeline(val is_print:Boolean =
result_fsum_ure_masked,
result_fmax_re,
result_fmin_re,
result_fsum_ore_masked,
)
)
val fflags_NV_stage0 = ((is_min | is_max) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_feq | is_fne) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_flt | is_fle | is_fgt | is_fge) & (fp_a_is_NAN | fp_b_is_NAN))
val fflags_stage0 = Cat(fflags_NV_stage0,0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked),float_adder_result,RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked),float_adder_fflags,RegNext(fflags_stage0))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked),float_adder_result,RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked),float_adder_fflags,RegNext(fflags_stage0))
}
else {
io.fp_c := float_adder_result
Expand Down Expand Up @@ -1680,6 +1691,7 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val is_fsum_ore = io.op_code === VfaddOpCode.fsum_ore
val fp_a_sign = io.fp_a.head(1)
val fp_b_sign = io.fp_b.head(1)
val fp_b_sign_is_greater = fp_a_sign & !fp_b_sign
Expand Down Expand Up @@ -1760,12 +1772,19 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
fp_a_is_NAN & !fp_a_is_SNAN
)))
val is_fsum_ure_notmasked = is_fsum_ure && io.maskForReduction.andR
val is_fsum_ure_masked = is_fsum_ure && io.maskForReduction.orR
val is_fsum_ure_masked = is_fsum_ure && !io.maskForReduction.andR
val is_fsum_ore_notmasked = is_fsum_ore && io.maskForReduction(0)
val is_fsum_ore_masked = is_fsum_ore && !io.maskForReduction(0)
val result_fsum_ure_masked = Mux(
io.maskForReduction === 0.U,
0.U(floatWidth.W),
Mux(io.maskForReduction(0), io.fp_a, io.fp_b)
)
// Candidate result for the ordered-sum (fsum_ore) case when the element is masked off.
// NOTE(review): is_fsum_ore_masked is defined as is_fsum_ore && !io.maskForReduction(0).
// If this value is only selected under that condition, bit 0 of the mask is always 0 here,
// so the io.fp_b arm can never be chosen and the result is always zero.
// TODO confirm that zero (rather than passing an operand through) is the intended output.
val result_fsum_ore_masked = Mux(
io.maskForReduction(0) === 0.U,
0.U(floatWidth.W),
io.fp_b
)
val outInf = Cat(is_fmax_re, Fill(exponentWidth, 1.U), 0.U((significandWidth-1).W))
val re_masked_one_out = Mux(
io.maskForReduction(0),
Expand Down Expand Up @@ -1801,6 +1820,7 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
is_fsum_ure_masked,
is_fmax_re,
is_fmin_re,
is_fsum_ore_masked,
),
Seq(
result_min,
Expand All @@ -1820,14 +1840,15 @@ private[vector] class FloatAdderF64WidenPipeline(val is_print:Boolean = false,va
result_fsum_ure_masked,
result_fmax_re,
result_fmin_re,
result_fsum_ore_masked,
)
)
val fflags_NV_stage0 = ((is_min | is_max) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_feq | is_fne) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_flt | is_fle | is_fgt | is_fge) & (fp_a_is_NAN | fp_b_is_NAN))
val fflags_stage0 = Cat(fflags_NV_stage0, 0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_fflags, RegNext(fflags_stage0))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked), float_adder_fflags, RegNext(fflags_stage0))
}
else {
io.fp_c := float_adder_result
Expand Down Expand Up @@ -2325,6 +2346,7 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
val is_fsum_ure = io.op_code === VfaddOpCode.fsum_ure
val is_fmin_re = io.op_code === VfaddOpCode.fmin_re
val is_fmax_re = io.op_code === VfaddOpCode.fmax_re
val is_fsum_ore = io.op_code === VfaddOpCode.fsum_ore
val fp_a_sign = io.fp_a.head(1)
val fp_b_sign = io.fp_b.head(1)
val fp_b_sign_is_greater = fp_a_sign & !fp_b_sign
Expand Down Expand Up @@ -2405,12 +2427,19 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
fp_a_is_NAN & !fp_a_is_SNAN
)))
val is_fsum_ure_notmasked = is_fsum_ure && io.maskForReduction.andR
val is_fsum_ure_masked = is_fsum_ure && io.maskForReduction.orR
val is_fsum_ure_masked = is_fsum_ure && !io.maskForReduction.andR
val is_fsum_ore_notmasked = is_fsum_ore && io.maskForReduction(0)
val is_fsum_ore_masked = is_fsum_ore && !io.maskForReduction(0)
val result_fsum_ure_masked = Mux(
io.maskForReduction === 0.U,
0.U(floatWidth.W),
Mux(io.maskForReduction(0), io.fp_a, io.fp_b)
)
// Candidate result for the ordered-sum (fsum_ore) case when the element is masked off.
// NOTE(review): is_fsum_ore_masked is defined as is_fsum_ore && !io.maskForReduction(0).
// If this value is only selected under that condition, bit 0 of the mask is always 0 here,
// so the io.fp_b arm can never be chosen and the result is always zero.
// TODO confirm that zero (rather than passing an operand through) is the intended output.
val result_fsum_ore_masked = Mux(
io.maskForReduction(0) === 0.U,
0.U(floatWidth.W),
io.fp_b
)
val outInf = Cat(is_fmax_re, Fill(exponentWidth, 1.U), 0.U((significandWidth-1).W))
val re_masked_one_out = Mux(
io.maskForReduction(0),
Expand Down Expand Up @@ -2446,6 +2475,7 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
is_fsum_ure_masked,
is_fmax_re,
is_fmin_re,
is_fsum_ore_masked,
),
Seq(
result_min,
Expand All @@ -2465,14 +2495,15 @@ private[vector] class FloatAdderF16Pipeline(val is_print:Boolean = false,val has
result_fsum_ure_masked,
result_fmax_re,
result_fmin_re,
result_fsum_ore_masked,
)
)
val fflags_NV_stage0 = ((is_min | is_max) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_feq | is_fne) & (fp_a_is_SNAN | fp_b_is_SNAN)) |
((is_flt | is_fle | is_fgt | is_fge) & (fp_a_is_NAN | fp_b_is_NAN))
val fflags_stage0 = Cat(fflags_NV_stage0, 0.U(4.W))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked), float_adder_fflags, RegNext(fflags_stage0))
io.fp_c := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked), float_adder_result, RegNext(result_stage0))
io.fflags := Mux(RegNext(is_add | is_sub | is_fsum_ure_notmasked | is_fsum_ore_notmasked), float_adder_fflags, RegNext(fflags_stage0))
}
else {
io.fp_c := float_adder_result
Expand Down
Loading