From 84498920c848d2da38e1d6dc02a69acf6a65ad64 Mon Sep 17 00:00:00 2001 From: Yanqi Yang Date: Mon, 7 Aug 2023 16:16:30 +0800 Subject: [PATCH] [sqrtfloat] add exceptions in RoundingUnit --- arithmetic/src/float/RoundingUnit.scala | 56 +++++++++++++++---- arithmetic/src/float/SqrtFloat.scala | 42 ++++++++++++-- arithmetic/src/sqrt/SquareRoot.scala | 2 +- .../tests/src/float/SqrtFloatTester.scala | 2 +- 4 files changed, 83 insertions(+), 19 deletions(-) diff --git a/arithmetic/src/float/RoundingUnit.scala b/arithmetic/src/float/RoundingUnit.scala index 5a163d5..6df4606 100644 --- a/arithmetic/src/float/RoundingUnit.scala +++ b/arithmetic/src/float/RoundingUnit.scala @@ -1,20 +1,21 @@ package float import chisel3._ -import chiseltest._ -import utest._ +import chisel3.util._ -import scala.util.Random -import scala.math._ /** * input.rbits = 2bits + sticky bit * + * leave + * + * output is subnormal + * * */ class RoundingUnit extends Module{ val input = IO(Input(new Bundle{ -// val invalidExc = Bool() // overrides 'infiniteExc' and 'in' -// val infiniteExc = Bool() // overrides 'in' except for 'in.sign' + val invalidExc = Bool() // overrides 'infiniteExc' and 'in' + val infiniteExc = Bool() // overrides 'in' except for 'in.sign' val sig = UInt(23.W) val exp = UInt(8.W) val rBits = UInt(3.W) @@ -32,6 +33,21 @@ class RoundingUnit extends Module{ val roundingMode_max = (input.roundingMode === consts.round_max) val roundingMode_near_maxMag = (input.roundingMode === consts.round_near_maxMag) + + val common_case = !(input.infiniteExc || input.invalidExc) + val common_overflow = Wire(Bool()) + val common_inexact = Wire(Bool()) + + + // exception data with Spike + + val invalidOut = "h7FC00000".U + /** Inf with sign */ + val infiniteOut = Cat(input.sign,"h7F800000".U) + val outSele1H = common_case ## input.infiniteExc ## input.invalidExc + + + val sigPlus = Wire(UInt(23.W)) val expPlus = Wire(UInt(8.W)) val sigIncr = Wire(Bool()) @@ -51,18 +67,32 @@ class RoundingUnit extends Module{ expIncr := input.sig.andR && sigIncr expPlus := input.exp + expIncr - val expOverflow = input.exp.andR && expIncr + common_overflow := input.exp.andR && expIncr + common_inexact := input.rBits.orR + + val common_sigOut = Mux(sigIncr, sigPlus, input.sig) + val common_expOut = Mux(expIncr, expPlus, input.exp) + + val common_out = Mux(common_overflow, infiniteOut, input.sign ## common_expOut ## common_sigOut) + + output.data := Mux1H(Seq( + outSele1H(0) -> invalidOut, + outSele1H(1) -> infiniteOut, + outSele1H(2) -> common_out) + ) - val sigOut = Mux(sigIncr, sigPlus, input.sig) - val expOut = Mux(expIncr, expPlus, input.exp) + val invalidOpration = input.invalidExc + val divideByzero = false.B + val overflow = common_case && common_overflow + val underflow = false.B + val inexact = overflow || (common_case && common_inexact) - output.data := input.sign ## expOut ## sigOut - output.exceptionFlags := 0.U + output.exceptionFlags := invalidOpration ## divideByzero ## overflow ## underflow ## inexact } object RoundingUnit { - def apply(sign: Bool, exp:UInt, sig: UInt, rbits:UInt, rmode: UInt): UInt = { + def apply(sign: Bool, exp:UInt, sig: UInt, rbits:UInt, rmode: UInt,invalidExc:Bool, infiniteExc:Bool): UInt = { val rounder = Module(new RoundingUnit) rounder.input.sign := sign @@ -70,6 +100,8 @@ object RoundingUnit { rounder.input.exp := exp rounder.input.rBits := rbits rounder.input.roundingMode := rmode + rounder.input.invalidExc := invalidExc + rounder.input.infiniteExc := infiniteExc rounder.output.data } diff --git a/arithmetic/src/float/SqrtFloat.scala b/arithmetic/src/float/SqrtFloat.scala index 6adc35d..5b3f5be 100644 --- a/arithmetic/src/float/SqrtFloat.scala +++ b/arithmetic/src/float/SqrtFloat.scala @@ -4,15 +4,39 @@ import chisel3._ import chisel3.util._ import sqrt._ +/** + * + * @todo Opt for zero + * input is Subnormal! + * + * */ class SqrtFloat(expWidth: Int, sigWidth: Int) extends Module{ val input = IO(Flipped(DecoupledIO(new FloatSqrtInput(expWidth, sigWidth)))) val output = IO(DecoupledIO(new FloatSqrtOutput(expWidth, sigWidth))) val debug = IO(Output(new Bundle() { val fractIn = UInt(26.W) })) + val rawFloatIn = rawFloatFromFN(expWidth,sigWidth,input.bits.oprand) + + /** Control path */ + val isNegaZero = rawFloatIn.isZero && rawFloatIn.sign + val isPosiInf = rawFloatIn.isInf && rawFloatIn.sign + + val fastWorking = RegInit(false.B) + val fastCase = Wire(Bool()) + + /** negative or NaN*/ + val invalidExec = (rawFloatIn.sign && !isNegaZero) || rawFloatIn.isNaN + /** positive inf */ + val infinitExec = isPosiInf + + fastCase := invalidExec || infinitExec + fastWorking := input.fire && fastCase + + /** Data path */ - val rawFloatIn = rawFloatFromFN(expWidth,sigWidth,input.bits.oprand) + val adjustedExp = Cat(rawFloatIn.sExp(expWidth-1), rawFloatIn.sExp(expWidth-1, 0)) /** {{{ @@ -28,18 +52,26 @@ class SqrtFloat(expWidth: Int, sigWidth: Int) extends Module{ Cat(rawFloatIn.sig(sigWidth-1, 0),0.U(2.W))) val SqrtModule = Module(new SquareRoot(2, 2, 26, 26)) - SqrtModule.input.valid := input.valid + SqrtModule.input.valid := input.valid && !fastCase SqrtModule.input.bits.operand := fractIn SqrtModule.output.ready := output.ready val rbits = SqrtModule.output.bits.result(1,0) ## (!SqrtModule.output.bits.zeroRemainder) - val sigRound = SqrtModule.output.bits.result(24,2) + val sigforRound = SqrtModule.output.bits.result(24,2) + input.ready := SqrtModule.input.ready - output.bits.result := RoundingUnit(input.bits.oprand(expWidth + sigWidth-1) ,expOut,sigRound,rbits,consts.round_near_even) + output.bits.result := RoundingUnit( + input.bits.oprand(expWidth + sigWidth-1) , + expOut, + sigforRound, + rbits, + consts.round_near_even, + invalidExec, + infinitExec) output.bits.sig := SqrtModule.output.bits.result output.bits.exp := expOut - output.valid := SqrtModule.output.valid + output.valid := SqrtModule.output.valid || fastWorking debug.fractIn := fractIn diff --git a/arithmetic/src/sqrt/SquareRoot.scala b/arithmetic/src/sqrt/SquareRoot.scala index fb3af4e..d50a751 100644 --- a/arithmetic/src/sqrt/SquareRoot.scala +++ b/arithmetic/src/sqrt/SquareRoot.scala @@ -158,5 +158,5 @@ class SquareRoot( counterNext := Mux(input.fire, 0.U, counter + 1.U) output.bits.result := Mux(needCorrect, resultMinusOne, resultOrigin) - output.bits.zeroRemainder := remainderFinal.orR + output.bits.zeroRemainder := !remainderFinal.orR } diff --git a/arithmetic/tests/src/float/SqrtFloatTester.scala b/arithmetic/tests/src/float/SqrtFloatTester.scala index 9ac92f4..c0e619f 100644 --- a/arithmetic/tests/src/float/SqrtFloatTester.scala +++ b/arithmetic/tests/src/float/SqrtFloatTester.scala @@ -11,7 +11,7 @@ object SquareRootTester extends TestSuite with ChiselUtestTester { test("Sqrt Float FP32 should pass") { def testcase(): Unit = { def extendTofull(input:String, width:Int) =(Seq.fill(width - input.length)("0").mkString("") + input) - val oprandFloat: Float = (5.877471754111438e-39).toFloat + val oprandFloat: Float = Random.nextInt(1000000)+Random.nextFloat() val oprandDouble: Double = oprandFloat.toDouble val oprandString = extendTofull(java.lang.Float.floatToIntBits(oprandFloat).toBinaryString,32)