Skip to content

Commit

Permalink
BigInt in EInt to allow large literals in let (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark1626 authored Jan 17, 2024
1 parent b72bb3c commit ffb3866
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 28 deletions.
6 changes: 3 additions & 3 deletions src/main/scala/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,18 @@ case class Parser(input: String) {
positioned(
P(
"0" | "-".? ~ (CharIn("1-9") ~ CharsWhileIn("0-9").?)
).!.map((n: String) => EInt(n.toInt)).opaque("integer")
).!.map((n: String) => EInt(BigInt(n))).opaque("integer")
)
def hex[K: P]: P[Expr] =
positioned(
P("0x" ~/ CharIn("0-9a-fA-F").rep(1)).!.map((n: String) =>
EInt(Integer.parseInt(n.substring(2), 16), 16)
EInt(BigInt(n.substring(2), 16), 16)
).opaque("hexademical")
)
def octal[K: P]: P[Expr] =
positioned(
P("0" ~ CharsWhileIn("0-7")).!.map((n: String) =>
EInt(Integer.parseInt(n.substring(1), 8), 8)
EInt(BigInt(n.substring(1), 8), 8)
).opaque("ocatal")
)
def rational[K: P]: P[Expr] =
Expand Down
16 changes: 14 additions & 2 deletions src/main/scala/Utils.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package fuselang
import scala.{PartialFunction => PF}
import scala.math.{log10, ceil, abs}
import scala.math.{log10, ceil}

object Utils {

Expand All @@ -11,10 +11,22 @@ object Utils {
}
}

// https://codereview.stackexchange.com/questions/14561/matching-bigints-in-scala
// TODO: This can overflow and result in an runtime exception
object Big {
def unapply(n: BigInt) = Some(n.toInt)
}

def bitsNeeded(n: Int): Int = n match {
case 0 => 1
case n if n > 0 => ceil(log10(n + 1) / log10(2)).toInt
case n if n < 0 => bitsNeeded(abs(n)) + 1
case n if n < 0 => bitsNeeded(n.abs) + 1
}

def bitsNeeded(n: BigInt): Int = n match {
case Big(0) => 1
case n if n > 0 => ceil(log10((n + 1).toDouble) / log10(2)).toInt
case n if n < 0 => bitsNeeded(n.abs) + 1
}

def cartesianProduct[T](llst: Seq[Seq[T]]): Seq[Seq[T]] = {
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/backends/CppLike.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ object Cpp {

implicit def IdToString(id: Id): Doc = value(id.v)

def emitBaseInt(v: Int, base: Int): String = base match {
case 8 => s"0${Integer.toString(v, 8)}"
def emitBaseInt(v: BigInt, base: Int): String = base match {
case 8 => s"0${v.toString(8)}"
case 10 => v.toString
case 16 => s"0x${Integer.toString(v, 16)}"
case 16 => s"0x${v.toString(16)}"
}

implicit def emitExpr(e: Expr): Doc = e match {
Expand Down
22 changes: 12 additions & 10 deletions src/main/scala/common/CodeGenHelpers.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package fuselang.common

import fuselang.Utils.Big

import scala.math.log10

object CodeGenHelpers {
Expand Down Expand Up @@ -35,10 +37,10 @@ object CodeGenHelpers {
}

// Using the trick defined here: https://www.geeksforgeeks.org/program-to-find-whether-a-no-is-power-of-two/
def isPowerOfTwo(x: Int) =
def isPowerOfTwo(x: BigInt) =
x != 0 && ((x & (x - 1)) == 0)

def log2(n: Int) = log10(n) / log10(2)
def log2(n: BigInt) = log10(n.toDouble) / log10(2)

def fastDiv(l: Expr, r: Expr) = (l, r) match {
case (EInt(n, b), EInt(m, _)) => EInt(n / m, b)
Expand Down Expand Up @@ -68,14 +70,14 @@ object CodeGenHelpers {

// Simple peephole optimization to turn: 1 * x => x, 0 + x => x, 0 * x => 0
def binop(op: BOp, l: Expr, r: Expr) = (op, l, r) match {
case (NumOp("*", _), EInt(1, _), r) => r
case (NumOp("*", _), l, EInt(1, _)) => l
case (NumOp("*", _), EInt(0, b), _) => EInt(0, b)
case (NumOp("*", _), _, EInt(0, b)) => EInt(0, b)
case (NumOp("+", _), l, EInt(0, _)) => l
case (NumOp("+", _), EInt(0, _), r) => r
case (BitOp("<<"), l, EInt(0, _)) => l
case (BitOp(">>"), l, EInt(0, _)) => l
case (NumOp("*", _), EInt(Big(1), _), r) => r
case (NumOp("*", _), l, EInt(Big(1), _)) => l
case (NumOp("*", _), EInt(Big(0), b), _) => EInt(0, b)
case (NumOp("*", _), _, EInt(Big(0), b)) => EInt(0, b)
case (NumOp("+", _), l, EInt(Big(0), _)) => l
case (NumOp("+", _), EInt(Big(0), _), r) => r
case (BitOp("<<"), l, EInt(Big(0), _)) => l
case (BitOp(">>"), l, EInt(Big(0), _)) => l
case _ => EBinop(op, l, r)
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/common/Errors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ object Errors {
pos
)

case class IndexOutOfBounds(id: Id, size: Int, mv: Int, pos: Position)
case class IndexOutOfBounds(id: Id, size: BigInt, mv: BigInt, pos: Position)
extends TypeError(
s"Index out of bounds for `$id'. Memory size is $size, iterator max val is $mv",
pos
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/common/Pretty.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ object Pretty {

def emitTyp(t: Type): Doc = text(t.toString)

def emitBaseInt(v: Int, base: Int): String = base match {
case 8 => s"0${Integer.toString(v, 8)}"
def emitBaseInt(v: BigInt, base: Int): String = base match {
case 8 => s"0${v.toString(8)}"
case 10 => v.toString
case 16 => s"0x${Integer.toString(v, 16)}"
case 16 => s"0x${v.toString(16)}"
}

implicit def emitExpr(e: Expr)(implicit debug: Boolean): Doc = e match {
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/common/Syntax.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ object Syntax {
// Types that can be upcast to Ints
sealed trait IntType
case class TSizedInt(len: Int, unsigned: Boolean) extends Type with IntType
case class TStaticInt(v: Int) extends Type with IntType
case class TStaticInt(v: BigInt) extends Type with IntType
case class TIndex(static: (Int, Int), dynamic: (Int, Int))
extends Type
with IntType {
Expand Down Expand Up @@ -135,7 +135,7 @@ object Syntax {
case _ => false
}
}
case class EInt(v: Int, base: Int = 10) extends Expr
case class EInt(v: BigInt, base: Int = 10) extends Expr
case class ERational(d: String) extends Expr
case class EBool(v: Boolean) extends Expr
case class EBinop(op: BOp, e1: Expr, e2: Expr) extends Expr
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/passes/BoundsCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ object BoundsChecker {
case Rotation(e) => (e, 1)
}

val maxVal: Int =
val maxVal: BigInt =
sufExpr.typ
.getOrThrow(Impossible(s"$sufExpr is missing type"))
.matchOrError(viewId.pos, "view", "Integer Type") {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/passes/LowerUnroll.scala
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ object LowerUnroll extends PartialTransformer {
.map({
case ((_, bank), idx) =>
idx match {
case EInt(n, 10) => Some(n % bank)
case EInt(n, 10) => Some((n % bank).toInt)
case EInt(_, _) =>
throw NotImplemented(
"Indexing using non decimal integers",
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/typechecker/AffineCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ object AffineChecker {
(bres * (e - s), Range(s, e) +: consume)
// Index is a statically known number.
case TStaticInt(v) =>
(bres * 1, Seq(v % dims(dim)._2) +: consume)
(bres * 1, Seq((v % dims(dim)._2).toInt) +: consume)
// Index is a dynamic number.
case _: TSizedInt =>
if (dims(dim)._2 != 1)
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/typechecker/Subtyping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ object Subtyping {
Some(TSizedInt(max(s, bitsNeeded(v)), un))
case (st: TStaticInt, idx: TIndex) =>
// Infer unsigned
Some(TSizedInt(bitsNeeded(max(idx.maxVal, st.v)), false))
Some(TSizedInt(bitsNeeded(st.v.max(idx.maxVal)), false))
case (t2: TSizedInt, _: TIndex) => Some(t2)
case (_: TFloat, _: TDouble) => Some(TDouble())
case (_: TRational, _: TFloat) => Some(TFloat())
Expand Down
1 change: 1 addition & 0 deletions src/test/scala/ParsingPositive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class ParsingTests extends org.scalatest.FunSuite {
parseAst("0.25;")
parseAst("0x19;")
parseAst("014;")
parseAst("0x9e3779b9;")
}

test("atoms") {
Expand Down
9 changes: 9 additions & 0 deletions src/test/scala/TypeCheckerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class TypeCheckerSpec extends FunSpec {
typeCheck("let x: fix<2,2> = 2+2.1;")
}
}
it("should allow large unsigned literals") {
typeCheck("let x: ubit<32> = 0x9e3779b9;")
typeCheck("let x: ubit<64> = 0xffffffffffffffff;")
typeCheck("let x: bit<64> = 0x7fffffffffffffff;")
typeCheck("let x: ubit<128> = 0xffffffffffffffffffffffffffffffff;")
}
}

describe("with explicit type and without initializer") {
Expand Down Expand Up @@ -231,6 +237,9 @@ class TypeCheckerSpec extends FunSpec {
it("result of fix type addition upcast to subtype join") {
typeCheck("decl x: fix<32,16>; decl y: fix<16,8>; let z = x + y;")
}
it("result of peephole optimization on large unsigned int should be valid") {
typeCheck("let x: ubit<32> = 0x9e3779b9; let y = x + 0;")
}
}

describe("Reassign") {
Expand Down

0 comments on commit ffb3866

Please sign in to comment.