Skip to content

Commit

Permalink
[sqrtfloat] add exceptions in RoundingUnit
Browse files Browse the repository at this point in the history
  • Loading branch information
midnighter95 committed Aug 7, 2023
1 parent e611c42 commit 8449892
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 19 deletions.
56 changes: 44 additions & 12 deletions arithmetic/src/float/RoundingUnit.scala
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
package float

import chisel3._
import chiseltest._
import utest._
import chisel3.util._

import scala.util.Random
import scala.math._

/**
* input.rbits = 2bits + sticky bit
*
* leave
*
* output is subnormal
*
* */
class RoundingUnit extends Module{
val input = IO(Input(new Bundle{
// val invalidExc = Bool() // overrides 'infiniteExc' and 'in'
// val infiniteExc = Bool() // overrides 'in' except for 'in.sign'
val invalidExc = Bool() // overrides 'infiniteExc' and 'in'
val infiniteExc = Bool() // overrides 'in' except for 'in.sign'
val sig = UInt(23.W)
val exp = UInt(8.W)
val rBits = UInt(3.W)
Expand All @@ -32,6 +33,21 @@ class RoundingUnit extends Module{
val roundingMode_max = (input.roundingMode === consts.round_max)
val roundingMode_near_maxMag = (input.roundingMode === consts.round_near_maxMag)


val common_case = !(input.infiniteExc || input.invalidExc)
val common_overflow = Wire(Bool())
val common_inexact = Wire(Bool())


// exception data with Spike

val invalidOut = "h7FC00000".U
/** Inf with sign */
val infiniteOut = Cat(input.sign,"h7F800000".U)
val outSele1H = common_case ## input.infiniteExc ## input.invalidExc



val sigPlus = Wire(UInt(23.W))
val expPlus = Wire(UInt(8.W))
val sigIncr = Wire(Bool())
Expand All @@ -51,25 +67,41 @@ class RoundingUnit extends Module{
expIncr := input.sig.andR && sigIncr
expPlus := input.exp + expIncr

val expOverflow = input.exp.andR && expIncr
common_overflow := input.exp.andR && expIncr
common_inexact := input.rBits.orR

val common_sigOut = Mux(sigIncr, sigPlus, input.sig)
val common_expOut = Mux(expIncr, expPlus, input.exp)

val common_out = Mux(common_overflow, infiniteOut, input.sign ## common_expOut ## common_sigOut)

output.data := Mux1H(Seq(
outSele1H(0) -> invalidOut,
outSele1H(1) -> infiniteOut,
outSele1H(2) -> common_out)
)

val sigOut = Mux(sigIncr, sigPlus, input.sig)
val expOut = Mux(expIncr, expPlus, input.exp)
val invalidOpration = input.invalidExc
val divideByzero = false.B
val overflow = common_case && common_overflow
val underflow = false.B
val inexact = overflow || (common_case && common_inexact)

output.data := input.sign ## expOut ## sigOut
output.exceptionFlags := 0.U
output.exceptionFlags := invalidOpration ## divideByzero ## overflow ## underflow ## inexact

}

object RoundingUnit {
def apply(sign: Bool, exp:UInt, sig: UInt, rbits:UInt, rmode: UInt): UInt = {
def apply(sign: Bool, exp:UInt, sig: UInt, rbits:UInt, rmode: UInt,invalidExc:Bool, infiniteExc:Bool): UInt = {

val rounder = Module(new RoundingUnit)
rounder.input.sign := sign
rounder.input.sig := sig
rounder.input.exp := exp
rounder.input.rBits := rbits
rounder.input.roundingMode := rmode
rounder.input.invalidExc := invalidExc
rounder.input.infiniteExc := infiniteExc
rounder.output.data
}

Expand Down
42 changes: 37 additions & 5 deletions arithmetic/src/float/SqrtFloat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,39 @@ import chisel3._
import chisel3.util._
import sqrt._

/**
*
* @todo Opt for zero
* input is Subnormal!
*
* */
class SqrtFloat(expWidth: Int, sigWidth: Int) extends Module{
val input = IO(Flipped(DecoupledIO(new FloatSqrtInput(expWidth, sigWidth))))
val output = IO(DecoupledIO(new FloatSqrtOutput(expWidth, sigWidth)))
val debug = IO(Output(new Bundle() {
val fractIn = UInt(26.W)
}))
val rawFloatIn = rawFloatFromFN(expWidth,sigWidth,input.bits.oprand)

/** Control path */
val isNegaZero = rawFloatIn.isZero && rawFloatIn.sign
val isPosiInf = rawFloatIn.isInf && rawFloatIn.sign

val fastWorking = RegInit(false.B)
val fastCase = Wire(Bool())

/** negative or NaN*/
val invalidExec = (rawFloatIn.sign && !isNegaZero) || rawFloatIn.isNaN
/** positive inf */
val infinitExec = isPosiInf

fastCase := invalidExec || infinitExec
fastWorking := input.fire && fastCase



/** Data path */
val rawFloatIn = rawFloatFromFN(expWidth,sigWidth,input.bits.oprand)

val adjustedExp = Cat(rawFloatIn.sExp(expWidth-1), rawFloatIn.sExp(expWidth-1, 0))

/** {{{
Expand All @@ -28,18 +52,26 @@ class SqrtFloat(expWidth: Int, sigWidth: Int) extends Module{
Cat(rawFloatIn.sig(sigWidth-1, 0),0.U(2.W)))

val SqrtModule = Module(new SquareRoot(2, 2, 26, 26))
SqrtModule.input.valid := input.valid
SqrtModule.input.valid := input.valid && !fastCase
SqrtModule.input.bits.operand := fractIn
SqrtModule.output.ready := output.ready

val rbits = SqrtModule.output.bits.result(1,0) ## (!SqrtModule.output.bits.zeroRemainder)
val sigRound = SqrtModule.output.bits.result(24,2)
val sigforRound = SqrtModule.output.bits.result(24,2)


input.ready := SqrtModule.input.ready
output.bits.result := RoundingUnit(input.bits.oprand(expWidth + sigWidth-1) ,expOut,sigRound,rbits,consts.round_near_even)
output.bits.result := RoundingUnit(
input.bits.oprand(expWidth + sigWidth-1) ,
expOut,
sigforRound,
rbits,
consts.round_near_even,
invalidExec,
infinitExec)
output.bits.sig := SqrtModule.output.bits.result
output.bits.exp := expOut
output.valid := SqrtModule.output.valid
output.valid := SqrtModule.output.valid || fastWorking

debug.fractIn := fractIn

Expand Down
2 changes: 1 addition & 1 deletion arithmetic/src/sqrt/SquareRoot.scala
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,5 @@ class SquareRoot(
counterNext := Mux(input.fire, 0.U, counter + 1.U)

output.bits.result := Mux(needCorrect, resultMinusOne, resultOrigin)
output.bits.zeroRemainder := remainderFinal.orR
output.bits.zeroRemainder := !remainderFinal.orR
}
2 changes: 1 addition & 1 deletion arithmetic/tests/src/float/SqrtFloatTester.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ object SquareRootTester extends TestSuite with ChiselUtestTester {
test("Sqrt Float FP32 should pass") {
def testcase(): Unit = {
def extendTofull(input:String, width:Int) =(Seq.fill(width - input.length)("0").mkString("") + input)
val oprandFloat: Float = (5.877471754111438e-39).toFloat
val oprandFloat: Float = Random.nextInt(1000000)+Random.nextFloat()
val oprandDouble: Double = oprandFloat.toDouble

val oprandString = extendTofull(java.lang.Float.floatToIntBits(oprandFloat).toBinaryString,32)
Expand Down

0 comments on commit 8449892

Please sign in to comment.