diff --git a/.github/configs/mergify_config.yml b/.github/configs/mergify_config.yml new file mode 100644 index 0000000000..90bf4ad54c --- /dev/null +++ b/.github/configs/mergify_config.yml @@ -0,0 +1,7 @@ +# Configuration for generating .mergify.yml +conditions: + - status-success=all tests passed +branches: + - 1.2.x + - 1.3.x + - 1.4.x diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000000..d04f7e05c6 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,128 @@ +name: Continuous Integration + +on: + pull_request: + push: + branches: + - master + - 1.4.x + - 1.3.x + - 1.2.x + +jobs: + test: + name: sbt test + runs-on: ubuntu-latest + strategy: + matrix: + scala: [2.13.4, 2.12.13] + container: + image: ucbbar/chisel3-tools + options: --user github --entrypoint /bin/bash + env: + CONTAINER_HOME: /home/github + + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Setup Scala + uses: olafurpg/setup-scala@v10 + with: + java-version: adopt@1.8 + - name: Cache Scala + uses: coursier/cache-action@v6 + - name: Check Formatting (Scala 2.12 only) + if: startsWith(matrix.scala, '2.12') + run: sbt ++${{ matrix.scala }} scalafmtCheckAll + - name: Unidoc + run: sbt ++${{ matrix.scala }} unidoc + - name: Sanity check benchmarking scripts (Scala 2.13 only) + if: startsWith(matrix.scala, '2.13') + run: | + benchmark/scripts/benchmark_cold_compile.py -N 2 --designs regress/ICache.fir --versions HEAD + benchmark/scripts/find_heap_bound.py -- -cp firrtl*jar firrtl.stage.FirrtlMain -i regress/ICache.fir -o out -X verilog + - name: Test + run: sbt ++${{ matrix.scala }} test + - name: Binary compatibility + run: sbt ++${{ matrix.scala }} mimaReportBinaryIssues + + mill: + name: Mill Sanity Check + runs-on: ubuntu-latest + strategy: + matrix: + scala: [2.13.4, 2.12.13] + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Setup Scala + uses: olafurpg/setup-scala@v10 + - name: Cache + uses: coursier/cache-action@v5 + - name: Setup Mill + uses: jodersky/setup-mill@v0.2.3 + - name: Mill sanity check + run: mill _[${{ matrix.scala }}].compile + + # TODO find better way to express Ops and AddNot as single test + equiv: + name: formal equivalence + runs-on: ubuntu-latest + strategy: + matrix: + design: [RocketCore, FPU, ICache, Ops, AddNot] + container: + image: ucbbar/chisel3-tools + options: --user github --entrypoint /bin/bash + env: + CONTAINER_HOME: /home/github + + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Setup Scala + uses: olafurpg/setup-scala@v10 + with: + java-version: adopt@1.8 + - name: Cache Scala + uses: coursier/cache-action@v5 + - name: Run Formal Equivalence + run: ./.run_formal_checks.sh ${{ matrix.design }} + + # Sentinel job to simplify how we specify which checks need to pass in branch + # protection and in Mergify + # + # When adding new jobs, please add them to `needs` below + all_tests_passed: + name: "all tests passed" + needs: [test, mill, equiv] + runs-on: ubuntu-latest + steps: + - run: echo Success! + + # sbt ci-release publishes all cross versions so this job needs to be + # separate from a Scala versions build matrix to avoid duplicate publishing + publish: + needs: [all_tests_passed] + runs-on: ubuntu-latest + if: github.event_name == 'push' + + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Setup Scala + uses: olafurpg/setup-scala@v10 + with: + java-version: adopt@1.8 + - name: Cache Scala + uses: coursier/cache-action@v5 + - name: Setup GPG (for Publish) + uses: olafurpg/setup-gpg@v3 + - name: Publish + run: sbt ci-release + env: + PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} + PGP_SECRET: ${{ secrets.PGP_SECRET }} + SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} + SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} + diff --git a/.github/workflows/update_mergify.yml b/.github/workflows/update_mergify.yml new file mode 100644 index 0000000000..fbcdebd743 --- /dev/null +++ b/.github/workflows/update_mergify.yml @@ -0,0 +1,45 @@ +name: Update .mergify.yml +on: + schedule: + # Runs once an hour + - cron: "0 * * * *" + +jobs: + update-mergify: + name: Update .mergify.yml + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + ref: master + - name: Checkout Chisel Repo Tools + uses: actions/checkout@v2 + with: + repository: ucb-bar/chisel-repo-tools + path: tools + - name: Setup Scala + uses: olafurpg/setup-scala@v10 + - name: Install Ammonite + run: sudo sh -c '(echo "#!/usr/bin/env sh" && curl -L https://github.com/com-lihaoyi/Ammonite/releases/download/2.3.8/2.13-2.3.8) > /usr/local/bin/amm && chmod +x /usr/local/bin/amm' + - name: Cache Scala + uses: coursier/cache-action@v6 + - name: Generate .mergify.yml + run: | + ./tools/scripts/mergify.sc .github/configs/mergify_config.yml > .mergify.yml + # Delete tools so they don't get included in commit + rm -rf tools/ + - name: Create Pull Request + uses: peter-evans/create-pull-request@v3.8.2 + with: + # TODO Default GITHUB_TOKEN cannot trigger PR CI + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: Update .mergify.yml + branch: actions/workflows/update-mergify + delete-branch: true + title: Update .mergify.yml + body: | + This is an automated pull request by "Update .mergify.yml" workflow created with [create-pull-request][1]. + + [1]: https://github.com/peter-evans/create-pull-request + diff --git a/.gitignore b/.gitignore index 8b3d51a642..dd2f998ba5 100644 --- a/.gitignore +++ b/.gitignore @@ -66,3 +66,4 @@ project/metals.sbt *~ *#*# +.vscode diff --git a/.mergify.yml b/.mergify.yml index b09c861ae4..5e77c8d945 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -1,116 +1,106 @@ pull_request_rules: - - name: remove outdated reviews - conditions: - - base=master - actions: - dismiss_reviews: +- name: automatic squash-and-merge on CI success and review + conditions: + - status-success=all tests passed + - '#approved-reviews-by>=1' + - '#changes-requested-reviews-by=0' + - base=master + - label="Please Merge" + - label!="DO NOT MERGE" + - label!="bp-conflict" + actions: + merge: + method: squash + strict: smart + strict_method: merge +- name: backport to 1.4.x + conditions: + - merged + - base=master + - milestone=1.4.x + actions: + backport: + branches: + - 1.4.x + labels: + - Backport + ignore_conflicts: true + label_conflicts: bp-conflict + label: + add: + - Backported +- name: backport to 1.3.x, 1.4.x + conditions: + - merged + - base=master + - milestone=1.3.x + actions: + backport: + branches: + - 1.3.x + - 1.4.x + labels: + - Backport + ignore_conflicts: true + label_conflicts: bp-conflict + label: + add: + - Backported +- name: backport to 1.2.x, 1.3.x, 1.4.x + conditions: + - merged + - base=master + - milestone=1.2.x + actions: + backport: + branches: + - 1.2.x + - 1.3.x + - 1.4.x + labels: + - Backport + ignore_conflicts: true + label_conflicts: bp-conflict + label: + add: + - Backported +- name: automatic squash-and-mege of 1.2.x backport PRs + conditions: + - status-success=all tests passed + - '#changes-requested-reviews-by=0' + - base=1.2.x + - label="Backport" + - label!="DO NOT MERGE" + - label!="bp-conflict" + actions: + merge: + method: squash + strict: smart + strict_method: merge +- name: automatic squash-and-mege of 1.3.x backport PRs + conditions: + - status-success=all tests passed + - '#changes-requested-reviews-by=0' + - base=1.3.x + - label="Backport" + - label!="DO NOT MERGE" + - label!="bp-conflict" + actions: + merge: + method: squash + strict: smart + strict_method: merge +- name: automatic squash-and-mege of 1.4.x backport PRs + conditions: + - status-success=all tests passed + - '#changes-requested-reviews-by=0' + - base=1.4.x + - label="Backport" + - label!="DO NOT MERGE" + - label!="bp-conflict" + actions: + merge: + method: squash + strict: smart + strict_method: merge -pull_request_rules: - - name: automatic squash-and-merge on CI success and review - conditions: - - status-success=Travis CI - Pull Request - - "#approved-reviews-by>=1" - - "#changes-requested-reviews-by=0" - - base=master - - label="Please Merge" - - label!="DO NOT MERGE" - - label!="bp-conflict" - actions: - merge: - method: squash - strict: smart - strict_method: merge - - - name: backport to 1.4.x - conditions: - - merged - - base=master - - milestone=1.4.x - actions: - backport: - branches: - - 1.4.x - ignore_conflicts: True - label_conflicts: "bp-conflict" - label: - add: [Backported] - - - name: backport to 1.3.x and 1.4.x - conditions: - - merged - - base=master - - milestone=1.3.x - actions: - backport: - branches: - - 1.3.x - - 1.4.x - ignore_conflicts: True - label_conflicts: "bp-conflict" - label: - add: [Backported] - - - name: backport to 1.2.x, 1.3.x, and 1.4.x - conditions: - - merged - - base=master - - milestone=1.2.X - actions: - backport: - branches: - - 1.2.x - - 1.3.x - - 1.4.x - ignore_conflicts: True - label_conflicts: "bp-conflict" - label: - add: [Backported] - - - name: label Mergify backport PR - conditions: - - body~=This is an automated backport of pull request \#\d+ done by Mergify - actions: - label: - add: [Backport] - - - name: automatic squash-and-merge of 1.4.x backport PRs - conditions: - - status-success=Travis CI - Pull Request - - "#changes-requested-reviews-by=0" - - base=1.4.x - - label="Backport" - - label!="DO NOT MERGE" - - label!="bp-conflict" - actions: - merge: - method: squash - strict: smart - strict_method: merge - - - name: automatic squash-and-merge of 1.3.x backport PRs - conditions: - - status-success=Travis CI - Pull Request - - "#changes-requested-reviews-by=0" - - base=1.3.x - - label="Backport" - - label!="DO NOT MERGE" - - label!="bp-conflict" - actions: - merge: - method: squash - strict: smart - strict_method: merge - - - name: automatic squash-and-merge of 1.2.x backport PRs - conditions: - - status-success=Travis CI - Pull Request - - "#changes-requested-reviews-by=0" - - base=1.2.x - - label="Backport" - - label!="DO NOT MERGE" - - label!="bp-conflict" - actions: - merge: - method: squash - strict: smart - strict_method: merge diff --git a/.run_chisel_tests.sh b/.run_chisel_tests.sh deleted file mode 100644 index dbec894e6e..0000000000 --- a/.run_chisel_tests.sh +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -set -e - -# Use appropriate branches. -# Each stable branch of FIRRTL should have a fixed value for these branches. -CHISEL_BRANCH="master" -TREADLE_BRANCH="master" - -# Skip chisel tests if the commit message says to -# Replace ... with .. in TRAVIS_COMMIT_RANGE, see https://github.com/travis-ci/travis-ci/issues/4596 -if git log --format=%B --no-merges ${TRAVIS_COMMIT_RANGE/.../..} | grep '\[skip chisel tests\]'; then - exit 0 -else - sbt $SBT_ARGS publishLocal - git clone https://github.com/freechipsproject/treadle.git --single-branch -b ${TREADLE_BRANCH} --depth 10 - (cd treadle && sbt $SBT_ARGS publishLocal) - git clone https://github.com/ucb-bar/chisel3.git --single-branch -b ${CHISEL_BRANCH} - cd chisel3 - sbt $SBT_ARGS test -fi diff --git a/.run_formal_checks.sh b/.run_formal_checks.sh index f5ffec25c9..0a28a1d16a 100755 --- a/.run_formal_checks.sh +++ b/.run_formal_checks.sh @@ -11,23 +11,25 @@ fi DUT=$1 -# Run formal check only for PRs -if [ $TRAVIS_PULL_REQUEST = "false" ]; then - echo "Not a pull request, no formal check" - exit 0 -else - # $TRAVIS_BRANCH is branch targeted by PR - # Travis does a shallow clone, checkout PR target so that we have it - # THen return to previous branch so HEAD points to the commit we're testing - git remote set-branches origin $TRAVIS_BRANCH && git fetch - git checkout $TRAVIS_BRANCH +# See https://docs.github.com/en/actions/reference/environment-variables +# for info about these variables + +# Run formal check only for PRs, GITHUB_BASE_REF is only set for PRs +if [ ! -z "$GITHUB_BASE_REF" ]; then + # Github Actions does a shallow clone, checkout PR target so that we have it + # Then return to previous branch so HEAD points to the commit we're testing + git remote set-branches origin $GITHUB_BASE_REF && git fetch + git checkout $GITHUB_BASE_REF git checkout - # Skip if '[skip formal checks]' shows up in any of the commit messages in the PR - if git log --format=%B --no-merges $TRAVIS_BRANCH..HEAD | grep '\[skip formal checks\]'; then + if git log --format=%B --no-merges $GITHUB_BASE_REF..HEAD | grep '\[skip formal checks\]'; then echo "Commit message says to skip formal checks" exit 0 else cp regress/$DUT.fir $DUT.fir - ./scripts/formal_equiv.sh HEAD $TRAVIS_BRANCH $DUT + ./scripts/formal_equiv.sh HEAD $GITHUB_BASE_REF $DUT fi +else + echo "Not a pull request, no formal check" + exit 0 fi diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index efcb6434a5..0000000000 --- a/.travis.yml +++ /dev/null @@ -1,151 +0,0 @@ -language: scala -sudo: false - -jdk: openjdk8 - -branches: - only: - - master - - 1.4.x - - 1.3.x - - 1.2.x - -cache: - directories: - $HOME/.cache/coursier - $HOME/.sbt - $INSTALL_DIR - -git: - depth: 10 - -env: - global: - INSTALL_DIR=$TRAVIS_BUILD_DIR/install - VERILATOR_ROOT=$INSTALL_DIR - PATH=$PATH:$VERILATOR_ROOT/bin:$TRAVIS_BUILD_DIR/utils/bin - SBT_ARGS="-Dsbt.log.noformat=true" - -before_script: - - OLDEST_SHARED=`git log --format=%H $TRAVIS_COMMIT_RANGE | tail -n1` - - OLDEST_COMMIT=`git log --format=%H | tail -n1` - - if [ $OLDEST_SHARED == $OLDEST_COMMIT ]; then git fetch --unshallow; fi - -stages: - - name: prepare - - name: test - - name: release - if: (branch IN (master, 1.4.x, 1.3.x, 1.2.x)) AND (type = push) AND (NOT fork) - -# We do not use the built-in tests as generated by using multiple Scala -# versions because the cache is not shared between stages with any -# environmental differences. -# Instead, we specify the version of Scala manually for each test (or leave it -# as the default as defined in FIRRTL's build.sbt). -jobs: - include: - # Because these write to the same install directory, they must run in the - # same script - - stage: prepare - name: "Install: [Verilator, Yosys, Z3]" - script: - - bash .install_verilator.sh - - verilator --version - - bash .install_yosys.sh - - yosys -V - - bash .install_z3.sh - - z3 -version - - stage: prepare - name: "Compile FIRRTL to share with subsequent stages" - script: - - sbt $SBT_ARGS assembly - workspaces: - create: - name: firrtl_build - paths: - - target/ - - utils/bin/firrtl.jar - - project/project/ - - project/target/ - - stage: test - name: "All files must be formatted with scalafmt" - workspaces: - use: firrtl_build - script: - - sbt scalafmtCheckAll - - stage: test - name: "Unidoc builds (no warnings)" - workspaces: - use: firrtl_build - script: - - sbt $SBT_ARGS +unidoc - - stage: test - name: "Tests: FIRRTL (2.13)" - workspaces: - use: firrtl_build - script: - - verilator --version - - sbt ++2.13.2 $SBT_ARGS test - - stage: test - name: "Tests: FIRRTL (2.12)" - workspaces: - use: firrtl_build - script: - - verilator --version - - sbt $SBT_ARGS test - - stage: test - name: "Tests: FIRRTL (2.11)" - script: - - verilator --version - - sbt ++2.11.12 $SBT_ARGS test - - stage: test - name: "Tests: chisel3 (2.12)" - workspaces: - use: firrtl_build - script: - - verilator --version - - bash .run_chisel_tests.sh - - stage: test - name: "Formal equivalence: RocketCore" - workspaces: - use: firrtl_build - script: - - yosys -V - - "travis_wait 30 sleep 1800 &" - - ./.run_formal_checks.sh RocketCore - - stage: test - name: "Formal equivalence: FPU" - workspaces: - use: firrtl_build - script: - - yosys -V - - "travis_wait 30 sleep 1800 &" - - ./.run_formal_checks.sh FPU - - stage: test - name: "Formal equivalence: ICache" - workspaces: - use: firrtl_build - script: - - yosys -V - - "travis_wait 30 sleep 1800 &" - - ./.run_formal_checks.sh ICache - - stage: test - name: "Formal equivalence: small expression-tree stress tests" - workspaces: - use: firrtl_build - script: - - yosys -V - - "travis_wait 30 sleep 1800 &" - - ./.run_formal_checks.sh Ops - - ./.run_formal_checks.sh AddNot - - stage: test - name: "Sanity check benchmarking scripts" - workspaces: - use: firrtl_build - script: - - benchmark/scripts/benchmark_cold_compile.py -N 2 --designs regress/ICache.fir --versions HEAD - - benchmark/scripts/find_heap_bound.py -- -cp firrtl*jar firrtl.stage.FirrtlMain -i regress/ICache.fir -o out -X verilog - # run ci-release only if previous stages passed - - stage: release - name: "Publish SNAPSHOT release" - script: sbt ci-release diff --git a/README.md b/README.md index 70921ed521..5efe975e1f 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,11 @@ --- [![Join the chat at https://gitter.im/freechipsproject/firrtl](https://badges.gitter.im/freechipsproject/firrtl.svg)](https://gitter.im/freechipsproject/firrtl?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![Build Status](https://travis-ci.org/freechipsproject/firrtl.svg?branch=master)](https://travis-ci.org/freechipsproject/firrtl) +![Build Status](https://github.com/chipsalliance/firrtl/workflows/Continuous%20Integration/badge.svg) [![Mergify Status][mergify-status]][mergify] [mergify]: https://mergify.io -[mergify-status]: https://gh.mergify.io/badges/:freechipsproject/:firrtl.png?style=cut +[mergify-status]: https://img.shields.io/endpoint.svg?url=https://gh.mergify.io/badges/chipsalliance/firrtl&style=flat #### Flexible Internal Representation for RTL diff --git a/build.sbt b/build.sbt index d39f99dc29..395c18861c 100644 --- a/build.sbt +++ b/build.sbt @@ -2,27 +2,20 @@ enablePlugins(SiteScaladocPlugin) -def javacOptionsVersion(scalaVersion: String): Seq[String] = { - Seq() ++ { - // Scala 2.12 requires Java 8, but we continue to generate - // Java 7 compatible code until we need Java 8 features - // for compatibility with old clients. - CrossVersion.partialVersion(scalaVersion) match { - case Some((2, scalaMajor: Long)) if scalaMajor < 12 => - Seq("-source", "1.7", "-target", "1.7") - case _ => - Seq("-source", "1.8", "-target", "1.8") - } - } -} - - lazy val commonSettings = Seq( organization := "edu.berkeley.cs", + scalaVersion := "2.12.13", + crossScalaVersions := Seq("2.13.4", "2.12.13") +) + +lazy val isAtLeastScala213 = Def.setting { + import Ordering.Implicits._ + CrossVersion.partialVersion(scalaVersion.value).exists(_ >= (2, 13)) +} + +lazy val firrtlSettings = Seq( name := "firrtl", version := "1.5-SNAPSHOT", - scalaVersion := "2.12.12", - crossScalaVersions := Seq("2.13.2", "2.12.12", "2.11.12"), addCompilerPlugin(scalafixSemanticdb), scalacOptions := Seq( "-deprecation", @@ -32,21 +25,32 @@ lazy val commonSettings = Seq( "-language:implicitConversions", "-Yrangepos", // required by SemanticDB compiler plugin ), - javacOptions ++= javacOptionsVersion(scalaVersion.value), + // Always target Java8 for maximum compatibility + javacOptions ++= Seq("-source", "1.8", "-target", "1.8"), libraryDependencies ++= Seq( "org.scala-lang" % "scala-reflect" % scalaVersion.value, - "org.scalatest" %% "scalatest" % "3.2.0" % "test", + "org.scalatest" %% "scalatest" % "3.2.8" % "test", "org.scalatestplus" %% "scalacheck-1-14" % "3.1.3.0" % "test", "com.github.scopt" %% "scopt" % "3.7.1", "net.jcazevedo" %% "moultingyaml" % "0.4.2", - "org.json4s" %% "json4s-native" % "3.6.9", - "org.apache.commons" % "commons-text" % "1.8" + "org.json4s" %% "json4s-native" % "3.6.11", + "org.apache.commons" % "commons-text" % "1.8", + "io.github.alexarchambault" %% "data-class" % "0.2.5", ), + // macros for the data-class library + libraryDependencies ++= { + if (isAtLeastScala213.value) Nil + else Seq(compilerPlugin("org.scalamacros" % "paradise" % "2.1.1" cross CrossVersion.full)) + }, + scalacOptions ++= { + if (isAtLeastScala213.value) Seq("-Ymacro-annotations") + else Nil + }, // starting with scala 2.13 the parallel collections are separate from the standard library libraryDependencies ++= { CrossVersion.partialVersion(scalaVersion.value) match { case Some((2, major)) if major <= 12 => Seq() - case _ => Seq("org.scala-lang.modules" %% "scala-parallel-collections" % "0.2.0") + case _ => Seq("org.scala-lang.modules" %% "scala-parallel-collections" % "1.0.2") } }, resolvers ++= Seq( @@ -55,6 +59,10 @@ lazy val commonSettings = Seq( ) ) +lazy val mimaSettings = Seq( + mimaPreviousArtifacts := Set() +) + lazy val protobufSettings = Seq( sourceDirectory in ProtobufConfig := baseDirectory.value / "src" / "main" / "proto", protobufRunProtoc in ProtobufConfig := (args => @@ -85,16 +93,15 @@ lazy val antlrSettings = Seq( antlr4GenVisitor in Antlr4 := true, antlr4GenListener in Antlr4 := false, antlr4PackageName in Antlr4 := Option("firrtl.antlr"), - antlr4Version in Antlr4 := "4.8", + antlr4Version in Antlr4 := "4.9.2", javaSource in Antlr4 := (sourceManaged in Compile).value ) lazy val publishSettings = Seq( - // publishMavenStyle and publishTo handled by sbt-ci-release + publishMavenStyle := true, publishArtifact in Test := false, pomIncludeRepository := { x => false }, - // Don't add 'scm' elements if we have a git.remoteRepo definition, - // but since we don't (with the removal of ghpages), add them in below. + // scm is set by sbt-ci-release pomExtra := http://chisel.eecs.berkeley.edu/ @@ -103,10 +110,6 @@ lazy val publishSettings = Seq( repo - - https://github.com/freechipsproject/firrtl.git - scm:git:github.com/freechipsproject/firrtl.git - jackbackrack @@ -114,25 +117,23 @@ lazy val publishSettings = Seq( http://www.eecs.berkeley.edu/~jrb/ , + publishTo := { + val v = version.value + val nexus = "https://oss.sonatype.org/" + if (v.trim.endsWith("SNAPSHOT")) { + Some("snapshots" at nexus + "content/repositories/snapshots") + } else { + Some("releases" at nexus + "service/local/staging/deploy/maven2") + } + } ) -def scalacDocOptionsVersion(scalaVersion: String): Seq[String] = { - Seq() ++ { - // If we're building with Scala > 2.11, enable the compile option - // to flag warnings as errors. This must be disabled for 2.11 since - // references to the Java class library from Java 9 on generate warnings. - // https://github.com/scala/bug/issues/10675 - CrossVersion.partialVersion(scalaVersion) match { - case Some((2, scalaMajor: Long)) if scalaMajor < 12 => Seq() - case _ => Seq("-Xfatal-warnings") - } - } -} lazy val docSettings = Seq( doc in Compile := (doc in ScalaUnidoc).value, autoAPIMappings := true, scalacOptions in Compile in doc ++= Seq( + "-Xfatal-warnings", "-feature", "-diagrams", "-diagrams-max-classes", "25", @@ -148,9 +149,9 @@ lazy val docSettings = Seq( } else { s"v${version.value}" } - s"https://github.com/freechipsproject/firrtl/tree/$branch€{FILE_PATH}.scala" + s"https://github.com/chipsalliance/firrtl/tree/$branch€{FILE_PATH_EXT}#L€{FILE_LINE}" } - ) ++ scalacDocOptionsVersion(scalaVersion.value) + ) ) lazy val firrtl = (project in file(".")) @@ -162,6 +163,7 @@ lazy val firrtl = (project in file(".")) Test / testForkedParallel := true ) .settings(commonSettings) + .settings(firrtlSettings) .settings(protobufSettings) .settings(antlrSettings) .settings(assemblySettings) @@ -175,9 +177,11 @@ lazy val firrtl = (project in file(".")) buildInfoUsePackageAsPath := true, buildInfoKeys := Seq[BuildInfoKey](buildInfoPackage, version, scalaVersion, sbtVersion) ) + .settings(mimaSettings) lazy val benchmark = (project in file("benchmark")) .dependsOn(firrtl) + .settings(commonSettings) .settings( assemblyJarName in assembly := "firrtl-benchmark.jar", test in assembly := {}, @@ -187,6 +191,7 @@ lazy val benchmark = (project in file("benchmark")) val JQF_VERSION = "1.5" lazy val jqf = (project in file("jqf")) + .settings(commonSettings) .settings( libraryDependencies ++= Seq( "edu.berkeley.cs.jqf" % "jqf-fuzz" % JQF_VERSION, @@ -212,6 +217,7 @@ lazy val testClassAndMethodParser = { lazy val fuzzer = (project in file("fuzzer")) .dependsOn(firrtl) + .settings(commonSettings) .settings( libraryDependencies ++= Seq( "com.pholser" % "junit-quickcheck-core" % "0.8", diff --git a/build.sc b/build.sc index ebef2b3d09..36412024f3 100644 --- a/build.sc +++ b/build.sc @@ -3,13 +3,14 @@ import mill._ import mill.scalalib._ import mill.scalalib.publish._ +import mill.scalalib.scalafmt._ import mill.modules.Util import $ivy.`com.lihaoyi::mill-contrib-buildinfo:$MILL_VERSION` import mill.contrib.buildinfo.BuildInfo -object firrtl extends mill.Cross[firrtlCrossModule]("2.11.12", "2.12.12", "2.13.2") +object firrtl extends mill.Cross[firrtlCrossModule]("2.12.13", "2.13.4") -class firrtlCrossModule(val crossScalaVersion: String) extends CrossSbtModule with PublishModule with BuildInfo { +class firrtlCrossModule(val crossScalaVersion: String) extends CrossSbtModule with ScalafmtModule with PublishModule with BuildInfo { override def millSourcePath = super.millSourcePath / os.up // 2.12.12 -> Array("2", "12", "12") -> "12" -> 12 @@ -21,17 +22,14 @@ class firrtlCrossModule(val crossScalaVersion: String) extends CrossSbtModule wi Some("firrtl.stage.FirrtlMain") } - private def javacCrossOptions = majorVersion match { - case i if i < 12 => Seq("-source", "1.7", "-target", "1.7") - case _ => Seq("-source", "1.8", "-target", "1.8") - } + private def javacCrossOptions = Seq("-source", "1.8", "-target", "1.8") override def scalacOptions = T { super.scalacOptions() ++ Seq( "-deprecation", "-unchecked", "-Yrangepos" // required by SemanticDB compiler plugin - ) + ) ++ (if (majorVersion == 13) Seq("-Ymacro-annotations") else Nil) } override def javacOptions = T { @@ -43,29 +41,27 @@ class firrtlCrossModule(val crossScalaVersion: String) extends CrossSbtModule wi ivy"${scalaOrganization()}:scala-reflect:${scalaVersion()}", ivy"com.github.scopt::scopt:3.7.1", ivy"net.jcazevedo::moultingyaml:0.4.2", - ivy"org.json4s::json4s-native:3.6.9", + ivy"org.json4s::json4s-native:3.6.11", ivy"org.apache.commons:commons-text:1.8", + ivy"io.github.alexarchambault::data-class:0.2.5", ivy"org.antlr:antlr4-runtime:$antlr4Version", ivy"com.google.protobuf:protobuf-java:$protocVersion" ) ++ { - if (majorVersion > 12) - Agg(ivy"org.scala-lang.modules::scala-parallel-collections:0.2.0") + if (majorVersion == 13) + Agg(ivy"org.scala-lang.modules::scala-parallel-collections:1.0.2") else Agg() } } - object test extends Tests { - private def ivyCrossDeps = majorVersion match { - case i if i < 12 => Agg(ivy"junit:junit:4.13.1") - case _ => Agg() - } + override def scalacPluginIvyDeps = if (majorVersion == 12) Agg(ivy"org.scalamacros:::paradise:2.1.1") else super.scalacPluginIvyDeps + object test extends Tests { override def ivyDeps = T { Agg( - ivy"org.scalatest::scalatest:3.2.0", + ivy"org.scalatest::scalatest:3.2.8", ivy"org.scalatestplus::scalacheck-1-14:3.1.3.0" - ) ++ ivyCrossDeps + ) } def testFrameworks = T { @@ -94,14 +90,16 @@ class firrtlCrossModule(val crossScalaVersion: String) extends CrossSbtModule wi } /* antlr4 */ - def antlr4Version = "4.8" + def antlr4Version = "4.9.2" def antlrSource = T.source { millSourcePath / "src" / "main" / "antlr4" / "FIRRTL.g4" } def downloadAntlr4Jar = T.persistent { - Util.download(s"https://www.antlr.org/download/antlr-$antlr4Version-complete.jar") + if (!os.isFile( T.ctx.dest / "antlr4" )) + Util.download(s"https://www.antlr.org/download/antlr-$antlr4Version-complete.jar", os.rel / "antlr4") + PathRef(T.ctx.dest / "antlr4") } def generatedAntlr4Source = T.sources { @@ -117,19 +115,68 @@ class firrtlCrossModule(val crossScalaVersion: String) extends CrossSbtModule wi } /* protoc */ - def protocVersion = "3.5.1" + def protocVersion = "3.15.6" def protobufSource = T.source { millSourcePath / "src" / "main" / "proto" / "firrtl.proto" } - def downloadProtocJar = T.persistent { - Util.download(s"https://repo.maven.apache.org/maven2/com/github/os72/protoc-jar/$protocVersion/protoc-jar-$protocVersion.jar") + def architecture = T { + System.getProperty("os.arch") + } + def operationSystem = T { + System.getProperty("os.name") + } + + def downloadProtoc = T.persistent { + val isMac = operationSystem().toLowerCase.startsWith("mac") + val isLinux = operationSystem().toLowerCase.startsWith("linux") + val isWindows = operationSystem().toLowerCase.startsWith("win") + + val aarch_64 = architecture().equals("aarch64") | architecture().startsWith("armv8") + val ppcle_64 = architecture().equals("ppc64le") + val s390x = architecture().equals("s390x") + val x86_32 = architecture().matches("^(x8632|x86|i[3-6]86|ia32|x32)$") + val x86_64 = architecture().matches("^(x8664|amd64|ia32e|em64t|x64|x86_64)$") + + val protocBinary = + if (isMac) + // MacOS ARM 64-bit still supports x86_64 binaries via Rosetta 2 + if (aarch_64 || x86_64) "osx-x86_64" + else throw new Exception("mill cannot detect your architecture of your Mac") + else if (isLinux) + if (aarch_64) "linux-aarch_64" + else if (ppcle_64) "linux-ppcle_64" + else if (s390x) "linux-s390x" + else if (x86_32) "linux-x86_32" + else if (x86_64) "linux-x86_64" + else throw new Exception("mill cannot detect your architecture of your Linux") + else if (isWindows) + if (x86_32) "win32" + else if (x86_64) "win64" + else throw new Exception("mill cannot detect your architecture of your Windows") + else throw new Exception("mill cannot detect your operation system.") + + val unpackPath = os.rel / "unpacked" + + val bin = if(isWindows) + T.ctx.dest / unpackPath / "bin" / "protoc.exe" + else + T.ctx.dest / unpackPath / "bin" / "protoc" + + if (!os.exists(bin)) + Util.downloadUnpackZip( + s"https://github.com/protocolbuffers/protobuf/releases/download/v$protocVersion/protoc-$protocVersion-$protocBinary.zip", + unpackPath + ) + // Download Linux/Mac binary doesn't have x. + if (!isWindows) os.perms.set(bin, "rwx------") + PathRef(bin) } def generatedProtoSources = T.sources { - os.proc("java", - "-jar", downloadProtocJar().path.toString, + os.proc( + downloadProtoc().path.toString, "-I", protobufSource().path / os.up, s"--java_out=${T.ctx.dest.toString}", protobufSource().path.toString() diff --git a/project/build.properties b/project/build.properties index 797e7ccfdb..f0be67b9f7 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.3.10 +sbt.version=1.5.1 diff --git a/project/plugins.sbt b/project/plugins.sbt index bb342af2bd..7e19a7b2a8 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -10,20 +10,22 @@ addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.10.0") addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.4.3") -addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.6.1") +addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.7.0") addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.10") addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.8.2") -addSbtPlugin("com.github.gseitz" % "sbt-protobuf" % "0.6.5") +addSbtPlugin("com.github.sbt" % "sbt-protobuf" % "0.7.0") -addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.9.19") +addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.9.27") addSbtPlugin("com.thoughtworks.sbt-api-mappings" % "sbt-api-mappings" % "3.0.0") -addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.0") +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.2") -addSbtPlugin("com.geirsson" % "sbt-ci-release" % "1.5.4") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.8.1") + +addSbtPlugin("com.geirsson" % "sbt-ci-release" % "1.5.7") libraryDependencies += "com.github.os72" % "protoc-jar" % "3.11.4" diff --git a/scripts/formal_equiv.sh b/scripts/formal_equiv.sh index 2304b74e17..c3d4535750 100755 --- a/scripts/formal_equiv.sh +++ b/scripts/formal_equiv.sh @@ -28,7 +28,7 @@ make_verilog () { git checkout $1 local filename="$DUT.$1.v" - sbt "runMain firrtl.Driver -i $DUT.fir -o $filename -X verilog" + sbt "runMain firrtl.stage.FirrtlMain -i $DUT.fir -o $filename -X verilog" RET=$filename } diff --git a/spec/spec.pdf b/spec/spec.pdf index 45dd4879d4..e392dec686 100644 Binary files a/spec/spec.pdf and b/spec/spec.pdf differ diff --git a/spec/spec.tex b/spec/spec.tex index 41de5d5e1c..8201b32ed6 100644 --- a/spec/spec.tex +++ b/spec/spec.tex @@ -989,10 +989,15 @@ \subsection{Stops} \label{stop_stmt} side-effect-having statements in different modules or with different clocks that trigger concurrently is undefined. +The stop statement has an optional name attribute which can be used to +attach metadata to the statement. The name is part of the module level +namespace. However it can never be used in a reference since it is not of +any valid type. + \begin{lstlisting} wire clk:Clock wire halt:UInt<1> -stop(clk,halt,42) +stop(clk,halt,42) : optional_name ... \end{lstlisting} @@ -1004,12 +1009,17 @@ \subsection{Formatted Prints} For information about execution ordering of clocked statements with observable environmental side effects, see section \ref{stop_stmt}. +The printf statement has an optional name attribute which can be used to +attach metadata to the statement. The name is part of the module level +namespace. However it can never be used in a reference since it is not of +any valid type. + \begin{lstlisting} wire clk:Clock wire condition:UInt<1> wire a:UInt wire b:UInt -printf(clk, condition, "a in hex: %x, b in decimal:%d.\n", a, b) +printf(clk, condition, "a in hex: %x, b in decimal:%d.\n", a, b) : optional_name ... \end{lstlisting} @@ -1043,6 +1053,11 @@ \subsection{Verification} For information about execution ordering of clocked statements with observable environmental side effects, see section \ref{stop_stmt}. +Any verification statement has an optional name attribute which can be used to +attach metadata to the statement. The name is part of the module level +namespace. However it can never be used in a reference since it is not of +any valid type. + \subsubsection{Assert} The assert statement verifies that the predicate is true on the rising edge of any clock cycle when the enable is true. In other words, it verifies that enable implies predicate. @@ -1053,7 +1068,7 @@ \subsubsection{Assert} wire en:UInt<1> pred <= eq(X, Y) en <= Z_valid -assert(clk, pred, en, "X equals Y when Z is valid") +assert(clk, pred, en, "X equals Y when Z is valid") : optional_name \end{lstlisting} \subsubsection{Assume} @@ -1066,7 +1081,7 @@ \subsubsection{Assume} wire en:UInt<1> pred <= eq(X, Y) en <= Z_valid -assume(clk, pred, en, "X equals Y when Z is valid") +assume(clk, pred, en, "X equals Y when Z is valid") : optional_name \end{lstlisting} \subsubsection{Cover} @@ -1079,7 +1094,7 @@ \subsubsection{Cover} wire en:UInt<1> pred <= eq(X, Y) en <= Z_valid -cover(clk, pred, en, "X equals Y when Z is valid") +cover(clk, pred, en, "X equals Y when Z is valid") : optional_name \end{lstlisting} \section{Expressions} @@ -2108,8 +2123,8 @@ \subsection{Concrete Syntax Tree} \label{syntax_tree} &\pipe &\pd{exp} \vv{is invalid} \opt{\pd{info}} &\text{Invalidate}\\ &\pipe &\vv{attach}\vv{(}\rpt{\pd{exp}}\vv{)} \opt{\pd{info}} &\text{Attach}\\ &\pipe &\vv{when} \pd{exp} \vv{:} \opt{\pd{info}} \pd{stmt} \opt{\vv{else} \vv{:} \pd{stmt}} &\text{Conditional}\\ - &\pipe &\vv{stop(}\pd{exp}, \pd{exp}, \pd{int})\opt{\pd{info}} &\text{Stop}\\ - &\pipe &\vv{printf(}\pd{exp}, \pd{exp}, \pd{string}, \rpt{\pd{exp}}\vv{)} \opt{\pd{info}} &\text{Printf}\\ + &\pipe &\vv{stop(}\pd{exp}, \pd{exp}, \pd{int}) \opt{\vv{:}\pd{id}} \opt{\pd{info}} &\text{Stop}\\ + &\pipe &\vv{printf(}\pd{exp}, \pd{exp}, \pd{string}, \rpt{\pd{exp}}\vv{)} \opt{\vv{:}\pd{id}} \opt{\pd{info}} &\text{Printf}\\ &\pipe &\vv{skip} \opt{\pd{info}} &\text{Empty}\\ &\pipe &\vv{(}\rpt{\pd{stmt}}\vv{)} &\text{Statement Group}\\ \pd{ruw} &= &\vv{old} \pipe \vv{ new} \pipe \vv{ undefined} &\text{Read Under Write Flag}\\ diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index 54ad8d0eec..aa53f2f0b1 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -103,13 +103,17 @@ stmt | exp '<-' exp info? | exp 'is' 'invalid' info? | when - | 'stop(' exp exp intLit ')' info? - | 'printf(' exp exp StringLit ( exp)* ')' info? + | 'stop(' exp exp intLit ')' stmtName? info? + | 'printf(' exp exp StringLit ( exp)* ')' stmtName? info? | 'skip' info? | 'attach' '(' exp+ ')' info? - | 'assert' '(' exp exp exp StringLit ')' info? - | 'assume' '(' exp exp exp StringLit ')' info? - | 'cover' '(' exp exp exp StringLit ')' info? + | 'assert' '(' exp exp exp StringLit ')' stmtName? info? + | 'assume' '(' exp exp exp StringLit ')' stmtName? info? + | 'cover' '(' exp exp exp StringLit ')' stmtName? info? + ; + +stmtName + : ':' id ; memField diff --git a/src/main/proto/firrtl.proto b/src/main/proto/firrtl.proto index e8451d7a25..6ce1c10822 100644 --- a/src/main/proto/firrtl.proto +++ b/src/main/proto/firrtl.proto @@ -276,6 +276,7 @@ message Firrtl { IsInvalid is_invalid = 17; MemoryPort memory_port = 18; Attach attach = 20; + Verification verification = 21; } SourceInfo source_info = 19; diff --git a/src/main/resources/META-INF/services/firrtl.options.RegisteredTransform b/src/main/resources/META-INF/services/firrtl.options.RegisteredTransform index bb72d45ccb..9304c39aaf 100644 --- a/src/main/resources/META-INF/services/firrtl.options.RegisteredTransform +++ b/src/main/resources/META-INF/services/firrtl.options.RegisteredTransform @@ -3,3 +3,4 @@ firrtl.transforms.CheckCombLoops firrtl.passes.InlineInstances firrtl.passes.clocklist.ClockListTransform firrtl.transforms.formal.AssertSubmoduleAssumptions +firrtl.transforms.ConstantPropagation diff --git a/src/main/scala/firrtl/AddDescriptionNodes.scala b/src/main/scala/firrtl/AddDescriptionNodes.scala index 9424d4a76d..123ae6e364 100644 --- a/src/main/scala/firrtl/AddDescriptionNodes.scala +++ b/src/main/scala/firrtl/AddDescriptionNodes.scala @@ -136,7 +136,7 @@ class AddDescriptionNodes extends Transform with DependencyAPIMigration { Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], Dependency[firrtl.transforms.InlineBitExtractionsTransform], Dependency[firrtl.transforms.PropagatePresetAnnotations], - Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.InlineAcrossCastsTransform], Dependency[firrtl.transforms.LegalizeClocksTransform], Dependency[firrtl.transforms.FlattenRegUpdate], Dependency(passes.VerilogModulusCleanup), diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index b4629a2a99..2998af3c6e 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -116,6 +116,14 @@ sealed abstract class CircuitForm(private val value: Int) extends Ordered[Circui /** Defines a suffix to use if this form is written to a file */ def outputSuffix: String } +private[firrtl] object CircuitForm { + // Private internal utils to reduce number of deprecation warnings + val ChirrtlForm = firrtl.ChirrtlForm + val HighForm = firrtl.HighForm + val MidForm = firrtl.MidForm + val LowForm = firrtl.LowForm + val UnknownForm = firrtl.UnknownForm +} // These magic numbers give an ordering to CircuitForm /** Chirrtl Form @@ -202,22 +210,9 @@ final case object UnknownForm extends CircuitForm(-1) { // Internal utilities to keep code DRY, not a clean interface private[firrtl] object Transform { - // Run transform with logging - def runTransform(name: String, mk: => CircuitState, logger: Logger): CircuitState = { - logger.info(s"======== Starting Transform $name ========") - - val (timeMillis, result) = Utils.time(mk) - - logger.info(s"""----------------------------${"-" * name.size}---------\n""") - logger.info(f"Time: $timeMillis%.1f ms") - - result - } - def remapAnnotations(name: String, before: CircuitState, after: CircuitState, logger: Logger): CircuitState = { val remappedAnnotations = propagateAnnotations(name, logger, before.annotations, after.annotations, after.renames) - logger.info(s"Form: ${after.form}") logger.trace(s"Annotations:") logger.trace { JsonProtocol @@ -232,7 +227,6 @@ private[firrtl] object Transform { } logger.trace(s"Circuit:\n${after.circuit.serialize}") - logger.info(s"======== Finished Transform $name ========\n") CircuitState(after.circuit, after.form, remappedAnnotations, None) } @@ -310,7 +304,7 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform def transform(state: CircuitState): CircuitState = execute(state) - import firrtl.{ChirrtlForm => C, HighForm => H, MidForm => M, LowForm => L, UnknownForm => U} + import firrtl.CircuitForm.{ChirrtlForm => C, HighForm => H, MidForm => M, LowForm => L, UnknownForm => U} override def prerequisites: Seq[Dependency[Transform]] = inputForm match { case C => Nil @@ -377,7 +371,7 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform * @return A transformed Firrtl AST */ final def runTransform(state: CircuitState): CircuitState = { - val result = Transform.runTransform(name, execute(prepare(state)), logger) + val result = execute(prepare(state)) Transform.remapAnnotations(name, state, result, logger) } diff --git a/src/main/scala/firrtl/EmissionOption.scala b/src/main/scala/firrtl/EmissionOption.scala index 074d9c2886..90eb585186 100644 --- a/src/main/scala/firrtl/EmissionOption.scala +++ b/src/main/scala/firrtl/EmissionOption.scala @@ -2,6 +2,8 @@ package firrtl +import firrtl.annotations.MemoryLoadFileType + /** * Base type for emission customization options * NOTE: all the following traits must be mixed with SingleTargetAnnotation[T <: Named] @@ -18,6 +20,7 @@ sealed trait MemoryInitValue case object MemoryRandomInit extends MemoryInitValue case class MemoryScalarInit(value: BigInt) extends MemoryInitValue case class MemoryArrayInit(values: Seq[BigInt]) extends MemoryInitValue +case class MemoryFileInlineInit(filename: String, hexOrBinary: MemoryLoadFileType.FileType) extends MemoryInitValue /** default Emitter behavior for memories */ case object MemoryEmissionOptionDefault extends MemoryEmissionOption diff --git a/src/main/scala/firrtl/FileUtils.scala b/src/main/scala/firrtl/FileUtils.scala index 2cc2961e87..f92d50cc1a 100644 --- a/src/main/scala/firrtl/FileUtils.scala +++ b/src/main/scala/firrtl/FileUtils.scala @@ -71,7 +71,7 @@ object FileUtils { val ioToDevNull = BasicIO(withIn = false, ProcessLogger(line => sb.append(line))) try { - cmd.run(ioToDevNull).exitValue == 0 + cmd.run(ioToDevNull).exitValue() == 0 } catch { case _: Throwable => false } diff --git a/src/main/scala/firrtl/LexerHelper.scala b/src/main/scala/firrtl/LexerHelper.scala index e1327cd75e..b755790a26 100644 --- a/src/main/scala/firrtl/LexerHelper.scala +++ b/src/main/scala/firrtl/LexerHelper.scala @@ -98,7 +98,7 @@ abstract class LexerHelper { if (tokenBuffer.isEmpty) pullToken() else - tokenBuffer.dequeue + tokenBuffer.dequeue() if (reachedEof) t @@ -157,6 +157,6 @@ abstract class LexerHelper { doPop() indentations.push(targetIndent) - tokenBuffer.dequeue + tokenBuffer.dequeue() } } diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index e27cafb7c4..261465e26f 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -31,7 +31,11 @@ class ChirrtlToHighFirrtl extends CoreTransform { class IRToWorkingIR extends CoreTransform { def inputForm = HighForm def outputForm = HighForm - def transforms = new TransformManager(Forms.WorkingIR, Forms.MinimalHighForm).flattenedTransformOrder + def transforms = Seq( + new Transform with firrtl.options.IdentityLike[CircuitState] with DependencyAPIMigration { + override def execute(a: CircuitState) = transform(a) + } + ) } /** Resolves types, kinds, and flows, and checks the circuit legality. @@ -44,7 +48,7 @@ class IRToWorkingIR extends CoreTransform { class ResolveAndCheck extends CoreTransform { def inputForm = HighForm def outputForm = HighForm - def transforms = new TransformManager(Forms.Resolved, Forms.WorkingIR).flattenedTransformOrder + def transforms = new TransformManager(Forms.Resolved, Forms.MinimalHighForm).flattenedTransformOrder } /** Expands aggregate connects, removes dynamic accesses, and when @@ -123,7 +127,7 @@ class NoneCompiler extends Compiler { ) class HighFirrtlCompiler extends Compiler { val emitter = new HighFirrtlEmitter - def transforms: Seq[Transform] = Forms.HighForm.map(_.getObject) + def transforms: Seq[Transform] = Forms.HighForm.map(_.getObject()) } /** Emits middle Firrtl input circuit */ @@ -133,7 +137,7 @@ class HighFirrtlCompiler extends Compiler { ) class MiddleFirrtlCompiler extends Compiler { val emitter = new MiddleFirrtlEmitter - def transforms: Seq[Transform] = Forms.MidForm.map(_.getObject) + def transforms: Seq[Transform] = Forms.MidForm.map(_.getObject()) } /** Emits lowered input circuit */ @@ -143,7 +147,7 @@ class MiddleFirrtlCompiler extends Compiler { ) class LowFirrtlCompiler extends Compiler { val emitter = new LowFirrtlEmitter - def transforms: Seq[Transform] = Forms.LowForm.map(_.getObject) + def transforms: Seq[Transform] = Forms.LowForm.map(_.getObject()) } /** Emits Verilog */ @@ -153,7 +157,7 @@ class LowFirrtlCompiler extends Compiler { ) class VerilogCompiler extends Compiler { val emitter = new VerilogEmitter - def transforms: Seq[Transform] = Forms.LowFormOptimized.map(_.getObject) + def transforms: Seq[Transform] = Forms.LowFormOptimized.map(_.getObject()) } /** Emits Verilog without optimizations */ @@ -163,7 +167,7 @@ class VerilogCompiler extends Compiler { ) class MinimumVerilogCompiler extends Compiler { val emitter = new MinimumVerilogEmitter - def transforms: Seq[Transform] = Forms.LowFormMinimumOptimized.map(_.getObject) + def transforms: Seq[Transform] = Forms.LowFormMinimumOptimized.map(_.getObject()) } /** Currently just an alias for the [[VerilogCompiler]] */ diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala index 25f4a805f5..a4b7bc7a52 100644 --- a/src/main/scala/firrtl/Namespace.scala +++ b/src/main/scala/firrtl/Namespace.scala @@ -53,9 +53,11 @@ object Namespace { val namespace = new Namespace def buildNamespaceStmt(s: Statement): Seq[String] = s match { - case s: IsDeclaration => Seq(s.name) - case s: Conditionally => buildNamespaceStmt(s.conseq) ++ buildNamespaceStmt(s.alt) - case s: Block => s.stmts.flatMap(buildNamespaceStmt) + // Empty names are allowed for backwards compatibility reasons and + // indicate that the entity has essentially no name. + case s: IsDeclaration if s.name.nonEmpty => Seq(s.name) + case s: Conditionally => buildNamespaceStmt(s.conseq) ++ buildNamespaceStmt(s.alt) + case s: Block => s.stmts.flatMap(buildNamespaceStmt) case _ => Nil } namespace.namespace ++= m.ports.map(_.name) diff --git a/src/main/scala/firrtl/RenameMap.scala b/src/main/scala/firrtl/RenameMap.scala index df98f72f2a..82c00ca5c7 100644 --- a/src/main/scala/firrtl/RenameMap.scala +++ b/src/main/scala/firrtl/RenameMap.scala @@ -4,7 +4,9 @@ package firrtl import annotations._ import firrtl.RenameMap.IllegalRenameException +import firrtl.analyses.InstanceKeyGraph import firrtl.annotations.TargetToken.{Field, Index, Instance, OfModule} +import TargetUtils.{instKeyPathToTarget, unfoldInstanceTargets} import scala.collection.mutable @@ -21,6 +23,58 @@ object RenameMap { rm } + /** RenameMap factory for simple renaming of instances + * + * @param graph [[InstanceKeyGraph]] from *before* renaming + * @param renames Mapping of old instance name to new within Modules + */ + private[firrtl] def fromInstanceRenames( + graph: InstanceKeyGraph, + renames: Map[OfModule, Map[Instance, Instance]] + ): RenameMap = { + def renameAll(it: InstanceTarget): InstanceTarget = { + var prevMod = OfModule(it.module) + val pathx = it.path.map { + case (inst, of) => + val instx = renames + .get(prevMod) + .flatMap(_.get(inst)) + .getOrElse(inst) + prevMod = of + instx -> of + } + // Sanity check, the last one should always be a rename (or we wouldn't be calling this method) + val instx = renames(prevMod)(Instance(it.instance)) + it.copy(path = pathx, instance = instx.value) + } + val underlying = new mutable.HashMap[CompleteTarget, Seq[CompleteTarget]] + val instOf: String => Map[String, String] = + graph.getChildInstances.toMap + // Laziness here is desirable, we only access each key once, some we don't access + .mapValues(_.map(k => k.name -> k.module).toMap) + for ((OfModule(module), instMapping) <- renames) { + val modLookup = instOf(module) + val parentInstances = graph.findInstancesInHierarchy(module) + for { + // For every instance of the Module where the renamed instance resides + parent <- parentInstances + parentTarget = instKeyPathToTarget(parent) + // Create the absolute InstanceTarget to be renamed + (Instance(from), _) <- instMapping // The to is given by renameAll + instMod = modLookup(from) + fromTarget = parentTarget.instOf(from, instMod) + // Ensure all renames apply to the InstanceTarget + toTarget = renameAll(fromTarget) + // RenameMap only allows 1 hit when looking up InstanceTargets, so rename all possible + // paths to this instance + (fromx, tox) <- unfoldInstanceTargets(fromTarget).zip(unfoldInstanceTargets(toTarget)) + } yield { + underlying(fromx) = List(tox) + } + } + new RenameMap(underlying) + } + /** Initialize a new RenameMap */ def apply(): RenameMap = new RenameMap diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 921ec60b33..a58b69975d 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -5,6 +5,7 @@ package firrtl import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Mappers._ +import firrtl.traversals.Foreachers._ import firrtl.WrappedExpression._ import scala.collection.mutable @@ -50,7 +51,18 @@ object getWidth { def apply(e: Expression): Width = apply(e.tpe) } +/** + * Helper object for computing the width of a firrtl type. + */ object bitWidth { + + /** + * Compute the width of a firrtl type. + * For example, a Vec of 4 UInts of width 8 should have a width of 32. + * + * @param dt firrtl type + * @return Width of the given type + */ def apply(dt: Type): BigInt = widthOf(dt) private def widthOf(dt: Type): BigInt = dt match { case t: VectorType => t.size * bitWidth(t.tpe) @@ -210,6 +222,24 @@ object Utils extends LazyLogging { case _ => false } + /** Selects all the elements of this list ignoring the duplicates as determined by == after + * applying the transforming function f + * + * @note In Scala Standard Library starting in 2.13 + */ + def distinctBy[A, B](xs: List[A])(f: A => B): List[A] = { + val buf = new mutable.ListBuffer[A] + val seen = new mutable.HashSet[B] + for (x <- xs) { + val y = f(x) + if (!seen(y)) { + buf += x + seen += y + } + } + buf.toList + } + /** Provide a nice name to create a temporary * */ def niceName(e: Expression): String = niceName(1)(e) def niceName(depth: Int)(e: Expression): String = { @@ -632,7 +662,6 @@ object Utils extends LazyLogging { def get_flow(s: Statement): Flow = s match { case sx: DefWire => DuplexFlow case sx: DefRegister => DuplexFlow - case sx: WDefInstance => SourceFlow case sx: DefNode => SourceFlow case sx: DefInstance => SourceFlow case sx: DefMemory => SourceFlow @@ -650,6 +679,19 @@ object Utils extends LazyLogging { case _ => NoInfo } + /** Finds all root References in a nested Expression */ + def getAllRefs(expr: Expression): Seq[Reference] = { + val refs = mutable.ListBuffer.empty[Reference] + def rec(e: Expression): Unit = { + e match { + case ref: Reference => refs += ref + case other => other.foreach(rec) + } + } + rec(expr) + refs.toList + } + /** Splits an Expression into root Ref and tail * * @example @@ -660,19 +702,24 @@ object Utils extends LazyLogging { * Given: SubField(SubIndex(Ref("b"), 2), "c") * Returns: (Ref("b"), SubField(SubIndex(EmptyExpression, 2), "c")) * b[2].c -> (b, EMPTY[2].c) - * @note This function only supports WRef, WSubField, and WSubIndex + * @note This function only supports [[firrtl.ir.RefLikeExpression RefLikeExpression]]s: [[firrtl.ir.Reference + * Reference]], [[firrtl.ir.SubField SubField]], [[firrtl.ir.SubIndex SubIndex]], and [[firrtl.ir.SubAccess + * SubAccess]] */ def splitRef(e: Expression): (WRef, Expression) = e match { - case e: WRef => (e, EmptyExpression) - case e: WSubIndex => + case e: Reference => (e, EmptyExpression) + case e: SubIndex => val (root, tail) = splitRef(e.expr) - (root, WSubIndex(tail, e.value, e.tpe, e.flow)) - case e: WSubField => + (root, SubIndex(tail, e.value, e.tpe, e.flow)) + case e: SubField => val (root, tail) = splitRef(e.expr) tail match { - case EmptyExpression => (root, WRef(e.name, e.tpe, root.kind, e.flow)) - case exp => (root, WSubField(tail, e.name, e.tpe, e.flow)) + case EmptyExpression => (root, Reference(e.name, e.tpe, root.kind, e.flow)) + case exp => (root, SubField(tail, e.name, e.tpe, e.flow)) } + case e: SubAccess => + val (root, tail) = splitRef(e.expr) + (root, SubAccess(tail, e.index, e.tpe, e.flow)) } /** Adds a root reference to some SubField/SubIndex chain */ @@ -819,7 +866,7 @@ object Utils extends LazyLogging { "scalared", "sequence", "shortint", "shortreal", "showcancelled", "signed", "small", "solve", "specify", "specparam", "static", - "strength", "string", "strong0", "strong1", "struct", "super", + "strength", "string", "strong", "strong0", "strong1", "struct", "super", "supply0", "supply1", "table", "tagged", "task", "this", "throughout", "time", "timeprecision", @@ -830,7 +877,7 @@ object Utils extends LazyLogging { "var", "vectored", "virtual", "void", - "wait", "wait_order", "wand", "weak0", "weak1", "while", + "wait", "wait_order", "wand", "weak", "weak0", "weak1", "while", "wildcard", "wire", "with", "within", "wor", "xnor", "xor", @@ -864,6 +911,79 @@ object Utils extends LazyLogging { def maskBigInt(value: BigInt, width: Int): BigInt = { value & ((BigInt(1) << width) - 1) } + + /** Returns true iff the expression is a Literal or a Literal cast to a different type. */ + def isLiteral(e: Expression): Boolean = e match { + case _: Literal => true + case DoPrim(op, args, _, _) if isCast(op) => args.exists(isLiteral) + case _ => false + } + + /** Applies the firrtl And primop. Automatically constant propagates when one of the expressions is True or False. */ + def and(e1: Expression, e2: Expression): Expression = { + assert(e1.tpe == e2.tpe) + (e1, e2) match { + case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value | b.value, a.width) + case (True(), b) => b + case (a, True()) => a + case (False(), _) => False() + case (_, False()) => False() + case (a, b) if a == b => a + case (a, b) => DoPrim(PrimOps.And, Seq(a, b), Nil, BoolType) + } + } + + /** Applies the firrtl Eq primop. */ + def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, BoolType) + + /** Applies the firrtl Or primop. Automatically constant propagates when one of the expressions is True or False. */ + def or(e1: Expression, e2: Expression): Expression = { + assert(e1.tpe == e2.tpe) + (e1, e2) match { + case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value | b.value, a.width) + case (True(), _) => True() + case (_, True()) => True() + case (False(), b) => b + case (a, False()) => a + case (a, b) if a == b => a + case (a, b) => DoPrim(PrimOps.Or, Seq(a, b), Nil, BoolType) + } + } + + /** Applies the firrtl Not primop. Automatically constant propagates when the expressions is True or False. */ + def not(e: Expression): Expression = e match { + case True() => False() + case False() => True() + case a => DoPrim(PrimOps.Not, Seq(a), Nil, BoolType) + } + + /** implies(e1, e2) = or(not(e1), e2). Automatically constant propagates when one of the expressions is True or False. */ + def implies(e1: Expression, e2: Expression): Expression = or(not(e1), e2) + + /** Builds a Mux expression with the correct type. */ + def mux(cond: Expression, tval: Expression, fval: Expression): Expression = { + require(tval.tpe == fval.tpe) + Mux(cond, tval, fval, tval.tpe) + } + + object True { + private val _True = UIntLiteral(1, IntWidth(1)) + + /** Matches `UInt<1>(1)` */ + def unapply(e: UIntLiteral): Boolean = e.value == 1 && e.width == _True.width + + /** Returns `UInt<1>(1)` */ + def apply(): UIntLiteral = _True + } + object False { + private val _False = UIntLiteral(0, IntWidth(1)) + + /** Matches `UInt<1>(0)` */ + def unapply(e: UIntLiteral): Boolean = e.value == 0 && e.width == _False.width + + /** Returns `UInt<1>(0)` */ + def apply(): UIntLiteral = _False + } } object MemoizedHash { diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 7ba8a0bf79..f1b3a5c201 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -164,12 +164,12 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } IntervalType(UnknownBound, UnknownBound, point) case 2 => - val lower = (ctx.lowerBound.getText, ctx.boundValue(0).getText) match { + val lower = ((ctx.lowerBound.getText, ctx.boundValue(0).getText): @unchecked) match { case (_, "?") => UnknownBound case ("(", v) => Open(string2BigDecimal(v)) case ("[", v) => Closed(string2BigDecimal(v)) } - val upper = (ctx.upperBound.getText, ctx.boundValue(1).getText) match { + val upper = ((ctx.upperBound.getText, ctx.boundValue(1).getText): @unchecked) match { case (_, "?") => UnknownBound case (")", v) => Open(string2BigDecimal(v)) case ("]", v) => Closed(string2BigDecimal(v)) @@ -318,6 +318,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private def visitStmt(ctx: StmtContext): Statement = { val ctx_exp = ctx.exp.asScala val info = visitInfo(Option(ctx.info), ctx) + def stmtName = Option(ctx.stmtName).map(_.id.getText).getOrElse("") ctx.getChild(0) match { case when: WhenContext => visitWhen(when) case term: TerminalNode => @@ -346,7 +347,8 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText) case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx_exp(0))) - case "stop(" => Stop(info, string2Int(ctx.intLit().getText), visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) + case "stop(" => + Stop(info, string2Int(ctx.intLit().getText), visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), name = stmtName) case "attach" => Attach(info, ctx_exp.map(visitExp).toSeq) case "printf(" => Print( @@ -354,7 +356,8 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w visitStringLit(ctx.StringLit), ctx_exp.drop(2).map(visitExp).toSeq, visitExp(ctx_exp(0)), - visitExp(ctx_exp(1)) + visitExp(ctx_exp(1)), + name = stmtName ) // formal case "assert" => @@ -364,7 +367,8 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), - visitStringLit(ctx.StringLit) + visitStringLit(ctx.StringLit), + name = stmtName ) case "assume" => Verification( @@ -373,7 +377,8 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), - visitStringLit(ctx.StringLit) + visitStringLit(ctx.StringLit), + name = stmtName ) case "cover" => Verification( @@ -382,7 +387,8 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), - visitStringLit(ctx.StringLit) + visitStringLit(ctx.StringLit), + name = stmtName ) // end formal case "skip" => EmptyStmt diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index a0b85007ac..e9dd95bc31 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -13,6 +13,7 @@ trait Kind case object WireKind extends Kind case object PoisonKind extends Kind case object RegKind extends Kind +case object RandomKind extends Kind case object InstanceKind extends Kind case object PortKind extends Kind case object NodeKind extends Kind @@ -119,6 +120,7 @@ case class WDefInstanceConnector( portCons: Seq[(Expression, Expression)]) extends Statement with IsDeclaration + with CanBeReferenced with UseSerializer { def mapExpr(f: Expression => Expression): Statement = this.copy(portCons = portCons.map { case (e1, e2) => (f(e1), f(e2)) }) @@ -346,6 +348,7 @@ case class CDefMemory( readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined) extends Statement with HasInfo + with CanBeReferenced with UseSerializer { def mapExpr(f: Expression => Expression): Statement = this def mapStmt(f: Statement => Statement): Statement = this @@ -361,6 +364,7 @@ case class CDefMemory( case class CDefMPort(info: Info, name: String, tpe: Type, mem: String, exps: Seq[Expression], direction: MPortDir) extends Statement with HasInfo + with CanBeReferenced with UseSerializer { def mapExpr(f: Expression => Expression): Statement = this.copy(exps = exps.map(f)) def mapStmt(f: Statement => Statement): Statement = this diff --git a/src/main/scala/firrtl/analyses/ConnectionGraph.scala b/src/main/scala/firrtl/analyses/ConnectionGraph.scala index e5e3bde2dc..85cbe4df65 100644 --- a/src/main/scala/firrtl/analyses/ConnectionGraph.scala +++ b/src/main/scala/firrtl/analyses/ConnectionGraph.scala @@ -147,7 +147,7 @@ class ConnectionGraph protected (val circuit: Circuit, val digraph: DiGraph[Refe val bfsQueue = new mutable.PriorityQueue[ReferenceTarget]()(ordering) bfsQueue.enqueue(root) while (bfsQueue.nonEmpty) { - val u = bfsQueue.dequeue + val u = bfsQueue.dequeue() for (v <- getEdges(u)) { if (!prev.contains(v) && !blacklist.contains(v)) { prev(v) = u @@ -416,6 +416,7 @@ object ConnectionGraph { case firrtl.ir.Field(name, Default, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SourceFlow)) // Module input case firrtl.ir.Field(name, Flip, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SinkFlow)) + case x => Utils.error(s"Unexpected flip: ${x.flip}") } assert(instPorts.size == modulePorts.size) val o = m.circuitTarget.module(ofModule) diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala index 0017ff8ba1..83a04ca389 100644 --- a/src/main/scala/firrtl/analyses/InstanceGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala @@ -45,7 +45,7 @@ class InstanceGraph(c: Circuit) { val topInstance = DefInstance(subTop, subTop) instanceQueue.enqueue(topInstance) while (instanceQueue.nonEmpty) { - val current = instanceQueue.dequeue + val current = instanceQueue.dequeue() instanceGraph.addVertex(current) for (child <- childInstances(current.module)) { if (!instanceGraph.contains(child)) { @@ -185,7 +185,6 @@ object InstanceGraph { @deprecated("Use InstanceKeyGraph.collectInstances instead.", "FIRRTL 1.4") def collectInstances(insts: mutable.Set[DefInstance])(s: Statement): Unit = s match { case i: DefInstance => insts += i - case i: DefInstance => throwInternalError("Expecting DefInstance, found a DefInstance!") case i: WDefInstanceConnector => throwInternalError("Expecting DefInstance, found a DefInstanceConnector!") case _ => s.foreach(collectInstances(insts)) } diff --git a/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala index 06c0b3774f..7584e3c84c 100644 --- a/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala @@ -158,7 +158,7 @@ object InstanceKeyGraph { val instanceQueue = new mutable.Queue[InstanceKey] instanceQueue.enqueue(topInstance) while (instanceQueue.nonEmpty) { - val current = instanceQueue.dequeue + val current = instanceQueue.dequeue() instanceGraph.addVertex(current) for (child <- childInstanceMap(current.module)) { if (!instanceGraph.contains(child)) { diff --git a/src/main/scala/firrtl/analyses/SymbolTable.scala b/src/main/scala/firrtl/analyses/SymbolTable.scala index 3b304bc1f3..e4a534444f 100644 --- a/src/main/scala/firrtl/analyses/SymbolTable.scala +++ b/src/main/scala/firrtl/analyses/SymbolTable.scala @@ -87,6 +87,10 @@ object SymbolTable { case d: DefNode => table.declare(d) case d: DefWire => table.declare(d) case d: DefRegister => table.declare(d) + // Matches named statements like printf, stop, assert, assume, cover if the name is not empty. + // Empty names are allowed for backwards compatibility reasons and + // indicate that the entity has essentially no name. + case s: IsDeclaration if s.name.nonEmpty => table.declare(s.name, UnknownType, firrtl.UnknownKind) case other => other.foreachStmt(scanStatement) } } diff --git a/src/main/scala/firrtl/annotations/Annotation.scala b/src/main/scala/firrtl/annotations/Annotation.scala index 5f7921279f..b5c9c7e0da 100644 --- a/src/main/scala/firrtl/annotations/Annotation.scala +++ b/src/main/scala/firrtl/annotations/Annotation.scala @@ -71,7 +71,7 @@ trait SingleTargetAnnotation[T <: Named] extends Annotation { case c: CircuitTarget => c.toNamed case other => throw Target.NamedException(s"Cannot convert $other to [[Named]]") } - Target.convertTarget2Named(result) match { + (Target.convertTarget2Named(result): @unchecked) match { case newTarget: T @unchecked => try { duplicate(newTarget) diff --git a/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala b/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala index 44656e0d08..1e81301d75 100644 --- a/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala +++ b/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala @@ -2,7 +2,14 @@ package firrtl.annotations -import firrtl.{MemoryArrayInit, MemoryEmissionOption, MemoryInitValue, MemoryRandomInit, MemoryScalarInit} +import firrtl.{ + MemoryArrayInit, + MemoryEmissionOption, + MemoryFileInlineInit, + MemoryInitValue, + MemoryRandomInit, + MemoryScalarInit +} /** * Represents the initial value of the annotated memory. @@ -33,3 +40,21 @@ case class MemoryArrayInitAnnotation(target: ReferenceTarget, values: Seq[BigInt override def initValue: MemoryInitValue = MemoryArrayInit(values) override def isRandomInit: Boolean = false } + +/** Initialize the `target` memory with inline readmem[hb] statement. */ +case class MemoryFileInlineAnnotation( + target: ReferenceTarget, + filename: String, + hexOrBinary: MemoryLoadFileType.FileType = MemoryLoadFileType.Hex) + extends MemoryInitAnnotation { + require(filename.trim.nonEmpty, "empty filename not allowed in MemoryFileInlineAnnotation") + override def duplicate(n: ReferenceTarget): Annotation = copy(n) + override def initValue: MemoryInitValue = MemoryFileInlineInit(filename, hexOrBinary) + override def isRandomInit: Boolean = false +} + +/** Initializes the memory inside the `ifndef SYNTHESIS` block (default) */ +case object MemoryNoSynthInit extends NoTargetAnnotation + +/** Initializes the memory outside the `ifndef SYNTHESIS` block */ +case object MemorySynthInit extends NoTargetAnnotation diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala index 137d070e15..92339946b7 100644 --- a/src/main/scala/firrtl/annotations/Target.scala +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -265,7 +265,7 @@ case class GenericTarget(circuitOpt: Option[String], moduleOpt: Option[String], case GenericTarget(Some(c), Some(m), Instance(i) +: OfModule(o) +: Vector()) => InstanceTarget(c, m, Nil, i, o) case GenericTarget(Some(c), Some(m), component) => val path = getPath.getOrElse(Nil) - (getRef, getInstanceOf) match { + ((getRef, getInstanceOf): @unchecked) match { case (Some((r, comps)), _) => ReferenceTarget(c, m, path, r, comps) case (None, Some((i, o))) => InstanceTarget(c, m, path, i, o) } @@ -516,6 +516,7 @@ trait IsComponent extends IsMember { case ("", Ref(name)) => name case (string, Field(value)) => s"$string.$value" case (string, Index(value)) => s"$string[$value]" + case (_, token) => Utils.error(s"Unexpected token: $token") } ComponentName(name, mn) case Seq(Instance(name), OfModule(o)) => ComponentName(name, mn) @@ -612,7 +613,7 @@ case class ModuleTarget(circuit: String, module: String) extends IsModule { } /** Target pointing to a declared named component in a [[firrtl.ir.DefModule]] - * This includes: [[firrtl.ir.Port]], [[firrtl.ir.DefWire]], [[firrtl.ir.DefRegister]], [[firrtl.ir.DefInstance]], + * This includes: [[firrtl.ir.Port]], [[firrtl.ir.DefWire]], [[firrtl.ir.DefRegister]], * [[firrtl.ir.DefMemory]], [[firrtl.ir.DefNode]] * @param circuit Name of the encapsulating circuit * @param module Name of the root module of this reference @@ -660,6 +661,7 @@ case class ReferenceTarget( case Index(idx) => sub_type(baseType) case Field(field) => field_type(baseType, field) case _: Ref => baseType + case token => Utils.error(s"Unexpected token $token") } componentType(headType, tokens.tail) } diff --git a/src/main/scala/firrtl/annotations/TargetUtils.scala b/src/main/scala/firrtl/annotations/TargetUtils.scala new file mode 100644 index 0000000000..164c430b55 --- /dev/null +++ b/src/main/scala/firrtl/annotations/TargetUtils.scala @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.annotations + +import firrtl._ +import firrtl.analyses.InstanceKeyGraph +import firrtl.analyses.InstanceKeyGraph.InstanceKey +import firrtl.annotations.TargetToken._ + +object TargetUtils { + + /** Turns an instance path into a corresponding [[IsModule]] + * + * @note First InstanceKey is treated as the [[CircuitTarget]] + * @param path Instance path + * @param start Module in instance path to be starting [[ModuleTarget]] + * @return [[IsModule]] corresponding to Instance path + */ + def instKeyPathToTarget(path: Seq[InstanceKey], start: Option[String] = None): IsModule = { + val head = path.head + val startx = start.getOrElse(head.module) + val top: IsModule = CircuitTarget(head.module).module(startx) // ~Top|Start + val pathx = path.dropWhile(_.module != startx) + if (pathx.isEmpty) top + else pathx.tail.foldLeft(top) { case (acc, key) => acc.instOf(key.name, key.module) } + } + + /** Calculates all [[InstanceTarget]]s that refer to the given [[IsModule]] + * + * {{{ + * ~Top|Top/a:A/b:B/c:C unfolds to: + * * ~Top|Top/a:A/b:B/c:C + * * ~Top|A/b:B/c:C + * * ~Top|B/c:C + * }}} + * @note [[ModuleTarget]] arguments return an empty Iterable + */ + def unfoldInstanceTargets(ismod: IsModule): Iterable[InstanceTarget] = { + // concretely use List which is fast in practice + def rec(im: IsModule): List[InstanceTarget] = im match { + case inst: InstanceTarget => inst :: rec(inst.stripHierarchy(1)) + case _ => Nil + } + rec(ismod) + } +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala index 4cd5c9f7a7..f96fd4e80f 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala @@ -3,12 +3,28 @@ package firrtl.backends.experimental.smt +import firrtl.backends.experimental.smt.Btor2Serializer.functionCallToArrayRead + import scala.collection.mutable private object Btor2Serializer { def serialize(sys: TransitionSystem, skipOutput: Boolean = false): Iterable[String] = { new Btor2Serializer().run(sys, skipOutput) } + + private def functionCallToArrayRead(call: BVFunctionCall): BVExpr = { + if (call.args.isEmpty) { + BVSymbol(call.name, call.width) + } else { + val index = concat(call.args) + val a = ArraySymbol(call.name, indexWidth = index.width, dataWidth = call.width) + ArrayRead(a, index) + } + } + private def concat(e: Iterable[BVExpr]): BVExpr = { + require(e.nonEmpty) + e.reduce((a, b) => BVConcat(a, b)) + } } private class Btor2Serializer private () { @@ -65,6 +81,7 @@ private class Btor2Serializer private () { case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b) case BVOp(op, a, b) => binary(s(op), expr.width, a, b) case BVConcat(a, b) => binary("concat", expr.width, a, b) + case call: BVFunctionCall => s(functionCallToArrayRead(call)) case ArrayRead(array, index) => line(s"read ${t(expr.width)} ${s(array)} ${s(index)}") case BVIte(cond, tru, fals) => @@ -164,6 +181,17 @@ private class Btor2Serializer private () { declare(ii.name, line(s"input ${t(ii.width)} ${ii.name}")) } + // declare uninterpreted functions a constant arrays + sys.ufs.foreach { foo => + val sym = if (foo.argWidths.isEmpty) { BVSymbol(foo.name, foo.width) } + else { + ArraySymbol(foo.name, foo.argWidths.sum, foo.width) + } + comment(foo.toString) + declare(sym.name, line(s"state ${t(sym)} ${sym.name}")) + line(s"next ${t(sym)} ${s(sym)} ${s(sym)}") + } + // define state init sys.states.foreach { st => // calculate init expression before declaring the state diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala index d85fbfe5fb..099b671218 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala @@ -8,21 +8,10 @@ import firrtl.PrimOps import firrtl.passes.CheckWidths.WidthTooBig private trait TranslationContext { - def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe)) - def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe)) - def getRandom(width: Int): BVExpr + def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, firrtl.bitWidth(tpe).toInt) } private object FirrtlExpressionSemantics { - def getWidth(tpe: ir.Type): Int = tpe match { - case ir.UIntType(ir.IntWidth(w)) => w.toInt - case ir.SIntType(ir.IntWidth(w)) => w.toInt - case ir.ClockType => 1 - case ir.ResetType => 1 - case ir.AnalogType(ir.IntWidth(w)) => w.toInt - case other => throw new RuntimeException(s"Cannot handle type $other") - } - def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = { val eSMT = e match { case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts) @@ -34,9 +23,8 @@ private object FirrtlExpressionSemantics { case ir.Mux(cond, tval, fval, _) => val width = List(tval, fval).map(getWidth).max BVIte(toSMT(cond), toSMT(tval, width), toSMT(fval, width)) - case ir.ValidIf(cond, value, tpe) => - val tru = toSMT(value) - BVIte(toSMT(cond), tru, ctx.getRandom(tpe)) + case v: ir.ValidIf => + throw new RuntimeException(s"Unsupported expression: ValidIf ${v.serialize}") } assert( eSMT.width == getWidth(e), @@ -81,15 +69,7 @@ private object FirrtlExpressionSemantics { val (width, op) = if (isSigned(num)) { (getWidth(num) + 1, Op.SignedDiv) } else { (getWidth(num), Op.UnsignedDiv) } - // "The result of a division where den is zero is undefined." - val undef = ctx.getRandom(width) - val denSMT = toSMT(den) - val denIsZero = BVEqual(denSMT, BVLiteral(0, denSMT.width)) - val numByDen = BVOp(op, toSMT(num, width), forceWidth(denSMT, isSigned(den), width)) - BVIte(denIsZero, undef, numByDen) - case (PrimOps.Div, Seq(num, den), _) if isSigned(num) => - val width = getWidth(num) + 1 - BVOp(Op.SignedDiv, toSMT(num, width), toSMT(den, width)) + BVOp(op, toSMT(num, width), forceWidth(toSMT(den), isSigned(den), width)) case (PrimOps.Rem, Seq(num, den), _) => val op = if (isSigned(num)) Op.SignedRem else Op.UnsignedRem val width = args.map(getWidth).max @@ -194,5 +174,7 @@ private object FirrtlExpressionSemantics { case _: ir.SIntType => true case _ => false } - private def getWidth(e: ir.Expression): Int = getWidth(e.tpe) + + // Helper function + private def getWidth(e: ir.Expression): Int = firrtl.bitWidth(e.tpe).toInt } diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index aed2011af5..cfab61b99c 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -4,13 +4,17 @@ package firrtl.backends.experimental.smt import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation} +import firrtl.bitWidth import FirrtlExpressionSemantics.getWidth +import firrtl.backends.experimental.smt.random._ import firrtl.graph.MutableDiGraph import firrtl.options.Dependency +import firrtl.passes.MemPortUtils.memPortField import firrtl.passes.PassException +import firrtl.passes.memlib.VerilogMemDelays import firrtl.stage.Forms import firrtl.stage.TransformManager.TransformDependency -import firrtl.transforms.PropagatePresetAnnotations +import firrtl.transforms.{DeadCodeElimination, PropagatePresetAnnotations} import firrtl.{ ir, CircuitState, @@ -40,10 +44,12 @@ private case class TransitionSystem( assumes: Set[String], asserts: Set[String], fair: Set[String], + ufs: List[BVFunctionSymbol] = List(), comments: Map[String, String] = Map(), header: Array[String] = Array()) { def serialize: String = { (Iterator(name) ++ + ufs.map(u => u.toString) ++ inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++ signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++ states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")).mkString("\n") @@ -56,7 +62,8 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { // TODO: We only really need [[Forms.MidForm]] + LowerTypes, but we also want to fail if there are CombLoops // TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare // precisely and PadWidths emits ill-typed firrtl. - override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm + override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++ + Seq(Dependency(InvalidToRandomPass), Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays)) override def invalidates(a: Transform): Boolean = false // since this pass only runs on the main module, inlining needs to happen before override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) @@ -65,9 +72,12 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { // Verilog emission passes. // Ideally we would go in and enable the [[PropagatePresetAnnotations]] to only depend on LowForm. private val presetPass = new PropagatePresetAnnotations + // We also need to run the DeadCodeElimination since PropagatePresets does not remove possible remaining + // AsyncReset nodes. + private val deadCodeElimination = new DeadCodeElimination override protected def execute(state: CircuitState): CircuitState = { // run the preset pass to extract all preset registers and remove preset reset signals - val afterPreset = presetPass.execute(state) + val afterPreset = deadCodeElimination.execute(presetPass.execute(state)) val circuit = afterPreset.circuit val presetRegs = afterPreset.annotations.collect { case PresetRegAnnotation(target) if target.module == circuit.main => target.ref @@ -79,15 +89,25 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { .map(a => a.target.ref -> a.initValue) .toMap + // module look up table + val modules = circuit.modules.map(m => m.name -> m).toMap + + // collect uninterpreted module annotations + val uninterpreted = afterPreset.annotations.collect { + case a: UninterpretedModuleAnnotation => + UninterpretedModuleAnnotation.checkModule(modules(a.target.module), a) + a.target.module -> a + }.toMap + // convert the main module - val main = circuit.modules.find(_.name == circuit.main).get + val main = modules(circuit.main) val sys = main match { case x: ir.ExtModule => throw new ExtModuleException( "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog." ) case m: ir.Module => - new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit) + new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit, uninterpreted = uninterpreted) } val sortedSys = TopologicalSort.run(sys) @@ -122,12 +142,13 @@ private class MissingFeatureException(s: String) private class ModuleToTransitionSystem extends LazyLogging { def run( - m: ir.Module, - presetRegs: Set[String] = Set(), - memInit: Map[String, MemoryInitValue] = Map() + m: ir.Module, + presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map(), + uninterpreted: Map[String, UninterpretedModuleAnnotation] = Map() ): TransitionSystem = { // first pass over the module to convert expressions; discover state and I/O - val scan = new ModuleScanner(makeRandom) + val scan = new ModuleScanner(uninterpreted) m.foreachPort(scan.onPort) m.foreachStmt(scan.onStatement) @@ -153,8 +174,7 @@ private class ModuleToTransitionSystem extends LazyLogging { onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs) } // turn memories into state - val memoryEncoding = new MemoryEncoding(makeRandom, scan.namespace) - val memoryStatesAndOutputs = scan.memories.map(m => memoryEncoding.onMemory(m, scan.connects, memInit.get(m.name))) + val memoryStatesAndOutputs = scan.memories.map(m => onMemory(m, scan.connects, memInit.get(m.name))) // replace pseudo assigns for memory outputs val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap val signalsWithMem = signals.map { s => @@ -169,7 +189,7 @@ private class ModuleToTransitionSystem extends LazyLogging { case _ => true } ) - val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1) + val states = regStates.toArray ++ memoryStatesAndOutputs.map(_._1) // generate comments from infos val comments = mutable.HashMap[String, String]() @@ -181,13 +201,17 @@ private class ModuleToTransitionSystem extends LazyLogging { } } - // inputs are original module inputs and any "random" signal we need for modelling - val inputs = scan.inputs ++ randoms.values + // inputs are original module inputs and any DefRandom signal + val inputs = scan.inputs // module info to the comment header val header = serializeInfo(m.info).map(InfoPrefix + _).toArray val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints + + // collect unique functions + val ufs = scan.functionCalls.groupBy(_.name).map(_._2.head).toList + TransitionSystem( m.name, inputs.toArray, @@ -197,6 +221,7 @@ private class ModuleToTransitionSystem extends LazyLogging { constraints, bad, fair, + ufs, comments.toMap, header ) @@ -226,237 +251,112 @@ private class ModuleToTransitionSystem extends LazyLogging { } } - private val InfoSeparator = ", " - private val InfoPrefix = "@ " - private def serializeInfo(info: ir.Info): Option[String] = info match { - case ir.NoInfo => None - case f: ir.FileInfo => Some(f.escaped) - case m: ir.MultiInfo => - val infos = m.flatten - if (infos.isEmpty) { None } - else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } - } - - private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]() - private def makeRandom(baseName: String, width: Int): BVExpr = { - // TODO: actually ensure that there cannot be any name clashes with other identifiers - val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii) - val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get - val sym = BVSymbol(name, width) - randoms(name) = sym - sym - } -} - -private class MemoryEncoding(makeRandom: (String, Int) => BVExpr, namespace: Namespace) extends LazyLogging { type Connects = Iterable[(String, BVExpr)] - def onMemory( - defMem: ir.DefMemory, - connects: Connects, - initValue: Option[MemoryInitValue] - ): (Iterable[State], Connects) = { - // we can only work on appropriately lowered memories - assert( - defMem.dataType.isInstanceOf[ir.GroundType], - s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!" - ) - assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") + private def onMemory(m: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (State, Connects) = { + checkMem(m) - // collect all memory meta-data in a custom class - val m = new MemInfo(defMem) + // map of inputs to the memory + val inputs = connects.filter(_._1.startsWith(m.name)).toMap - // find all connections related to this memory - val inputs = connects.filter(_._1.startsWith(m.prefix)).toMap + // derive the type of the memory from the dataType and depth + val dataWidth = bitWidth(m.dataType).toInt + val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) + val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth) // there could be a constant init - val init = initValue.map(getInit(m, _)) - - // parse and check read and write ports - val writers = defMem.writers.map(w => new WritePort(m, w, inputs)) - val readers = defMem.readers.map(r => new ReadPort(m, r, inputs)) - - // derive next state from all write ports - assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.") - val next: ArrayExpr = if (writers.isEmpty) { m.sym } - else { - if (writers.length > 2) { - throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})") - } - val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) } - if (writers.length == 1) { validData } - else { - assert(writers.length == 2) - val conflict = writers.head.doesConflict(writers.last) - val conflictData = writers.head.makeRandomData("_write_write_collision") - val conflictStore = ArrayStore(m.sym, writers.head.addr, conflictData) - ArrayIte(conflict, conflictStore, validData) - } - } - val state = State(m.sym, init, Some(next)) + val init = initValue.map(getInit(m, indexWidth, dataWidth, _)) + init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth)) - // derive data signals from all read ports - assert(defMem.readLatency >= 0) - if (defMem.readLatency > 1) { - throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})") - } - val readPortSignals = if (defMem.readLatency == 0) { - readers.map { r => - // combinatorial read - if (defMem.readUnderWrite != ir.ReadUnderWrite.New) { - logger.warn( - s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." + - s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored." - ) - } - // since we do a combinatorial read, the "old" data is the current data - val data = r.read() - r.data.name -> data - } - } else { Seq() } - val readPortSignalsAndStates = if (defMem.readLatency == 1) { - readers.map { r => - defMem.readUnderWrite match { - case ir.ReadUnderWrite.New => - // create a state to save the address and the enable signal - val enPrev = BVSymbol(namespace.newName(r.en.name + "_prev"), r.en.width) - val addrPrev = BVSymbol(namespace.newName(r.addr.name + "_prev"), r.addr.width) - val signal = r.data.name -> r.read(addr = addrPrev, en = enPrev) - val states = Seq(State(enPrev, None, next = Some(r.en)), State(addrPrev, None, next = Some(r.addr))) - (Seq(signal), states) - case ir.ReadUnderWrite.Undefined => - // check for potential read/write conflicts in which case we need to return an arbitrary value - val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r))) - val next = if (anyWriteToTheSameAddress == False) { r.read() } - else { - val readUnderWriteData = r.makeRandomData("_read_under_write_undefined") - BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.read()) - } - (Seq(), Seq(State(r.data, init = None, next = Some(next)))) - case ir.ReadUnderWrite.Old => - // we create a register for the read port data - (Seq(), Seq(State(r.data, init = None, next = Some(r.read())))) - } + // derive next state expression + val next = if (m.writers.isEmpty) { + memSymbol + } else { + m.writers.foldLeft[ArrayExpr](memSymbol) { + case (prev, write) => + // update + val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth) + val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth) + val update = ArrayStore(prev, index = addr, data = data) + + // update guard + val en = BVSymbol(memPortField(m, write, "en").serialize, 1) + val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1) + val alwaysEnabled = Seq(en, mask).forall(s => inputs(s.name) == True) + if (alwaysEnabled) { update } + else { + ArrayIte(and(en, mask), update, prev) + } } - } else { Seq() } + } - val allReadPortSignals = readPortSignals ++ readPortSignalsAndStates.flatMap(_._1) - val readPortStates = readPortSignalsAndStates.flatMap(_._2) + val state = State(memSymbol, init, Some(next)) - (state +: readPortStates, allReadPortSignals) - } + // derive read expressions + val readSignals = m.readers.map { read => + val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth) + memPortField(m, read, "data").serialize -> ArrayRead(memSymbol, addr) + } - private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match { - case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth) - case MemoryArrayInit(values) => - assert( - values.length == m.depth, - s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" - ) - // in order to get a more compact encoding try to find the most common values - val histogram = mutable.LinkedHashMap[BigInt, Int]() - values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) - val baseValue = histogram.maxBy(_._2)._1 - val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth) - values.zipWithIndex - .filterNot(_._1 == baseValue) - .foldLeft[ArrayExpr](base) { - case (array, (value, index)) => - ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth)) - } - case other => throw new RuntimeException(s"Unsupported memory init option: $other") + (state, readSignals) } - private class MemInfo(m: ir.DefMemory) { - val name = m.name - val depth = m.depth - // derrive the type of the memory from the dataType and depth - val dataWidth = getWidth(m.dataType) - val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) - val sym = ArraySymbol(m.name, indexWidth, dataWidth) - val prefix = m.name + "." - val fullAddressRange = (BigInt(1) << indexWidth) == m.depth - lazy val depthBV = BVLiteral(m.depth, indexWidth) - def isValidAddress(addr: BVExpr): BVExpr = { - if (fullAddressRange) { True } - else { - BVComparison(Compare.Greater, depthBV, addr, signed = false) - } - } - } - private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) { - val en: BVSymbol = makeField("en", 1) - val data: BVSymbol = makeField("data", memory.dataWidth) - val addr: BVSymbol = makeField("addr", memory.indexWidth) - protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width) - // make sure that all widths are correct - assert(inputs(en.name).width == en.width) - assert(inputs(addr.name).width == addr.width) - val enIsTrue: Boolean = inputs(en.name) == True - def makeRandomData(suffix: String): BVExpr = - makeRandom(memory.name + "_" + name + suffix, memory.dataWidth) - def read(addr: BVSymbol = addr, en: BVSymbol = en): BVExpr = { - val canBeOutOfRange = !memory.fullAddressRange - val canBeDisabled = !enIsTrue - val data = ArrayRead(memory.sym, addr) - val dataWithRangeCheck = if (canBeOutOfRange) { - val outOfRangeData = makeRandomData("_addr_out_of_range") - BVIte(memory.isValidAddress(addr), data, outOfRangeData) - } else { data } - val dataWithEnabledCheck = if (canBeDisabled) { - val disabledData = makeRandomData("_not_enabled") - BVIte(en, dataWithRangeCheck, disabledData) - } else { dataWithRangeCheck } - dataWithEnabledCheck - } - } - private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr) - extends MemPort(memory, name, inputs) { - assert(inputs(data.name).width == data.width) - val mask: BVSymbol = makeField("mask", 1) - assert(inputs(mask.name).width == mask.width) - val maskIsTrue: Boolean = inputs(mask.name) == True - val doWrite: BVExpr = (enIsTrue, maskIsTrue) match { - case (true, true) => True - case (true, false) => mask - case (false, true) => en - case (false, false) => and(en, mask) - } - def doesConflict(r: ReadPort): BVExpr = { - val sameAddress = BVEqual(r.addr, addr) - if (doWrite == True) { sameAddress } - else { and(doWrite, sameAddress) } - } - def doesConflict(w: WritePort): BVExpr = { - val bothWrite = and(doWrite, w.doWrite) - val sameAddress = BVEqual(addr, w.addr) - if (bothWrite == True) { sameAddress } - else { and(bothWrite, sameAddress) } - } - def writeTo(array: ArrayExpr): ArrayExpr = { - val doUpdate = if (memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr)) - val update = ArrayStore(array, index = addr, data = data) - if (doUpdate == True) update else ArrayIte(doUpdate, update, array) + private def getInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr = + initValue match { + case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth) + case MemoryArrayInit(values) => + assert( + values.length == m.depth, + s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" + ) + // in order to get a more compact encoding try to find the most common values + val histogram = mutable.LinkedHashMap[BigInt, Int]() + values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) + val baseValue = histogram.maxBy(_._2)._1 + val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth) + values.zipWithIndex + .filterNot(_._1 == baseValue) + .foldLeft[ArrayExpr](base) { + case (array, (value, index)) => + ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth)) + } + case other => throw new RuntimeException(s"Unsupported memory init option: $other") } - } - private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr) - extends MemPort(memory, name, inputs) {} - + // TODO: add to BV expression library private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match { case (True, True) => True case (True, x) => x case (x, True) => x case _ => BVOp(Op.And, a, b) } - private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b) + private val True = BVLiteral(1, 1) - private val False = BVLiteral(0, 1) - private def all(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) False else b.reduce((a, b) => and(a, b)) - private def any(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) True else b.reduce((a, b) => or(a, b)) + private def checkMem(m: ir.DefMemory): Unit = { + assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?") + assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?") + assert( + m.dataType.isInstanceOf[ir.GroundType], + s"Memory $m is of type ${m.dataType} which is not a ground type!" + ) + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") + } + + private val InfoSeparator = ", " + private val InfoPrefix = "@ " + private def serializeInfo(info: ir.Info): Option[String] = info match { + case ir.NoInfo => None + case f: ir.FileInfo => Some(f.escaped) + case m: ir.MultiInfo => + val infos = m.flatten + if (infos.isEmpty) { None } + else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } + } } // performas a first pass over the module collecting all connections, wires, registers, input and outputs -private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLogging { +private class ModuleScanner( + uninterpreted: Map[String, UninterpretedModuleAnnotation]) + extends LazyLogging { import FirrtlExpressionSemantics.getWidth private[firrtl] val inputs = mutable.ArrayBuffer[BVSymbol]() @@ -473,10 +373,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog private[firrtl] val assumes = mutable.ArrayBuffer[String]() // maps identifiers to their info private[firrtl] val infos = mutable.ArrayBuffer[(String, ir.Info)]() - // keeps track of unused memory (data) outputs so that we can see where they are first used - private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]() + // Keeps track of (so far) unused memory (data) and uninterpreted module outputs. + // This is used in order to delay declaring them for as long as possible. + private val unusedOutputs = mutable.LinkedHashMap[String, BVExpr]() // ensure unique names for assert/assume signals private[firrtl] val namespace = Namespace() + // keep track of all uninterpreted functions called + private[firrtl] val functionCalls = mutable.ArrayBuffer[BVFunctionSymbol]() private[firrtl] def onPort(p: ir.Port): Unit = { if (isAsyncReset(p.tpe)) { @@ -489,7 +392,7 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog if (isClock(p.tpe)) { clocks.add(p.name) } else { - inputs.append(BVSymbol(p.name, getWidth(p.tpe))) + inputs.append(BVSymbol(p.name, bitWidth(p.tpe).toInt)) } case ir.Output => if (!isClock(p.tpe)) { // we ignore clock outputs @@ -499,6 +402,12 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog } private[firrtl] def onStatement(s: ir.Statement): Unit = s match { + case DefRandom(info, name, tpe, _, _) => + namespace.newName(name) + assert(!isClock(tpe), "rand should never be a clock!") + // we model random sources as inputs and ignore the enable signal + infos.append(name -> info) + inputs.append(BVSymbol(name, bitWidth(tpe).toInt)) case ir.DefWire(info, name, tpe) => namespace.newName(name) if (!isClock(tpe)) { @@ -508,79 +417,50 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog case ir.DefNode(info, name, expr) => namespace.newName(name) if (!isClock(expr.tpe)) { - insertDummyAssignsForMemoryOutputs(expr) + insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) - val e = onExpression(expr, name) + val e = onExpression(expr) nodes.append(name) connects.append((name, e)) } case ir.DefRegister(info, name, tpe, _, reset, init) => namespace.newName(name) - insertDummyAssignsForMemoryOutputs(reset) - insertDummyAssignsForMemoryOutputs(init) + insertDummyAssignsForUnusedOutputs(reset) + insertDummyAssignsForUnusedOutputs(init) infos.append(name -> info) - val width = getWidth(tpe) - val resetExpr = onExpression(reset, 1, name + "_reset") - val initExpr = onExpression(init, width, name + "_init") + val width = bitWidth(tpe).toInt + val resetExpr = onExpression(reset, 1) + val initExpr = onExpression(init, width) registers.append((name, width, resetExpr, initExpr)) case m: ir.DefMemory => namespace.newName(m.name) infos.append(m.name -> m.info) val outputs = getMemOutputs(m) (getMemInputs(m) ++ outputs).foreach(memSignals.append(_)) - val dataWidth = getWidth(m.dataType) - outputs.foreach(name => unusedMemOutputs(name) = dataWidth) + val dataWidth = bitWidth(m.dataType).toInt + outputs.foreach(name => unusedOutputs(name) = BVSymbol(name, dataWidth)) memories.append(m) case ir.Connect(info, loc, expr) => if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") if (!isClock(loc.tpe)) { // we ignore clock connections val name = loc.serialize - insertDummyAssignsForMemoryOutputs(expr) + insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) - connects.append((name, onExpression(expr, getWidth(loc.tpe), name))) + connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt))) } - case ir.IsInvalid(info, loc) => + case i @ ir.IsInvalid(info, loc) => if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") - val name = loc.serialize - infos.append(name -> info) - connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe)))) - case ir.DefInstance(info, name, module, tpe) => - namespace.newName(name) - if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") - // we treat all instances as blackboxes - logger.warn( - s"WARN: treating instance $name of $module as blackbox. " + - "Please flatten your hierarchy if you want to include submodules in the formal model." - ) - val ports = tpe.asInstanceOf[ir.BundleType].fields - // skip async reset ports - ports.filterNot(p => isAsyncReset(p.tpe)).foreach { p => - if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p") - val isOutput = p.flip == ir.Default - val pName = name + "." + p.name - infos.append(pName -> info) - // outputs of the submodule become inputs to our module - if (isOutput) { - if (isClock(p.tpe)) { - clocks.add(pName) - } else { - inputs.append(BVSymbol(pName, getWidth(p.tpe))) - } - } else { - if (!isClock(p.tpe)) { // we ignore clock outputs - outputs.append(pName) - } - } - } + throw new UnsupportedFeatureException(s"IsInvalid statements are not supported: ${i.serialize}") + case ir.DefInstance(info, name, module, tpe) => onInstance(info, name, module, tpe) case s @ ir.Verification(op, info, _, pred, en, msg) => if (op == ir.Formal.Cover) { logger.warn(s"WARN: Cover statement was ignored: ${s.serialize}") } else { - insertDummyAssignsForMemoryOutputs(pred) - insertDummyAssignsForMemoryOutputs(en) + insertDummyAssignsForUnusedOutputs(pred) + insertDummyAssignsForUnusedOutputs(en) val name = namespace.newName(msgToName(op.toString, msg.string)) - val predicate = onExpression(pred, name + "_predicate") - val enabled = onExpression(en, name + "_enabled") + val predicate = onExpression(pred) + val enabled = onExpression(en) val e = BVImplies(enabled, predicate) infos.append(name -> info) connects.append(name -> e) @@ -604,6 +484,70 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog case other => other.foreachStmt(onStatement) } + private def onInstance(info: ir.Info, name: String, module: String, tpe: ir.Type): Unit = { + namespace.newName(name) + if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") + if (uninterpreted.contains(module)) { + onUninterpretedInstance(info: ir.Info, name: String, module: String, tpe: ir.Type) + } else { + // We treat all instances that aren't annotated as uninterpreted as blackboxes + // this means that their outputs could be any value, no matter what their inputs are. + logger.warn( + s"WARN: treating instance $name of $module as blackbox. " + + "Please flatten your hierarchy if you want to include submodules in the formal model." + ) + val ports = tpe.asInstanceOf[ir.BundleType].fields + // skip async reset ports + ports.filterNot(p => isAsyncReset(p.tpe)).foreach { p => + if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p") + val isOutput = p.flip == ir.Default + val pName = name + "." + p.name + infos.append(pName -> info) + // outputs of the submodule become inputs to our module + if (isOutput) { + if (isClock(p.tpe)) { + clocks.add(pName) + } else { + inputs.append(BVSymbol(pName, bitWidth(p.tpe).toInt)) + } + } else { + if (!isClock(p.tpe)) { // we ignore clock outputs + outputs.append(pName) + } + } + } + } + } + + private def onUninterpretedInstance(info: ir.Info, instanceName: String, module: String, tpe: ir.Type): Unit = { + val anno = uninterpreted(module) + + // sanity checks for ports were done already using the UninterpretedModule.checkModule function + val ports = tpe.asInstanceOf[ir.BundleType].fields + + val outputs = ports.filter(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt)) + val inputs = ports.filterNot(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt)) + + assert(anno.stateBits == 0, "TODO: implement support for uninterpreted stateful modules!") + + // for state-less (i.e. combinatorial) circuits, the outputs only depend on the inputs + val args = inputs.map(i => BVSymbol(instanceName + "." + i.name, i.width)).toList + outputs.foreach { out => + val functionName = anno.prefix + "." + out.name + val call = BVFunctionCall(functionName, args, out.width) + val wireName = instanceName + "." + out.name + // remember which functions were called + functionCalls.append(call.toSymbol) + // insert the output definition right before its first use in an attempt to get SSA + unusedOutputs(wireName) = call + // treat these outputs as wires + wires.append(wireName) + } + + // we also treat the arguments as wires + wires ++= args.map(_.name) + } + private val readInputFields = List("en", "addr") private val writeInputFields = List("en", "mask", "addr", "data") private def getMemInputs(m: ir.DefMemory): Iterable[String] = { @@ -617,39 +561,37 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog val p = m.name + "." m.readers.map(r => p + r + ".data") } - // inserts a dummy assign right before a memory output is used for the first time + // inserts a dummy assign right before a memory/uninterpreted module output is used for the first time // example: // m.r.data <= m.r.data ; this is the dummy assign // test <= m.r.data ; this is the first use of m.r.data - private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if (unusedMemOutputs.nonEmpty) { - implicit val uses = mutable.ArrayBuffer[String]() - findUnusedMemoryOutputUse(next) + private def insertDummyAssignsForUnusedOutputs(next: ir.Expression): Unit = if (unusedOutputs.nonEmpty) { + val uses = mutable.ArrayBuffer[String]() + findUnusedOutputUse(next)(uses) if (uses.nonEmpty) { val useSet = uses.toSet - unusedMemOutputs.foreach { - case (name, width) => - if (useSet.contains(name)) connects.append(name -> BVSymbol(name, width)) + unusedOutputs.foreach { + case (name, value) => + if (useSet.contains(name)) connects.append(name -> value) } - useSet.foreach(name => unusedMemOutputs.remove(name)) + useSet.foreach(name => unusedOutputs.remove(name)) } } - private def findUnusedMemoryOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match { + private def findUnusedOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match { case s: ir.SubField => val name = s.serialize - if (unusedMemOutputs.contains(name)) uses.append(name) - case other => other.foreachExpr(findUnusedMemoryOutputUse) + if (unusedOutputs.contains(name)) uses.append(name) + case other => other.foreachExpr(findUnusedOutputUse) } - private case class Context(baseName: String) extends TranslationContext { - override def getRandom(width: Int): BVExpr = makeRandom(baseName, width) - } + private case class Context() extends TranslationContext {} - private def onExpression(e: ir.Expression, width: Int, randomPrefix: String): BVExpr = { - implicit val ctx: TranslationContext = Context(randomPrefix) + private def onExpression(e: ir.Expression, width: Int): BVExpr = { + implicit val ctx: TranslationContext = Context() FirrtlExpressionSemantics.toSMT(e, width, allowNarrow = false) } - private def onExpression(e: ir.Expression, randomPrefix: String): BVExpr = { - implicit val ctx: TranslationContext = Context(randomPrefix) + private def onExpression(e: ir.Expression): BVExpr = { + implicit val ctx: TranslationContext = Context() FirrtlExpressionSemantics.toSMT(e) } diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala index 6369200625..0fc507e6bb 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala @@ -138,6 +138,17 @@ private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr override def children: List[BVExpr] = List(cond, tru, fals) } +/** apply bv arguments to a function which returns a result of bit vector type */ +private case class BVFunctionCall(name: String, args: List[BVExpr], width: Int) extends BVExpr { + override def children = args + def toSymbol: BVFunctionSymbol = BVFunctionSymbol(name, args.map(_.width), width) + override def toString: String = args.mkString(name + "(", ", ", ")") +} + +private case class BVFunctionSymbol(name: String, argWidths: List[Int], width: Int) { + override def toString: String = s"$name : " + (argWidths :+ width).map(w => s"bv<$w>").mkString(" -> ") +} + private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int } private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol { assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala index 19f1de8483..13ed8bdd61 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala @@ -54,6 +54,11 @@ private object SMTExprVisitor { case old @ BVIte(a, b, c) => val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) bv(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC)) + // n-ary + case old @ BVFunctionCall(name, args, width) => + val nArgs = args.map(a => map(a, bv, ar)) + val noneNew = nArgs.zip(args).forall { case (n, o) => n.eq(o) } + bv(if (noneNew) old else BVFunctionCall(name, nArgs, width)) } private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match { diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala index 7bc0a077e9..75bde09c65 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala @@ -24,6 +24,11 @@ private object SMTLibSerializer { case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth) } + def declareFunction(foo: BVFunctionSymbol): SMTCommand = { + val args = foo.argWidths.map(serializeBitVectorType) + DeclareFunction(BVSymbol(foo.name, foo.width), args) + } + private def serialize(e: BVExpr): String = e match { case BVLiteral(value, width) => val mask = (BigInt(1) << width) - 1 @@ -74,6 +79,7 @@ private object SMTLibSerializer { case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})" case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})" case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case BVFunctionCall(name, args, _) => args.map(serialize).mkString(s"($name ", " ", ")") case BVRawExpr(serialized, _) => serialized } diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala index f6d9a26f32..d35fe139d8 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala @@ -20,6 +20,9 @@ private object SMTTransitionSystemEncoder { // emit header as comments cmds ++= sys.header.map(Comment) + // declare uninterpreted functions used in model + cmds ++= sys.ufs.map(SMTLibSerializer.declareFunction) + // declare state type val stateType = id(name + "_s") cmds += DeclareUninterpretedSort(stateType) diff --git a/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala b/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala new file mode 100644 index 0000000000..c7442f6941 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/UninterpretedModuleAnnotation.scala @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// Author: Kevin Laeufer + +package firrtl.backends.experimental.smt + +import firrtl.annotations._ +import firrtl.ir +import firrtl.passes.PassException + +/** ExtModules annotated as UninterpretedModule will be modelled as + * UninterpretedFunction (SMTLib) or constant arrays (btor2). + * This can be useful when trying to abstract over a function that the + * SMT solver or model checker is struggling with. + * + * E.g., one could declare an abstract 64bit multiplier like this: + * ``` + * extmodule Mul64 : + * input a : UInt<64> + * input b : UInt<64> + * output r : UInt<64> + * ``` + * Now instead of using Chisel to actually implement a multiplication circuit + * we can instantiate this Mul64 module twice: Once in our implementation + * and once for our correctness property that might specify how the + * multiply instruction is supposed to be executed on our CPU. + * Now instead of having to prove equivalence of multiplication circuits, the + * solver only has to make sure that the connections to the multiplier are correct, + * since if `a` and `b` are the same on both instances of `Mul64`, then the `r` output + * will also be the same. This is a much easier problem and will result in much faster + * solving due to manual abstraction. + * + * When [[stateBits]] is 0, we model the module as purely combinatorial circuit and + * thus expect there to be no clock wire going into the module. + * Every output is thus a function of all inputs of the module. + * + * When [[stateBits]] is an N greater than zero, we will model the module as having an abstract state of width N. + * Thus on every clock transition the abstract state is updated and all outputs will take the state + * as well as the current inputs as arguments. + * TODO: Support for stateful circuits is work in progress. + * + * All output functions well be prefixed with [[prefix]] and end in the name of the output pin. + * It is the users responsibility to ensure that all function names will be unique by choosing apropriate + * prefixes. + * + * The annotation is consumed by the [[FirrtlToTransitionSystem]] pass. + */ +case class UninterpretedModuleAnnotation(target: ModuleTarget, prefix: String, stateBits: Int = 0) + extends SingleTargetAnnotation[ModuleTarget] { + require(stateBits >= 0, "negative number of bits is forbidden") + if (stateBits > 0) throw new NotImplementedError("TODO: support for stateful circuits is not implemented yet!") + override def duplicate(n: ModuleTarget) = copy(n) +} + +object UninterpretedModuleAnnotation { + + /** checks to see whether the annotation module can actually be abstracted. Use *after* LowerTypes! */ + def checkModule(m: ir.DefModule, anno: UninterpretedModuleAnnotation): Unit = m match { + case _: ir.Module => + throw new UninterpretedModuleException(s"UninterpretedModuleAnnotation can only be used with extmodule! $anno") + case m: ir.ExtModule => + val clockInputs = m.ports.collect { case p @ ir.Port(_, _, ir.Input, ir.ClockType) => p.name } + val clockOutput = m.ports.collect { case p @ ir.Port(_, _, ir.Output, ir.ClockType) => p.name } + val asyncResets = m.ports.collect { case p @ ir.Port(_, _, _, ir.AsyncResetType) => p.name } + if (clockOutput.nonEmpty) { + throw new UninterpretedModuleException( + s"We do not support clock outputs for uninterpreted modules! $clockOutput" + ) + } + if (asyncResets.nonEmpty) { + throw new UninterpretedModuleException( + s"We do not support async reset I/O for uninterpreted modules! $asyncResets" + ) + } + if (anno.stateBits == 0) { + if (clockInputs.nonEmpty) { + throw new UninterpretedModuleException(s"A combinatorial module may not have any clock inputs! $clockInputs") + } + } else { + if (clockInputs.size != 1) { + throw new UninterpretedModuleException(s"A stateful module must have exactly one clock input! $clockInputs") + } + } + } +} + +private class UninterpretedModuleException(s: String) extends PassException(s) diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala b/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala new file mode 100644 index 0000000000..7381056e5a --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.experimental.smt.random + +import firrtl.Utils +import firrtl.ir._ + +/** Named source of random values. If there is no clock expression, than it will be clocked by the global clock. */ +case class DefRandom( + info: Info, + name: String, + tpe: Type, + clock: Option[Expression], + en: Expression = Utils.True()) + extends Statement + with HasInfo + with IsDeclaration + with CanBeReferenced + with UseSerializer { + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = + DefRandom(info, name, tpe, clock.map(f), f(en)) + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = { clock.foreach(f); f(en) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala new file mode 100644 index 0000000000..c7eaad74b0 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/random/InvalidToRandomPass.scala @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.experimental.smt.random + +import firrtl._ +import firrtl.annotations.NoTargetAnnotation +import firrtl.ir._ +import firrtl.passes._ +import firrtl.options.Dependency +import firrtl.stage.Forms +import firrtl.transforms.RemoveWires + +import scala.collection.mutable + +/** Chooses how to model explicit and implicit invalid values in the circuit */ +case class InvalidToRandomOptions( + randomizeInvalidSignals: Boolean = true, + randomizeDivisionByZero: Boolean = true) + extends NoTargetAnnotation + +/** Replaces all explicit and implicit "invalid" values with random values. + * Explicit invalids are: + * - signal is invalid + * - signal <= valid(..., expr) + * Implicit invalids are: + * - a / b when eq(b, 0) + */ +object InvalidToRandomPass extends Transform with DependencyAPIMigration { + override def prerequisites = Forms.LowForm + // once ValidIf has been removed, we can no longer detect and randomize them + override def optionalPrerequisiteOf = Seq(Dependency(RemoveValidIf)) + override def invalidates(a: Transform) = a match { + // this pass might destroy SSA form, as we add a wire for the data field of every read port + case _: RemoveWires => true + // TODO: should we add some optimization passes here? we could be generating some dead code. + case _ => false + } + + override protected def execute(state: CircuitState): CircuitState = { + val opts = state.annotations.collect { case o: InvalidToRandomOptions => o } + require(opts.size < 2, s"Multiple options: $opts") + val opt = opts.headOption.getOrElse(InvalidToRandomOptions()) + + // quick exit if we just want to skip this pass + if (!opt.randomizeDivisionByZero && !opt.randomizeInvalidSignals) { + state + } else { + val c = state.circuit.mapModule(onModule(_, opt)) + state.copy(circuit = c) + } + } + + private def onModule(m: DefModule, opt: InvalidToRandomOptions): DefModule = m match { + case d: DescribedMod => + throw new RuntimeException(s"CompilerError: Unexpected internal node: ${d.serialize}") + case e: ExtModule => e + case mod: Module => + val namespace = Namespace(mod) + mod.mapStmt(onStmt(namespace, opt, _)) + } + + private def onStmt(namespace: Namespace, opt: InvalidToRandomOptions, s: Statement): Statement = s match { + case IsInvalid(info, loc: RefLikeExpression) if opt.randomizeInvalidSignals => + val name = namespace.newName(loc.serialize.replace('.', '_') + "_invalid") + val rand = DefRandom(info, name, loc.tpe, None) + Block(List(rand, Connect(info, loc, Reference(rand)))) + case other => + val info = other match { + case h: HasInfo => h.info + case _ => NoInfo + } + val prefix = other match { + case c: Connect => c.loc.serialize.replace('.', '_') + case h: HasName => h.name + case _ => "" + } + val ctx = ExprCtx(namespace, opt, prefix, info, mutable.ListBuffer[Statement]()) + val stmt = other.mapExpr(onExpr(ctx, _)).mapStmt(onStmt(namespace, opt, _)) + if (ctx.rands.isEmpty) { stmt } + else { Block(Block(ctx.rands.toList), stmt) } + } + + private case class ExprCtx( + namespace: Namespace, + opt: InvalidToRandomOptions, + prefix: String, + info: Info, + rands: mutable.ListBuffer[Statement]) + + private def onExpr(ctx: ExprCtx, e: Expression): Expression = + e.mapExpr(onExpr(ctx, _)) match { + case ValidIf(_, value, tpe) if tpe == ClockType => + // we currently assume that clocks are always valid + // TODO: is that a good assumption? + value + case ValidIf(cond, value, tpe) if ctx.opt.randomizeInvalidSignals => + makeRand(ctx, cond, tpe, value, invert = true) + case d @ DoPrim(PrimOps.Div, Seq(_, den), _, tpe) if ctx.opt.randomizeDivisionByZero => + val denIsZero = Utils.eq(den, Utils.getGroundZero(den.tpe.asInstanceOf[GroundType])) + makeRand(ctx, denIsZero, tpe, d, invert = false) + case other => other + } + + private def makeRand( + ctx: ExprCtx, + cond: Expression, + tpe: Type, + value: Expression, + invert: Boolean + ): Expression = { + val name = ctx.namespace.newName(if (ctx.prefix.isEmpty) "invalid" else ctx.prefix + "_invalid") + // create a condition node if the condition isn't a reference already + val condRef = cond match { + case r: RefLikeExpression => if (invert) Utils.not(r) else r + case other => + val cond = if (invert) Utils.not(other) else other + val condNode = DefNode(ctx.info, ctx.namespace.newName(name + "_cond"), cond) + ctx.rands.append(condNode) + Reference(condNode) + } + val rand = DefRandom(ctx.info, name, tpe, None, condRef) + ctx.rands.append(rand) + Utils.mux(condRef, Reference(rand), value) + } +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala new file mode 100644 index 0000000000..5fd0e68091 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala @@ -0,0 +1,461 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.experimental.smt.random + +import firrtl.Utils.{isLiteral, BoolType} +import firrtl._ +import firrtl.annotations.NoTargetAnnotation +import firrtl.backends.experimental.smt._ +import firrtl.ir._ +import firrtl.options.Dependency +import firrtl.passes.MemPortUtils.memPortField +import firrtl.passes.memlib.AnalysisUtils.Connects +import firrtl.passes.memlib.InferReadWritePass.checkComplement +import firrtl.passes.memlib.{AnalysisUtils, InferReadWritePass, VerilogMemDelays} +import firrtl.stage.Forms +import firrtl.transforms.RemoveWires + +import scala.collection.mutable + +/** Chooses which undefined memory behaviors should be instrumented. */ +case class UndefinedMemoryBehaviorOptions( + randomizeWriteWriteConflicts: Boolean = true, + assertNoOutOfBoundsWrites: Boolean = false, + randomizeOutOfBoundsRead: Boolean = true, + randomizeDisabledReads: Boolean = true, + randomizeReadWriteConflicts: Boolean = true) + extends NoTargetAnnotation + +/** Adds sources of randomness to model the various "undefined behaviors" of firrtl memory. + * - Write/Write conflict: leads to arbitrary value written to write address + * - Out-of-bounds write: assertion failure (disabled by default) + * - Out-Of-bounds read: leads to arbitrary value being read + * - Read w/ en=0: leads to arbitrary value being read + * - Read/Write conflict: leads to arbitrary value being read + */ +object UndefinedMemoryBehaviorPass extends Transform with DependencyAPIMigration { + override def prerequisites = Forms.LowForm + override def optionalPrerequisiteOf = Seq(Dependency(VerilogMemDelays)) + override def invalidates(a: Transform) = a match { + // this pass might destroy SSA form, as we add a wire for the data field of every read port + case _: RemoveWires => true + // TODO: should we add some optimization passes here? we could be generating some dead code. + case _ => false + } + + override protected def execute(state: CircuitState): CircuitState = { + val opts = state.annotations.collect { case o: UndefinedMemoryBehaviorOptions => o } + require(opts.size < 2, s"Multiple options: $opts") + val opt = opts.headOption.getOrElse(UndefinedMemoryBehaviorOptions()) + + val c = state.circuit.mapModule(onModule(_, opt)) + state.copy(circuit = c) + } + + private def onModule(m: DefModule, opt: UndefinedMemoryBehaviorOptions): DefModule = m match { + case mod: Module => + val mems = findMems(mod) + if (mems.isEmpty) { mod } + else { + val namespace = Namespace(mod) + val connects = AnalysisUtils.getConnects(mod) + new InstrumentMems(opt, mems, connects, namespace).run(mod) + } + case other => other + } + + /** finds all memory instantiations in a circuit */ + private def findMems(m: Module): List[DefMemory] = { + val mems = mutable.ListBuffer[DefMemory]() + m.foreachStmt(findMems(_, mems)) + mems.toList + } + private def findMems(s: Statement, mems: mutable.ListBuffer[DefMemory]): Unit = s match { + case mem: DefMemory => mems.append(mem) + case other => other.foreachStmt(findMems(_, mems)) + } +} + +private class InstrumentMems( + opt: UndefinedMemoryBehaviorOptions, + mems: List[DefMemory], + connects: Connects, + namespace: Namespace) { + def run(m: Module): DefModule = { + // ensure that all memories are the kind we can support + mems.foreach(checkSupported(m.name, _)) + + // transform circuit + val body = m.body.mapStmt(transform) + m.copy(body = Block(body +: newStmts.toList)) + } + + // used to replace memory signals like `m.r.data` in RHS expressions + private val exprReplacements = mutable.HashMap[String, Expression]() + // add new statements at the end of the circuit + private val newStmts = mutable.ListBuffer[Statement]() + // disconnect references so that they can be reassigned + private val doDisconnect = mutable.HashSet[String]() + + // generates new expression replacements and immediately uses them + private def transform(s: Statement): Statement = s.mapStmt(transform) match { + case mem: DefMemory => onMem(mem) + case sx: Connect if doDisconnect.contains(sx.loc.serialize) => EmptyStmt // Filter old mem connections + case sx => sx.mapExpr(swapMemRefs) + } + private def swapMemRefs(e: Expression): Expression = e.mapExpr(swapMemRefs) match { + case sf: RefLikeExpression => exprReplacements.getOrElse(sf.serialize, sf) + case ex => ex + } + + private def onMem(m: DefMemory): Statement = { + // collect wire and random statement defines + implicit val declarations: mutable.ListBuffer[Statement] = mutable.ListBuffer[Statement]() + + // cache for the expressions of memory inputs + implicit val cache: mutable.HashMap[String, Expression] = mutable.HashMap[String, Expression]() + + // only for non power of 2 memories do we have to worry about reading or writing out of bounds + val canBeOutOfBounds = !isPow2(m.depth) + + // only if we have at least two write ports, can there be conflicts + val canHaveWriteWriteConflicts = m.writers.size > 1 + + // only certain memory types exhibit undefined read/write conflicts + val readWriteUndefined = (m.readLatency == m.writeLatency) && (m.readUnderWrite == ReadUnderWrite.Undefined) + assert( + m.readLatency == 0 || m.readLatency == m.writeLatency, + "TODO: what happens if a sync read mem has asymmetrical latencies?" + ) + + // a write port is enabled iff mask & en + val writeEn = m.writers.map { write => + val enRef = memPortField(m, write, "en") + val maskRef = memPortField(m, write, "mask") + + val prods = getProductTerms(enRef) ++ getProductTerms(maskRef) + val expr = Utils.and(readInput(m.info, enRef), readInput(m.info, maskRef)) + + (expr, prods) + } + + // implement the three undefined read behaviors + m.readers.foreach { read => + // many memories have their read enable hard wired to true + val canBeDisabled = !isTrue(readInput(m, read, "en")) + val readEn = if (canBeDisabled) readInput(m, read, "en") else Utils.True() + + // collect signals that would lead to a randomization + var doRand = List[Expression]() + + // randomize the read value when the address is out of bounds + if (canBeOutOfBounds && opt.randomizeOutOfBoundsRead) { + val addr = readInput(m, read, "addr") + val cond = Utils.and(readEn, Utils.not(isInBounds(m.depth, addr))) + val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_oob"), cond) + declarations += node + doRand = Reference(node) +: doRand + } + + if (readWriteUndefined && opt.randomizeReadWriteConflicts) { + val cond = readWriteConflict(m, read, writeEn) + val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_rwc"), cond) + declarations += node + doRand = Reference(node) +: doRand + } + + // randomize the read value when the read is disabled + if (canBeDisabled && opt.randomizeDisabledReads) { + val cond = Utils.not(readEn) + val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_disabled"), cond) + declarations += node + doRand = Reference(node) +: doRand + } + + // if there are no signals that would require a randomization, there is nothing to do + if (doRand.isEmpty) { + // nothing to do + } else { + val doRandName = s"${m.name}_${read}_do_rand" + val doRandNode = if (doRand.size == 1) { doRand.head } + else { + val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_do_rand"), doRand.reduce(Utils.or)) + declarations += node + Reference(node) + } + val doRandSignal = if (m.readLatency == 0) { doRandNode } + else { + val clock = readInput(m, read, "clk") + val (signal, regDecls) = pipeline(m.info, clock, doRandName, doRandNode, m.readLatency) + declarations ++= regDecls + signal + } + + // all old rhs references to m.r.data need to replace with m_r_data which might be random + val dataRef = memPortField(m, read, "data") + val dataWire = DefWire(m.info, namespace.newName(s"${m.name}_${read}_data"), m.dataType) + declarations += dataWire + exprReplacements(dataRef.serialize) = Reference(dataWire) + + // create a source of randomness and connect the new wire either to the actual data port or to the random value + val randName = namespace.newName(s"${m.name}_${read}_rand_data") + val random = DefRandom(m.info, randName, m.dataType, Some(readInput(m, read, "clk")), doRandSignal) + declarations += random + val data = Utils.mux(doRandSignal, Reference(random), dataRef) + newStmts.append(Connect(m.info, Reference(dataWire), data)) + } + } + + // write + if (opt.randomizeWriteWriteConflicts) { + writeWriteConflicts(m, writeEn) + } + + // add an assertion that if the write is taking place, then the address must be in range + if (canBeOutOfBounds && opt.assertNoOutOfBoundsWrites) { + m.writers.zip(writeEn).foreach { + case (write, (combinedEn, _)) => + val addr = readInput(m, write, "addr") + val cond = Utils.implies(combinedEn, isInBounds(m.depth, addr)) + val clk = readInput(m, write, "clk") + val a = Verification(Formal.Assert, m.info, clk, cond, Utils.True(), StringLit("out of bounds read")) + newStmts.append(a) + } + } + + Block(m +: declarations.toList) + } + + private def pipeline( + info: Info, + clk: Expression, + prefix: String, + e: Expression, + latency: Int + ): (Expression, Seq[Statement]) = { + require(latency > 0) + val regs = (1 to latency).map { i => + val name = namespace.newName(prefix + s"_r$i") + DefRegister(info, name, e.tpe, clk, Utils.False(), Reference(name, e.tpe, RegKind, UnknownFlow)) + } + val expr = regs.foldLeft(e) { + case (prev, reg) => + newStmts.append(Connect(info, Reference(reg), prev)) + Reference(reg) + } + (expr, regs) + } + + private def readWriteConflict( + m: DefMemory, + read: String, + writeEn: Seq[(Expression, ProdTerms)] + )( + implicit cache: mutable.HashMap[String, Expression], + declarations: mutable.ListBuffer[Statement] + ): Expression = { + if (m.writers.isEmpty) return Utils.False() + + val readEn = readInput(m, read, "en") + val readProd = getProductTerms(readEn) + + // create all conflict signals + val conflicts = m.writers.zip(writeEn).map { + case (write, (writeEn, writeProd)) => + if (isMutuallyExclusive(readProd, writeProd)) { + Utils.False() + } else { + val name = namespace.newName(s"${m.name}_${read}_${write}_rwc") + val bothEn = Utils.and(readEn, writeEn) + val sameAddr = Utils.eq(readInput(m, read, "addr"), readInput(m, write, "addr")) + // we need a wire because this condition might be used in a random statement + val wire = DefWire(m.info, name, BoolType) + declarations += wire + newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr))) + Reference(wire) + } + } + + conflicts.reduce(Utils.or) + } + + private type ProdTerms = Seq[Expression] + private def writeWriteConflicts( + m: DefMemory, + writeEn: Seq[(Expression, ProdTerms)] + )( + implicit cache: mutable.HashMap[String, Expression], + declarations: mutable.ListBuffer[Statement] + ): Unit = { + if (m.writers.size < 2) return + + // we first create all conflict signals: + val conflict = + m.writers + .zip(writeEn) + .zipWithIndex + .flatMap { + case ((w1, (en1, en1Prod)), i1) => + m.writers.zip(writeEn).drop(i1 + 1).map { + case (w2, (en2, en2Prod)) => + if (isMutuallyExclusive(en1Prod, en2Prod)) { + (w1, w2) -> Utils.False() + } else { + val name = namespace.newName(s"${m.name}_${w1}_${w2}_wwc") + val bothEn = Utils.and(en1, en2) + val sameAddr = Utils.eq(readInput(m, w1, "addr"), readInput(m, w2, "addr")) + // we need a wire because this condition might be used in a random statement + val wire = DefWire(m.info, name, BoolType) + declarations += wire + newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr))) + (w1, w2) -> Reference(wire) + } + } + } + .toMap + + // now we calculate the new enable and data signals + m.writers.zip(writeEn).zipWithIndex.foreach { + case ((w1, (en1, _)), i1) => + val prev = m.writers.take(i1) + val next = m.writers.drop(i1 + 1) + + // the write is enabled if the original enable is true and there are no prior conflicts + val en = if (prev.isEmpty) { + en1 + } else { + val prevConflicts = prev.map(o => conflict(o, w1)).reduce(Utils.or) + Utils.and(en1, Utils.not(prevConflicts)) + } + + // we write random data if there is a conflict with any of the next ports + if (next.isEmpty) { + // nothing to do, leave data as is + } else { + val nextConflicts = next.map(n => conflict(w1, n)).reduce(Utils.or) + // if the conflict expression is more complex, create a node for the signal + val hasConflict = nextConflicts match { + case _: DoPrim | _: Mux => + val node = DefNode(m.info, namespace.newName(s"${m.name}_${w1}_wwc_active"), nextConflicts) + declarations += node + Reference(node) + case _ => nextConflicts + } + + // create the source of randomness + val name = namespace.newName(s"${m.name}_${w1}_wwc_data") + val random = DefRandom(m.info, name, m.dataType, Some(readInput(m, w1, "clk")), hasConflict) + declarations.append(random) + + // generate new data input + val data = Utils.mux(hasConflict, Reference(random), readInput(m, w1, "data")) + newStmts.append(Connect(m.info, memPortField(m, w1, "data"), data)) + doDisconnect.add(memPortField(m, w1, "data").serialize) + } + + // connect data enable signals + val maskIsOne = isTrue(readInput(m, w1, "mask")) + if (!maskIsOne) { + newStmts.append(Connect(m.info, memPortField(m, w1, "mask"), Utils.True())) + doDisconnect.add(memPortField(m, w1, "mask").serialize) + } + newStmts.append(Connect(m.info, memPortField(m, w1, "en"), en)) + doDisconnect.add(memPortField(m, w1, "en").serialize) + } + } + + /** check whether two signals can be proven to be mutually exclusive */ + private def isMutuallyExclusive(prodA: ProdTerms, prodB: ProdTerms): Boolean = { + // this uses the same approach as the InferReadWrite pass + val proofOfMutualExclusion = prodA.find(a => prodB.exists(b => checkComplement(a, b))) + proofOfMutualExclusion.nonEmpty + } + + /** memory inputs my not be read, only assigned to, thus we might need to add a wire to make them accessible */ + private def readInput( + info: Info, + signal: RefLikeExpression + )( + implicit cache: mutable.HashMap[String, Expression], + declarations: mutable.ListBuffer[Statement] + ): Expression = + cache.getOrElseUpdate( + signal.serialize, { + // if it is a literal, we just return it + val value = connects(signal.serialize) + if (isLiteral(value)) { + value + } else { + // otherwise we make a wire that refelect the value + val wire = DefWire(info, copyName(signal), signal.tpe) + declarations += wire + + // connect the old expression to the new wire + val con = Connect(info, Reference(wire), value) + newStmts.append(con) + + // use a reference to this new wire + Reference(wire) + } + } + ) + private def readInput( + m: DefMemory, + port: String, + field: String + )( + implicit cache: mutable.HashMap[String, Expression], + declarations: mutable.ListBuffer[Statement] + ): Expression = + readInput(m.info, memPortField(m, port, field)) + + private def copyName(ref: RefLikeExpression): String = + namespace.newName(ref.serialize.replace('.', '_')) + + private def isInBounds(depth: BigInt, addr: Expression): Expression = { + val width = getWidth(addr) + // depth >= addr + DoPrim(PrimOps.Geq, List(UIntLiteral(depth, width), addr), List(), BoolType) + } + + private def isPow2(v: BigInt): Boolean = ((v - 1) & v) == 0 + + private def checkSupported(modName: String, m: DefMemory): Unit = { + assert(m.readwriters.isEmpty, s"[$modName] Combined read/write ports are currently not supported!") + if (m.writeLatency != 1) { + throw new UnsupportedFeatureException(s"[$modName] memories with write latency > 1 (${m.name})") + } + if (m.readLatency > 1) { + throw new UnsupportedFeatureException(s"[$modName] memories with read latency > 1 (${m.name})") + } + } + + private def getProductTerms(e: Expression): ProdTerms = + InferReadWritePass.getProductTerms(connects)(e) + + /** tries to expand the expression based on the connects we collected */ + private def expandExpr(e: Expression, fuel: Int): Expression = { + e match { + case m @ Mux(cond, tval, fval, _) => + m.copy(cond = expandExpr(cond, fuel), tval = expandExpr(tval, fuel), fval = expandExpr(fval, fuel)) + case p @ DoPrim(_, args, _, _) => + p.copy(args = args.map(expandExpr(_, fuel))) + case r: RefLikeExpression => + if (fuel > 0) { + connects.get(r.serialize) match { + case None => r + case Some(expr) => expandExpr(expr, fuel - 1) + } + } else { + r + } + case other => other + } + } + + private def isTrue(e: Expression): Boolean = simplifyExpr(expandExpr(e, fuel = 2)) == Utils.True() + + private def simplifyExpr(e: Expression): Expression = { + e // TODO: better simplification could improve the resulting circuit size + } +} diff --git a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala index 26e0363319..80aea9963d 100644 --- a/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala +++ b/src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala @@ -20,8 +20,7 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em // Use list instead of set to maintain order val modules = mutable.ArrayBuffer.empty[DefModule] def onStmt(stmt: Statement): Unit = stmt match { - case DefInstance(_, _, name, _) => modules += map(name) - case WDefInstance(_, _, name, _) => modules += map(name) + case DefInstance(_, _, name, _) => modules += map(name) case _: WDefInstanceConnector => throwInternalError(s"unrecognized statement: $stmt") case other => other.foreach(onStmt) } @@ -61,7 +60,7 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em def emit(state: CircuitState, writer: Writer): Unit = writer.write(state.circuit.serialize) } -class ChirrtlEmitter extends FirrtlEmitter(ChirrtlForm) -class HighFirrtlEmitter extends FirrtlEmitter(HighForm) -class MiddleFirrtlEmitter extends FirrtlEmitter(MidForm) -class LowFirrtlEmitter extends FirrtlEmitter(LowForm) +class ChirrtlEmitter extends FirrtlEmitter(CircuitForm.ChirrtlForm) +class HighFirrtlEmitter extends FirrtlEmitter(CircuitForm.HighForm) +class MiddleFirrtlEmitter extends FirrtlEmitter(CircuitForm.MidForm) +class LowFirrtlEmitter extends FirrtlEmitter(CircuitForm.LowForm) diff --git a/src/main/scala/firrtl/backends/verilog/VerilogEmitter.scala b/src/main/scala/firrtl/backends/verilog/VerilogEmitter.scala index 3ecd127974..13f7b7933f 100644 --- a/src/main/scala/firrtl/backends/verilog/VerilogEmitter.scala +++ b/src/main/scala/firrtl/backends/verilog/VerilogEmitter.scala @@ -7,7 +7,14 @@ import firrtl.PrimOps._ import firrtl.Utils._ import firrtl.WrappedExpression._ import firrtl.traversals.Foreachers._ -import firrtl.annotations.{CircuitTarget, ReferenceTarget, SingleTargetAnnotation} +import firrtl.annotations.{ + CircuitTarget, + MemoryLoadFileType, + MemoryNoSynthInit, + MemorySynthInit, + ReferenceTarget, + SingleTargetAnnotation +} import firrtl.passes.LowerTypes import firrtl.passes.MemPortUtils._ import firrtl.stage.TransformManager @@ -57,6 +64,13 @@ object VerilogEmitter { private def precedenceGt(op1: PrimOp, op2: PrimOp): Boolean = { precedenceMap(op1) < precedenceMap(op2) } + + /** Identifies PrimOps that never need parentheses + * + * These PrimOps emit either {..., a0, ...} or a0 so they never need parentheses + */ + private val neverParens: PrimOp => Boolean = + Set(Shl, Cat, Cvt, AsUInt, AsSInt, AsClock, AsAsyncReset, Pad) } class VerilogEmitter extends SeqTransform with Emitter { @@ -229,8 +243,7 @@ class VerilogEmitter extends SeqTransform with Emitter { // to ensure Verilog operations are signed. def op_stream(doprim: DoPrim): Seq[Any] = { def parenthesize(e: Expression, isFirst: Boolean): Any = doprim.op match { - // these PrimOps emit either {..., a0, ...} or a0 so they never need parentheses - case Shl | Cat | Cvt | AsUInt | AsSInt | AsClock | AsAsyncReset => e + case op if neverParens(op) => e case _ => e match { case e: DoPrim => @@ -247,7 +260,8 @@ class VerilogEmitter extends SeqTransform with Emitter { */ case other => val noParens = - precedenceGt(e.op, doprim.op) || + neverParens(e.op) || + precedenceGt(e.op, doprim.op) || (isFirst && precedenceEq(e.op, doprim.op) && !isUnaryOp(e.op)) if (noParens) other else Seq("(", other, ")") } @@ -475,6 +489,21 @@ class VerilogEmitter extends SeqTransform with Emitter { def getConnectEmissionOption(target: ReferenceTarget): ConnectEmissionOption = connectEmissionOption(target) + // Defines the memory initialization based on the annotation + // Defaults to having the memories inside the `ifndef SYNTHESIS` block + def emitMemoryInitAsNoSynth: Boolean = { + val annos = annotations.collect { case a @ (MemoryNoSynthInit | MemorySynthInit) => a } + annos match { + case Seq() => true + case Seq(MemoryNoSynthInit) => true + case Seq(MemorySynthInit) => false + case other => + throw new FirrtlUserException( + "There should only be at most one memory initialization option annotation, got $other" + ) + } + } + private val emissionAnnos = annotations.collect { case m: SingleTargetAnnotation[ReferenceTarget] @unchecked with EmissionOption => m } @@ -756,7 +785,7 @@ class VerilogEmitter extends SeqTransform with Emitter { val lines = noResetAlwaysBlocks.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) if (weq(en, one)) lines += Seq(e, " <= ", value, ";") else { - lines += Seq("if(", en, ") begin") + lines += Seq("if (", en, ") begin") lines += Seq(tab, e, " <= ", value, ";", info) lines += Seq("end") } @@ -849,6 +878,19 @@ class VerilogEmitter extends SeqTransform with Emitter { rstring, ";" ) + case MemoryFileInlineInit(filename, hexOrBinary) => + val readmem = hexOrBinary match { + case MemoryLoadFileType.Binary => "$readmemb" + case MemoryLoadFileType.Hex => "$readmemh" + } + if (emissionOptions.emitMemoryInitAsNoSynth) { + memoryInitials += Seq(s"""$readmem("$filename", ${s.name});""") + } else { + val inlineLoad = s"""initial begin + | $readmem("$filename", ${s.name}); + | end""".stripMargin + memoryInitials += Seq(inlineLoad) + } } } @@ -1023,32 +1065,56 @@ class VerilogEmitter extends SeqTransform with Emitter { val decl = if (fullSize > (1 << 29)) "reg /* sparse */" else "reg" declareVectorType(decl, sx.name, sx.dataType, sx.depth, sx.info) initialize_mem(sx, options) - if (sx.readLatency != 0 || sx.writeLatency != 1) + // Currently, no idiomatic way to directly emit write-first RW ports + val hasComplexRW = (sx.readwriters.nonEmpty && + (sx.readLatency != 1 || sx.readUnderWrite == ReadUnderWrite.New)) + if (sx.readLatency > 1 || sx.writeLatency != 1 || hasComplexRW) throw EmitterException( - "All memories should be transformed into " + - "blackboxes or combinational by previous passses" + Seq( + s"Memory ${sx.name} is too complex to emit directly.", + "Consider running VerilogMemDelays to simplify complex memories.", + "Alternatively, add the --repl-seq-mem flag to replace memories with blackboxes." + ).mkString(" ") ) + def createMemWire(firrtlRef: Expression, rhs: InfoExpr): Unit = { + // Don't use declaration-assignment, since this assignment might be emitted earlier than the + // actual connection to the memory port field in the source FIRRTL + declare("wire", LowerTypes.loweredName(firrtlRef), firrtlRef.tpe, MultiInfo(sx.info, rhs.info)) + assign(firrtlRef, rhs) + } + for (r <- sx.readers) { val data = memPortField(sx, r, "data") val addr = memPortField(sx, r, "addr") - // Ports should share an always@posedge, so can't have intermediary wire - - declare("wire", LowerTypes.loweredName(data), data.tpe, sx.info) - declare("wire", LowerTypes.loweredName(addr), addr.tpe, sx.info) - // declare("wire", LowerTypes.loweredName(en), en.tpe) - - //; Read port - assign(addr, netlist(addr)) - // assign(en, netlist(en)) //;Connects value to m.r.en - val mem = WRef(sx.name, memType(sx), MemKind, UnknownFlow) - val memPort = WSubAccess(mem, addr, sx.dataType, UnknownFlow) + val en = memPortField(sx, r, "en") + val memPort = WSubAccess(WRef(sx), addr, sx.dataType, UnknownFlow) val depthValue = UIntLiteral(sx.depth, IntWidth(sx.depth.bitLength)) val garbageGuard = DoPrim(Geq, Seq(addr, depthValue), Seq(), UnknownType) - if ((sx.depth & (sx.depth - 1)) == 0) - assign(data, memPort, sx.info) - else - garbageAssign(data, memPort, garbageGuard, sx.info) + val clkSource = netlist(memPortField(sx, r, "clk")).expr + + createMemWire(en, netlist(en)) + + if (sx.readLatency == 1 && sx.readUnderWrite != ReadUnderWrite.Old) { + val InfoExpr(addrInfo, addrDriver) = netlist(addr) + declare("reg", LowerTypes.loweredName(addr), addr.tpe, sx.info) + initialize(WRef(LowerTypes.loweredName(addr), addr.tpe), zero, zero) + update(addr, addrDriver, clkSource, en, addrInfo) + } else { + createMemWire(addr, netlist(addr)) + } + + if (sx.readLatency == 1 && sx.readUnderWrite == ReadUnderWrite.Old) { + declare("reg", LowerTypes.loweredName(data), data.tpe, sx.info) + initialize(WRef(LowerTypes.loweredName(data), data.tpe), zero, zero) + update(data, memPort, clkSource, en, sx.info) + } else { + declare("wire", LowerTypes.loweredName(data), data.tpe, sx.info) + if ((sx.depth & (sx.depth - 1)) == 0) + assign(data, memPort, sx.info) + else + garbageAssign(data, memPort, garbageGuard, sx.info) + } } for (w <- sx.writers) { @@ -1056,31 +1122,41 @@ class VerilogEmitter extends SeqTransform with Emitter { val addr = memPortField(sx, w, "addr") val mask = memPortField(sx, w, "mask") val en = memPortField(sx, w, "en") - //Ports should share an always@posedge, so can't have intermediary wire - // TODO should we use the info here for anything? - val InfoExpr(_, clk) = netlist(memPortField(sx, w, "clk")) - - declare("wire", LowerTypes.loweredName(data), data.tpe, sx.info) - declare("wire", LowerTypes.loweredName(addr), addr.tpe, sx.info) - declare("wire", LowerTypes.loweredName(mask), mask.tpe, sx.info) - declare("wire", LowerTypes.loweredName(en), en.tpe, sx.info) - - // Write port - assign(data, netlist(data)) - assign(addr, netlist(addr)) - assign(mask, netlist(mask)) - assign(en, netlist(en)) - - val mem = WRef(sx.name, memType(sx), MemKind, UnknownFlow) - val memPort = WSubAccess(mem, addr, sx.dataType, UnknownFlow) - update(memPort, data, clk, AND(en, mask), sx.info) + + val clkSource = netlist(memPortField(sx, w, "clk")).expr + + createMemWire(data, netlist(data)) + createMemWire(addr, netlist(addr)) + createMemWire(mask, netlist(mask)) + createMemWire(en, netlist(en)) + + val memPort = WSubAccess(WRef(sx), addr, sx.dataType, UnknownFlow) + update(memPort, data, clkSource, AND(en, mask), sx.info) + } + + for (rw <- sx.readwriters) { + val rdata = memPortField(sx, rw, "rdata") + val wdata = memPortField(sx, rw, "wdata") + val addr = memPortField(sx, rw, "addr") + val en = memPortField(sx, rw, "en") + val wmode = memPortField(sx, rw, "wmode") + val wmask = memPortField(sx, rw, "wmask") + val memPort = WSubAccess(WRef(sx), addr, sx.dataType, UnknownFlow) + + val clkSource = netlist(memPortField(sx, rw, "clk")).expr + + createMemWire(wdata, netlist(wdata)) + createMemWire(addr, netlist(addr)) + createMemWire(wmode, netlist(wmode)) + createMemWire(wmask, netlist(wmask)) + createMemWire(en, netlist(en)) + + declare("reg", LowerTypes.loweredName(rdata), rdata.tpe, sx.info) + initialize(WRef(LowerTypes.loweredName(rdata), rdata.tpe), zero, zero) + update(rdata, memPort, clkSource, en, sx.info) + update(memPort, wdata, clkSource, AND(en, AND(wmode, wmask)), sx.info) } - if (sx.readwriters.nonEmpty) - throw EmitterException( - "All readwrite ports should be transformed into " + - "read & write ports by previous passes" - ) case _ => } } @@ -1187,13 +1263,19 @@ class VerilogEmitter extends SeqTransform with Emitter { for (x <- initials) emit(Seq(tab, x)) for (x <- asyncInitials) emit(Seq(tab, x)) emit(Seq(" `endif // RANDOMIZE")) - for (x <- memoryInitials) emit(Seq(tab, x)) + + if (emissionOptions.emitMemoryInitAsNoSynth) { + for (x <- memoryInitials) emit(Seq(tab, x)) + } emit(Seq("end // initial")) // User-defined macro of code to run after an initial block emit(Seq("`ifdef FIRRTL_AFTER_INITIAL")) emit(Seq("`FIRRTL_AFTER_INITIAL")) emit(Seq("`endif")) emit(Seq("`endif // SYNTHESIS")) + if (!emissionOptions.emitMemoryInitAsNoSynth) { + for (x <- memoryInitials) emit(Seq(tab, x)) + } } if (formals.keys.nonEmpty) { diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index b99c696115..99bf84038d 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -4,6 +4,7 @@ package firrtl.graph import scala.collection.{mutable, Map, Set} import scala.collection.mutable.{LinkedHashMap, LinkedHashSet} +import firrtl.options.DependencyManagerUtils.{CharSet, PrettyCharSet} /** An exception that is raised when an assumed DAG has a cycle */ class CyclicException(val node: Any) extends Exception(s"No valid linearization for cyclic graph, found at $node") @@ -31,6 +32,16 @@ object DiGraph { } new DiGraph(edgeDataCopy) } + + /** Create a DiGraph from edges */ + def apply[T](edges: (T, T)*): DiGraph[T] = { + val edgeMap = new LinkedHashMap[T, LinkedHashSet[T]] + for ((from, to) <- edges) { + val set = edgeMap.getOrElseUpdate(from, new LinkedHashSet[T]) + set += to + } + new DiGraph(edgeMap) + } } /** Represents common behavior of all directed graphs */ @@ -153,7 +164,7 @@ class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { val queue = new mutable.Queue[T] queue.enqueue(root) while (queue.nonEmpty) { - val u = queue.dequeue + val u = queue.dequeue() for (v <- getEdges(u)) { if (!prev.contains(v) && !blacklist.contains(v)) { prev(v) = u @@ -257,7 +268,7 @@ class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { } frame.childCall = None while (frame.edgeIter.hasNext && frame.childCall.isEmpty) { - val w = frame.edgeIter.next + val w = frame.edgeIter.next() if (!indices.contains(w)) { frame.childCall = Some(w) callStack.push(new StrongConnectFrame(w, getEdges(w).iterator)) @@ -269,13 +280,13 @@ class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { if (lowlinks(v) == indices(v)) { val scc = new mutable.ArrayBuffer[T] do { - val w = stack.pop + val w = stack.pop() onstack -= w scc += w } while (scc.last != v); sccs.append(scc.toSeq) } - callStack.pop + callStack.pop() } } } @@ -305,7 +316,7 @@ class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { queue += start queue ++= linearize.filter(reachable.contains(_)) while (!queue.isEmpty) { - val current = queue.dequeue + val current = queue.dequeue() for (v <- getEdges(current)) { for (p <- paths(current)) { addBinding(v, p :+ v) @@ -386,6 +397,41 @@ class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { that.edges.foreach({ case (k, v) => eprime.getOrElseUpdate(k, new LinkedHashSet[T]) ++= v }) new DiGraph(eprime) } + + /** Serializes a `DiGraph[String]` as a pretty tree + * + * Multiple roots are supported, but cycles are not. + */ + def prettyTree(charSet: CharSet = PrettyCharSet)(implicit ev: T =:= String): String = { + // Set up characters for building the tree + val (l, n, c) = (charSet.lastNode, charSet.notLastNode, charSet.continuation) + val ctab = " " * c.size + " " + + // Recursively adds each node of the DiGraph to accumulating List[String] + // Uses List because prepend is cheap and this prevents quadratic behavior of String + // concatenations or even flatMapping on Seqs + def rec(tab: String, node: T, mark: String, prev: List[String]): List[String] = { + val here = s"$mark$node" + val children = this.getEdges(node) + val last = children.size - 1 + children.toList // Convert LinkedHashSet to List to avoid determinism issues + .zipWithIndex // Find last + .foldLeft(here :: prev) { + case (acc, (nodex, idx)) => + val nextTab = if (idx == last) tab + ctab else tab + c + " " + val nextMark = if (idx == last) tab + l else tab + n + rec(nextTab, nodex, nextMark + " ", acc) + } + } + this.findSources.toList // Convert LinkedHashSet to List to avoid determinism issues + .sortBy(_.toString) // Make order deterministic + .foldLeft(Nil: List[String]) { + case (acc, root) => rec("", root, "", acc) + } + .reverse + .mkString("\n") + } + } class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]]) { diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index a26a2a944e..13ba3d46cb 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -4,6 +4,8 @@ package firrtl package ir import Utils.{dec2string, trim} +import firrtl.backends.experimental.smt.random.DefRandom +import dataclass.{data, since} import firrtl.constraint.{Constraint, IsKnown, IsVar} import org.apache.commons.text.translate.{AggregateTranslator, JavaUnicodeEscaper, LookupTranslator} @@ -226,6 +228,13 @@ abstract class Expression extends FirrtlNode { */ sealed trait RefLikeExpression extends Expression { def flow: Flow } +/** Represents a statement that can be referenced in a firrtl expression. + * This explicitly excludes named side-effecting statements like Print, Stop and Verification. + * Note: This trait cannot be sealed since the memory ports are declared in WIR.scala. + * Once we fully remove all WIR, this trait could be sealed. + */ +trait CanBeReferenced + object Reference { /** Creates a Reference from a Wire */ @@ -234,6 +243,9 @@ object Reference { /** Creates a Reference from a Register */ def apply(reg: DefRegister): Reference = Reference(reg.name, reg.tpe, RegKind, UnknownFlow) + /** Creates a Reference from a Random Source */ + def apply(rnd: DefRandom): Reference = Reference(rnd.name, rnd.tpe, RandomKind, UnknownFlow) + /** Creates a Reference from a Node */ def apply(node: DefNode): Reference = Reference(node.name, node.value.tpe, NodeKind, SourceFlow) @@ -386,7 +398,11 @@ abstract class Statement extends FirrtlNode { def foreachString(f: String => Unit): Unit def foreachInfo(f: Info => Unit): Unit } -case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration with UseSerializer { +case class DefWire(info: Info, name: String, tpe: Type) + extends Statement + with IsDeclaration + with CanBeReferenced + with UseSerializer { def mapStmt(f: Statement => Statement): Statement = this def mapExpr(f: Expression => Expression): Statement = this def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe)) @@ -407,6 +423,7 @@ case class DefRegister( init: Expression) extends Statement with IsDeclaration + with CanBeReferenced with UseSerializer { def mapStmt(f: Statement => Statement): Statement = this def mapExpr(f: Expression => Expression): Statement = @@ -428,6 +445,7 @@ object DefInstance { case class DefInstance(info: Info, name: String, module: String, tpe: Type = UnknownType) extends Statement with IsDeclaration + with CanBeReferenced with UseSerializer { def mapExpr(f: Expression => Expression): Statement = this def mapStmt(f: Statement => Statement): Statement = this @@ -461,6 +479,7 @@ case class DefMemory( readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined) extends Statement with IsDeclaration + with CanBeReferenced with UseSerializer { def mapStmt(f: Statement => Statement): Statement = this def mapExpr(f: Expression => Expression): Statement = this @@ -476,6 +495,7 @@ case class DefMemory( case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration + with CanBeReferenced with UseSerializer { def mapStmt(f: Statement => Statement): Statement = this def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value)) @@ -593,40 +613,67 @@ case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with Has def foreachString(f: String => Unit): Unit = () def foreachInfo(f: Info => Unit): Unit = f(info) } -case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) + +@data class Stop(info: Info, ret: Int, clk: Expression, en: Expression, @since("FIRRTL 1.5") name: String = "") extends Statement with HasInfo + with IsDeclaration with UseSerializer { def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en)) + def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en), name) def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this + def mapString(f: String => String): Statement = withName(f(name)) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(en) } def foreachType(f: Type => Unit): Unit = () - def foreachString(f: String => Unit): Unit = () + def foreachString(f: String => Unit): Unit = f(name) def foreachInfo(f: Info => Unit): Unit = f(info) + def copy(info: Info = info, ret: Int = ret, clk: Expression = clk, en: Expression = en): Stop = { + Stop(info, ret, clk, en, name) + } +} +object Stop { + def unapply(s: Stop): Some[(Info, Int, Expression, Expression)] = { + Some((s.info, s.ret, s.clk, s.en)) + } } -case class Print( +@data class Print( info: Info, string: StringLit, args: Seq[Expression], clk: Expression, - en: Expression) + en: Expression, + @since("FIRRTL 1.5") + name: String = "") extends Statement with HasInfo + with IsDeclaration with UseSerializer { def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = Print(info, string, args.map(f), f(clk), f(en)) + def mapExpr(f: Expression => Expression): Statement = Print(info, string, args.map(f), f(clk), f(en), name) def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this + def mapString(f: String => String): Statement = withName(f(name)) def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { args.foreach(f); f(clk); f(en) } def foreachType(f: Type => Unit): Unit = () - def foreachString(f: String => Unit): Unit = () + def foreachString(f: String => Unit): Unit = f(name) def foreachInfo(f: Info => Unit): Unit = f(info) + def copy( + info: Info = info, + string: StringLit = string, + args: Seq[Expression] = args, + clk: Expression = clk, + en: Expression = en + ): Print = { + Print(info, string, args, clk, en, name) + } +} +object Print { + def unapply(s: Print): Some[(Info, StringLit, Seq[Expression], Expression, Expression)] = { + Some((s.info, s.string, s.args, s.clk, s.en)) + } } // formal @@ -636,27 +683,45 @@ object Formal extends Enumeration { val Cover = Value("cover") } -case class Verification( +@data class Verification( op: Formal.Value, info: Info, clk: Expression, pred: Expression, en: Expression, - msg: StringLit) + msg: StringLit, + @since("FIRRTL 1.5") + name: String = "") extends Statement with HasInfo + with IsDeclaration with UseSerializer { def mapStmt(f: Statement => Statement): Statement = this def mapExpr(f: Expression => Expression): Statement = copy(clk = f(clk), pred = f(pred), en = f(en)) def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this + def mapString(f: String => String): Statement = withName(f(name)) def mapInfo(f: Info => Info): Statement = copy(info = f(info)) def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(pred); f(en); } def foreachType(f: Type => Unit): Unit = () - def foreachString(f: String => Unit): Unit = () + def foreachString(f: String => Unit): Unit = f(name) def foreachInfo(f: Info => Unit): Unit = f(info) + def copy( + op: Formal.Value = op, + info: Info = info, + clk: Expression = clk, + pred: Expression = pred, + en: Expression = en, + msg: StringLit = msg + ): Verification = { + Verification(op, info, clk, pred, en, msg, name) + } +} +object Verification { + def unapply(s: Verification): Some[(Formal.Value, Info, Expression, Expression, Expression, StringLit)] = { + Some((s.op, s.info, s.clk, s.pred, s.en, s.msg)) + } } // end formal @@ -891,6 +956,7 @@ case class IntervalType(lower: Bound, upper: Bound, point: Width) extends Ground case x => Some(x.setScale(0, FLOOR) * prec) } case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, FLOOR) * prec) + case _ => None } def minAdjusted: Option[BigInt] = min.map(_ * BigDecimal(BigInt(1) << bp) match { @@ -977,6 +1043,7 @@ case class Port( tpe: Type) extends FirrtlNode with IsDeclaration + with CanBeReferenced with UseSerializer { def mapType(f: Type => Type): Port = Port(info, name, direction, f(tpe)) def mapString(f: String => String): Port = Port(info, f(name), direction, tpe) diff --git a/src/main/scala/firrtl/ir/Serializer.scala b/src/main/scala/firrtl/ir/Serializer.scala index 4a1bfbedf3..1c2bfc82a6 100644 --- a/src/main/scala/firrtl/ir/Serializer.scala +++ b/src/main/scala/firrtl/ir/Serializer.scala @@ -2,16 +2,24 @@ package firrtl.ir +import firrtl.Utils +import firrtl.backends.experimental.smt.random.DefRandom import firrtl.constraint.Constraint object Serializer { val NewLine = '\n' val Indent = " " - /** Converts a `FirrtlNode` into its string representation. */ + /** Converts a `FirrtlNode` into its string representation with + * default indentation. + */ def serialize(node: FirrtlNode): String = { + serialize(node, 0) + } + + /** Converts a `FirrtlNode` into its string representation. */ + def serialize(node: FirrtlNode, indent: Int): String = { val builder = new StringBuilder() - val indent = 0 node match { case n: Info => s(n)(builder, indent) case n: StringLit => s(n)(builder, indent) @@ -99,19 +107,26 @@ object Serializer { case Block(stmts) => val it = stmts.iterator while (it.hasNext) { - s(it.next) + s(it.next()) if (it.hasNext) newLineAndIndent() } - case Stop(info, ret, clk, en) => - b ++= "stop("; s(clk); b ++= ", "; s(en); b ++= ", "; b ++= ret.toString; b += ')'; s(info) - case Print(info, string, args, clk, en) => + case stop @ Stop(info, ret, clk, en) => + b ++= "stop("; s(clk); b ++= ", "; s(en); b ++= ", "; b ++= ret.toString; b += ')' + sStmtName(stop.name); s(info) + case print @ Print(info, string, args, clk, en) => b ++= "printf("; s(clk); b ++= ", "; s(en); b ++= ", "; b ++= string.escape - if (args.nonEmpty) b ++= ", "; s(args, ", "); b += ')'; s(info) + if (args.nonEmpty) b ++= ", "; s(args, ", "); b += ')' + sStmtName(print.name); s(info) case IsInvalid(info, expr) => s(expr); b ++= " is invalid"; s(info) case DefWire(info, name, tpe) => b ++= "wire "; b ++= name; b ++= " : "; s(tpe); s(info) case DefRegister(info, name, tpe, clock, reset, init) => b ++= "reg "; b ++= name; b ++= " : "; s(tpe); b ++= ", "; s(clock); b ++= " with :"; newLineAndIndent(1) b ++= "reset => ("; s(reset); b ++= ", "; s(init); b += ')'; s(info) + case DefRandom(info, name, tpe, clock, en) => + b ++= "rand "; b ++= name; b ++= " : "; s(tpe); + if (clock.isDefined) { b ++= ", "; s(clock.get); } + en match { case Utils.True() => case _ => b ++= " when "; s(en) } + s(info) case DefInstance(info, name, module, _) => b ++= "inst "; b ++= name; b ++= " of "; b ++= module; s(info) case DefMemory( info, @@ -138,9 +153,9 @@ object Serializer { case Attach(info, exprs) => // exprs should never be empty since the attach statement takes *at least* two signals according to the spec b ++= "attach ("; s(exprs, ", "); b += ')'; s(info) - case Verification(op, info, clk, pred, en, msg) => + case veri @ Verification(op, info, clk, pred, en, msg) => b ++= op.toString; b += '('; s(List(clk, pred, en), ", ", false); b ++= msg.escape - b += ')'; s(info) + b += ')'; sStmtName(veri.name); s(info) // WIR case firrtl.CDefMemory(info, name, tpe, size, seq, readUnderWrite) => @@ -155,6 +170,10 @@ object Serializer { case other => b ++= other.serialize // Handle user-defined nodes } + private def sStmtName(lbl: String)(implicit b: StringBuilder): Unit = { + if (lbl.nonEmpty) { b ++= s" : $lbl" } + } + private def s(node: Width)(implicit b: StringBuilder, indent: Int): Unit = node match { case IntWidth(width) => b += '<'; b ++= width.toString(); b += '>' case UnknownWidth => // empty string @@ -240,6 +259,7 @@ object Serializer { newLineAndIndent(1); s(modules.head)(b, indent + 1) modules.drop(1).foreach { m => newLineNoIndent(); newLineAndIndent(1); s(m)(b, indent + 1) } } + newLineNoIndent() } // serialize constraints diff --git a/src/main/scala/firrtl/options/DependencyManager.scala b/src/main/scala/firrtl/options/DependencyManager.scala index 39998ed7fb..f8299bbb1b 100644 --- a/src/main/scala/firrtl/options/DependencyManager.scala +++ b/src/main/scala/firrtl/options/DependencyManager.scala @@ -91,7 +91,7 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends } while (queue.nonEmpty) { - val u: Dependency[B] = queue.dequeue + val u: Dependency[B] = queue.dequeue() for (v <- extractor(dependencyToObject(u))) { if (!blacklist.contains(v) && !edges.contains(v)) { queue.enqueue(v) @@ -193,13 +193,13 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends ) } - /** An ordering of [[firrtl.options.TransformLike TransformLike]]s that causes the requested [[DependencyManager.targets - * targets]] to be executed starting from the [[DependencyManager.currentState currentState]]. This ordering respects + /** An ordering of [[firrtl.options.TransformLike TransformLike]]s that causes the requested [[firrtl.options.DependencyManager.targets + * targets]] to be executed starting from the [[firrtl.options.DependencyManager.currentState currentState]]. This ordering respects * prerequisites, optionalPrerequisites, optionalPrerequisiteOf, and invalidates of all constituent * [[firrtl.options.TransformLike TransformLike]]s. This uses an algorithm that attempts to reduce the number of - * re-lowerings due to invalidations. Re-lowerings are implemented as new [[DependencyManager]]s. - * @throws DependencyManagerException if a cycle exists in either the [[DependencyManager.dependencyGraph - * dependencyGraph]] or the [[DependencyManager.invalidateGraph invalidateGraph]]. + * re-lowerings due to invalidations. Re-lowerings are implemented as new [[firrtl.options.DependencyManager]]s. + * @throws firrtl.options.DependencyManagerException if a cycle exists in either the [[firrtl.options.DependencyManager.dependencyGraph + * dependencyGraph]] or the [[firrtl.options.DependencyManager.invalidateGraph invalidateGraph]]. */ lazy val transformOrder: Seq[B] = { @@ -244,8 +244,8 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends l ++ postprocessing } - /** A version of the [[DependencyManager.transformOrder transformOrder]] that flattens the transforms of any internal - * [[DependencyManager]]s. + /** A version of the [[firrtl.options.DependencyManager.transformOrder transformOrder]] that flattens the transforms of any internal + * [[firrtl.options.DependencyManager DependencyManager]]s. */ lazy val flattenedTransformOrder: Seq[B] = transformOrder.flatMap { case p: DependencyManager[A, B] => p.flattenedTransformOrder @@ -275,7 +275,14 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends | prerequisites: ${prerequisites.mkString("\n -", "\n -", "")}""".stripMargin ) } - (t.transform(a), ((state + wrapperToClass(t)).map(dToO).filterNot(t.invalidates).map(oToD))) + val logger = t.getLogger + logger.info(s"======== Starting ${t.name} ========") + val (timeMillis, annosx) = firrtl.Utils.time { t.transform(a) } + logger.info(s"""----------------------------${"-" * t.name.size}---------\n""") + logger.info(f"Time: $timeMillis%.1f ms") + logger.info(s"======== Finished ${t.name} ========") + val statex = (state + wrapperToClass(t)).map(dToO).filterNot(t.invalidates).map(oToD) + (annosx, statex) }._1 } diff --git a/src/main/scala/firrtl/options/Stage.scala b/src/main/scala/firrtl/options/Stage.scala index 0783bfb5e7..cefdd95793 100644 --- a/src/main/scala/firrtl/options/Stage.scala +++ b/src/main/scala/firrtl/options/Stage.scala @@ -28,7 +28,7 @@ abstract class Stage extends Phase { /** Execute this stage on some input annotations. Annotations will be read from any input annotation files. * @param annotations input annotations * @return output annotations - * @throws OptionsException if command line or annotation validation fails + * @throws firrtl.options.OptionsException if command line or annotation validation fails */ final def transform(annotations: AnnotationSeq): AnnotationSeq = { val annotationsx = @@ -52,7 +52,7 @@ abstract class Stage extends Phase { * @param args command line arguments * @param initialAnnotations annotation * @return output annotations - * @throws OptionsException if command line or annotation validation fails + * @throws firrtl.options.OptionsException if command line or annotation validation fails */ final def execute(args: Array[String], annotations: AnnotationSeq): AnnotationSeq = transform(shell.parse(args, annotations)) diff --git a/src/main/scala/firrtl/passes/CInferMDir.scala b/src/main/scala/firrtl/passes/CInferMDir.scala index 90f1c73950..cca8fde4ff 100644 --- a/src/main/scala/firrtl/passes/CInferMDir.scala +++ b/src/main/scala/firrtl/passes/CInferMDir.scala @@ -22,22 +22,19 @@ object CInferMDir extends Pass { case None => case Some(p) => mports(e.name) = (p, dir) match { - case (MInfer, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") case (MInfer, MWrite) => MWrite case (MInfer, MRead) => MRead case (MInfer, MReadWrite) => MReadWrite - case (MWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") case (MWrite, MWrite) => MWrite case (MWrite, MRead) => MReadWrite case (MWrite, MReadWrite) => MReadWrite - case (MRead, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") case (MRead, MWrite) => MReadWrite case (MRead, MRead) => MRead case (MRead, MReadWrite) => MReadWrite - case (MReadWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") case (MReadWrite, MWrite) => MReadWrite case (MReadWrite, MRead) => MReadWrite case (MReadWrite, MReadWrite) => MReadWrite + case _ => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") } } e diff --git a/src/main/scala/firrtl/passes/CheckFlows.scala b/src/main/scala/firrtl/passes/CheckFlows.scala index 2816887f55..f78a115a04 100644 --- a/src/main/scala/firrtl/passes/CheckFlows.scala +++ b/src/main/scala/firrtl/passes/CheckFlows.scala @@ -10,7 +10,7 @@ import firrtl.options.Dependency object CheckFlows extends Pass { - override def prerequisites = Dependency(passes.ResolveFlows) +: firrtl.stage.Forms.WorkingIR + override def prerequisites = Dependency(passes.ResolveFlows) +: firrtl.stage.Forms.MinimalHighForm override def optionalPrerequisiteOf = Seq( diff --git a/src/main/scala/firrtl/passes/CheckHighForm.scala b/src/main/scala/firrtl/passes/CheckHighForm.scala index 5514741ae8..05635d0054 100644 --- a/src/main/scala/firrtl/passes/CheckHighForm.scala +++ b/src/main/scala/firrtl/passes/CheckHighForm.scala @@ -22,6 +22,10 @@ trait CheckHighFormLike { this: Pass => moduleNS += name scopes.head += name } + // ensures that the name cannot be used again, but prevent references to this name + def addToNamespace(name: String): Unit = { + moduleNS += name + } def expandMPortVisibility(port: CDefMPort): Unit = { // Legacy CHIRRTL ports are visible in any scope where their parent memory is visible scopes.find(_.contains(port.mem)).getOrElse(scopes.head) += port.name @@ -243,18 +247,26 @@ trait CheckHighFormLike { this: Pass => errors.append(new NegUIntException(info, mname)) case ex: DoPrim => checkHighFormPrimop(info, mname, ex) case _: Reference | _: WRef | _: UIntLiteral | _: Mux | _: ValidIf => - case ex: SubAccess => validSubexp(info, mname)(ex.expr) - case ex: WSubAccess => validSubexp(info, mname)(ex.expr) + case ex: SubAccess => validSubexp(info, mname)(ex.expr) case ex => ex.foreach(validSubexp(info, mname)) } e.foreach(checkHighFormW(info, mname + "/" + e.serialize)) e.foreach(checkHighFormE(info, mname, names)) } - def checkName(info: Info, mname: String, names: ScopeView)(name: String): Unit = { - if (!names.legalDecl(name)) - errors.append(new NotUniqueException(info, mname, name)) - names.declare(name) + def checkName(info: Info, mname: String, names: ScopeView, canBeReference: Boolean)(name: String): Unit = { + // Empty names are allowed for backwards compatibility reasons and + // indicate that the entity has essentially no name. + if (name.isEmpty) { assert(!canBeReference, "A statement with an empty name cannot be used as a reference!") } + else { + if (!names.legalDecl(name)) + errors.append(new NotUniqueException(info, mname, name)) + if (canBeReference) { + names.declare(name) + } else { + names.addToNamespace(name) + } + } } def checkInstance(info: Info, child: String, parent: String): Unit = { @@ -271,7 +283,11 @@ trait CheckHighFormLike { this: Pass => case NoInfo => minfo case x => x } - s.foreach(checkName(info, mname, names)) + val canBeReference = s match { + case _: CanBeReferenced => true + case _ => false + } + s.foreach(checkName(info, mname, names, canBeReference)) s match { case DefRegister(info, name, tpe, _, reset, init) => if (hasFlip(tpe)) @@ -284,7 +300,6 @@ trait CheckHighFormLike { this: Pass => if (sx.depth <= 0) errors.append(new NegMemSizeException(info, mname)) case sx: DefInstance => checkInstance(info, mname, sx.module) - case sx: WDefInstance => checkInstance(info, mname, sx.module) case sx: Connect => checkValidLoc(info, mname, sx.loc) case sx: PartialConnect => checkValidLoc(info, mname, sx.loc) case sx: Print => checkFstring(info, mname, sx.string, sx.args.length) @@ -350,7 +365,7 @@ trait CheckHighFormLike { this: Pass => object CheckHighForm extends Pass with CheckHighFormLike { - override def prerequisites = firrtl.stage.Forms.WorkingIR + override def prerequisites = firrtl.stage.Forms.MinimalHighForm override def optionalPrerequisiteOf = Seq( diff --git a/src/main/scala/firrtl/passes/CheckTypes.scala b/src/main/scala/firrtl/passes/CheckTypes.scala index bbccb35395..50fbfc2e5b 100644 --- a/src/main/scala/firrtl/passes/CheckTypes.scala +++ b/src/main/scala/firrtl/passes/CheckTypes.scala @@ -13,7 +13,7 @@ import firrtl.options.Dependency object CheckTypes extends Pass { - override def prerequisites = Dependency(InferTypes) +: firrtl.stage.Forms.WorkingIR + override def prerequisites = Dependency(InferTypes) +: firrtl.stage.Forms.MinimalHighForm override def optionalPrerequisiteOf = Seq( @@ -55,9 +55,9 @@ object CheckTypes extends Pass { class RegReqClk(info: Info, mname: String, name: String) extends PassException(s"$info: [module $mname] Register $name requires a clock typed signal.") class EnNotUInt(info: Info, mname: String) - extends PassException(s"$info: [module $mname] Enable must be a UIntType typed signal.") + extends PassException(s"$info: [module $mname] Enable must be a 1-bit UIntType typed signal.") class PredNotUInt(info: Info, mname: String) - extends PassException(s"$info: [module $mname] Predicate not a UIntType.") + extends PassException(s"$info: [module $mname] Predicate not a 1-bit UIntType.") class OpNotGround(info: Info, mname: String, op: String) extends PassException(s"$info: [module $mname] Primop $op cannot operate on non-ground types.") class OpNotUInt(info: Info, mname: String, op: String, e: String) @@ -81,7 +81,7 @@ object CheckTypes extends Pass { class MuxPassiveTypes(info: Info, mname: String) extends PassException(s"$info: [module $mname] Must mux between passive types.") class MuxCondUInt(info: Info, mname: String) - extends PassException(s"$info: [module $mname] A mux condition must be of type UInt.") + extends PassException(s"$info: [module $mname] A mux condition must be of type 1-bit UInt.") class MuxClock(info: Info, mname: String) extends PassException(s"$info: [module $mname] Firrtl does not support muxing clocks.") class ValidIfPassiveTypes(info: Info, mname: String) @@ -120,6 +120,15 @@ object CheckTypes extends Pass { case _ => false } + private def legalCondType(tpe: Type): Boolean = tpe match { + // If width is known, must be 1 + case UIntType(IntWidth(w)) => w == 1 + // Unknown width or variable widths (for width inference) are acceptable (checked in later run) + case UIntType(_) => true + // Any other type is not okay + case _ => false + } + private def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { (t1, t2) match { case (ClockType, ClockType) => flip1 == flip2 @@ -165,7 +174,8 @@ object CheckTypes extends Pass { bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default) //;---------------- Helper Functions -------------- - def ut: UIntType = UIntType(UnknownWidth) + private val UIntUnknown = UIntType(UnknownWidth) + def ut: UIntType = UIntUnknown def st: SIntType = SIntType(UnknownWidth) def run(c: Circuit): Circuit = { @@ -332,9 +342,8 @@ object CheckTypes extends Pass { errors.append(new MuxSameType(info, mname, e.tval.tpe.serialize, e.fval.tpe.serialize)) if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info, mname)) - e.cond.tpe match { - case _: UIntType => - case _ => errors.append(new MuxCondUInt(info, mname)) + if (!legalCondType(e.cond.tpe)) { + errors.append(new MuxCondUInt(info, mname)) } case (e: ValidIf) => if (!passive(e.tpe)) @@ -375,7 +384,7 @@ object CheckTypes extends Pass { if (sx.clock.tpe != ClockType) { errors.append(new RegReqClk(info, mname, sx.name)) } - case sx: Conditionally if wt(sx.pred.tpe) != wt(ut) => + case sx: Conditionally if !legalCondType(sx.pred.tpe) => errors.append(new PredNotUInt(info, mname)) case sx: DefNode => sx.value.tpe match { @@ -396,16 +405,16 @@ object CheckTypes extends Pass { } case sx: Stop => if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) - if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) + if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname)) case sx: Print => if (sx.args.exists(x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st))) errors.append(new PrintfArgNotGround(info, mname)) if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) - if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) + if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname)) case sx: Verification => if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) - if (wt(sx.pred.tpe) != wt(ut)) errors.append(new PredNotUInt(info, mname)) - if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) + if (!legalCondType(sx.pred.tpe)) errors.append(new PredNotUInt(info, mname)) + if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname)) case sx: DefMemory => sx.dataType match { case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index 3f30e1243d..a9f4442165 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -13,7 +13,7 @@ import firrtl.options.Dependency object CheckWidths extends Pass { - override def prerequisites = Dependency[passes.InferWidths] +: firrtl.stage.Forms.WorkingIR + override def prerequisites = Dependency[passes.InferWidths] +: firrtl.stage.Forms.MinimalHighForm override def optionalPrerequisiteOf = Seq(Dependency[transforms.InferResets]) @@ -23,7 +23,7 @@ object CheckWidths extends Pass { val MaxWidth = 1000000 val DshlMaxWidth = getUIntWidth(MaxWidth) class UninferredWidth(info: Info, target: String) - extends PassException(s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?) + extends PassException(s"""|$info : Uninferred width for target below. (Did you forget to assign to it?) |$target""".stripMargin) class UninferredBound(info: Info, target: String, bound: String) extends PassException(s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?) diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 7456d2ab06..8fb4e5fbd6 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -125,13 +125,13 @@ object ExpandWhens extends Pass { EmptyStmt // For simulation constructs, update simlist with predicated statement and return EmptyStmt case sx: Print => - simlist += (if (weq(p, one)) sx else Print(sx.info, sx.string, sx.args, sx.clk, AND(p, sx.en))) + simlist += (if (weq(p, one)) sx else sx.withEn(AND(p, sx.en))) EmptyStmt case sx: Stop => - simlist += (if (weq(p, one)) sx else Stop(sx.info, sx.ret, sx.clk, AND(p, sx.en))) + simlist += (if (weq(p, one)) sx else sx.withEn(AND(p, sx.en))) EmptyStmt case sx: Verification => - simlist += (if (weq(p, one)) sx else sx.copy(en = AND(p, sx.en))) + simlist += (if (weq(p, one)) sx else sx.withEn(AND(p, sx.en))) EmptyStmt // Expand conditionally, see comments below case sx: Conditionally => diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 01f0b823bc..8ab78fee6b 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -10,7 +10,7 @@ import firrtl.options.Dependency object InferTypes extends Pass { - override def prerequisites = Dependency(ResolveKinds) +: firrtl.stage.Forms.WorkingIR + override def prerequisites = Dependency(ResolveKinds) +: firrtl.stage.Forms.MinimalHighForm override def invalidates(a: Transform) = false @deprecated("This should never have been public", "FIRRTL 1.3.2") diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index aa2095fabb..d0677fadd5 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -71,7 +71,7 @@ class InferWidths extends Transform with ResolvedAnnotationPaths with Dependency Dependency(passes.ResolveFlows), Dependency[passes.InferBinaryPoints], Dependency[passes.TrimIntervals] - ) ++ firrtl.stage.Forms.WorkingIR + ) ++ firrtl.stage.Forms.MinimalHighForm override def invalidates(a: Transform) = false val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) @@ -110,6 +110,7 @@ class InferWidths extends Transform with ResolvedAnnotationPaths with Dependency case (AsyncResetType, AsyncResetType) => Nil case (ResetType, _) => Nil case (_, ResetType) => Nil + case _ => throwInternalError("Shouldn't be here") } private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver): Expression = diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 4eba5d5901..912acf8ef1 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -187,18 +187,23 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe renameMap: RenameMap )(s: Statement ): Statement = { - def onName(ofModuleOpt: Option[String])(name: String) = { - if (prefix.nonEmpty && !ns.tryName(prefix + name)) { - throw new Exception(s"Inlining failed. Inlined name '${prefix + name}' already exists") - } - ofModuleOpt match { - case None => - renameMap.record(currentModule.ref(name), nextModule.ref(prefix + name)) - case Some(ofModule) => - renameMap.record(currentModule.instOf(name, ofModule), nextModule.instOf(prefix + name, ofModule)) + def onName(ofModuleOpt: Option[String])(name: String): String = { + // Empty names are allowed for backwards compatibility reasons and + // indicate that the entity has essentially no name and thus cannot be prefixed. + if (name.isEmpty) { name } + else { + if (prefix.nonEmpty && !ns.tryName(prefix + name)) { + throw new Exception(s"Inlining failed. Inlined name '${prefix + name}' already exists") + } + ofModuleOpt match { + case None => + renameMap.record(currentModule.ref(name), nextModule.ref(prefix + name)) + case Some(ofModule) => + renameMap.record(currentModule.instOf(name, ofModule), nextModule.instOf(prefix + name, ofModule)) + } + renames(name) = prefix + name + prefix + name } - renames(name) = prefix + name - prefix + name } s match { diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala index ef0e17060c..e1a39fbe15 100644 --- a/src/main/scala/firrtl/passes/Legalize.scala +++ b/src/main/scala/firrtl/passes/Legalize.scala @@ -3,11 +3,11 @@ package firrtl.passes import firrtl.PrimOps._ -import firrtl.Utils.{error, zero, BoolType} +import firrtl.Utils.{error, getGroundZero, zero, BoolType} import firrtl.ir._ import firrtl.options.Dependency import firrtl.transforms.ConstantPropagation -import firrtl.{bitWidth, Transform} +import firrtl.{bitWidth, getWidth, Transform} import firrtl.Mappers._ // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts @@ -56,6 +56,19 @@ object Legalize extends Pass { SIntLiteral(value, IntWidth(expr.consts.head)) case _ => expr } + // Convert `-x` to `0 - x` + private def legalizeNeg(expr: DoPrim): Expression = { + val arg = expr.args.head + arg.tpe match { + case tpe: SIntType => + val zero = getGroundZero(tpe) + DoPrim(Sub, Seq(zero, arg), Nil, expr.tpe) + case tpe: UIntType => + val zero = getGroundZero(tpe) + val sub = DoPrim(Sub, Seq(zero, arg), Nil, UIntType(tpe.width + IntWidth(1))) + DoPrim(AsSInt, Seq(sub), Nil, expr.tpe) + } + } private def legalizeConnect(c: Connect): Statement = { val t = c.loc.tpe val w = bitWidth(t) @@ -78,6 +91,7 @@ object Legalize extends Pass { case Shr => legalizeShiftRight(prim) case Pad => legalizePad(prim) case Bits | Head | Tail => legalizeBitExtract(prim) + case Neg => legalizeNeg(prim) case _ => prim } case e => e // respect pre-order traversal diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 592caf5d17..0bd44a8ca3 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -8,8 +8,10 @@ import firrtl.annotations.{ MemoryInitAnnotation, MemoryRandomInitAnnotation, ModuleTarget, - ReferenceTarget + ReferenceTarget, + TargetToken } +import TargetToken.{Instance, OfModule} import firrtl.{ CircuitForm, CircuitState, @@ -73,16 +75,18 @@ object LowerTypes extends Transform with DependencyAPIMigration { val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule) val c = CircuitTarget(state.circuit.main) - val resultAndRenames = state.circuit.modules.map(m => onModule(c, m, memInitByModule.getOrElse(m.name, Seq()))) + val refRenameMap = RenameMap() + val resultAndRenames = + state.circuit.modules.map(m => onModule(c, m, memInitByModule.getOrElse(m.name, Seq()), refRenameMap)) val result = state.circuit.copy(modules = resultAndRenames.map(_._1)) // memory init annotations could have been modified val newAnnos = otherAnnos ++ resultAndRenames.flatMap(_._3) - // chain module renames in topological order - val moduleRenames = resultAndRenames.map { case (m, r, _) => m.name -> r }.toMap - val moduleOrderBottomUp = InstanceKeyGraph(result).moduleOrder.reverseIterator - val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a, b) => a.andThen(b)) + // Build RenameMap for instances + val moduleRenames = resultAndRenames.map { case (m, r, _) => OfModule(m.name) -> r }.toMap + val instRenameMap = RenameMap.fromInstanceRenames(InstanceKeyGraph(state.circuit), moduleRenames) + val renames = instRenameMap.andThen(refRenameMap) state.copy(circuit = result, renames = Some(renames), annotations = newAnnos) } @@ -90,9 +94,9 @@ object LowerTypes extends Transform with DependencyAPIMigration { private def onModule( c: CircuitTarget, m: DefModule, - memoryInit: Seq[MemoryInitAnnotation] - ): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = { - val renameMap = RenameMap() + memoryInit: Seq[MemoryInitAnnotation], + renameMap: RenameMap + ): (DefModule, Map[Instance, Instance], Seq[MemoryInitAnnotation]) = { val ref = c.module(m.name) // first we lower the ports in order to ensure that their names are independent of the module body @@ -105,7 +109,9 @@ object LowerTypes extends Transform with DependencyAPIMigration { implicit val memInit: Seq[MemoryInitAnnotation] = memoryInit val newMod = mLoweredPorts.mapStmt(onStatement) - (newMod, renameMap, memInit) + val instRenames = symbols.getInstanceRenames.toMap + + (newMod, instRenames, memInit) } // We lower ports in a separate pass in order to ensure that statements inside the module do not influence port names. @@ -221,6 +227,7 @@ private class LoweringTable( private val namespace = mutable.HashSet[String]() ++ table.getSymbolNames // Serialized old access string to new ground type reference. private val nameToExprs = mutable.HashMap[String, Seq[RefLikeExpression]]() ++ portNameToExprs + private val instRenames = mutable.ListBuffer[(Instance, Instance)]() def lower(mem: DefMemory): Seq[DefMemory] = { val (mems, refs) = DestructTypes.destructMemory(m, mem, namespace, renameMap, portNames) @@ -228,7 +235,7 @@ private class LoweringTable( mems } def lower(inst: DefInstance): DefInstance = { - val (newInst, refs) = DestructTypes.destructInstance(m, inst, namespace, renameMap, portNames) + val (newInst, refs) = DestructTypes.destructInstance(m, inst, namespace, instRenames, portNames) nameToExprs ++= refs.map { case (name, r) => name -> List(r) } newInst } @@ -245,6 +252,7 @@ private class LoweringTable( } def getReferences(expr: RefLikeExpression): Seq[RefLikeExpression] = nameToExprs(serialize(expr)) + def getInstanceRenames: List[(Instance, Instance)] = instRenames.toList // We could just use FirrtlNode.serialize here, but we want to make sure there are not SubAccess nodes left. private def serialize(expr: RefLikeExpression): String = expr match { @@ -296,11 +304,11 @@ private object DestructTypes { * instead of a flat Reference when turning them into access expressions. */ def destructInstance( - m: ModuleTarget, - instance: DefInstance, - namespace: Namespace, - renameMap: RenameMap, - reserved: Set[String] + m: ModuleTarget, + instance: DefInstance, + namespace: Namespace, + instRenames: mutable.ListBuffer[(Instance, Instance)], + reserved: Set[String] ): (DefInstance, Seq[(String, SubField)]) = { val (rename, _) = uniquify(Field(instance.name, Default, instance.tpe), namespace, reserved) val newName = rename.map(_.name).getOrElse(instance.name) @@ -314,7 +322,7 @@ private object DestructTypes { // rename all references to the instance if necessary if (newName != instance.name) { - renameMap.record(m.instOf(instance.name, instance.module), m.instOf(newName, instance.module)) + instRenames += Instance(instance.name) -> Instance(newName) } // The ports do not need to be explicitly renamed here. They are renamed when the module ports are lowered. diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 875e80ae59..1a430778ff 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -58,7 +58,7 @@ object PadWidths extends Pass { case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value)) case ex: DoPrim => ex.op match { - case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Mul | Div | Rem | Shr => + case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Rem | Shr => // sensitive ops ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max))) case Dshl => diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index f49af9353e..073bf49d93 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -47,8 +47,7 @@ object RemoveAccesses extends Pass { * Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1))) */ private def getLocations(e: Expression): Seq[Location] = e match { - case e: WRef => create_exps(e).map(Location(_, one)) - case e: WSubIndex => + case e: SubIndex => val ls = getLocations(e.expr) val start = get_point(e) val end = start + get_size(e.tpe) @@ -57,7 +56,7 @@ object RemoveAccesses extends Pass { (l, i) <- ls.zipWithIndex if ((i % stride) >= start) & ((i % stride) < end) ) yield l - case e: WSubField => + case e: SubField => val ls = getLocations(e.expr) val start = get_point(e) val end = start + get_size(e.tpe) @@ -66,17 +65,27 @@ object RemoveAccesses extends Pass { (l, i) <- ls.zipWithIndex if ((i % stride) >= start) & ((i % stride) < end) ) yield l - case e: WSubAccess => - val ls = getLocations(e.expr) - val stride = get_size(e.tpe) - val wrap = e.expr.tpe.asInstanceOf[VectorType].size - ls.zipWithIndex.map { - case (l, i) => - val c = (i / stride) % wrap - val basex = l.base - val guardx = AND(l.guard, EQV(UIntLiteral(c), e.index)) - Location(basex, guardx) + case SubAccess(expr, index, tpe, _) => + getLocations(expr).zipWithIndex.flatMap { + case (Location(exprBase, exprGuard), exprIndex) => + getLocations(index).map { + case Location(indexBase, indexGuard) => + Location( + exprBase, + AND( + AND( + indexGuard, + exprGuard + ), + EQV( + UIntLiteral((exprIndex / get_size(tpe)) % expr.tpe.asInstanceOf[VectorType].size), + indexBase + ) + ) + ) + } } + case e => create_exps(e).map(Location(_, one)) } /** Returns true if e contains a [[firrtl.WSubAccess]] @@ -111,26 +120,26 @@ object RemoveAccesses extends Pass { /** Replaces a subaccess in a given source expression */ val stmts = mutable.ArrayBuffer[Statement]() - def removeSource(e: Expression): Expression = e match { - case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) => - val rs = getLocations(e) - rs.find(x => x.guard != one) match { - case None => throwInternalError(s"removeSource: shouldn't be here - $e") - case Some(_) => - val (wire, temp) = create_temp(e) - val temps = create_exps(temp) - def getTemp(i: Int) = temps(i % temps.size) - stmts += wire - rs.zipWithIndex.foreach { - case (x, i) if i < temps.size => - stmts += IsInvalid(get_info(s), getTemp(i)) - stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) - case (x, i) => - stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) - } - temp - } - case _ => e + // Only called on RefLikes that definitely have a SubAccess + // Must accept Expression because that's the output type of fixIndices + def removeSource(e: Expression): Expression = { + val rs = getLocations(e) + rs.find(x => x.guard != one) match { + case None => throwInternalError(s"removeSource: shouldn't be here - $e") + case Some(_) => + val (wire, temp) = create_temp(e) + val temps = create_exps(temp) + def getTemp(i: Int) = temps(i % temps.size) + stmts += wire + rs.zipWithIndex.foreach { + case (x, i) if i < temps.size => + stmts += IsInvalid(get_info(s), getTemp(i)) + stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) + case (x, i) => + stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) + } + temp + } } /** Replaces a subaccess in a given sink expression @@ -151,14 +160,23 @@ object RemoveAccesses extends Pass { case _ => loc } + /** Recurse until find SubAccess and call fixSource on its index + * @note this only accepts [[RefLikeExpression]]s but we can't enforce it because map + * requires Expression => Expression + */ + def fixIndices(e: Expression): Expression = e match { + case e: SubAccess => e.copy(index = fixSource(e.index)) + case other => other.map(fixIndices) + } + /** Recursively walks a source expression and fixes all subaccesses - * If we see a sub-access, replace it. - * Otherwise, map to children. + * + * If we see a RefLikeExpression that contains a SubAccess, we recursively remove + * subaccesses from the indices of any SubAccesses, then process modified RefLikeExpression */ def fixSource(e: Expression): Expression = e match { - case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow)) - //case w: WSubIndex => removeSource(w) - //case w: WSubField => removeSource(w) + case ref: RefLikeExpression => + if (hasAccess(ref)) removeSource(fixIndices(ref)) else ref case x => x.map(fixSource) } diff --git a/src/main/scala/firrtl/passes/ResolveFlows.scala b/src/main/scala/firrtl/passes/ResolveFlows.scala index 1af87c10dc..a8672f55b5 100644 --- a/src/main/scala/firrtl/passes/ResolveFlows.scala +++ b/src/main/scala/firrtl/passes/ResolveFlows.scala @@ -9,7 +9,7 @@ import firrtl.options.Dependency object ResolveFlows extends Pass { - override def prerequisites = Seq(Dependency(passes.InferTypes)) ++ firrtl.stage.Forms.WorkingIR + override def prerequisites = Seq(Dependency(passes.InferTypes)) ++ firrtl.stage.Forms.MinimalHighForm override def invalidates(a: Transform) = false diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala index 25be003dc9..745be1e26c 100644 --- a/src/main/scala/firrtl/passes/ResolveKinds.scala +++ b/src/main/scala/firrtl/passes/ResolveKinds.scala @@ -5,11 +5,12 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Mappers._ +import firrtl.backends.experimental.smt.random.DefRandom import firrtl.traversals.Foreachers._ object ResolveKinds extends Pass { - override def prerequisites = firrtl.stage.Forms.WorkingIR + override def prerequisites = firrtl.stage.Forms.MinimalHighForm override def invalidates(a: Transform) = false @@ -31,6 +32,7 @@ object ResolveKinds extends Pass { case sx: DefRegister => kinds(sx.name) = RegKind case sx: WDefInstance => kinds(sx.name) = InstanceKind case sx: DefMemory => kinds(sx.name) = MemKind + case sx: DefRandom => kinds(sx.name) = RandomKind case _ => } s.map(resolve_stmt(kinds)) diff --git a/src/main/scala/firrtl/passes/ToWorkingIR.scala b/src/main/scala/firrtl/passes/ToWorkingIR.scala index 2a81449762..12a9bdd762 100644 --- a/src/main/scala/firrtl/passes/ToWorkingIR.scala +++ b/src/main/scala/firrtl/passes/ToWorkingIR.scala @@ -5,8 +5,15 @@ package firrtl.passes import firrtl.ir._ import firrtl.Transform +@deprecated( + "This pass is an identity transform. For an equivalent dependency, use firrtl.stage.forms.MinimalHighForm", + "FIRRTL 1.4.2" +) object ToWorkingIR extends Pass { override def prerequisites = firrtl.stage.Forms.MinimalHighForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = + (firrtl.stage.Forms.LowFormOptimized.toSet -- firrtl.stage.Forms.MinimalHighForm).toSeq override def invalidates(a: Transform) = false def run(c: Circuit): Circuit = c } diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 8191628246..20ae8db538 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -35,7 +35,7 @@ import MemPortUtils.memType object Uniquify extends Transform with DependencyAPIMigration { override def prerequisites = - Seq(Dependency(ResolveKinds), Dependency(InferTypes)) ++ firrtl.stage.Forms.WorkingIR + Seq(Dependency(ResolveKinds), Dependency(InferTypes)) ++ firrtl.stage.Forms.MinimalHighForm override def invalidates(a: Transform): Boolean = a match { case ResolveKinds | InferTypes => true diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala index baad2f4f02..03dcf0a397 100644 --- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala +++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala @@ -32,7 +32,7 @@ object VerilogModulusCleanup extends Pass { Dependency[firrtl.transforms.FixAddingNegativeLiterals], Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.InlineAcrossCastsTransform], Dependency[firrtl.transforms.LegalizeClocksTransform], Dependency[firrtl.transforms.FlattenRegUpdate] ) diff --git a/src/main/scala/firrtl/passes/VerilogPrep.scala b/src/main/scala/firrtl/passes/VerilogPrep.scala index ed5db92efb..9499889a19 100644 --- a/src/main/scala/firrtl/passes/VerilogPrep.scala +++ b/src/main/scala/firrtl/passes/VerilogPrep.scala @@ -28,7 +28,7 @@ object VerilogPrep extends Pass { Dependency[firrtl.transforms.FixAddingNegativeLiterals], Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.InlineAcrossCastsTransform], Dependency[firrtl.transforms.LegalizeClocksTransform], Dependency[firrtl.transforms.FlattenRegUpdate], Dependency(passes.VerilogModulusCleanup), diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index 60439ec14c..003ab3c9a0 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -25,19 +25,12 @@ object ZeroWidth extends Transform with DependencyAPIMigration { case _ => false } - private def makeEmptyMemBundle(name: String): Field = - Field( - name, - Flip, - BundleType( - Seq( - Field("addr", Default, UIntType(IntWidth(0))), - Field("en", Default, UIntType(IntWidth(0))), - Field("clk", Default, UIntType(IntWidth(0))), - Field("data", Flip, UIntType(IntWidth(0))) - ) - ) - ) + private def makeZero(tpe: ir.Type): ir.Type = tpe match { + case ClockType => UIntType(IntWidth(0)) + case a: UIntType => a.copy(IntWidth(0)) + case a: SIntType => a.copy(IntWidth(0)) + case a: AggregateType => a.map(makeZero) + } private def onEmptyMemStmt(s: Statement): Statement = s match { case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) => @@ -46,11 +39,9 @@ object ZeroWidth extends Transform with DependencyAPIMigration { DefWire( info, name, - BundleType( - rs.map(r => makeEmptyMemBundle(r)) ++ - ws.map(w => makeEmptyMemBundle(w)) ++ - rws.map(rw => makeEmptyMemBundle(rw)) - ) + MemPortUtils + .memType(d) + .map(makeZero) ) case Some(_) => d } @@ -135,8 +126,9 @@ object ZeroWidth extends Transform with DependencyAPIMigration { } } nonZeros match { - case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1))) - case Seq(x) => x + case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1))) + // We may have an SInt, Cat has type UInt so cast + case Seq(x) => castRhs(tpe, x) case seq => DoPrim(Cat, seq, consts, tpe).map(onExp) } case DoPrim(Andr, Seq(x), _, _) if (bitWidth(x.tpe) == 0) => UIntLiteral(1) // nothing false diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala index 9ad653cf40..671a08b950 100644 --- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala +++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala @@ -16,7 +16,6 @@ object RemoveAllButClocks extends Pass { case DefWire(i, n, ClockType) => s case DefNode(i, n, value) if value.tpe == ClockType => s case Connect(i, l, r) if l.tpe == ClockType => s - case sx: WDefInstance => sx case sx: DefInstance => sx case sx: Block => sx case sx: Conditionally => sx diff --git a/src/main/scala/firrtl/passes/memlib/CreateMemoryAnnotations.scala b/src/main/scala/firrtl/passes/memlib/CreateMemoryAnnotations.scala new file mode 100644 index 0000000000..240c2c9ad6 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/CreateMemoryAnnotations.scala @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl +package passes +package memlib + +import firrtl.Utils.error +import firrtl.stage.Forms + +import java.io.File + +class CreateMemoryAnnotations extends Transform with DependencyAPIMigration { + + override def prerequisites = Forms.MidForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Forms.MidEmitters + override def invalidates(a: Transform) = false + + def execute(state: CircuitState): CircuitState = { + state.copy(annotations = state.annotations.flatMap { + case ReplSeqMemAnnotation(inputFileName, outputConfig) => + Seq(MemLibOutConfigFileAnnotation(outputConfig, Nil)) ++ { + if (inputFileName.isEmpty) None + else if (new File(inputFileName).exists) { + import CustomYAMLProtocol._ + Some(PinAnnotation(new YamlFileReader(inputFileName).parse[Config].map(_.pin.name))) + } else error("Input configuration file does not exist!") + } + case a => Seq(a) + }) + } +} diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala deleted file mode 100644 index 1cdecdfa42..0000000000 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtl -package passes -package memlib - -import firrtl.stage.Forms - -class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform with DependencyAPIMigration { - - override def prerequisites = Forms.MidForm - override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Forms.MidEmitters - override def invalidates(a: Transform) = false - - def execute(state: CircuitState): CircuitState = reader match { - case None => state - case Some(r) => - import CustomYAMLProtocol._ - val configs = r.parse[Config] - val oldAnnos = state.annotations - val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { - case ((annos, pins), config) => - (annos, pins :+ config.pin.name) - } - state.copy(annotations = PinAnnotation(pins.toSeq) +: as) - } -} diff --git a/src/main/scala/firrtl/passes/memlib/DumpMemoryAnnotations.scala b/src/main/scala/firrtl/passes/memlib/DumpMemoryAnnotations.scala new file mode 100644 index 0000000000..5cc1e0bf0f --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/DumpMemoryAnnotations.scala @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl +package passes +package memlib + +import firrtl.stage.Forms + +class DumpMemoryAnnotations extends Transform with DependencyAPIMigration { + + override def prerequisites = Forms.MidForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Forms.MidEmitters + override def invalidates(a: Transform) = false + + def execute(state: CircuitState): CircuitState = { + state.copy(annotations = state.annotations.flatMap { + // convert and remove AnnotatedMemoriesAnnotation to CustomFileEmission + case AnnotatedMemoriesAnnotation(annotatedMemories) => + state.annotations.collect { + case a: MemLibOutConfigFileAnnotation => + a.copy(annotatedMemories = annotatedMemories) + // todo convert xxx to verilogs here. + } + case MemLibOutConfigFileAnnotation(_, Nil) => Nil + case a => Seq(a) + }) + } +} diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 0bb9445206..39c79bc625 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -78,6 +78,16 @@ object InferReadWritePass extends Pass { case sx => sx } + /* If the ports share the same address in an undefined-collision SyncReadMem, reads issued while the write + * is enabled are *always* undefined; we may treat the read as if it were gated by the complement of w.en. + * Though not a strict requirement, this currently applies only to single-cycle read/write memories. + * N.B. for aggregate-typed memories, the spec is conservative and 'undefined' is not a function of the + * write mask, allowing optimization regardless of mask value. This must be revisited if the spec changes. + */ + private def canOptimizeCollidingRW(mem: DefMemory): Boolean = { + mem.readUnderWrite == ReadUnderWrite.Undefined && mem.readLatency == 1 && mem.writeLatency == 1 + } + def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match { // infer readwrite ports only for non combinational memories case mem: DefMemory if mem.readLatency > 0 => @@ -94,7 +104,10 @@ object InferReadWritePass extends Pass { val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b))) val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) - if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) { + val waddr = getOrigin(connects)(memPortField(mem, w, "addr")) + val raddr = getOrigin(connects)(memPortField(mem, r, "addr")) + val optimizeCollision = (weq(waddr, raddr) && canOptimizeCollidingRW(mem)) + if (weq(wclk, rclk) && (proofOfMutualExclusion.nonEmpty || optimizeCollision)) { val rw = namespace.newName("rw") val rwExp = WSubField(WRef(mem.name), rw) readwriters += rw @@ -104,28 +117,32 @@ object InferReadWritePass extends Pass { repl(memPortField(mem, r, "en")) = EmptyExpression repl(memPortField(mem, r, "addr")) = EmptyExpression repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata") - repl(memPortField(mem, w, "clk")) = EmptyExpression - repl(memPortField(mem, w, "en")) = EmptyExpression - repl(memPortField(mem, w, "addr")) = EmptyExpression + repl(memPortField(mem, w, "clk")) = WSubField(rwExp, "clk") repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata") repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask") - stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get) - stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk) stmts += Connect( NoInfo, WSubField(rwExp, "en"), DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType) ) - stmts += Connect( - NoInfo, - WSubField(rwExp, "addr"), - Mux( - connects(memPortField(mem, w, "en")), - connects(memPortField(mem, w, "addr")), - connects(memPortField(mem, r, "addr")), - UnknownType + if (optimizeCollision) { + repl(memPortField(mem, w, "en")) = WSubField(rwExp, "wmode") + repl(memPortField(mem, w, "addr")) = WSubField(rwExp, "addr") + } else { + repl(memPortField(mem, w, "en")) = EmptyExpression + repl(memPortField(mem, w, "addr")) = EmptyExpression + stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get) + stmts += Connect( + NoInfo, + WSubField(rwExp, "addr"), + Mux( + connects(memPortField(mem, w, "en")), + connects(memPortField(mem, w, "addr")), + connects(memPortField(mem, r, "addr")), + UnknownType + ) ) - ) + } } } if (readwriters.isEmpty) mem diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index 63ae049032..4b43118ba0 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -3,15 +3,17 @@ package firrtl.passes package memlib -import firrtl._ -import firrtl.ir._ -import firrtl.Utils._ import firrtl.Mappers._ -import MemPortUtils.{MemPortMap, Modules} -import MemTransformUtils._ +import firrtl.Utils._ +import firrtl._ import firrtl.annotations._ +import firrtl.ir._ +import firrtl.passes.MemPortUtils.{MemPortMap, Modules} +import firrtl.passes.memlib.MemTransformUtils._ +import firrtl.passes.wiring._ import firrtl.stage.Forms -import wiring._ + +import scala.collection.mutable.ListBuffer /** Annotates the name of the pins to add for WiringTransform */ case class PinAnnotation(pins: Seq[String]) extends NoTargetAnnotation @@ -24,8 +26,7 @@ object ReplaceMemMacros { * This will not generate wmask ports if not needed. * Creates the minimum # of black boxes needed by the design. */ -class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIMigration { - +class ReplaceMemMacros extends Transform with DependencyAPIMigration { override def prerequisites = Forms.MidForm override def optionalPrerequisites = Seq.empty override def optionalPrerequisiteOf = Forms.MidEmitters @@ -121,12 +122,13 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM }) ) - def memToBundle(s: DefAnnotatedMemory) = BundleType( + private def memToBundle(s: DefAnnotatedMemory) = BundleType( s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ s.readwriters.map(Field(_, Flip, rwPortToBundle(s))) ) - def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType( + + private def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType( s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))) @@ -136,7 +138,11 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM * The wrapper module has the same type as the memory it replaces * The external module */ - def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = { + private def createMemModule( + m: DefAnnotatedMemory, + wrapperName: String, + annotatedMemoriesBuffer: ListBuffer[DefAnnotatedMemory] + ): Seq[DefModule] = { assert(m.dataType != UnknownType) val wrapperIoType = memToBundle(m) val wrapperIoPorts = wrapperIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) @@ -156,25 +162,25 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM // TODO: Annotate? -- use actual annotation map // add to conf file - writer.append(m) + annotatedMemoriesBuffer += m Seq(bb, wrapper) } // TODO(shunshou): get rid of copy pasta // Connects the clk, en, and addr fields from the wrapperPort to the bbPort - def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] = + private def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] = Seq("clk", "en", "addr").map(f => connectFields(bbPort, f, wrapperPort, f)) // Generates mask bits (concatenates an aggregate to ground type) // depending on mask granularity (# bits = data width / mask granularity) - def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression = + private def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression = if (fillMask) toBitMask(mask, dataType) else toBits(mask) - def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = + private def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = defaultConnects(wrapperPort, bbPort) :+ fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data")) - def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { + private def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { val wrapperData = WSubField(wrapperPort, "data") val defaultSeq = defaultConnects(wrapperPort, bbPort) :+ Connect(NoInfo, WSubField(bbPort, "data"), toBits(wrapperData)) @@ -189,7 +195,12 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM } } - def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { + private def adaptReadWriter( + wrapperPort: WRef, + bbPort: WSubField, + hasMask: Boolean, + fillMask: Boolean + ): Seq[Statement] = { val wrapperWData = WSubField(wrapperPort, "wdata") val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq( fromBits(WSubField(wrapperPort, "rdata"), WSubField(bbPort, "rdata")), @@ -211,7 +222,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM private type NameMap = collection.mutable.HashMap[(String, String), String] /** Construct NameMap by assigning unique names for each memory blackbox */ - def constructNameMap(namespace: Namespace, nameMap: NameMap, mname: String)(s: Statement): Statement = { + private def constructNameMap(namespace: Namespace, nameMap: NameMap, mname: String)(s: Statement): Statement = { s match { case m: DefAnnotatedMemory => m.memRef match { @@ -223,13 +234,14 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM s.map(constructNameMap(namespace, nameMap, mname)) } - def updateMemStmts( - namespace: Namespace, - nameMap: NameMap, - mname: String, - memPortMap: MemPortMap, - memMods: Modules - )(s: Statement + private def updateMemStmts( + namespace: Namespace, + nameMap: NameMap, + mname: String, + memPortMap: MemPortMap, + memMods: Modules, + annotatedMemoriesBuffer: ListBuffer[DefAnnotatedMemory] + )(s: Statement ): Statement = s match { case m: DefAnnotatedMemory => if (m.maskGran.isEmpty) { @@ -242,41 +254,49 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM val newWrapperName = nameMap(mname -> m.name) val newMemBBName = namespace.newName(s"${newWrapperName}_ext") val newMem = m.copy(name = newMemBBName) - memMods ++= createMemModule(newMem, newWrapperName) + memMods ++= createMemModule(newMem, newWrapperName, annotatedMemoriesBuffer) WDefInstance(m.info, m.name, newWrapperName, UnknownType) case Some((module, mem)) => WDefInstance(m.info, m.name, nameMap(module -> mem), UnknownType) } - case sx => sx.map(updateMemStmts(namespace, nameMap, mname, memPortMap, memMods)) + case sx => sx.map(updateMemStmts(namespace, nameMap, mname, memPortMap, memMods, annotatedMemoriesBuffer)) } - def updateMemMods(namespace: Namespace, nameMap: NameMap, memMods: Modules)(m: DefModule) = { + private def updateMemMods( + namespace: Namespace, + nameMap: NameMap, + memMods: Modules, + annotatedMemoriesBuffer: ListBuffer[DefAnnotatedMemory] + )(m: DefModule + ) = { val memPortMap = new MemPortMap - (m.map(updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods)) + (m.map(updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods, annotatedMemoriesBuffer)) .map(updateStmtRefs(memPortMap))) } def execute(state: CircuitState): CircuitState = { + val annotatedMemoriesBuffer: collection.mutable.ListBuffer[DefAnnotatedMemory] = ListBuffer[DefAnnotatedMemory]() val c = state.circuit val namespace = Namespace(c) val memMods = new Modules val nameMap = new NameMap c.modules.map(m => m.map(constructNameMap(namespace, nameMap, m.name))) - val modules = c.modules.map(updateMemMods(namespace, nameMap, memMods)) - // print conf - writer.serialize() - val pannos = state.annotations.collect { case a: PinAnnotation => a } - val pins = pannos match { - case Seq() => Nil - case Seq(PinAnnotation(pins)) => pins - case _ => throwInternalError("Something went wrong") - } - val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) => - seq ++ memMods.collect { - case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), pin) - } - } ++ state.annotations - state.copy(circuit = c.copy(modules = modules ++ memMods), annotations = annos) + val modules = c.modules.map(updateMemMods(namespace, nameMap, memMods, annotatedMemoriesBuffer)) + state.copy( + circuit = c.copy(modules = modules ++ memMods), + annotations = + state.annotations ++ + (state.annotations.collectFirst { case a: PinAnnotation => a } match { + case None => Nil + case Some(PinAnnotation(pins)) => + pins.foldLeft(Seq[Annotation]()) { (seq, pin) => + seq ++ memMods.collect { + case m: ExtModule => SinkAnnotation(ModuleName(m.name, CircuitName(c.main)), pin) + } + } + }) :+ + AnnotatedMemoriesAnnotation(annotatedMemoriesBuffer.toList) + ) } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index f0325e8ec7..f9df27a789 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -3,14 +3,15 @@ package firrtl.passes package memlib +import firrtl.Utils.error import firrtl._ import firrtl.annotations._ -import firrtl.options.{HasShellOptions, ShellOption} -import Utils.error -import java.io.{CharArrayWriter, File, PrintWriter} -import wiring._ +import firrtl.options.{CustomFileEmission, HasShellOptions, ShellOption} +import firrtl.passes.wiring._ import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} +import java.io.{CharArrayWriter, PrintWriter} + sealed trait PassOption case object InputConfigFileName extends PassOption case object OutputConfigFileName extends PassOption @@ -43,14 +44,20 @@ object PassConfigUtil { } } -class ConfWriter(filename: String) { - val outputBuffer = new CharArrayWriter - def append(m: DefAnnotatedMemory) = { - // legacy - // assert that we don't overflow going from BigInt to Int conversion +case class ReplSeqMemAnnotation(inputFileName: String, outputConfig: String) extends NoTargetAnnotation + +/** Generate conf file for a sequence of [[DefAnnotatedMemory]] + * @note file already has its suffix adding by `--replSeqMem` + */ +case class MemLibOutConfigFileAnnotation(file: String, annotatedMemories: Seq[DefAnnotatedMemory]) + extends NoTargetAnnotation + with CustomFileEmission { + def baseFileName(annotations: AnnotationSeq) = file + def suffix = None + def getBytes = annotatedMemories.map { m => require(bitWidth(m.dataType) <= Int.MaxValue) - m.maskGran.foreach { case x => require(x <= Int.MaxValue) } - val conf = MemConf( + m.maskGran.foreach(x => require(x <= Int.MaxValue)) + MemConf( m.name, m.depth, bitWidth(m.dataType).toInt, @@ -58,17 +65,12 @@ class ConfWriter(filename: String) { m.writers.length, m.readwriters.length, m.maskGran.map(_.toInt) - ) - outputBuffer.append(conf.toString) - } - def serialize() = { - val outputFile = new PrintWriter(filename) - outputFile.write(outputBuffer.toString) - outputFile.close() - } + ).toString + }.mkString("\n").getBytes } -case class ReplSeqMemAnnotation(inputFileName: String, outputConfig: String) extends NoTargetAnnotation +private[memlib] case class AnnotatedMemoriesAnnotation(annotatedMemories: List[DefAnnotatedMemory]) + extends NoTargetAnnotation object ReplSeqMemAnnotation { def parse(t: String): ReplSeqMemAnnotation = { @@ -111,7 +113,7 @@ class SimpleTransform(p: Pass, form: CircuitForm) extends Transform { class SimpleMidTransform(p: Pass) extends SimpleTransform(p, MidForm) // SimpleRun instead of PassBased because of the arguments to passSeq -class ReplSeqMem extends Transform with HasShellOptions with DependencyAPIMigration { +class ReplSeqMem extends SeqTransform with HasShellOptions with DependencyAPIMigration { override def prerequisites = Forms.MidForm override def optionalPrerequisites = Seq.empty @@ -132,31 +134,16 @@ class ReplSeqMem extends Transform with HasShellOptions with DependencyAPIMigrat ) ) - def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = + val transforms: Seq[Transform] = Seq( new SimpleMidTransform(Legalize), new SimpleMidTransform(ToMemIR), new SimpleMidTransform(ResolveMaskGranularity), new SimpleMidTransform(RenameAnnotatedMemoryPorts), + new CreateMemoryAnnotations, new ResolveMemoryReference, - new CreateMemoryAnnotations(inConfigFile), - new ReplaceMemMacros(outConfigFile), - new WiringTransform + new ReplaceMemMacros, + new WiringTransform, + new DumpMemoryAnnotations ) - - def execute(state: CircuitState): CircuitState = { - val annos = state.annotations.collect { case a: ReplSeqMemAnnotation => a } - annos match { - case Nil => state // Do nothing if there are no annotations - case Seq(ReplSeqMemAnnotation(inputFileName, outputConfig)) => - val inConfigFile = { - if (inputFileName.isEmpty) None - else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) - else error("Input configuration file does not exist!") - } - val outConfigFile = new ConfWriter(outputConfig) - transforms(inConfigFile, outConfigFile).foldLeft(state) { (in, xform) => xform.runTransform(in) } - case _ => error("Unexpected transform annotation") - } - } } diff --git a/src/main/scala/firrtl/passes/memlib/SeparateWriteClocks.scala b/src/main/scala/firrtl/passes/memlib/SeparateWriteClocks.scala new file mode 100644 index 0000000000..f526f64bb1 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/SeparateWriteClocks.scala @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.passes.LowerTypes +import firrtl.options.{Dependency, OptionsException} + +/** + * This transform introduces an intermediate wire on the clock field of each write port of synchronous-read memories + * that have *multiple* write/readwrite ports and undefined read-under-write collision behavior. Ultimately, the + * introduction of these intermediate wires does not change which clock net clocks each port; therefore, the purpose of + * this transform is to help generate Verilog that is more amenable to inference of RAM macros with multiple write + * ports in FPGA synthesis flows. This change will cause each write and each readwrite port to be emitted in a separate + * clocked procedure, yielding multiple benefits: + * + * 1) Separate write procedures avoid implicitly constraining cross-port read-write and write-write collision behaviors + * 2) The preference for separate clocked procedures for each write port is explicitly specified by Intel and Xilinx + * + * While this feature is not intended to be vendor-specific, inference of *multiple-write* RAM macros from behavioral + * Verilog or VHDL requires both advanced underlying RAM primitives and advanced synthesis tools. Currently, mapping + * such memories to programmable devices beyond modern Intel and Xilinx architectures can be prohibitive for users. + * + * Though the emission of separate processes for write ports could be absorbed into the Verilog emitter, the use of a + * pure-FIRRTL transform reduces implementation complexity and enhances reliability. + */ +class SeparateWriteClocks extends Transform with DependencyAPIMigration { + override def prerequisites = Seq(Dependency(passes.RemoveCHIRRTL), Dependency(passes.ExpandConnects)) + override def optionalPrerequisites = Seq(Dependency[InferReadWrite]) + override def optionalPrerequisiteOf = Seq(Dependency[SetDefaultReadUnderWrite]) + override def invalidates(a: Transform): Boolean = a match { + case ResolveFlows => true + case _ => false + } + + private type ExprMap = collection.mutable.HashMap[WrappedExpression, Reference] + + private def onExpr(replaceExprs: ExprMap)(expr: Expression): Expression = expr match { + case wsf: WSubField if (replaceExprs.contains(WrappedExpression(wsf))) => + replaceExprs(WrappedExpression(wsf)) + case e => e.mapExpr(onExpr(replaceExprs)) + } + + private def isMultiWriteSyncReadUndefinedRUW(mem: DefMemory): Boolean = { + (mem.writers.size + mem.readwriters.size) > 1 && + mem.readLatency == 1 && mem.writeLatency == 1 && + mem.readUnderWrite == ReadUnderWrite.Undefined + } + + private def onStmt(replaceExprs: ExprMap, ns: Namespace)(stmt: Statement): Statement = stmt match { + case mem: DefMemory if isMultiWriteSyncReadUndefinedRUW(mem) => + val clockRefs = (mem.writers ++ mem.readwriters).map { p => MemPortUtils.memPortField(mem, p, "clk") } + val clockWireMap = clockRefs.map { pClk => + WrappedExpression(pClk) -> DefWire(mem.info, ns.newName(LowerTypes.loweredName(pClk)), ClockType) + } + val clockStmts = clockWireMap.flatMap { + case (pClk, clkWire) => Seq(clkWire, Connect(mem.info, pClk.e1, Reference(clkWire))) + } + replaceExprs ++= clockWireMap.map { case (pClk, clkWire) => pClk -> Reference(clkWire) } + Block(mem +: clockStmts) + case Connect(i, lhs, rhs) => Connect(i, onExpr(replaceExprs)(lhs), rhs) + case PartialConnect(i, lhs, rhs) => PartialConnect(i, onExpr(replaceExprs)(lhs), rhs) + case IsInvalid(i, invalidated) => IsInvalid(i, onExpr(replaceExprs)(invalidated)) + case s => s.mapStmt(onStmt(replaceExprs, ns)) + } + + override def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val cPrime = c.copy(modules = c.modules.map(m => m.mapStmt(onStmt(new ExprMap, Namespace(m))))) + state.copy(circuit = cPrime) + } +} diff --git a/src/main/scala/firrtl/passes/memlib/SetDefaultReadUnderWrite.scala b/src/main/scala/firrtl/passes/memlib/SetDefaultReadUnderWrite.scala new file mode 100644 index 0000000000..d56460990a --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/SetDefaultReadUnderWrite.scala @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.options.{Dependency, OptionsException} +import firrtl.annotations.NoTargetAnnotation + +sealed trait DefaultReadUnderWriteAnnotation extends NoTargetAnnotation + +/** This annotation directs the [[SetDefaultReadUnderWrite]] transform to assign a default value of 'old' (read-first + * behavior) to all synchronous-read memories with 'undefined' read-under-write parameters. + */ +case object DefaultReadFirstAnnotation extends DefaultReadUnderWriteAnnotation + +/** This annotation directs the [[SetDefaultReadUnderWrite]] transform to assign a default value of 'new' (write-first + * behavior) to all synchronous-read memories with 'undefined' read-under-write parameters. + */ +case object DefaultWriteFirstAnnotation extends DefaultReadUnderWriteAnnotation + +/** + * Adding a [[DefaultReadUnderWriteAnnotation]] and running the [[SetDefaultReadUnderWrite]] transform will cause all + * synchronous-read memories with 'undefined' read-under-write parameters to be assigned a default parameter value, + * either 'old' (read-first behavior) or 'new' (write-first behavior). This can help generate Verilog that is amenable + * to RAM macro inference for various FPGA tools, or it can be used to satisfy other downstream design constraints. + */ +class SetDefaultReadUnderWrite extends Transform with DependencyAPIMigration { + override def prerequisites = firrtl.stage.Forms.HighForm + override def optionalPrerequisites = Seq(Dependency[InferReadWrite]) + override def optionalPrerequisiteOf = Seq(Dependency(VerilogMemDelays)) + override def invalidates(a: Transform): Boolean = false + + private def onStmt(defaultRUW: ReadUnderWrite.Value)(stmt: Statement): Statement = stmt match { + case mem: DefMemory if (mem.readLatency > 0 && mem.readUnderWrite == ReadUnderWrite.Undefined) => + mem.copy(readUnderWrite = defaultRUW) + case s => s.mapStmt(onStmt(defaultRUW)) + } + + override def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val ruwDefaults = state.annotations + .collect({ + case DefaultReadFirstAnnotation => ReadUnderWrite.Old + case DefaultWriteFirstAnnotation => ReadUnderWrite.New + }) + .toSet + if (ruwDefaults.size == 0) { + state + } else if (ruwDefaults.size == 1) { + state.copy(circuit = c.copy(modules = c.modules.map(m => m.mapStmt(onStmt(ruwDefaults.head))))) + } else { + throw new OptionsException("Conflicting default read-under-write settings.") + } + } +} diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index 8fb2dc889f..a9b42ebac0 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -10,12 +10,22 @@ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ import firrtl.transforms import firrtl.options.Dependency +import firrtl.annotations.NoTargetAnnotation import MemPortUtils._ import WrappedExpression._ import collection.mutable +/** + * Adding this annotation will allow the [[VerilogMemDelays]] transform to let 'simple' synchronous-read memories to + * pass through without explicitly breaking them apart into combinational-read memories and pipeline registers. Here, + * 'simple' memories are defined as those that have one-cycle read and write latencies AND either no readwrite ports or + * read-under-write behavior that is either 'undefined' or 'old'. This second restriction avoids the particularly + * complex case of blending FIRRTL readwrite port semantics with cross-port 'bypassing' of new data on collisions. + */ +case object PassthroughSimpleSyncReadMemsAnnotation extends NoTargetAnnotation + object MemDelayAndReadwriteTransformer { // Representation of a group of signals and associated valid signals case class WithValid(valid: Expression, payload: Seq[Expression]) @@ -77,13 +87,14 @@ object MemDelayAndReadwriteTransformer { * * @note The final transformed module is found in the (sole public) field [[transformed]] */ -class MemDelayAndReadwriteTransformer(m: DefModule) { +class MemDelayAndReadwriteTransformer(m: DefModule, passthroughSimpleSyncReadMems: Boolean = false) { import MemDelayAndReadwriteTransformer._ private val ns = Namespace(m) private val netlist = new collection.mutable.HashMap[WrappedExpression, Expression] private val exprReplacements = new collection.mutable.HashMap[WrappedExpression, Expression] private val newConns = new mutable.ArrayBuffer[Connect] + private val passthroughMems = new collection.mutable.HashSet[WrappedExpression] private def findMemConns(s: Statement): Unit = s match { case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(we(loc)) = expr @@ -95,7 +106,15 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { case ex => ex } + def canPassthrough(mem: DefMemory): Boolean = { + (mem.readLatency <= 1 && mem.writeLatency == 1 && + (mem.readwriters.isEmpty || (mem.readLatency == 1 && mem.readUnderWrite != ReadUnderWrite.New))) + } + private def transform(s: Statement): Statement = s.map(transform) match { + case mem: DefMemory if passthroughSimpleSyncReadMems && canPassthrough(mem) => + passthroughMems += WRef(mem) + mem case mem: DefMemory => // Per-memory bookkeeping val portNS = Namespace(mem.readers ++ mem.writers) @@ -163,7 +182,13 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { newConns ++= (readStmts ++ writeStmts).flatMap(_.conns) Block(newMem +: (readStmts ++ writeStmts).flatMap(_.decls)) - case sx: Connect if kind(sx.loc) == MemKind => EmptyStmt // Filter old mem connections + case sx: Connect if kind(sx.loc) == MemKind => + val (memRef, _) = Utils.splitRef(sx.loc) + // Filter old mem connections for *transformed* memories only + if (passthroughMems(WrappedExpression(memRef))) + sx + else + EmptyStmt case sx => sx.map(swapMemRefs) } @@ -177,7 +202,8 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { object VerilogMemDelays extends Pass { - override def prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf) + override def prerequisites = firrtl.stage.Forms.LowForm + override val optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf)) override val optionalPrerequisiteOf = Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter]) @@ -187,6 +213,14 @@ object VerilogMemDelays extends Pass { case _ => false } - def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed - def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) + private def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed + + @deprecated("VerilogMemDelays will change from a Pass to a Transform in FIRRTL 1.6.", "FIRRTL 1.5") + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) + + override def execute(state: CircuitState): CircuitState = { + val enablePassthrough = state.annotations.contains(PassthroughSimpleSyncReadMemsAnnotation) + def transform(m: DefModule) = (new MemDelayAndReadwriteTransformer(m, enablePassthrough)).transformed + state.copy(circuit = state.circuit.copy(modules = state.circuit.modules.map(transform))) + } } diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index 4f6705ec3f..86afe520ad 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -47,8 +47,7 @@ class WiringTransform extends Transform with DependencyAPIMigration { /** Defines the sequence of Transform that should be applied */ private def transforms(w: Seq[WiringInfo]): Seq[Transform] = Seq( - new Wiring(w), - ToWorkingIR + new Wiring(w) ) def execute(state: CircuitState): CircuitState = { val annos = state.annotations.collect { diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala index d926f6a95d..6f9b4f83c4 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala @@ -90,8 +90,6 @@ object WiringUtils { def getChildrenMap(c: Circuit): ChildrenMap = { val childrenMap = new ChildrenMap() def getChildren(mname: String)(s: Statement): Unit = s match { - case s: WDefInstance => - childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module)) case s: DefInstance => childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module)) case s => s.foreach(getChildren(mname)) @@ -152,11 +150,11 @@ object WiringUtils { * sources/sinks not under sinks/sources. */ if (queue.size == 1) { - val u = queue.dequeue + val u = queue.dequeue() sinkInsts.foreach { v => owners(v) = Vector(u) } } else { while (queue.nonEmpty) { - val u = queue.dequeue + val u = queue.dequeue() visited(u) = true val edges = (i.graph.getEdges(u.last).map(u :+ _).toVector :+ u.dropRight(1)) @@ -224,11 +222,11 @@ object WiringUtils { * sources/sinks not under sinks/sources. */ if (queue.size == 1) { - val u = queue.dequeue + val u = queue.dequeue() sinkInsts.foreach { v => owners(v) = Vector(u) } } else { while (queue.nonEmpty) { - val u = queue.dequeue + val u = queue.dequeue() visited(u) = true val edges = i.graph.getEdges(u.last).map(u :+ _).toVector :+ u.dropRight(1) diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala index 663e91b3cb..ed641eb2ae 100644 --- a/src/main/scala/firrtl/proto/FromProto.scala +++ b/src/main/scala/firrtl/proto/FromProto.scala @@ -148,6 +148,9 @@ object FromProto { case ReadUnderWrite.UNDEFINED => ir.ReadUnderWrite.Undefined case ReadUnderWrite.OLD => ir.ReadUnderWrite.Old case ReadUnderWrite.NEW => ir.ReadUnderWrite.New + case ReadUnderWrite.UNRECOGNIZED => + val msg = s"Unrecognized ReadUnderWrite value '$ruw', perhaps this version of FIRRTL is too old?" + throw new FirrtlUserException(msg) } def convert(dt: Firrtl.Statement.CMemory.TypeAndDepth): (ir.Type, BigInt) = @@ -171,6 +174,10 @@ object FromProto { case MEMORY_PORT_DIRECTION_READ => MRead case MEMORY_PORT_DIRECTION_WRITE => MWrite case MEMORY_PORT_DIRECTION_READ_WRITE => MReadWrite + case MEMORY_PORT_DIRECTION_UNKNOWN => MInfer + case UNRECOGNIZED => + val msg = s"Unrecognized MemoryPort Direction value '$mportdir', perhaps this version of FIRRTL is too old?" + throw new FirrtlUserException(msg) } def convert(port: Firrtl.Statement.MemoryPort, info: Firrtl.SourceInfo): CDefMPort = { @@ -191,6 +198,9 @@ object FromProto { case Formal.ASSERT => ir.Formal.Assert case Formal.ASSUME => ir.Formal.Assume case Formal.COVER => ir.Formal.Cover + case Formal.UNRECOGNIZED => + val msg = s"Unrecognized Formal value '$formal', perhaps this version of FIRRTL is too old?" + throw new FirrtlUserException(msg) } def convert(ver: Firrtl.Statement.Verification, info: Firrtl.SourceInfo): ir.Verification = @@ -248,9 +258,10 @@ object FromProto { case MEMORY_FIELD_NUMBER => convert(stmt.getMemory, info) case IS_INVALID_FIELD_NUMBER => ir.IsInvalid(convert(info), convert(stmt.getIsInvalid.getExpression)) - case CMEMORY_FIELD_NUMBER => convert(stmt.getCmemory, info) - case MEMORY_PORT_FIELD_NUMBER => convert(stmt.getMemoryPort, info) - case ATTACH_FIELD_NUMBER => convert(stmt.getAttach, info) + case CMEMORY_FIELD_NUMBER => convert(stmt.getCmemory, info) + case MEMORY_PORT_FIELD_NUMBER => convert(stmt.getMemoryPort, info) + case ATTACH_FIELD_NUMBER => convert(stmt.getAttach, info) + case VERIFICATION_FIELD_NUMBER => convert(stmt.getVerification, info) } } @@ -308,9 +319,13 @@ object FromProto { } def convert(dir: Firrtl.Port.Direction): ir.Direction = { + import Firrtl.Port.Direction._ dir match { - case Firrtl.Port.Direction.PORT_DIRECTION_IN => ir.Input - case Firrtl.Port.Direction.PORT_DIRECTION_OUT => ir.Output + case PORT_DIRECTION_IN => ir.Input + case PORT_DIRECTION_OUT => ir.Output + case (PORT_DIRECTION_UNKNOWN | UNRECOGNIZED) => + val msg = s"Unrecognized Port Direction value '$dir', perhaps this version of FIRRTL is too old?" + throw new FirrtlUserException(msg) } } diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index 4cdf6b85ce..f5ade0e3e6 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -307,6 +307,7 @@ object ToProto { .setCond(convert(cond)) .setEn(convert(en)) .setMsg(msg.string) + sb.setVerification(vb) case ir.IsInvalid(_, expr) => val ib = Firrtl.Statement.IsInvalid .newBuilder() diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala index 99a6e9c3fb..44c884188c 100644 --- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala +++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala @@ -180,7 +180,7 @@ case class RunFirrtlTransformAnnotation(transform: Transform) extends NoTargetAn object RunFirrtlTransformAnnotation extends HasShellOptions { def apply(transform: TransformDependency): RunFirrtlTransformAnnotation = - RunFirrtlTransformAnnotation(transform.getObject) + RunFirrtlTransformAnnotation(transform.getObject()) private[firrtl] def stringToEmitter(a: String): RunFirrtlTransformAnnotation = { val emitter = a match { @@ -255,13 +255,18 @@ case class FirrtlCircuitAnnotation(circuit: Circuit) extends NoTargetAnnotation * * - set with `--warn:no-scala-version-deprecation` */ +@deprecated("Support for Scala 2.11 has been dropped, this object no longer does anything", "FIRRTL 1.5") case object WarnNoScalaVersionDeprecation extends NoTargetAnnotation with FirrtlOption with HasShellOptions { def longOption: String = "warn:no-scala-version-deprecation" val options = Seq( new ShellOption[Unit]( longOption = longOption, - toAnnotationSeq = { _ => Seq(this) }, - helpText = "Suppress Scala 2.11 deprecation warning (ignored in Scala 2.12+)" + toAnnotationSeq = { _ => + val msg = s"'$longOption' no longer does anything and will be removed in FIRRTL 1.6" + firrtl.options.StageUtils.dramaticWarning(msg) + Seq(this) + }, + helpText = "(deprecated, this option does nothing)" ) ) } @@ -280,3 +285,30 @@ case object PrettyNoExprInlining extends NoTargetAnnotation with FirrtlOption wi ) ) } + +/** Turn off folding a specific primitive operand + * @param op the op that should never be folded + */ +case class DisableFold(op: ir.PrimOp) extends NoTargetAnnotation with FirrtlOption + +@deprecated("will be removed and merged into ConstantPropagation in 1.5", "1.4") +object DisableFold extends HasShellOptions { + + private val mapping: Map[String, ir.PrimOp] = PrimOps.builtinPrimOps.map { case op => op.toString -> op }.toMap + + override val options = Seq( + new ShellOption[String]( + longOption = "dont-fold", + toAnnotationSeq = a => { + mapping + .get(a) + .orElse(throw new OptionsException(s"Unknown primop '$a'. (Did you misspell it?)")) + .map(DisableFold(_)) + .toSeq + }, + helpText = "Disable folding of specific primitive operations", + helpValueName = Some("") + ) + ) + +} diff --git a/src/main/scala/firrtl/stage/FirrtlCli.scala b/src/main/scala/firrtl/stage/FirrtlCli.scala index 18f14107c5..9cfa6be9e5 100644 --- a/src/main/scala/firrtl/stage/FirrtlCli.scala +++ b/src/main/scala/firrtl/stage/FirrtlCli.scala @@ -21,7 +21,9 @@ trait FirrtlCli { this: Shell => firrtl.EmitAllModulesAnnotation, NoCircuitDedupAnnotation, WarnNoScalaVersionDeprecation, - PrettyNoExprInlining + PrettyNoExprInlining, + DisableFold, + OptimizeForFPGA ) .map(_.addOptions(parser)) diff --git a/src/main/scala/firrtl/stage/FirrtlCompilerTargets.scala b/src/main/scala/firrtl/stage/FirrtlCompilerTargets.scala new file mode 100644 index 0000000000..662f3dc00e --- /dev/null +++ b/src/main/scala/firrtl/stage/FirrtlCompilerTargets.scala @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.stage + +import firrtl.transforms._ +import firrtl.passes.memlib._ +import firrtl.options.{HasShellOptions, ShellOption} + +/** + * This flag enables a set of options that guide the FIRRTL compilation flow to ultimately generate Verilog that is + * more amenable to using for synthesized FPGA designs. Currently, this flag affects only memories, as the need to emit + * memories that support downstream inference of hardened RAM macros. These options are not intended to be specialized + * to any particular vendor; instead, they aim to emit simple Verilog that more closely reflects traditional + * human-written definitions of synchronous-read memories. + * + * 1) Enable the [[firrtl.passes.memlib.InferReadWrite]] transform to reduce port count, where applicable. + * + * 2) Use the [[firrtl.transforms.SimplifyMems]] transform to Lower aggregate-typed memories with always-high masks to + * packed memories without splitting them into multiple independent ground-typed memories. + * + * 3) Use the [[firrtl.passes.memlib.SeparateWriteClocks]] transform to ensure that each write port of a + * multiple-write, synchronous-read memory with 'undefined' collision behavior ultimately maps to a separate clocked + * process in the emitted Verilog. This avoids the issue of implicitly constraining cross-port collision and write + * ordering behavior and helps simplify inference of true dual-port RAM macros. + * + * 4) Use the [[firrtl.passes.memlib.SetDefaultReadUnderWrite]] to specify that memories with undefined + * read-under-write behavior should map to emitted microarchitectures characteristic of "read-first" ports by + * default. This eliminates the difficulty of inferring a RAM macro that matches the strict semantics of + * "write-first" ports. + * + * 5) Add a [[firrtl.passes.memlib.PassthroughSimpleSyncReadMemsAnnotation]] to allow some synchronous-read memories + * and readwrite ports to pass through [[firrtl.passes.memlib.VerilogMemDelays]] without introducing explicit + * pipeline registers or splitting ports. + */ +object OptimizeForFPGA extends HasShellOptions { + private val fpgaAnnos = Seq( + InferReadWriteAnnotation, + RunFirrtlTransformAnnotation(new InferReadWrite), + RunFirrtlTransformAnnotation(new SeparateWriteClocks), + DefaultReadFirstAnnotation, + RunFirrtlTransformAnnotation(new SetDefaultReadUnderWrite), + RunFirrtlTransformAnnotation(new SimplifyMems), + PassthroughSimpleSyncReadMemsAnnotation + ) + val options = Seq( + new ShellOption[Unit]( + longOption = "target:fpga", + toAnnotationSeq = a => fpgaAnnos, + helpText = "Choose compilation strategies that generally favor FPGA targets" + ) + ) +} diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala index a6077f525c..4132f75861 100644 --- a/src/main/scala/firrtl/stage/Forms.scala +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -25,7 +25,8 @@ object Forms { Dependency[annotations.transforms.CleanupNamedTargets] ) - val WorkingIR: Seq[TransformDependency] = MinimalHighForm :+ Dependency(passes.ToWorkingIR) + @deprecated("Use firrtl.stage.forms.MinimalHighForm", "FIRRTL 1.4.2") + val WorkingIR: Seq[TransformDependency] = MinimalHighForm val Checks: Seq[TransformDependency] = Seq( @@ -35,7 +36,7 @@ object Forms { Dependency(passes.CheckWidths) ) - val Resolved: Seq[TransformDependency] = WorkingIR ++ Checks ++ + val Resolved: Seq[TransformDependency] = MinimalHighForm ++ Checks ++ Seq( Dependency(passes.ResolveKinds), Dependency(passes.InferTypes), @@ -50,7 +51,6 @@ object Forms { val HighForm: Seq[TransformDependency] = ChirrtlForm ++ MinimalHighForm ++ - WorkingIR ++ Resolved ++ Deduped @@ -101,7 +101,7 @@ object Forms { Dependency[firrtl.transforms.FixAddingNegativeLiterals], Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.InlineAcrossCastsTransform], Dependency[firrtl.transforms.LegalizeClocksTransform], Dependency[firrtl.transforms.FlattenRegUpdate], Dependency(passes.VerilogModulusCleanup), diff --git a/src/main/scala/firrtl/stage/package.scala b/src/main/scala/firrtl/stage/package.scala index c159f8522a..68e7a9c5d5 100644 --- a/src/main/scala/firrtl/stage/package.scala +++ b/src/main/scala/firrtl/stage/package.scala @@ -34,6 +34,7 @@ package object stage { case a: CompilerAnnotation => logger.warn(s"Use of CompilerAnnotation is deprecated. Ignoring $a"); c case WarnNoScalaVersionDeprecation => c case PrettyNoExprInlining => c + case _: DisableFold => c } } } diff --git a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala index a37774e5ae..7ad77212dd 100644 --- a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala +++ b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala @@ -58,8 +58,8 @@ object DriverCompatibility { def addOptions(p: OptionParser[AnnotationSeq]): Unit = p .opt[Unit]("top-name") .abbr("tn") - .hidden - .unbounded + .hidden() + .unbounded() .action((_, _) => throw new OptionsException(optionRemoved("--top-name/-tn"))) } @@ -71,8 +71,8 @@ object DriverCompatibility { def addOptions(p: OptionParser[AnnotationSeq]): Unit = p .opt[Unit]("split-modules") .abbr("fsm") - .hidden - .unbounded + .hidden() + .unbounded() .action((_, _) => throw new OptionsException(optionRemoved("--split-modules/-fsm"))) } diff --git a/src/main/scala/firrtl/stage/transforms/CheckScalaVersion.scala b/src/main/scala/firrtl/stage/transforms/CheckScalaVersion.scala index 9d894905f3..6ed900f1f2 100644 --- a/src/main/scala/firrtl/stage/transforms/CheckScalaVersion.scala +++ b/src/main/scala/firrtl/stage/transforms/CheckScalaVersion.scala @@ -6,14 +6,10 @@ import firrtl.{BuildInfo, CircuitState, DependencyAPIMigration, Transform} import firrtl.stage.WarnNoScalaVersionDeprecation import firrtl.options.StageUtils.dramaticWarning +@deprecated("Support for 2.11 has been dropped, this logic no longer does anything", "FIRRTL 1.5") object CheckScalaVersion { def migrationDocumentLink: String = "https://www.chisel-lang.org/chisel3/upgrading-from-scala-2-11.html" - private def getScalaMajorVersion: Int = { - val "2" :: major :: _ :: Nil = BuildInfo.scalaVersion.split("\\.").toList - major.toInt - } - final def deprecationMessage(version: String, option: String) = s"""|FIRRTL support for Scala $version is deprecated, please upgrade to Scala 2.12. | Migration guide: $migrationDocumentLink @@ -21,17 +17,10 @@ object CheckScalaVersion { } +@deprecated("Support for 2.11 has been dropped, this transform no longer does anything", "FIRRTL 1.5") class CheckScalaVersion extends Transform with DependencyAPIMigration { - import CheckScalaVersion._ override def invalidates(a: Transform) = false - def execute(state: CircuitState): CircuitState = { - def suppress = state.annotations.contains(WarnNoScalaVersionDeprecation) - if (getScalaMajorVersion == 11 && !suppress) { - val option = s"--${WarnNoScalaVersionDeprecation.longOption}" - dramaticWarning(deprecationMessage("2.11", option)) - } - state - } + def execute(state: CircuitState): CircuitState = state } diff --git a/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala b/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala index 4fa7788d2c..8bd29b9c3c 100644 --- a/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala +++ b/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala @@ -19,7 +19,7 @@ class UpdateAnnotations(val underlying: Transform) } def internalTransform(b: (CircuitState, CircuitState)): (CircuitState, CircuitState) = { - val result = Transform.runTransform(name, underlying.transform(b._2), logger) + val result = underlying.transform(b._2) (b._1, result) } } diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 5d57de3a8d..bc1fc9af7e 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -13,7 +13,8 @@ import firrtl.PrimOps._ import firrtl.graph.DiGraph import firrtl.analyses.InstanceKeyGraph import firrtl.annotations.TargetToken.Ref -import firrtl.options.Dependency +import firrtl.options.{Dependency, RegisteredTransform, ShellOption} +import firrtl.stage.DisableFold import annotation.tailrec import collection.mutable @@ -28,7 +29,17 @@ object ConstantPropagation { /** Pads e to the width of t */ def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { - case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) + case (we, wt) if we < wt => + DoPrim( + Pad, + Seq(e), + Seq(wt), + e.tpe match { + case UIntType(_) => UIntType(IntWidth(wt)) + case SIntType(_) => SIntType(IntWidth(wt)) + case _ => e.tpe + } + ) case (we, wt) if we == wt => e } @@ -100,16 +111,15 @@ object ConstantPropagation { } -class ConstantPropagation extends Transform with DependencyAPIMigration { +class ConstantPropagation extends Transform with RegisteredTransform with DependencyAPIMigration { import ConstantPropagation._ override def prerequisites = ((new mutable.LinkedHashSet()) ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq + - Dependency(firrtl.passes.Legalize)).toSeq - override def optionalPrerequisites = Seq.empty + override def optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf)) override def optionalPrerequisiteOf = Seq( @@ -124,6 +134,14 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { case _ => false } + val options = Seq( + new ShellOption[Unit]( + longOption = "no-constant-propagation", + toAnnotationSeq = _ => Seq(NoConstantPropagationAnnotation), + helpText = "Disable constant propagation elimination" + ) + ) + sealed trait SimplifyBinaryOp { def matchingArgsValue(e: DoPrim, arg: Expression): Expression def apply(e: DoPrim): Expression = { @@ -230,7 +248,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(pad(rhs, e.tpe), e.tpe) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs case _ => e } @@ -244,7 +262,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(pad(rhs, e.tpe), e.tpe) case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe)) @@ -401,7 +419,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { override def reduce = (a: Boolean, b: Boolean) => a ^ b } - private def constPropPrim(e: DoPrim): Expression = e.op match { + private def constPropPrim(e: DoPrim, disabledOps: Set[PrimOp]): Expression = e.op match { + case a if disabledOps(a) => e case Shl => foldShiftLeft(e) case Dshl => foldDynamicShiftLeft(e) case Shr => foldShiftRight(e) @@ -495,19 +514,25 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') def optimize(e: Expression): Expression = - constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + constPropExpression( + new NodeMap(), + Map.empty[Instance, OfModule], + Map.empty[OfModule, Map[String, Literal]], + Set.empty + )(e) def optimize(e: Expression, nodeMap: NodeMap): Expression = - constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]], Set.empty)(e) private def constPropExpression( nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], - constSubOutputs: Map[OfModule, Map[String, Literal]] + constSubOutputs: Map[OfModule, Map[String, Literal]], + disabledOps: Set[PrimOp] )(e: Expression ): Expression = { - val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs)) + val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)) val propagated = old match { - case p: DoPrim => constPropPrim(p) + case p: DoPrim => constPropPrim(p, disabledOps) case m: Mux => constPropMux(m) case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) @@ -519,7 +544,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { } // We're done when the Expression no longer changes if (propagated eq old) propagated - else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated) + else constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(propagated) } /** Hacky way of propagating source locators across nodes and connections that have just a @@ -555,6 +580,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { * @param instMap map of instance names to Module name * @param constInputs map of names of m's input ports to literal driving it (if applicable) * @param constSubOutputs Map of Module name to Map of output port name to literal driving it + * @param disabledOps a Set of any PrimOps that should not be folded * @return (Constpropped Module, Map of output port names to literal value, * Map of submodule modulenames to Map of input port names to literal values) */ @@ -564,7 +590,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { dontTouches: Set[String], instMap: collection.Map[Instance, OfModule], constInputs: Map[String, Literal], - constSubOutputs: Map[OfModule, Map[String, Literal]] + constSubOutputs: Map[OfModule, Map[String, Literal]], + disabledOps: Set[PrimOp] ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { var nPropagated = 0L @@ -637,7 +664,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { case WRef(rname, _, kind, _) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a // node declaration or the single connection to a wire or register - swapMap += (lname -> rname, rname -> lname) + swapMap += lname -> rname + swapMap += rname -> lname case _ => } nodeMap(lname) = InfoExpr.wrap(info, value) @@ -646,7 +674,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { def constPropStmt(s: Statement): Statement = { val s0 = s.map(constPropStmt) // Statement recurse val s1 = propagateDirectConnectionInfoOnly(nodeMap, dontTouches)(s0) // hacky source locator propagation - val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs)) // propagate sub-Expressions + // propagate sub-Expressions + val stmtx = s1.map(constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)) // Record things that should be propagated stmtx match { case DefNode(info, name, value) if !dontTouches.contains(name) => @@ -654,11 +683,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { case reg: DefRegister if reg.reset.tpe == AsyncResetType => asyncResetRegs(reg.name) = reg case Connect(info, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => - val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe)) + val exprx = constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(expr, wtpe)) propagateRef(wname, exprx, info) // Record constants driving outputs case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => - val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = + constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal] constOutputs(pname) = paddedLit // Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns // This requires that reset has been made explicit @@ -714,7 +744,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { case _ => } - def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe)) + def padCPExp(e: Expression) = + constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(e, ltpe)) asyncResetRegs.get(lname) match { // Normal Register @@ -725,7 +756,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { // Mark instance inputs connected to a constant case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) => - val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = + constPropExpression(nodeMap, instMap, constSubOutputs, disabledOps)(pad(lit, ptpe)).asInstanceOf[Literal] val module = instMap(inst.Instance) val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty) portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty) @@ -750,7 +782,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { // When we call this function again, constOutputs and constSubInputs are reconstructed and // strictly a superset of the versions here - if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs) + if (nPropagated > 0) constPropModule(modx, dontTouches, instMap, constInputs, constSubOutputs, disabledOps) else (modx, constOutputs.toMap, constSubInputs.mapValues(_.toMap).toMap) } @@ -761,7 +793,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) } - private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = { + private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]], disabledOps: Set[PrimOp]): Circuit = { val iGraph = InstanceKeyGraph(c) val moduleDeps = iGraph.getChildInstanceMap val instCount = iGraph.staticInstanceCount @@ -800,7 +832,8 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { dontTouches, moduleDeps(mname), constInputs.getOrElse(mname, Map.empty), - constOutputs + constOutputs, + disabledOps ) // Accumulate all Literals used to drive a particular Module port val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d)) @@ -852,6 +885,13 @@ class ConstantPropagation extends Transform with DependencyAPIMigration { val dontTouchMap: Map[OfModule, Set[String]] = dontTouches.groupBy(_._1).mapValues(_.map(_._2).toSet).toMap - state.copy(circuit = run(state.circuit, dontTouchMap)) + val disabledOps = state.annotations.collect { case DisableFold(op) => op }.toSet + + if (state.annotations.contains(NoConstantPropagationAnnotation)) { + logger.info("Skipping Constant Propagation") + state + } else { + state.copy(circuit = run(state.circuit, dontTouchMap, disabledOps)) + } } } diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index 13173fddbc..f72585d112 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -11,6 +11,7 @@ import firrtl.analyses.InstanceKeyGraph import firrtl.Mappers._ import firrtl.Utils.{kind, throwInternalError} import firrtl.MemoizedHash._ +import firrtl.backends.experimental.smt.random.DefRandom import firrtl.options.{Dependency, RegisteredTransform, ShellOption} import collection.mutable @@ -126,6 +127,11 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend val node = LogicNode(mod.name, name) depGraph.addVertex(node) Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(node, ref)) + case DefRandom(_, name, _, clock, en) => + val node = LogicNode(mod.name, name) + depGraph.addVertex(node) + val inputs = clock ++: en +: Nil + inputs.flatMap(getDeps).foreach(ref => depGraph.addPairWithEdge(node, ref)) case DefNode(_, name, value) => val node = LogicNode(mod.name, name) depGraph.addVertex(node) @@ -225,6 +231,7 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend val tpe = decl match { case _: DefNode => "node" case _: DefRegister => "reg" + case _: DefRandom => "rand" case _: DefWire => "wire" case _: Port => "port" case _: DefMemory => "mem" @@ -258,6 +265,11 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend renames.delete(inst.name) EmptyStmt } + case print: Print => deleteIfNotEnabled(print, print.en) + case stop: Stop => deleteIfNotEnabled(stop, stop.en) + case formal: Verification => deleteIfNotEnabled(formal, formal.en) + // Statements are also declarations and thus this case needs to come *after* checking the + // print, stop and verification statements. case decl: IsDeclaration => val node = LogicNode(mod.name, decl.name) if (deadNodes.contains(node)) { @@ -265,10 +277,7 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend renames.delete(decl.name) EmptyStmt } else decl - case print: Print => deleteIfNotEnabled(print, print.en) - case stop: Stop => deleteIfNotEnabled(stop, stop.en) - case formal: Verification => deleteIfNotEnabled(formal, formal.en) - case con: Connect => + case con: Connect => val node = getDeps(con.loc) match { case Seq(elt) => elt } if (deadNodes.contains(node)) EmptyStmt else con case Attach(info, exprs) => // If any exprs are dead then all are diff --git a/src/main/scala/firrtl/transforms/EnsureNamedStatements.scala b/src/main/scala/firrtl/transforms/EnsureNamedStatements.scala new file mode 100644 index 0000000000..a40409f9ff --- /dev/null +++ b/src/main/scala/firrtl/transforms/EnsureNamedStatements.scala @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.transforms + +import firrtl._ +import firrtl.ir._ + +/** Adds default names to print, stop and verification statements if their name is empty. */ +object EnsureNamedStatements extends Transform with DependencyAPIMigration { + override def invalidates(a: Transform) = false + + override protected def execute(state: CircuitState): CircuitState = { + val c = state.circuit.mapModule(onModule) + state.copy(circuit = c) + } + + private def onModule(m: DefModule): DefModule = m match { + case e: ExtModule => e + case mod: Module => + val namespace = Namespace(mod) + // Ensure we always start with _0 suffix + val prefixes = Seq("cover", "assert", "assume", "print", "stop") + prefixes.filterNot(namespace.contains).foreach(namespace.newName) + mod.mapStmt(onStmt(namespace)) + } + + private def onStmt(namespace: Namespace)(stmt: Statement): Statement = stmt match { + case s: Print if s.name.isEmpty => s.withName(namespace.newName("print")) + case s: Stop if s.name.isEmpty => s.withName(namespace.newName("stop")) + case s: Verification if s.name.isEmpty => + val baseName = s.op match { + case Formal.Cover => "cover" + case Formal.Assert => "assert" + case Formal.Assume => "assume" + } + s.withName(namespace.newName(baseName)) + case other => other.mapStmt(onStmt(namespace)) + } +} diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index 664ce1e645..3f497c91cc 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -170,7 +170,7 @@ class FlattenRegUpdate extends Transform with DependencyAPIMigration { Dependency[FixAddingNegativeLiterals], Dependency[ReplaceTruncatingArithmetic], Dependency[InlineBitExtractionsTransform], - Dependency[InlineCastsTransform], + Dependency[InlineAcrossCastsTransform], Dependency[LegalizeClocksTransform] ) diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala index 4c6ffde2d4..8fb556fd8c 100644 --- a/src/main/scala/firrtl/transforms/InferResets.scala +++ b/src/main/scala/firrtl/transforms/InferResets.scala @@ -117,7 +117,7 @@ class InferResets extends Transform with DependencyAPIMigration { Dependency(passes.InferTypes), Dependency(passes.ResolveFlows), Dependency[passes.InferWidths] - ) ++ stage.Forms.WorkingIR + ) ++ stage.Forms.MinimalHighForm override def invalidates(a: Transform): Boolean = a match { case _: checks.CheckResets | passes.CheckTypes => true diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala index 71318eee1a..de54a326af 100644 --- a/src/main/scala/firrtl/transforms/InlineCasts.scala +++ b/src/main/scala/firrtl/transforms/InlineCasts.scala @@ -10,6 +10,7 @@ import firrtl.options.Dependency import firrtl.Utils.{isBitExtract, isCast, NodeMap} +@deprecated("Replaced by InlineAcrossCastsTransform", "FIRRTL 1.4.3") object InlineCastsTransform { // Checks if an Expression is made up of only casts terminated by a Literal or Reference @@ -28,26 +29,33 @@ object InlineCastsTransform { * @param expr the Expression being transformed * @return Returns expr with [[WRef]]s replaced by values found in replace */ - def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match { - // Anything that may generate a part-select should not be inlined! - case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr - case e => - e.map(onExpr(replace)) match { - case e @ WRef(name, _, _, _) => - replace - .get(name) - .filter(isSimpleCast(castSeen = false)) - .getOrElse(e) - case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) => - replace - .get(name) - .map(value => e.copy(args = Seq(value))) - .getOrElse(e) - case other => other // Not a candidate - } + def onExpr(replace: NodeMap)(expr: Expression): Expression = { + // Keep track if we've seen any non-cast expressions while recursing + def rec(hasNonCastParent: Boolean)(expr: Expression): Expression = expr match { + // Skip pads to avoid inlining literals into pads which results in invalid Verilog + case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr + case e => + e.map(rec(hasNonCastParent || !isCast(e))) match { + case e @ WRef(name, _, _, _) => + replace + .get(name) + .filter(isSimpleCast(castSeen = false)) + .getOrElse(e) + case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) => + replace + .get(name) + // Only inline the Expression if there is no non-cast parent in the expression tree OR + // if the subtree contains only casts and references. + .filter(x => !hasNonCastParent || isSimpleCast(castSeen = true)(x)) + .map(value => e.copy(args = Seq(value))) + .getOrElse(e) + case other => other // Not a candidate + } + } + rec(false)(expr) } - /** Inline casts in a Statement + /** Inline across casts in a statement * * @param netlist a '''mutable''' HashMap mapping references to [[firrtl.ir.DefNode DefNode]]s to their connected * [[firrtl.ir.Expression Expression]]s. This function '''will''' mutate @@ -64,11 +72,17 @@ object InlineCastsTransform { case other => other } - /** Replaces truncating arithmetic in a Module */ + /** Inline across casts in a module */ def onMod(mod: DefModule): DefModule = mod.map(onStmt(new NodeMap)) } -/** Inline nodes that are simple casts */ +/** Inline expressions into casts and inline casts into other expressions + * + * Because casts are no-ops in the emitted Verilog, this transform eliminates statements that + * simply contain a cast. It does so by greedily building larger expression trees that contain at + * most one expression that is neither a cast nor reference-like node. + */ +@deprecated("Replaced by InlineAcrossCastsTransform", "FIRRTL 1.4.3") class InlineCastsTransform extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ diff --git a/src/main/scala/firrtl/transforms/LegalizeClocksAndAsyncResets.scala b/src/main/scala/firrtl/transforms/LegalizeClocksAndAsyncResets.scala index 5e3d276d0d..0765a2b141 100644 --- a/src/main/scala/firrtl/transforms/LegalizeClocksAndAsyncResets.scala +++ b/src/main/scala/firrtl/transforms/LegalizeClocksAndAsyncResets.scala @@ -91,7 +91,7 @@ class LegalizeClocksAndAsyncResetsTransform extends Transform with DependencyAPI Dependency[FixAddingNegativeLiterals], Dependency[ReplaceTruncatingArithmetic], Dependency[InlineBitExtractionsTransform], - Dependency[InlineCastsTransform] + Dependency[InlineAcrossCastsTransform] ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/LegalizeReductions.scala b/src/main/scala/firrtl/transforms/LegalizeReductions.scala index b5751b14e9..94cef65bfb 100644 --- a/src/main/scala/firrtl/transforms/LegalizeReductions.scala +++ b/src/main/scala/firrtl/transforms/LegalizeReductions.scala @@ -35,7 +35,7 @@ object LegalizeAndReductionsTransform { class LegalizeAndReductionsTransform extends Transform with DependencyAPIMigration { override def prerequisites = - firrtl.stage.Forms.WorkingIR ++ + firrtl.stage.Forms.MinimalHighForm ++ Seq(Dependency(passes.CheckTypes), Dependency(passes.CheckWidths)) override def optionalPrerequisites = Nil diff --git a/src/main/scala/firrtl/transforms/ManipulateNames.scala b/src/main/scala/firrtl/transforms/ManipulateNames.scala index 7be876ef0c..4a796e58c3 100644 --- a/src/main/scala/firrtl/transforms/ManipulateNames.scala +++ b/src/main/scala/firrtl/transforms/ManipulateNames.scala @@ -464,7 +464,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_]: ClassTag] extends Transf val block = state.annotations.collect { case ManipulateNamesBlocklistAnnotation(targetSeq, t) => - t.getObject match { + t.getObject() match { case _: A => targetSeq case _ => Nil } @@ -473,7 +473,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_]: ClassTag] extends Transf val allow = { val allowx = state.annotations.collect { case ManipulateNamesAllowlistAnnotation(targetSeq, t) => - t.getObject match { + t.getObject() match { case _: A => targetSeq case _ => Nil } @@ -491,13 +491,13 @@ abstract class ManipulateNames[A <: ManipulateNames[_]: ClassTag] extends Transf val annotationsx = state.annotations.flatMap { /* Consume blocklist annotations */ case foo @ ManipulateNamesBlocklistAnnotation(_, t) => - t.getObject match { + t.getObject() match { case _: A => None case _ => Some(foo) } /* Convert allowlist annotations to result annotations */ case foo @ ManipulateNamesAllowlistAnnotation(a, t) => - t.getObject match { + t.getObject() match { case _: A => (a, a.map(_.map(renames(_)).flatten)) match { case (a, b) => Some(ManipulateNamesAllowlistResultAnnotation(b, t, a)) diff --git a/src/main/scala/firrtl/transforms/MustDedup.scala b/src/main/scala/firrtl/transforms/MustDedup.scala new file mode 100644 index 0000000000..3e7629cda7 --- /dev/null +++ b/src/main/scala/firrtl/transforms/MustDedup.scala @@ -0,0 +1,245 @@ +// See LICENSE for license details. + +package firrtl.transforms + +import firrtl._ +import firrtl.annotations._ +import firrtl.annotations.TargetToken.OfModule +import firrtl.analyses.InstanceKeyGraph +import firrtl.analyses.InstanceKeyGraph.InstanceKey +import firrtl.options.Dependency +import firrtl.stage.Forms +import firrtl.graph.DiGraph + +import java.io.{File, FileWriter} + +/** Marks modules as "must deduplicate" */ +case class MustDeduplicateAnnotation(modules: Seq[IsModule]) extends MultiTargetAnnotation { + def targets: Seq[Seq[IsModule]] = modules.map(Seq(_)) + + def duplicate(n: Seq[Seq[Target]]): MustDeduplicateAnnotation = { + val newModules = n.map { + case Seq(mod: IsModule) => mod + case _ => + val msg = "Something went wrong! This anno should only rename to single IsModules! " + + s"Got: $modules -> $n" + throw new Exception(msg) + } + MustDeduplicateAnnotation(newModules) + } +} + +/** Specifies the directory where errors for modules that "must deduplicate" will be reported */ +case class MustDeduplicateReportDirectory(directory: String) extends NoTargetAnnotation + +object MustDeduplicateTransform { + sealed trait DedupFailureCandidate { + def message: String + def modules: Seq[OfModule] + } + case class LikelyShouldMatch(a: OfModule, b: OfModule) extends DedupFailureCandidate { + def message: String = s"Modules '${a.value}' and '${b.value}' likely should dedup but do not." + def modules = Seq(a, b) + } + object DisjointChildren { + sealed trait Reason + case object Left extends Reason + case object Right extends Reason + case object Both extends Reason + } + import DisjointChildren._ + case class DisjointChildren(a: OfModule, b: OfModule, reason: Reason) extends DedupFailureCandidate { + def message: String = { + def helper(x: OfModule, y: OfModule): String = s"'${x.value}' contains instances not found in '${y.value}'" + val why = reason match { + case Left => helper(a, b) + case Right => helper(b, a) + case Both => s"${helper(a, b)} and ${helper(b, a)}" + } + s"Modules '${a.value}' and '${b.value}' cannot be deduplicated because $why." + } + def modules = Seq(a, b) + } + + final class DeduplicationFailureException(msg: String) extends FirrtlUserException(msg) + + case class DedupFailure( + shouldDedup: Seq[OfModule], + relevantMods: Set[OfModule], + candidates: Seq[DedupFailureCandidate]) + + /** Reports deduplication failures two Modules + * + * @return (Set of Modules that only appear in one hierarchy or the other, candidate pairs of Module names) + */ + def findDedupFailures(shouldDedup: Seq[OfModule], graph: InstanceKeyGraph): DedupFailure = { + val instLookup = graph.getChildInstances.toMap + def recurse(a: OfModule, b: OfModule): Seq[DedupFailureCandidate] = { + val as = instLookup(a.value) + val bs = instLookup(b.value) + if (as.length != bs.length) { + val aa = as.toSet + val bb = bs.toSet + val reason = (aa.diff(bb).nonEmpty, bb.diff(aa).nonEmpty) match { + case (true, true) => Both + case (true, false) => Left + case (false, true) => Right + case _ => Utils.error("Impossible!") + } + Seq(DisjointChildren(a, b, reason)) + } else { + val fromChildren = as.zip(bs).flatMap { + case (ax, bx) => recurse(ax.OfModule, bx.OfModule) + } + if (fromChildren.nonEmpty) fromChildren + else if (a != b) Seq(LikelyShouldMatch(a, b)) + else Nil + } + } + + val allMismatches = { + // Recalculating this every time is a little wasteful, but we're on a failure path anyway + val digraph = graph.graph.transformNodes(_.OfModule) + val froms = shouldDedup.map(x => digraph.reachableFrom(x) + x) + val union = froms.reduce(_ union _) + val intersection = froms.reduce(_ intersect _) + union.diff(intersection) + }.toSet + val pairs = shouldDedup.tail.map(n => (shouldDedup.head, n)) + val candidates = pairs.flatMap { case (a, b) => recurse(a, b) } + DedupFailure(shouldDedup, allMismatches, candidates) + } + + // Find the minimal number of vertices in the graph to show paths from "mustDedup" to failure + // candidates and their context (eg. children for DisjoinChildren) + private def findNodesToKeep(failure: DedupFailure, graph: DiGraph[String]): collection.Set[String] = { + val shouldDedup = failure.shouldDedup.map(_.value).toSet + val nodeOfInterest: Set[String] = + shouldDedup ++ failure.candidates.flatMap { + case LikelyShouldMatch(OfModule(a), OfModule(b)) => Seq(a, b) + case DisjointChildren(OfModule(a), OfModule(b), _) => + Seq(a, b) ++ graph.getEdges(a) ++ graph.getEdges(b) + } + // Depth-first search looking for relevant nodes + def dfs(node: String): collection.Set[String] = { + val deeper = graph.getEdges(node).flatMap(dfs) + if (deeper.nonEmpty || nodeOfInterest(node)) deeper + node else deeper + } + shouldDedup.flatMap(dfs) + } + + /** Turn a [[DedupFailure]] into a pretty graph for visualization + * + * @param failure Failure to visualize + * @param graph DiGraph of module names (no instance information) + */ + def makeDedupFailureDiGraph(failure: DedupFailure, graph: DiGraph[String]): DiGraph[String] = { + // Recalculating this every time is a little wasteful, but we're on a failure path anyway + // Lookup the parent Module name of any Module + val getParents: String => Seq[String] = + graph.reverse.getEdgeMap + .mapValues(_.toSeq) + + val candidates = failure.candidates + val shouldDedup = failure.shouldDedup.map(_.value) + val shouldDedupSet = shouldDedup.toSet + val mygraph = { + // Create a graph of paths from "shouldDedup" nodes to the candidates + // rooted at the "shouldDedup" nodes + val nodesToKeep = findNodesToKeep(failure, graph) + graph.subgraph(nodesToKeep) + + // Add fake nodes to represent parents of the "shouldDedup" nodes + DiGraph(shouldDedup.map(n => getParents(n).mkString(", ") -> n): _*) + } + // Gather candidate modules and assign indices for reference + val candidateIdx: Map[String, Int] = + candidates.zipWithIndex.flatMap { case (c, idx) => c.modules.map(_.value -> idx) }.toMap + // Now mark the graph for modules of interest + val markedGraph = mygraph.transformNodes { n => + val next = if (shouldDedupSet(n)) s"($n)" else n + candidateIdx + .get(n) + .map(i => s"$next [$i]") + .getOrElse(next) + } + markedGraph + } +} + +/** Checks for modules that have been marked as "must deduplicate" + * + * In cases where marked modules did not deduplicate, this transform attempts to provide context on + * what went wrong for debugging. + */ +class MustDeduplicateTransform extends Transform with DependencyAPIMigration { + import MustDeduplicateTransform._ + + override def prerequisites = Seq(Dependency[DedupModules]) + + // Make this run as soon after Dedup as possible + override def optionalPrerequisiteOf = (Forms.MidForm.toSet -- Forms.HighForm).toSeq + + override def invalidates(a: Transform) = false + + def execute(state: CircuitState): CircuitState = { + + lazy val igraph = InstanceKeyGraph(state.circuit) + + val dedupFailures: Seq[DedupFailure] = + state.annotations.flatMap { + case MustDeduplicateAnnotation(mods) => + val moduleNames = mods.map(_.leafModule).distinct + if (moduleNames.size <= 1) None + else { + val modNames = moduleNames.map(OfModule) + Some(findDedupFailures(modNames, igraph)) + } + case _ => None + } + if (dedupFailures.nonEmpty) { + val modgraph = igraph.graph.transformNodes(_.module) + // Create and log reports + val reports = dedupFailures.map { + case fail @ DedupFailure(shouldDedup, _, candidates) => + val graph = makeDedupFailureDiGraph(fail, modgraph).prettyTree() + val mods = shouldDedup.map("'" + _.value + "'").mkString(", ") + val msg = + s"""===== $mods are marked as "must deduplicate", but did not deduplicate. ===== + |$graph + |Failure candidates: + |${candidates.zipWithIndex.map { case (c, i) => s" - [$i] " + c.message }.mkString("\n")} + |""".stripMargin + logger.error(msg) + msg + } + + // Write reports and modules to disk + val dirName = state.annotations.collectFirst { case MustDeduplicateReportDirectory(dir) => dir } + .getOrElse("dedup_failures") + val dir = new File(dirName) + logger.error(s"Writing error report(s) to ${dir}...") + FileUtils.makeDirectory(dir.toString) + for ((report, idx) <- reports.zipWithIndex) { + val f = new File(dir, s"report_$idx.rpt") + logger.error(s"Writing $f...") + val fw = new FileWriter(f) + fw.write(report) + fw.close() + } + + val modsDir = new File(dir, "modules") + FileUtils.makeDirectory(modsDir.toString) + logger.error(s"Writing relevant modules to $modsDir...") + val relevantModule = dedupFailures.flatMap(_.relevantMods.map(_.value)).toSet + for (mod <- state.circuit.modules if relevantModule(mod.name)) { + val fw = new FileWriter(new File(modsDir, s"${mod.name}.fir")) + fw.write(mod.serialize) + fw.close() + } + + val msg = s"Modules marked 'must deduplicate' failed to deduplicate! See error reports in $dirName" + throw new DeduplicationFailureException(msg) + } + state + } +} diff --git a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala index 10163b7259..e817571aab 100644 --- a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala +++ b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala @@ -9,6 +9,9 @@ import firrtl.passes.PassException /** Indicate that DCE should not be run */ case object NoDCEAnnotation extends NoTargetAnnotation +/** Indicate that ConstantPropagation should not be run */ +case object NoConstantPropagationAnnotation extends NoTargetAnnotation + /** Lets an annotation mark its ReferenceTarget members as DontTouch * * This permits a transform to run and remove its associated annotations, diff --git a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala index 0bf6419fb5..69d4aa8ddc 100644 --- a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala +++ b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala @@ -37,7 +37,7 @@ class VerilogRename extends RemoveKeywordCollisions(v_keywords) { Dependency[FixAddingNegativeLiterals], Dependency[ReplaceTruncatingArithmetic], Dependency[InlineBitExtractionsTransform], - Dependency[InlineCastsTransform], + Dependency[InlineAcrossCastsTransform], Dependency[LegalizeClocksTransform], Dependency[FlattenRegUpdate], Dependency(passes.VerilogModulusCleanup) diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index ee03ad300c..7500b3860f 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -11,6 +11,7 @@ import firrtl.WrappedExpression._ import firrtl.graph.{CyclicException, MutableDiGraph} import firrtl.options.Dependency import firrtl.Utils.getGroundZero +import firrtl.backends.experimental.smt.random.DefRandom import scala.collection.mutable import scala.util.{Failure, Success, Try} @@ -41,13 +42,13 @@ class RemoveWires extends Transform with DependencyAPIMigration { case _ => false } - // Extract all expressions that are references to a Node, Wire, or Reg + // Extract all expressions that are references to a Node, Wire, Reg or Rand // Since we are operating on LowForm, they can only be WRefs private def extractNodeWireRegRefs(expr: Expression): Seq[WRef] = { val refs = mutable.ArrayBuffer.empty[WRef] def rec(e: Expression): Expression = { e match { - case ref @ WRef(_, _, WireKind | NodeKind | RegKind, _) => refs += ref + case ref @ WRef(_, _, WireKind | NodeKind | RegKind | RandomKind, _) => refs += ref case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec) case _ => // Do nothing } @@ -59,8 +60,9 @@ class RemoveWires extends Transform with DependencyAPIMigration { // Transform netlist into DefNodes private def getOrderedNodes( - netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)], - regInfo: mutable.Map[WrappedExpression, DefRegister] + netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)], + regInfo: mutable.Map[WrappedExpression, DefRegister], + randInfo: mutable.Map[WrappedExpression, DefRandom] ): Try[Seq[Statement]] = { val digraph = new MutableDiGraph[WrappedExpression] for ((sink, (exprs, _)) <- netlist) { @@ -80,7 +82,8 @@ class RemoveWires extends Transform with DependencyAPIMigration { ordered.map { key => val WRef(name, _, kind, _) = key.e1 kind match { - case RegKind => regInfo(key) + case RegKind => regInfo(key) + case RandomKind => randInfo(key) case WireKind | NodeKind => val (Seq(rhs), info) = netlist(key) DefNode(info, name, rhs) @@ -100,6 +103,8 @@ class RemoveWires extends Transform with DependencyAPIMigration { val wireInfo = mutable.HashMap.empty[WrappedExpression, Info] // Additional info about registers val regInfo = mutable.HashMap.empty[WrappedExpression, DefRegister] + // Additional info about rand statements + val randInfo = mutable.HashMap.empty[WrappedExpression, DefRandom] def onStmt(stmt: Statement): Statement = { stmt match { @@ -115,7 +120,13 @@ class RemoveWires extends Transform with DependencyAPIMigration { val initDep = Some(reg.init).filter(we(WRef(reg)) != we(_)) // Dependency exists IF reg doesn't init itself regInfo(we(WRef(reg))) = reg netlist(we(WRef(reg))) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info) - case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires + case rand: DefRandom => + randInfo(we(Reference(rand))) = rand + netlist(we(Reference(rand))) = (rand.clock ++: rand.en +: List(), rand.info) + case decl: CanBeReferenced => + // Keep all declarations except for nodes and non-Analog wires and "other" statements. + // Thus this is expected to match DefInstance and DefMemory which both do not connect to + // any signals directly (instead a separate Connect is used). decls += decl case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match { @@ -145,7 +156,7 @@ class RemoveWires extends Transform with DependencyAPIMigration { m match { case mod @ Module(info, name, ports, body) => onStmt(body) - getOrderedNodes(netlist, regInfo) match { + getOrderedNodes(netlist, regInfo, randInfo) match { case Success(logic) => Module(info, name, ports, Block(List() ++ decls ++ logic ++ otherStmts)) // If we hit a CyclicException, just abort removing wires diff --git a/src/main/scala/firrtl/transforms/SimplifyMems.scala b/src/main/scala/firrtl/transforms/SimplifyMems.scala index 8ecc484ae2..92e19f7e0a 100644 --- a/src/main/scala/firrtl/transforms/SimplifyMems.scala +++ b/src/main/scala/firrtl/transforms/SimplifyMems.scala @@ -6,6 +6,7 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ import firrtl.annotations._ +import firrtl.options.Dependency import firrtl.passes._ import firrtl.passes.memlib._ import firrtl.stage.Forms @@ -21,9 +22,12 @@ import ResolveMaskGranularity._ class SimplifyMems extends Transform with DependencyAPIMigration { override def prerequisites = Forms.MidForm - override def optionalPrerequisites = Seq.empty + override def optionalPrerequisites = Seq(Dependency[InferReadWrite]) override def optionalPrerequisiteOf = Forms.MidEmitters - override def invalidates(a: Transform) = false + override def invalidates(a: Transform) = a match { + case InferTypes => true + case _ => false + } def onModule(c: Circuit, renames: RenameMap)(m: DefModule): DefModule = { val moduleNS = Namespace(m) diff --git a/src/main/scala/firrtl/transforms/SortModules.scala b/src/main/scala/firrtl/transforms/SortModules.scala new file mode 100644 index 0000000000..ffac16614d --- /dev/null +++ b/src/main/scala/firrtl/transforms/SortModules.scala @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.transforms + +import firrtl.{CircuitState, DependencyAPIMigration, Transform} +import firrtl.analyses.InstanceKeyGraph +import firrtl.options.Dependency +import firrtl.stage.Forms + +/** Return a circuit where all modules (and external modules) are defined before use. */ +class SortModules extends Transform with DependencyAPIMigration { + + override def prerequisites = Seq(Dependency(firrtl.passes.CheckChirrtl)) + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Forms.ChirrtlEmitters + override def invalidates(a: Transform) = false + + override def execute(state: CircuitState): CircuitState = { + val modulesx = InstanceKeyGraph(state.circuit).moduleOrder.reverse + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } + +} diff --git a/src/main/scala/firrtl/transforms/package.scala b/src/main/scala/firrtl/transforms/package.scala index d758fa0a6b..5455690e27 100644 --- a/src/main/scala/firrtl/transforms/package.scala +++ b/src/main/scala/firrtl/transforms/package.scala @@ -3,6 +3,9 @@ package firrtl package object transforms { + type InlineAcrossCastsTransform = InlineCastsTransform + val InlineAcrossCastsTransform = InlineCastsTransform + @deprecated("Replaced by LegalizeClocksAndAsyncResetsTransform", "FIRRTL 1.4.0") type LegalizeClocksTransform = LegalizeClocksAndAsyncResetsTransform @deprecated("Replaced by LegalizeClocksAndAsyncResetsTransform", "FIRRTL 1.4.0") diff --git a/src/main/scala/logger/Logger.scala b/src/main/scala/logger/Logger.scala index 20c1338ec1..09fc0924ef 100644 --- a/src/main/scala/logger/Logger.scala +++ b/src/main/scala/logger/Logger.scala @@ -47,6 +47,7 @@ object LogLevel extends Enumeration { */ trait LazyLogging { protected val logger = new Logger(this.getClass.getName) + def getLogger: Logger = logger } /** diff --git a/src/main/scala/logger/LoggerOptions.scala b/src/main/scala/logger/LoggerOptions.scala index 683d6741f7..bfd072df24 100644 --- a/src/main/scala/logger/LoggerOptions.scala +++ b/src/main/scala/logger/LoggerOptions.scala @@ -32,7 +32,7 @@ class LoggerOptions private[logger] ( } /** Return the name of the log file, defaults to `a.log` if unspecified */ - def getLogFileName(): Option[String] = if (!logToFile) None else logFileName.orElse(Some("a.log")) + def getLogFileName(): Option[String] = if (!logToFile()) None else logFileName.orElse(Some("a.log")) /** True if a [[Logger]] should be writing to a file */ @deprecated("logToFile was removed, use logFileName.nonEmpty", "FIRRTL 1.2") diff --git a/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala b/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala index 23ecb114ca..bab67a74e3 100644 --- a/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala +++ b/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala @@ -142,7 +142,7 @@ class AnalyzeCircuit extends Transform { visited match { // If e is a [[firrtl.ir.Mux Mux]], increment our ledger and return e. case Mux(cond, tval, fval, tpe) => - ledger.foundMux + ledger.foundMux() e // If e is not a [[firrtl.ir.Mux Mux]], return e. case notmux => notmux diff --git a/src/test/resources/passes/Legalize/Legalize.fir b/src/test/resources/passes/Legalize/Legalize.fir index 7e5386951a..a0a3984561 100644 --- a/src/test/resources/passes/Legalize/Legalize.fir +++ b/src/test/resources/passes/Legalize/Legalize.fir @@ -39,3 +39,25 @@ circuit Legalize : when neq(bar_15, UInt(1)) : printf(clock, UInt(1), "Assertion failed!\n bar_15 != 0\n") stop(clock, UInt(1), 1) + + ; Check neg of literals + node negUInt0 = neg(UInt(123)) + when neq(negUInt0, SInt(-123)) : + printf(clock, UInt(1), "Assertion failed!\n negUInt0 != -123\n") + stop(clock, UInt(1), 1) + node negUInt1 = neg(UInt<8>(0)) + when neq(negUInt1, SInt<8>(0)) : + printf(clock, UInt(1), "Assertion failed!\n negUInt1 != 0\n") + stop(clock, UInt(1), 1) + node negSInt0 = neg(SInt(123)) + when neq(negSInt0, SInt(-123)) : + printf(clock, UInt(1), "Assertion failed!\n negSInt0 != -123\n") + stop(clock, UInt(1), 1) + node negSInt1 = neg(SInt(-123)) + when neq(negSInt1, SInt(123)) : + printf(clock, UInt(1), "Assertion failed!\n negSInt1 != 123\n") + stop(clock, UInt(1), 1) + node negSInt2 = neg(SInt(0)) + when neq(negSInt2, SInt(0)) : + printf(clock, UInt(1), "Assertion failed!\n negSInt2 != 0\n") + stop(clock, UInt(1), 1) diff --git a/src/test/scala/firrtl/RenameMapPrivateSpec.scala b/src/test/scala/firrtl/RenameMapPrivateSpec.scala new file mode 100644 index 0000000000..d735e6c876 --- /dev/null +++ b/src/test/scala/firrtl/RenameMapPrivateSpec.scala @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl + +import firrtl.annotations.Target +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.InstanceKeyGraph +import firrtl.testutils.FirrtlFlatSpec + +class RenameMapPrivateSpec extends FirrtlFlatSpec { + "RenameMap.fromInstanceRenames" should "handle instance renames" in { + def tar(str: String): Target = Target.deserialize(str) + val circuit = parse( + """circuit Top : + | module Bar : + | skip + | module Foo : + | inst bar of Bar + | module Top : + | inst foo1 of Foo + | inst foo2 of Foo + | inst bar of Bar + |""".stripMargin + ) + val graph = InstanceKeyGraph(circuit) + val renames = Map( + OfModule("Foo") -> Map(Instance("bar") -> Instance("bbb")), + OfModule("Top") -> Map(Instance("foo1") -> Instance("ffff")) + ) + val rm = RenameMap.fromInstanceRenames(graph, renames) + rm.get(tar("~Top|Top/foo1:Foo")) should be(Some(Seq(tar("~Top|Top/ffff:Foo")))) + rm.get(tar("~Top|Top/foo2:Foo")) should be(None) + // Check of nesting + rm.get(tar("~Top|Top/foo1:Foo/bar:Bar")) should be(Some(Seq(tar("~Top|Top/ffff:Foo/bbb:Bar")))) + rm.get(tar("~Top|Top/foo2:Foo/bar:Bar")) should be(Some(Seq(tar("~Top|Top/foo2:Foo/bbb:Bar")))) + rm.get(tar("~Top|Foo/bar:Bar")) should be(Some(Seq(tar("~Top|Foo/bbb:Bar")))) + rm.get(tar("~Top|Top/bar:Bar")) should be(None) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystemPassSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystemPassSpec.scala new file mode 100644 index 0000000000..e9254d7fdb --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystemPassSpec.scala @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.experimental.smt + +import firrtl.annotations.{CircuitTarget, PresetAnnotation} +import firrtl.options.Dependency +import firrtl.testutils.LeanTransformSpec + +class FirrtlToTransitionSystemPassSpec + extends LeanTransformSpec(Seq(Dependency(firrtl.backends.experimental.smt.FirrtlToTransitionSystem))) { + behavior.of("FirrtlToTransitionSystem") + + it should "support preset wires" in { + // In order to give registers an initial wire, we use preset annotated resets. + // When using a wire instead of an input (which has the advantage of working regardless of the + // module hierarchy), we need to initialize it in order to get through the wire initialization check. + // In Chisel this generates a node which needs to be removed. + + val src = """circuit ModuleAB : + | module ModuleAB : + | input clock : Clock + | node _T = asAsyncReset(UInt<1>("h0")) + | node preset = _T + | reg REG : UInt<1>, clock with : + | reset => (preset, UInt<1>("h0")) + | assert(clock, UInt(1), not(REG), "REG == 0") + |""".stripMargin + val anno = PresetAnnotation(CircuitTarget("ModuleAB").module("ModuleAB").ref("preset")) + + val result = compile(src, List(anno)) + val sys = result.annotations.collectFirst { case TransitionSystemAnnotation(sys) => sys }.get + assert(sys.states.head.init.isDefined) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala index e489db7d5e..2a0276e1a9 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -195,4 +195,46 @@ class MemorySpec extends EndToEndSMTBaseSpec { "memory with two write ports" should "can have collisions when enables are unconstrained" taggedAs (RequiresZ3) in { test(collisionTest("UInt(1)"), MCFail(1), kmax = 1) } + + private def readEnableSrc(pred: String, num: Int) = + s""" + |circuit ReadEnableTest$num: + | module ReadEnableTest$num: + | input c : Clock + | input preset: AsyncReset + | + | reg first: UInt<1>, c with: (reset => (preset, UInt(1))) + | first <= UInt(0) + | + | reg even: UInt<1>, c with: (reset => (preset, UInt(0))) + | node odd = not(even) + | even <= not(even) + | + | mem m: + | data-type => UInt<8> + | depth => 4 + | reader => r + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.addr <= UInt(0) + | ; the read port is enabled in even cycles + | m.r.en <= even + | + | assert(c, $pred, not(first), "") + |""".stripMargin + + "a memory with read enable" should "supply valid data one cycle after en=1" in { + val init = Seq(MemoryScalarInitAnnotation(CircuitTarget(s"ReadEnableTest1").module(s"ReadEnableTest1").ref("m"), 0)) + // the read port is enabled on even cycles, so on odd cycles we should reliably get zeros + test(readEnableSrc("or(not(odd), eq(m.r.data, UInt(0)))", 1), MCSuccess, kmax = 3, annos = init) + } + + "a memory with read enable" should "supply invalid data one cycle after en=0" in { + val init = Seq(MemoryScalarInitAnnotation(CircuitTarget(s"ReadEnableTest2").module(s"ReadEnableTest2").ref("m"), 0)) + // the read port is disabled on odd cycles, so on even cycles we should *NOT* reliably get zeros + test(readEnableSrc("or(not(even), eq(m.r.data, UInt(0)))", 2), MCFail(1), kmax = 1, annos = init) + } } diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala index 729498301a..f846e19cc0 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala @@ -15,8 +15,7 @@ import scala.sys.process.{Process, ProcessLogger} class SMTCompilationTest extends AnyFlatSpec with LazyLogging { it should "generate valid SMTLib for AddNot" taggedAs (RequiresZ3) in { compileAndParse("AddNot") } it should "generate valid SMTLib for FPU" taggedAs (RequiresZ3) in { compileAndParse("FPU") } - // we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer - it should "generate valid SMTLib for HwachaSequencer" taggedAs (RequiresZ3) ignore { + it should "generate valid SMTLib for HwachaSequencer" taggedAs (RequiresZ3) in { compileAndParse("HwachaSequencer") } it should "generate valid SMTLib for ICache" taggedAs (RequiresZ3) in { compileAndParse("ICache") } diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala new file mode 100644 index 0000000000..e4404d1007 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UninterpretedModulesSpec.scala @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.experimental.smt.end2end + +import firrtl.annotations.CircuitTarget +import firrtl.backends.experimental.smt.UninterpretedModuleAnnotation + +class UninterpretedModulesSpec extends EndToEndSMTBaseSpec { + + private def testCircuit(assumption: String = ""): String = { + s"""circuit UF00: + | module UF00: + | input clk: Clock + | input a: UInt<128> + | input b: UInt<128> + | input c: UInt<128> + | + | inst m0 of Magic + | m0.a <= a + | m0.b <= b + | + | inst m1 of Magic + | m1.a <= a + | m1.b <= c + | + | assert(clk, eq(m0.r, m1.r), UInt(1), "m0.r == m1.r") + | $assumption + | extmodule Magic: + | input a: UInt<128> + | input b: UInt<128> + | output r: UInt<128> + |""".stripMargin + } + private val magicAnno = UninterpretedModuleAnnotation(CircuitTarget("UF00").module("Magic"), "magic", 0) + + "two instances of the same uninterpreted module" should "give the same result when given the same inputs" taggedAs (RequiresZ3) in { + val assumeTheSame = """assume(clk, eq(b,c), UInt(1), "b == c")""" + test(testCircuit(assumeTheSame), MCSuccess, 1, "inputs are the same ==> outputs are the same", Seq(magicAnno)) + } + "two instances of the same uninterpreted module" should "not always give the same result when given potentially different inputs" taggedAs (RequiresZ3) in { + test( + testCircuit(), + MCFail(0), + 1, + "inputs are not necessarily the same ==> outputs can be different", + Seq(magicAnno) + ) + } +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala new file mode 100644 index 0000000000..8f17a847ed --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/random/InvalidToRandomSpec.scala @@ -0,0 +1,56 @@ +package firrtl.backends.experimental.smt.random + +import firrtl.options.Dependency +import firrtl.testutils.LeanTransformSpec + +class InvalidToRandomSpec extends LeanTransformSpec(Seq(Dependency(InvalidToRandomPass))) { + behavior.of("InvalidToRandomPass") + + val src1 = + s""" + |circuit Test: + | module Test: + | input a : UInt<2> + | output o : UInt<8> + | output o2 : UInt<8> + | output o3 : UInt<8> + | + | o is invalid + | + | when eq(a, UInt(3)): + | o <= UInt(5) + | + | o2 is invalid + | node o2_valid = eq(a, UInt(2)) + | when o2_valid: + | o2 <= UInt(7) + | + | o3 is invalid + | o3 <= UInt(3) + |""".stripMargin + + it should "model invalid signals as random" in { + + val circuit = compile(src1, List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // the condition should end up as a new node if it wasn't a reference already + assert(result.contains("node _GEN_0_invalid_cond = not(eq(a, UInt<2>(\"h3\")))")) + assert(result.contains("node o2_valid = eq(a, UInt<2>(\"h2\"))")) + + // every invalid results in a random statement + assert(result.contains("rand _GEN_0_invalid : UInt<3> when _GEN_0_invalid_cond")) + assert(result.contains("rand _GEN_1_invalid : UInt<3> when not(o2_valid)")) + + // the random value is conditionally assigned + assert(result.contains("node _GEN_0 = mux(_GEN_0_invalid_cond, _GEN_0_invalid, UInt<3>(\"h5\"))")) + assert(result.contains("node _GEN_1 = mux(not(o2_valid), _GEN_1_invalid, UInt<3>(\"h7\"))")) + + // expressions that are trivially valid do not get randomized + assert(result.contains("o3 <= UInt<2>(\"h3\")")) + val defRandCount = result.count(_.contains("rand ")) + assert(defRandCount == 2) + } + +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala new file mode 100644 index 0000000000..f8f889ac33 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala @@ -0,0 +1,360 @@ +package firrtl.backends.experimental.smt.random + +import firrtl.options.Dependency +import firrtl.testutils.LeanTransformSpec + +class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(UndefinedMemoryBehaviorPass))) { + behavior.of("UndefinedMemoryBehaviorPass") + + it should "model write-write conflicts between 2 ports" in { + + val circuit = compile(UBMSources.writeWriteConflict, List()).circuit + // println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // a random value should be declared for the data written on a write-write conflict + assert(result.contains("rand m_a_wwc_data : UInt<32>, m_a_clk when m_a_b_wwc")) + + // a write-write conflict occurs when both ports are enabled and the addresses match + assert(result.contains("m_a_b_wwc <= and(and(m_a_en, m_b_en), eq(m_a_addr, m_b_addr))")) + + // the data of read port a depends on whether there is a write-write conflict + assert(result.contains("m.a.data <= mux(m_a_b_wwc, m_a_wwc_data, m_a_data)")) + + // the enable of read port b depends on whether there is a write-write conflict + assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))")) + } + + it should "model write-write conflicts between 3 ports" in { + + val circuit = compile(UBMSources.writeWriteConflict3, List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // when there is more than one next write port, a "active" node is created + assert(result.contains("node m_a_wwc_active = or(m_a_b_wwc, m_a_c_wwc)")) + + // a random value should be declared for the data written on a write-write conflict + assert(result.contains("rand m_a_wwc_data : UInt<32>, m_a_clk when m_a_wwc_active")) + assert(result.contains("rand m_b_wwc_data : UInt<32>, m_b_clk when m_b_c_wwc")) + + // a write-write conflict occurs when both ports are enabled and the addresses match + Seq(("a", "b"), ("a", "c"), ("b", "c")).foreach { + case (w1, w2) => + assert( + result.contains(s"m_${w1}_${w2}_wwc <= and(and(m_${w1}_en, m_${w2}_en), eq(m_${w1}_addr, m_${w2}_addr))") + ) + } + + // the data of read port a depends on whether there is a write-write conflict + assert(result.contains("m.a.data <= mux(m_a_wwc_active, m_a_wwc_data, m_a_data)")) + + // the data of read port b depends on whether there is a write-write conflict + assert(result.contains("m.b.data <= mux(m_b_c_wwc, m_b_wwc_data, m_b_data)")) + + // the enable of read port b depends on whether there is a write-write conflict + assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))")) + + // the enable of read port c depends on whether there is a write-write conflict + // note that in this case we do not add an extra node since the disjunction is only used once + assert(result.contains("m.c.en <= and(m_c_en, not(or(m_a_c_wwc, m_b_c_wwc)))")) + } + + it should "model write-write conflicts more efficiently when ports are mutually exclusive" in { + + val circuit = compile(UBMSources.writeWriteConflict3Exclusive, List()).circuit + // println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // we should not compute the conflict between a and c since it is impossible + assert(!result.contains("node m_a_c_wwc = and(and(m_a_en, m_c_en), eq(m_a_addr, m_c_addr))")) + + // the enable of port b depends on whether there is a conflict with a + assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))")) + + // the data of port b depends on whether these is a conflict with c + assert(result.contains("m.b.data <= mux(m_b_c_wwc, m_b_wwc_data, m_b_data)")) + + // the enable of port c only depend on whether there is a conflict with b since c and a cannot conflict + assert(result.contains("m.c.en <= and(m_c_en, not(m_b_c_wwc))")) + + // the data of port a only depends on whether there is a conflict with b since a and c cannot conflict + assert(result.contains("m.a.data <= mux(m_a_b_wwc, m_a_wwc_data, m_a_data)")) + } + + it should "assert out-of-bounds writes when told to" in { + val anno = List(UndefinedMemoryBehaviorOptions(assertNoOutOfBoundsWrites = true)) + + val circuit = compile(UBMSources.readWrite(30, 0), anno).circuit + // println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + assert( + result.contains( + """assert(m_a_clk, geq(UInt<5>("h1e"), m_a_addr), UInt<1>("h1"), "out of bounds read")""" + ) + ) + } + + it should "model out-of-bounds reads" in { + val circuit = compile(UBMSources.readWrite(30, 0), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // an out of bounds read happens if the depth is not greater or equal to the address + assert(result.contains("node m_r_oob = not(geq(UInt<5>(\"h1e\"), m_r_addr))")) + + // the source of randomness needs to be triggered when there is an out of bounds read + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_oob")) + + // the data is random when there is an oob + assert(result.contains("m_r_data <= mux(m_r_oob, m_r_rand_data, m.r.data)")) + } + + it should "model un-enabled reads w/o out-of-bounds" in { + // without possible out-of-bounds + val circuit = compile(UBMSources.readEnable(32), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // the memory is disabled when it is not enabled + assert(result.contains("node m_r_disabled = not(m_r_en)")) + + // the source of randomness needs to be triggered when there is an read while the port is disabled + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_disabled")) + + // the data is random when there is an un-enabled read + assert(result.contains("m_r_data <= mux(m_r_disabled, m_r_rand_data, m.r.data)")) + } + + it should "model un-enabled reads with out-of-bounds" in { + // with possible out-of-bounds + val circuit = compile(UBMSources.readEnable(30), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // the memory is disabled when it is not enabled + assert(result.contains("node m_r_disabled = not(m_r_en)")) + + // an out of bounds read happens if the depth is not greater or equal to the address and the memory is enabled + assert(result.contains("node m_r_oob = and(m_r_en, not(geq(UInt<5>(\"h1e\"), m_r_addr)))")) + + // the two possible issues are combined into a single signal + assert(result.contains("node m_r_do_rand = or(m_r_disabled, m_r_oob)")) + + // the source of randomness needs to be triggered when either issue occurs + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_do_rand")) + + // the data is random when either issue occurs + assert(result.contains("m_r_data <= mux(m_r_do_rand, m_r_rand_data, m.r.data)")) + } + + it should "model un-enabled reads with out-of-bounds with read pipelining" in { + // with read latency one, we need to pipeline the `do_rand` signal + val circuit = compile(UBMSources.readEnable(30, 1), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // pipeline register + assert(result.contains("m_r_do_rand_r1 <= m_r_do_rand")) + + // the source of randomness needs to be triggered by the pipeline register + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_do_rand_r1")) + + // the data is random when the pipeline register is 1 + assert(result.contains("m_r_data <= mux(m_r_do_rand_r1, m_r_rand_data, m.r.data)")) + } + + it should "model read/write conflicts when they are undefined" in { + val circuit = compile(UBMSources.readWrite(32, 1), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // detect read/write conflicts + assert(result.contains("m_r_a_rwc <= eq(m_r_addr, m_a_addr)")) + + // delay the signal + assert(result.contains("m_r_do_rand_r1 <= m_r_rwc")) + + // randomize the data + assert(result.contains("rand m_r_rand_data : UInt<32>, m_r_clk when m_r_do_rand_r1")) + assert(result.contains("m_r_data <= mux(m_r_do_rand_r1, m_r_rand_data, m.r.data)")) + } +} + +private object UBMSources { + + val writeWriteConflict = + s""" + |circuit Test: + | module Test: + | input c : Clock + | input preset: AsyncReset + | input addr : UInt<8> + | input data : UInt<32> + | input aEn : UInt<1> + | input bEn : UInt<1> + | + | mem m: + | data-type => UInt<32> + | depth => 32 + | reader => r + | writer => a, b + | read-latency => 0 + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | m.r.addr <= addr + | + | ; both read ports write to the same address and the same data + | m.a.clk <= c + | m.a.en <= aEn + | m.a.addr <= addr + | m.a.data <= data + | m.a.mask <= UInt(1) + | m.b.clk <= c + | m.b.en <= bEn + | m.b.addr <= addr + | m.b.data <= data + | m.b.mask <= UInt(1) + """.stripMargin + + val writeWriteConflict3 = + s""" + |circuit Test: + | module Test: + | input c : Clock + | input preset: AsyncReset + | input addr : UInt<8> + | input data : UInt<32> + | input aEn : UInt<1> + | input bEn : UInt<1> + | input cEn : UInt<1> + | + | mem m: + | data-type => UInt<32> + | depth => 32 + | reader => r + | writer => a, b, c + | read-latency => 0 + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | m.r.addr <= addr + | + | ; both read ports write to the same address and the same data + | m.a.clk <= c + | m.a.en <= aEn + | m.a.addr <= addr + | m.a.data <= data + | m.a.mask <= UInt(1) + | m.b.clk <= c + | m.b.en <= bEn + | m.b.addr <= addr + | m.b.data <= data + | m.b.mask <= UInt(1) + | m.c.clk <= c + | m.c.en <= cEn + | m.c.addr <= addr + | m.c.data <= data + | m.c.mask <= UInt(1) + """.stripMargin + + val writeWriteConflict3Exclusive = + s""" + |circuit Test: + | module Test: + | input c : Clock + | input preset: AsyncReset + | input addr : UInt<8> + | input data : UInt<32> + | input aEn : UInt<1> + | input bEn : UInt<1> + | + | mem m: + | data-type => UInt<32> + | depth => 32 + | reader => r + | writer => a, b, c + | read-latency => 0 + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | m.r.addr <= addr + | + | ; both read ports write to the same address and the same data + | m.a.clk <= c + | m.a.en <= aEn + | m.a.addr <= addr + | m.a.data <= data + | m.a.mask <= UInt(1) + | m.b.clk <= c + | m.b.en <= bEn + | m.b.addr <= addr + | m.b.data <= data + | m.b.mask <= UInt(1) + | m.c.clk <= c + | m.c.en <= not(aEn) + | m.c.addr <= addr + | m.c.data <= data + | m.c.mask <= UInt(1) + """.stripMargin + + def readWrite(depth: Int, readLatency: Int) = + s"""circuit CollisionTest: + | module CollisionTest: + | input c : Clock + | input preset: AsyncReset + | input addr : UInt<8> + | input data : UInt<32> + | output dataOut : UInt<32> + | + | mem m: + | data-type => UInt<32> + | depth => $depth + | reader => r + | writer => a + | read-latency => $readLatency + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | m.r.addr <= addr + | dataOut <= m.r.data + | + | m.a.clk <= c + | m.a.mask <= UInt(1) + | m.a.en <= UInt(1) + | m.a.addr <= addr + | m.a.data <= data + |""".stripMargin + + def readEnable(depth: Int, latency: Int = 0) = + s"""circuit Test: + | module Test: + | input c : Clock + | input addr : UInt<8> + | input en : UInt<1> + | output data : UInt<32> + | + | mem m: + | data-type => UInt<32> + | depth => $depth + | reader => r + | read-latency => $latency + | write-latency => 1 + | read-under-write => old + | + | m.r.clk <= c + | m.r.en <= en + | m.r.addr <= addr + | data <= m.r.data + |""".stripMargin +} diff --git a/src/test/scala/firrtl/passes/LowerTypesSpec.scala b/src/test/scala/firrtl/passes/LowerTypesSpec.scala index 70fa51fdc1..7ca9854496 100644 --- a/src/test/scala/firrtl/passes/LowerTypesSpec.scala +++ b/src/test/scala/firrtl/passes/LowerTypesSpec.scala @@ -2,10 +2,13 @@ package firrtl.passes import firrtl.annotations.{CircuitTarget, IsMember} +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.InstanceKeyGraph import firrtl.{CircuitState, RenameMap, Utils} import firrtl.options.Dependency import firrtl.stage.TransformManager import firrtl.stage.TransformManager.TransformDependency +import firrtl.testutils.FirrtlMatchers import org.scalatest.flatspec.AnyFlatSpec /** Unit test style tests for [[LowerTypes]]. @@ -228,22 +231,35 @@ class LowerTypesRenamingSpec extends AnyFlatSpec { } /** Instances are a special case since they do not get completely destructed but instead become a 1-deep bundle. */ -class LowerTypesOfInstancesSpec extends AnyFlatSpec { +class LowerTypesOfInstancesSpec extends AnyFlatSpec with FirrtlMatchers { import LowerTypesSpecUtils._ private case class Lower(inst: firrtl.ir.DefInstance, fields: Seq[String], renameMap: RenameMap) private val m = CircuitTarget("m").module("m") + private val igraph = InstanceKeyGraph( + parse( + """circuit m: + | module c: + | skip + | module m: + | inst i of c + |""".stripMargin + ) + ) def resultToFieldSeq(res: Seq[(String, firrtl.ir.SubField)]): Seq[String] = res.map(_._2).map(r => s"${r.name} : ${r.tpe.serialize}") private def lower( - n: String, - tpe: String, - module: String, - namespace: Set[String], - renames: RenameMap = RenameMap() + n: String, + tpe: String, + module: String, + namespace: Set[String], + otherRenames: RenameMap = RenameMap() ): Lower = { val ref = firrtl.ir.DefInstance(firrtl.ir.NoInfo, n, module, parseType(tpe)) val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace - val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, renames, Set()) + val instRenames = scala.collection.mutable.ListBuffer[(Instance, Instance)]() + val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, instRenames, Set()) + val instMap = Map(OfModule("m") -> instRenames.toMap) + val renames = RenameMap.fromInstanceRenames(igraph, instMap).andThen(otherRenames) Lower(newInstance, resultToFieldSeq(res), renames) } private def get(l: Lower, m: IsMember): Set[IsMember] = l.renameMap.get(m).get.toSet @@ -305,7 +321,7 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { assert(get(l, i) == Set(i_)) // the ports renaming is also noted - val r = portRenames.andThen(otherRenames) + val r = portRenames.andThen(l.renameMap) assert(r.get(i.ref("b")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b").field("c")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b_c")).get == Seq(i_.ref("b_c"))) diff --git a/src/test/scala/firrtl/testutils/FirrtlSpec.scala b/src/test/scala/firrtl/testutils/FirrtlSpec.scala index 3a6f937237..4dc2d642da 100644 --- a/src/test/scala/firrtl/testutils/FirrtlSpec.scala +++ b/src/test/scala/firrtl/testutils/FirrtlSpec.scala @@ -4,6 +4,7 @@ package firrtl.testutils import java.io._ import java.security.Permission +import scala.sys.process._ import logger.{LazyLogging, LogLevel, LogLevelAnnotation} @@ -122,12 +123,73 @@ trait FirrtlRunners extends BackendCompilationUtilities { assert(BackendCompilationUtilities.yosysExpectSuccess(customName, refName, testDir, timesteps)) } + /** Check equivalence of Firrtl with reference Verilog + * + * @note the name of the reference Verilog module is grabbed via regex + * @param inputFirrtl string containing Firrtl source + * @param referenceVerilog Verilog that will be used as reference for LEC + * @param timesteps the maximum number of timesteps to consider + */ + def firrtlEquivalenceWithVerilog( + inputFirrtl: String, + referenceVerilog: String, + timesteps: Int = 1 + ): Unit = { + val VerilogModule = """(?s).*module\s(\w+).*""".r + val refName = referenceVerilog match { + case VerilogModule(name) => name + case _ => throw new Exception(s"Reference Verilog must match simple regex! $VerilogModule") + } + val circuit = Parser.parse(inputFirrtl.split("\n").toIterator) + val inputName = circuit.main + require(refName != inputName, s"Name of reference Verilog must not match name of input FIRRTL: $refName") + + val testDir = createTestDirectory(inputName + "_equivalence_test") + + val annos = List( + TargetDirAnnotation(testDir.toString), + InfoModeAnnotation("ignore"), + stage.FirrtlCircuitAnnotation(circuit), + stage.RunFirrtlTransformAnnotation.stringToEmitter("verilog"), + stage.OutputFileAnnotation(inputName) + ) + + (new firrtl.stage.FirrtlStage).execute(Array(), annos) + + // Write reference + val w = new FileWriter(new File(testDir, s"$refName.v")) + w.write(referenceVerilog) + w.close() + + assert(BackendCompilationUtilities.yosysExpectSuccess(inputName, refName, testDir, timesteps)) + } + /** Compiles input Firrtl to Verilog */ def compileToVerilog(input: String, annotations: AnnotationSeq = Seq.empty): String = { + compileToVerilogCircuitState(input, annotations).getEmittedCircuit.value + } + + /** Compiles input Firrtl to Verilog */ + def compileToVerilogCircuitState(input: String, annotations: AnnotationSeq = Seq.empty): CircuitState = { val circuit = Parser.parse(input.split("\n").toIterator) val compiler = new VerilogCompiler - val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms) - res.getEmittedCircuit.value + compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms) + } + + /** Run Verilator lint on some Verilog text + * + * @param inputVerilog Verilog to pass to `verilator --lint-only` + * @return Verilator return 0 + */ + def lintVerilog(inputVerilog: String): Unit = { + val testDir = createTestDirectory(s"${this.getClass.getSimpleName}_lint") + val filename = new File(testDir, "test.v") + val w = new FileWriter(filename) + w.write(inputVerilog) + w.close() + + val cmd = Seq("verilator", "--lint-only", filename.toString) + assert(cmd.!(loggingProcessLogger) == 0, "Lint must pass") } /** Compile a Firrtl file @@ -368,6 +430,14 @@ abstract class ExecutionTest( } } +/** Super class for execution driven Firrtl tests compiled without optimizations */ +abstract class ExecutionTestNoOpt( + name: String, + dir: String, + vFiles: Seq[String] = Seq.empty, + annotations: AnnotationSeq = Seq.empty) + extends ExecutionTest(name, dir, vFiles, RunFirrtlTransformAnnotation(new MinimumVerilogEmitter) +: annotations) + /** Super class for compilation driven Firrtl tests */ abstract class CompilationTest(name: String, dir: String) extends FirrtlPropSpec { property(s"$name should compile correctly") { diff --git a/src/test/scala/firrtl/transforms/EnsureNamedStatementsSpec.scala b/src/test/scala/firrtl/transforms/EnsureNamedStatementsSpec.scala new file mode 100644 index 0000000000..4c99399482 --- /dev/null +++ b/src/test/scala/firrtl/transforms/EnsureNamedStatementsSpec.scala @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.transforms + +import firrtl.options.Dependency +import firrtl.testutils.LeanTransformSpec + +class EnsureNamedStatementsSpec extends LeanTransformSpec(Seq(Dependency(EnsureNamedStatements))) { + behavior.of("EnsureNamedStatements") + + it should "automatically name statements that do not have a name yet" in { + val src = """circuit test : + | module test : + | input clock : Clock + | input stop_ : UInt<1> + | assert(clock, UInt(1), not(UInt(0)), "") + | stop(clock, UInt(1), 1) : stop_123 + | stop(clock, UInt(1), 1) + | assert(clock, UInt(0), UInt(0), "") + | assume(clock, UInt(0), UInt(0), "") + | cover(clock, UInt(0), UInt(0), "") + | cover(clock, UInt(0), UInt(0), "") + | + |""".stripMargin + + val result = compile(src, List()).circuit.serialize.split('\n').map(_.trim) + + val expected = List( + """assert(clock, UInt<1>("h1"), not(UInt<1>("h0")), "") : assert_0""", + """stop(clock, UInt<1>("h1"), 1) : stop_123""", + """stop(clock, UInt<1>("h1"), 1) : stop_0""", + """assert(clock, UInt<1>("h0"), UInt<1>("h0"), "") : assert_1""", + """assume(clock, UInt<1>("h0"), UInt<1>("h0"), "") : assume_0""", + """cover(clock, UInt<1>("h0"), UInt<1>("h0"), "") : cover_0""", + """cover(clock, UInt<1>("h0"), UInt<1>("h0"), "") : cover_1""" + ) + expected.foreach(e => assert(result.contains(e))) + } +} diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala index 1137f8cd3b..547639d609 100644 --- a/src/test/scala/firrtlTests/CheckSpec.scala +++ b/src/test/scala/firrtlTests/CheckSpec.scala @@ -86,6 +86,106 @@ class CheckSpec extends AnyFlatSpec with Matchers { } } + behavior.of("Check Types") + + def runCheckTypes(input: String) = { + val passes = List(InferTypes, CheckTypes) + val wrapped = "circuit test:\n module test:\n " + input.replaceAll("\n", "\n ") + passes.foldLeft(Parser.parse(wrapped)) { case (c, p) => p.run(c) } + } + + it should "disallow mux enable conditions that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input foo : UInt<8> + |input bar : UInt<8> + |node x = mux(en, foo, bar)""".stripMargin + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow when predicates that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input foo : UInt<8> + |input bar : UInt<8> + |output out : UInt<8> + |when en : + | out <= foo + |else: + | out <= bar""".stripMargin + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow print enables that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input clock : Clock + |printf(clock, en, "Hello World!\\n")""".stripMargin + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow stop enables that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input clock : Clock + |stop(clock, en, 0)""".stripMargin + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow verif node predicates that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input cond : UInt<1> + |input clock : Clock + |assert(clock, en, cond, "Howdy!")""".stripMargin + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow verif node enables that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : UInt<1> + |input cond : $tpe + |input clock : Clock + |assert(clock, en, cond, "Howdy!")""".stripMargin + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + "Instance loops a -> b -> a" should "be detected" in { val input = """ @@ -384,6 +484,36 @@ class CheckSpec extends AnyFlatSpec with Matchers { } } + "Attempting to shadow a statement name" should "throw an error" in { + val input = + s"""|circuit scopes: + | module scopes: + | input c: Clock + | input i: UInt<1> + | output o: UInt<1> + | wire x: UInt<1> + | when i: + | stop(c, UInt(1), 1) : x + | o <= and(x, i) + |""".stripMargin + assertThrows[CheckHighForm.NotUniqueException] { + checkHighInput(input) + } + } + + "Colliding statement names" should "throw an error" in { + val input = + s"""|circuit test: + | module test: + | input c: Clock + | stop(c, UInt(1), 1) : x + | stop(c, UInt(1), 1) : x + |""".stripMargin + assertThrows[CheckHighForm.NotUniqueException] { + checkHighInput(input) + } + } + "Conditionally statements" should "create separate consequent and alternate scopes" in { val input = s"""|circuit scopes: @@ -536,6 +666,20 @@ class CheckSpec extends AnyFlatSpec with Matchers { } } + it should "throw an exception if a statement name is used as a reference" in { + val src = """ + |circuit test: + | module test: + | input clock: Clock + | output a: UInt<2> + | stop(clock, UInt(1), 1) : hello + | a <= hello + |""".stripMargin + assertThrows[CheckHighForm.UndeclaredReferenceException] { + checkHighInput(src) + } + } + } object CheckSpec { diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 9497304298..ababb95b9c 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -7,6 +7,7 @@ import firrtl.passes._ import firrtl.transforms._ import firrtl.testutils._ import firrtl.annotations.Annotation +import firrtl.stage.DisableFold class ConstantPropagationSpec extends FirrtlFlatSpec { val transforms: Seq[Transform] = @@ -798,6 +799,17 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { castCheck("Clock", "asClock") castCheck("AsyncReset", "asAsyncReset") } + + /* */ + "The rule a / a -> 1" should "be ignored if division folds are disabled" in { + val input = + """circuit foo: + | module foo: + | input a: UInt<8> + | output b: UInt<8> + | b <= div(a, a)""".stripMargin + (parse(exec(input, Seq(DisableFold(PrimOps.Div))))) should be(parse(input)) + } } // More sophisticated tests of the full compiler @@ -896,6 +908,17 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input, check, Seq(dontTouch("Child.in1"))) } + it should "NOT optimize if no-constant-propagation is enabled" in { + val input = + """circuit Foo: + | module Foo: + | input a: UInt<1> + | output b: UInt<1> + | b <= and(UInt<1>(0), a)""".stripMargin + val check = parse(input).serialize + execute(input, check, Seq(NoConstantPropagationAnnotation)) + } + it should "still propagate constants even when there is name swapping" in { val input = """circuit Top : @@ -1508,22 +1531,29 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { val input = s"""|circuit Foo: | module Foo: + | input in1: SInt<3> | output out1: UInt<2> | output out2: UInt<2> | output out3: UInt<2> + | output out4: UInt<4> | out1 <= xor(SInt<2>(-1), SInt<2>(1)) | out2 <= or(SInt<2>(-1), SInt<2>(1)) | out3 <= and(SInt<2>(-1), SInt<2>(-2)) + | out4 <= xor(in1, SInt<4>(0)) |""".stripMargin val check = s"""|circuit Foo: | module Foo: + | input in1: SInt<3> | output out1: UInt<2> | output out2: UInt<2> | output out3: UInt<2> + | output out4: UInt<4> | out1 <= UInt<2>(2) | out2 <= UInt<2>(3) | out3 <= UInt<2>(2) + | node _GEN_0 = pad(in1, 4) + | out4 <= asUInt(_GEN_0) |""".stripMargin execute(input, check, Seq.empty) } @@ -1626,6 +1656,20 @@ class ConstantPropagationEquivalenceSpec extends FirrtlFlatSpec { firrtlEquivalenceTest(input, transforms) } + // https://github.com/chipsalliance/firrtl/issues/2034 + "SInt OR with constant zero" should "have the correct widths" in { + val input = + s"""circuit WidthsOrSInt : + | module WidthsOrSInt : + | input in : SInt<1> + | input in2 : SInt<4> + | output out : UInt<8> + | output out2 : UInt<8> + | out <= or(in, SInt<8>(0)) + | out2 <= or(in2, SInt<8>(0))""".stripMargin + firrtlEquivalenceTest(input, transforms) + } + "addition by zero width wires" should "have the correct widths" in { val input = s"""circuit ZeroWidthAdd: diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala index c186b5168c..52c87ffbc6 100644 --- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala +++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala @@ -142,10 +142,25 @@ class ExpandWhensSpec extends FirrtlFlatSpec { | input in : UInt<32> | input p : UInt<1> | when p : - | assert(clock, eq(in, UInt<1>("h1")), UInt<1>("h1"), "assert0") + | assert(clock, eq(in, UInt<1>("h1")), UInt<1>("h1"), "assert0") : test_assert | else : | skip""".stripMargin - val check = "assert(clock, eq(in, UInt<1>(\"h1\")), and(and(UInt<1>(\"h1\"), p), UInt<1>(\"h1\")), \"assert0\")" + val check = + "assert(clock, eq(in, UInt<1>(\"h1\")), and(and(UInt<1>(\"h1\"), p), UInt<1>(\"h1\")), \"assert0\") : test_assert" + executeTest(input, check, true) + } + it should "handle stops" in { + val input = + """circuit Test : + | module Test : + | input clock : Clock + | input in : UInt<32> + | input p : UInt<1> + | when p : + | stop(clock, UInt(1), 1) : test_stop + | else : + | skip""".stripMargin + val check = """stop(clock, and(and(UInt<1>("h1"), p), UInt<1>("h1")), 1) : test_stop""" executeTest(input, check, true) } } diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index 1fb242972f..62969df5f2 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -177,4 +177,55 @@ circuit sram6t : // Check correctness of firrtl res should containLine(s"mem.rw.wmode <= wen") } + + def sameAddr(ruw: String): String = { + s""" + |circuit sram6t : + | module sram6t : + | input clock : Clock + | output io : { flip addr : UInt<11>, flip valid : UInt<1>, flip write : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>} + | + | mem mem: + | data-type => UInt<4> + | depth => 64 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => ${ruw} + | + | mem.r.clk <= clock + | mem.r.addr <= io.addr + | mem.r.en <= io.valid + | io.dataOut <= mem.r.data + | + | node wen = and(io.valid, io.write) + | mem.w.clk <= clock + | mem.w.addr <= io.addr + | mem.w.en <= wen + | mem.w.mask <= UInt(1) + | mem.w.data <= io.dataIn""".stripMargin + } + + "Infer ReadWrite Ports" should "infer readwrite ports from shared addresses with undefined readUnderWrite" in { + val input = sameAddr("undefined") + val annos = Seq(memlib.InferReadWriteAnnotation) + val res = compileAndEmit(CircuitState(parse(input), HighForm, annos)) + // Check correctness of firrtl + res should containLine(s"mem.rw.wmode <= wen") + } + + Seq("old", "new").foreach { ruw => + "Infer ReadWrite Ports" should s"not infer readwrite ports from shared addresses with '${ruw}' readUnderWrite" in { + val input = sameAddr(ruw) + val annos = Seq(memlib.InferReadWriteAnnotation) + intercept[Exception] { + compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) + } match { + case CustomTransformException(_: InferReadWriteCheckException) => // success + case _ => fail() + } + } + } + } diff --git a/src/test/scala/firrtlTests/InlineAcrossCastsSpec.scala b/src/test/scala/firrtlTests/InlineAcrossCastsSpec.scala new file mode 100644 index 0000000000..669ae077ff --- /dev/null +++ b/src/test/scala/firrtlTests/InlineAcrossCastsSpec.scala @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests + +import firrtl.transforms.InlineAcrossCastsTransform +import firrtl.testutils.FirrtlFlatSpec +import firrtl.testutils.FirrtlCheckers._ + +class InlineAcrossCastsEquivalenceSpec extends FirrtlFlatSpec { + /* + * Note: InlineCasts is still part of mverilog, so this test must both: + * - Test that the InlineCasts fix is effective given the current mverilog + * - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog + * + * This is why the test passes InlineCasts as a custom transform: to future-proof it so that + * it can do real LEC against no-InlineCasts. It currently is just a sanity check that the + * emitted Verilog is legal, but it will automatically become a more meaningful test when + * InlineCasts is not run in mverilog. + */ + "InlineCastsTransform" should "not produce broken Verilog" in { + val input = + s"""circuit literalsel_fir: + | module literalsel_fir: + | input i: UInt<4> + | output o: SInt<8> + | o <= pad(asSInt(UInt<2>("h1")), 8) + |""".stripMargin + firrtlEquivalenceTest(input, Seq(new InlineAcrossCastsTransform)) + } + + it should "not inline complex expressions into other complex expressions" in { + val input = + """circuit NeverInlineComplexIntoComplex : + | module NeverInlineComplexIntoComplex : + | input a : SInt<3> + | input b : UInt<2> + | input c : UInt<2> + | input sel : UInt<1> + | output out : SInt<3> + | node diff = sub(b, c) + | out <= mux(sel, a, asSInt(diff)) + |""".stripMargin + val expected = + """module NeverInlineComplexIntoComplexRef( + | input [2:0] a, + | input [1:0] b, + | input [1:0] c, + | input sel, + | output [2:0] out + |); + | wire [2:0] diff = b - c; + | assign out = sel ? $signed(a) : $signed(diff); + |endmodule + |""".stripMargin + firrtlEquivalenceWithVerilog(input, expected) + } + + it should "inline casts on both sides of a more complex expression" in { + val input = + """circuit test : + | module test : + | input clock : Clock + | input in : UInt<8> + | output out : UInt<8> + | + | node _T_1 = asUInt(clock) + | node _T_2 = not(_T_1) + | node clock_n = asClock(_T_2) + | reg r : UInt<8>, clock_n + | r <= in + | out <= r + |""".stripMargin + val verilog = compileToVerilogCircuitState(input) + verilog should containLine("always @(posedge clock_n) begin") + + } +} diff --git a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala index 02ac3cd0d1..e11c428161 100644 --- a/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala +++ b/src/test/scala/firrtlTests/InlineBooleanExpressionsSpec.scala @@ -392,6 +392,22 @@ class InlineBooleanExpressionsSpec extends FirrtlFlatSpec { firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) } + // https://github.com/chipsalliance/firrtl/issues/2035 + // This is interesting because other ways of trying to express this get split out by + // SplitExpressions and don't get inlined again + // If we were to inline more expressions (ie. not just boolean ones) the issue this represents + // would come up more often + it should "handle cvt nested inside of a dshl" in { + val input = + """circuit DshlCvt: + | module DshlCvt: + | input a: UInt<4> + | input b: SInt<1> + | output o: UInt + | o <= dshl(a, asUInt(cvt(b)))""".stripMargin + firrtlEquivalenceTest(input, Seq(new InlineBooleanExpressions)) + } + it should s"respect --${PrettyNoExprInlining.longOption}" in { val input = """circuit Top : diff --git a/src/test/scala/firrtlTests/InlineCastsSpec.scala b/src/test/scala/firrtlTests/InlineCastsSpec.scala deleted file mode 100644 index e27020e5ed..0000000000 --- a/src/test/scala/firrtlTests/InlineCastsSpec.scala +++ /dev/null @@ -1,29 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package firrtlTests - -import firrtl.transforms.InlineCastsTransform -import firrtl.testutils.FirrtlFlatSpec - -/* - * Note: InlineCasts is still part of mverilog, so this test must both: - * - Test that the InlineCasts fix is effective given the current mverilog - * - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog - * - * This is why the test passes InlineCasts as a custom transform: to future-proof it so that - * it can do real LEC against no-InlineCasts. It currently is just a sanity check that the - * emitted Verilog is legal, but it will automatically become a more meaningful test when - * InlineCasts is not run in mverilog. - */ -class InlineCastsEquivalenceSpec extends FirrtlFlatSpec { - "InlineCastsTransform" should "not produce broken Verilog" in { - val input = - s"""circuit literalsel_fir: - | module literalsel_fir: - | input i: UInt<4> - | output o: SInt<8> - | o <= pad(asSInt(UInt<2>("h1")), 8) - |""".stripMargin - firrtlEquivalenceTest(input, Seq(new InlineCastsTransform)) - } -} diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 6bee2b7759..cc7257d242 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -460,6 +460,60 @@ class InlineInstancesTests extends LowTransformSpec { ) } + "inlining named statements" should "work" in { + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.clock <= clock + | i.a <= a + | b <= i.b + | module Inline : + | input clock : Clock + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | assert(clock, UInt(1), eq(a,b), "a == b") : assert1 + | assert(clock, UInt(1), not(eq(a,b)), "a != b") + | stop(clock, UInt(0), 0) + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input a : UInt<32> + | output b : UInt<32> + | wire i_clock : Clock + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | i_b <= i_a + | assert(i_clock, UInt(1), eq(i_a, i_b), "a == b") : i_assert1 + | assert(i_clock, UInt(1), not(eq(i_a, i_b)), "a != b") + | stop(i_clock, UInt(0), 0) + | b <= i_b + | i_clock <= clock + | i_a <= a + |""".stripMargin + val top = CircuitTarget("Top").module("Top") + val inlined = top.instOf("i", "Inline") + + executeWithAnnos( + input, + check, + Seq( + inline("Inline"), + NoCircuitDedupAnnotation, + DummyAnno(inlined.ref("assert1")) + ), + Seq( + DummyAnno(top.ref("i_assert1")) + ) + ) + } + "inlining both grandparent and grandchild" should "should work" in { val input = """circuit Top : diff --git a/src/test/scala/firrtlTests/LegalizeSpec.scala b/src/test/scala/firrtlTests/LegalizeSpec.scala index 905d578e4b..ad85668ef8 100644 --- a/src/test/scala/firrtlTests/LegalizeSpec.scala +++ b/src/test/scala/firrtlTests/LegalizeSpec.scala @@ -2,6 +2,8 @@ package firrtlTests -import firrtl.testutils.ExecutionTest +import firrtl.testutils.{ExecutionTest, ExecutionTestNoOpt} class LegalizeExecutionTest extends ExecutionTest("Legalize", "/passes/Legalize") +// Legalize also needs to work when optimizations are turned off +class LegalizeExecutionTestNoOpt extends ExecutionTestNoOpt("Legalize", "/passes/Legalize") diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 9e58b74c59..6e774d181c 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -321,6 +321,25 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { executeTest(input, expected) } + it should "rename nodes colliding with labled statements" in { + val input = + """circuit Test : + | module Test : + | input clock : Clock + | reg x : { b : UInt<1>, c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2], clock + | node a = x + | printf(clock, UInt(1), "") : a_0_c_ + | assert(clock, UInt(1), UInt(1), "") : a__0 + """.stripMargin + val expected = Seq( + "node a___0_b = x_0_b", + "node a___1_c__1_e = x_1_c__1_e", + "node a___1_c_1_e = x_1_c_1_e" + ) + + executeTest(input, expected) + } + it should "rename DefRegister expressions: clock, reset, and init" in { val input = """circuit Test : @@ -447,10 +466,27 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | input a : { b : UInt<1>, flip c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2] | output a_0_b : UInt<1> | input a__0_c_ : { d : UInt<2>, e : UInt<3>}[2] - | a_0_b <= mux(a[UInt(0)].c_1_e, or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e))) + | a_0_b <= mux(bits(a[UInt(0)].c_1_e, 0, 0), or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e))) """.stripMargin val expected = Seq( - "a_0_b <= mux(a___0_c_1_e, or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))" + "a_0_b <= mux(bits(a___0_c_1_e, 0, 0), or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))" + ) + + executeTest(input, expected) + } + + it should "remove index express in SubAccess" in { + val input = + s"""circuit Bug : + | module Bug : + | input in0 : UInt<1> [2][2] + | input in1 : UInt<1> [2] + | input in2 : UInt<1> [2] + | output out : UInt<1> + | out <= in0[in1[in2[0]]][in1[in2[1]]] + |""".stripMargin + val expected = Seq( + "out <= _in0_in1_in1_in2_1" ) executeTest(input, expected) diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index 54f0af8e42..d56ca65743 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -155,7 +155,10 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { it should "replicate the old order" in { val tm = new TransformManager(Forms.WorkingIR, Forms.MinimalHighForm) - compare(legacyTransforms(new firrtl.IRToWorkingIR), tm) + val patches = Seq( + Del(1) + ) + compare(legacyTransforms(new firrtl.IRToWorkingIR), tm, patches) } behavior.of("ResolveAndCheck") @@ -244,7 +247,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { new firrtl.transforms.ReplaceTruncatingArithmetic, new firrtl.transforms.InlineBitExtractionsTransform, new firrtl.transforms.PropagatePresetAnnotations, - new firrtl.transforms.InlineCastsTransform, + new firrtl.transforms.InlineAcrossCastsTransform, new firrtl.transforms.LegalizeClocksTransform, new firrtl.transforms.FlattenRegUpdate, firrtl.passes.VerilogModulusCleanup, @@ -268,7 +271,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { new firrtl.transforms.ReplaceTruncatingArithmetic, new firrtl.transforms.InlineBitExtractionsTransform, new firrtl.transforms.PropagatePresetAnnotations, - new firrtl.transforms.InlineCastsTransform, + new firrtl.transforms.InlineAcrossCastsTransform, new firrtl.transforms.LegalizeClocksTransform, new firrtl.transforms.FlattenRegUpdate, new firrtl.transforms.DeadCodeElimination, diff --git a/src/test/scala/firrtlTests/MemoryInitSpec.scala b/src/test/scala/firrtlTests/MemoryInitSpec.scala index 18f2b7ffda..44f0162ef5 100644 --- a/src/test/scala/firrtlTests/MemoryInitSpec.scala +++ b/src/test/scala/firrtlTests/MemoryInitSpec.scala @@ -4,7 +4,7 @@ package firrtlTests import firrtl._ import firrtl.annotations._ -import firrtl.testutils.FirrtlCheckers.containLine +import firrtl.testutils.FirrtlCheckers.{containLine, containLines} import firrtl.testutils.FirrtlFlatSpec import firrtlTests.execution._ @@ -165,6 +165,61 @@ class MemInitSpec extends FirrtlFlatSpec { assert(annos == Seq(MemoryArrayInitAnnotation(mRef, largeSeq))) } + "MemoryFileInlineAnnotation" should "emit $readmemh for text.hex" in { + val annos = Seq(MemoryFileInlineAnnotation(mRef, filename = "text.hex")) + val result = compile(annos) + result should containLine("""$readmemh("text.hex", """ + mRef.name + """);""") + } + + "MemoryFileInlineAnnotation" should "emit $readmemb for text.bin" in { + val annos = Seq(MemoryFileInlineAnnotation(mRef, filename = "text.bin", hexOrBinary = MemoryLoadFileType.Binary)) + val result = compile(annos) + result should containLine("""$readmemb("text.bin", """ + mRef.name + """);""") + } + + "MemoryFileInlineAnnotation" should "fail with blank filename" in { + assertThrows[Exception] { + compile(Seq(MemoryFileInlineAnnotation(mRef, filename = ""))) + } + } + + "MemoryInitialization" should "emit readmem in `ifndef SYNTHESIS` block by default" in { + val annos = Seq( + MemoryFileInlineAnnotation(mRef, filename = "text.hex", hexOrBinary = MemoryLoadFileType.Hex) + ) + val result = compile(annos) + result should containLines( + """`endif // RANDOMIZE""", + """$readmemh("text.hex", """ + mRef.name + """);""", + """end // initial""" + ) + } + + "MemoryInitialization" should "emit readmem outside `ifndef SYNTHESIS` block with MemorySynthInit annotation" in { + val annos = Seq( + MemoryFileInlineAnnotation(mRef, filename = "text.hex", hexOrBinary = MemoryLoadFileType.Hex) + ) ++ Seq(MemorySynthInit) + val result = compile(annos) + result should containLines( + """`endif // SYNTHESIS""", + """initial begin""", + """$readmemh("text.hex", """ + mRef.name + """);""", + """end""" + ) + } + + "MemoryInitialization" should "emit readmem outside `ifndef SYNTHESIS` block with MemoryNoSynthInit annotation" in { + val annos = Seq( + MemoryFileInlineAnnotation(mRef, filename = "text.hex", hexOrBinary = MemoryLoadFileType.Hex) + ) ++ Seq(MemoryNoSynthInit) + + val result = compile(annos) + result should containLines( + """`endif // RANDOMIZE""", + """$readmemh("text.hex", """ + mRef.name + """);""", + """end // initial""" + ) + } } abstract class MemInitExecutionSpec(values: Seq[Int], init: ReferenceTarget => Annotation) diff --git a/src/test/scala/firrtlTests/NegSpec.scala b/src/test/scala/firrtlTests/NegSpec.scala new file mode 100644 index 0000000000..c60294e3bc --- /dev/null +++ b/src/test/scala/firrtlTests/NegSpec.scala @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests + +import firrtl.testutils._ + +class NegSpec extends FirrtlFlatSpec { + "unsigned neg" should "be correct and lint-clean" in { + val input = + """|circuit UnsignedNeg : + | module UnsignedNeg : + | input in : UInt<8> + | output out : SInt + | out <= neg(in) + |""".stripMargin + val expected = + """|module UnsignedNegRef( + | input [7:0] in, + | output [8:0] out + |); + | assign out = 8'd0 - in; + |endmodule""".stripMargin + firrtlEquivalenceWithVerilog(input, expected) + lintVerilog(compileToVerilog(input)) + } + + "signed neg" should "be correct and lint-clean" in { + val input = + """|circuit SignedNeg : + | module SignedNeg : + | input in : SInt<8> + | output out : SInt + | out <= neg(in) + |""".stripMargin + // -$signed(in) is a lint warning in Verilator but is functionally correct + val expected = + """|module SignedNegRef( + | input [7:0] in, + | output [8:0] out + |); + | assign out = -$signed(in); + |endmodule""".stripMargin + firrtlEquivalenceWithVerilog(input, expected) + lintVerilog(compileToVerilog(input)) + } +} diff --git a/src/test/scala/firrtlTests/ParserSpec.scala b/src/test/scala/firrtlTests/ParserSpec.scala index 373b960c65..ba61b134e4 100644 --- a/src/test/scala/firrtlTests/ParserSpec.scala +++ b/src/test/scala/firrtlTests/ParserSpec.scala @@ -147,6 +147,27 @@ class ParserSpec extends FirrtlFlatSpec { } } + // ********** Statement labels ********** + it should "allow certain statement to have a label" in { + val prelude = Seq("circuit top :", " module top :", " input c : Clock") + val statements = Seq("stop(c, UInt(1), 0)", "printf(c, UInt(1), \"\")") ++ + Seq("assert", "assume", "cover").map(_ + "(c, UInt(1), UInt(1), \"\")") + val validLabels = Seq(":test" -> "test", " :test" -> "test", " : test" -> "test", " : test01" -> "test01") + statements.foreach { stmt => + validLabels.foreach { + case (lbl, expected) => + val line = " " + stmt + lbl + val src = (prelude :+ line).mkString("\n") + "\n" + val res = firrtl.Parser.parse(src) + CircuitState(res, Nil) should containTree { + case s: Stop => s.name == expected + case s: Print => s.name == expected + case s: Verification => s.name == expected + } + } + } + } + // ********** Keywords ********** "Keywords" should "be allowed as Ids" in { import KeywordTests._ diff --git a/src/test/scala/firrtlTests/ProtoBufSpec.scala b/src/test/scala/firrtlTests/ProtoBufSpec.scala index d56ef7b1c1..e590994492 100644 --- a/src/test/scala/firrtlTests/ProtoBufSpec.scala +++ b/src/test/scala/firrtlTests/ProtoBufSpec.scala @@ -218,6 +218,19 @@ class ProtoBufSpec extends FirrtlFlatSpec { FromProto.convert(ToProto.convert(vi).build) should equal(expected) } + it should "support Verification" in { + val clk = ir.Reference("clk", UnknownType) + val pred = ir.Reference("pred", UnknownType) + val en = ir.Reference("en", UnknownType) + val assert = ir.Verification(ir.Formal.Assert, ir.NoInfo, clk, pred, en, ir.StringLit("my assert message")) + val assume = ir.Verification(ir.Formal.Assume, ir.NoInfo, clk, pred, en, ir.StringLit("my assume message")) + val cover = ir.Verification(ir.Formal.Cover, ir.NoInfo, clk, pred, en, ir.StringLit("my cover message")) + + FromProto.convert(ToProto.convert(assert).head.build) should equal(assert) + FromProto.convert(ToProto.convert(assume).head.build) should equal(assume) + FromProto.convert(ToProto.convert(cover).head.build) should equal(cover) + } + it should "appropriately escape and unescape FileInfo strings" in { val pairs = Seq( "test\\ntest" -> "test\ntest", diff --git a/src/test/scala/firrtlTests/RenameMapSpec.scala b/src/test/scala/firrtlTests/RenameMapSpec.scala index 29466c72ed..bebeb0bf8e 100644 --- a/src/test/scala/firrtlTests/RenameMapSpec.scala +++ b/src/test/scala/firrtlTests/RenameMapSpec.scala @@ -5,6 +5,8 @@ package firrtlTests import firrtl.RenameMap import firrtl.RenameMap.IllegalRenameException import firrtl.annotations._ +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.InstanceKeyGraph import firrtl.testutils._ class RenameMapSpec extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index d21f80c8c6..4e00cb3ad8 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -3,14 +3,13 @@ package firrtlTests import firrtl._ +import firrtl.annotations._ import firrtl.ir._ import firrtl.passes._ -import firrtl.transforms._ import firrtl.passes.memlib._ -import firrtl.FileUtils +import firrtl.testutils.FirrtlCheckers._ import firrtl.testutils._ -import annotations._ -import FirrtlCheckers._ +import firrtl.transforms._ class ReplSeqMemSpec extends SimpleTransformSpec { def emitter = new LowFirrtlEmitter @@ -30,9 +29,11 @@ class ReplSeqMemSpec extends SimpleTransformSpec { } ) - def checkMemConf(filename: String, mems: Set[MemConf]) { + def checkMemConf(circuitState: CircuitState, mems: Set[MemConf]) { // Read the mem conf - val text = FileUtils.getText(filename) + val text = circuitState.annotations.collectFirst { + case a: MemLibOutConfigFileAnnotation => a.getBytes.map(_.toChar).mkString + }.get // Verify that this does not throw an exception val fromConf = MemConf.fromString(text) // Verify the mems in the conf are the same as the expected ones @@ -74,7 +75,7 @@ circuit Top : // Check correctness of firrtl parse(res.getEmittedCircuit.value) // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -99,7 +100,7 @@ circuit Top : // Check correctness of firrtl parse(res.getEmittedCircuit.value) // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -127,7 +128,7 @@ circuit CustomMemory : // Check correctness of firrtl parse(res.getEmittedCircuit.value) // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -155,7 +156,7 @@ circuit CustomMemory : // Check correctness of firrtl parse(res.getEmittedCircuit.value) // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -244,7 +245,7 @@ circuit CustomMemory : } numExtMods should be(2) // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -291,7 +292,7 @@ circuit CustomMemory : } numExtMods should be(2) // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -353,7 +354,7 @@ circuit CustomMemory : // would be 3 ExtModules numExtMods should be(2) // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -413,7 +414,7 @@ circuit CustomMemory : val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) res.getEmittedCircuit.value shouldNot include("mask") // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -422,7 +423,7 @@ circuit CustomMemory : circuit CustomMemory : module CustomMemory : input clock : Clock - output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] } + output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] } smem mem : UInt<8>[2][1024] read mport r = mem[io.raddr], clock @@ -443,7 +444,7 @@ circuit CustomMemory : res should containLine("mem.W0_mask_0 <= validif(io_en, io_mask_0)") res should containLine("mem.W0_mask_1 <= validif(io_en, io_mask_1)") // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -452,7 +453,7 @@ circuit CustomMemory : circuit CustomMemory : module CustomMemory : input clock : Clock - output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] } + output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] } io.out is invalid @@ -477,7 +478,7 @@ circuit CustomMemory : res should containLine("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)") res should containLine("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)") // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } @@ -498,7 +499,7 @@ circuit NoMemsHere : val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), InferReadWriteAnnotation) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check the emitted conf - checkMemConf(confLoc, mems) + checkMemConf(res, mems) (new java.io.File(confLoc)).delete() } diff --git a/src/test/scala/firrtlTests/SeparateWriteClocksSpec.scala b/src/test/scala/firrtlTests/SeparateWriteClocksSpec.scala new file mode 100644 index 0000000000..476a3ae2de --- /dev/null +++ b/src/test/scala/firrtlTests/SeparateWriteClocksSpec.scala @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests + +import firrtl._ +import firrtl.ir._ +import firrtl.passes.memlib.SeparateWriteClocks +import firrtl.testutils._ +import firrtl.testutils.FirrtlCheckers._ + +class SeparateWriteClocksSpec extends FirrtlFlatSpec { + def transform(input: String): CircuitState = { + val csx = (new SeparateWriteClocks).execute(CircuitState(parse(input), MidForm)) + val emittedCirc = EmittedFirrtlCircuit("top", csx.circuit.serialize, ".fir") + csx.copy(annotations = Seq(EmittedFirrtlCircuitAnnotation(emittedCirc))) + } + + behavior.of("SeparateWriteClocks") + + it should "add intermediate wires to clocks of multi-write sync-read memories" in { + val result = transform(s""" + |circuit top: + | module top: + | input clk: Clock + | input raddr: UInt<10> + | output rdata: UInt<8>[4] + | input waddr_a: UInt<10> + | input we_a: UInt<1> + | input wdata_a: UInt<8>[4] + | input waddr_a: UInt<10> + | input we_a: UInt<1> + | input wdata_a: UInt<8>[4] + | + | mem m: + | data-type => UInt<8> + | depth => 1024 + | reader => r + | writer => w_a + | writer => w_b + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= clk + | m.r.addr <= raddr + | m.r.en <= UInt(1) + | rdata <= m.r.data + | + | m.w_a.clk <= clk + | m.w_a.addr <= waddr_a + | m.w_a.en <= we_a + | m.w_a.mask <= UInt(1) + | m.w_a.data <= wdata_a + | + | m.w_b.clk <= clk + | m.w_b.addr <= waddr_b + | m.w_b.en <= we_b + | m.w_b.mask <= UInt(1) + | m.w_b.data <= wdata_b""".stripMargin) + + println(result.circuit.serialize) + result should containLine("m.r.clk <= clk") + result should containLine("m.w_a.clk <= m_w_a_clk") + result should containLine("m.w_b.clk <= m_w_b_clk") + result shouldNot containLine("m.w_a.clk <= clk") + result shouldNot containLine("m.w_b.clk <= clk") + } +} diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 0a0df355ba..061837d796 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -189,14 +189,14 @@ class UnitTests extends FirrtlFlatSpec { //TODO(azidar): I realize this is brittle, but unfortunately there // isn't a better way to test this pass val check = Seq( - """wire _table_1 : { a : UInt<8>}""", - """_table_1.a is invalid""", + """wire _table_1_a : UInt<8>""", + """_table_1_a is invalid""", """when UInt<1>("h1") :""", - """_table_1.a <= table[1].a""", + """_table_1_a <= table[1].a""", """wire _otherTable_table_1_a_a : UInt<8>""", - """when eq(UInt<1>("h0"), _table_1.a) :""", + """when eq(UInt<1>("h0"), _table_1_a) :""", """otherTable[0].a <= _otherTable_table_1_a_a""", - """when eq(UInt<1>("h1"), _table_1.a) :""", + """when eq(UInt<1>("h1"), _table_1_a) :""", """otherTable[1].a <= _otherTable_table_1_a_a""", """_otherTable_table_1_a_a <= UInt<1>("h0")""" ) diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index 7704a0a20c..ec30e55c87 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -754,6 +754,30 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { result should containLine("assign z = _GEN_0[1:0];") } + it should "not pad multiplication" in { + val compiler = new VerilogCompiler + val result = compileBody( + """input x : UInt<2> + |input y: UInt<4> + |output z : UInt<6> + |z <= mul(x, y) + |""".stripMargin + ) + result should containLine("assign z = x * y;") + } + + it should "not pad division" in { + val compiler = new VerilogCompiler + val result = compileBody( + """input x : UInt<4> + |input y: UInt<2> + |output z : UInt<4> + |z <= div(x, y) + |""".stripMargin + ) + result should containLine("assign z = x / y;") + } + it should "correctly emit addition with a negative literal with width > 32" in { val result = compileBody( """input x : SInt<34> diff --git a/src/test/scala/firrtlTests/VerilogEquivalenceSpec.scala b/src/test/scala/firrtlTests/VerilogEquivalenceSpec.scala new file mode 100644 index 0000000000..747f668990 --- /dev/null +++ b/src/test/scala/firrtlTests/VerilogEquivalenceSpec.scala @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests + +import firrtl.testutils._ + +class VerilogEquivalenceSpec extends FirrtlFlatSpec { + "mul followed by cat" should "be correct" in { + val header = s""" + |circuit Multiply : + | module Multiply : + | input x : UInt<4> + | input y : UInt<2> + | input z : UInt<2> + | output out : UInt<8> + |""".stripMargin + val input1 = header + """ + | out <= cat(z, mul(x, y))""".stripMargin + val input2 = header + """ + | node n = mul(x, y) + | node m = cat(z, n) + | out <= m""".stripMargin + val expected = s""" + |module MultiplyRef( + | input [3:0] x, + | input [1:0] y, + | input [1:0] z, + | output [7:0] out + |); + | wire [5:0] w = x * y; + | assign out = {z, w}; + |endmodule""".stripMargin + firrtlEquivalenceWithVerilog(input1, expected) + firrtlEquivalenceWithVerilog(input2, expected) + } + + "div followed by cat" should "be correct" in { + val header = s""" + |circuit Divide : + | module Divide : + | input x : UInt<4> + | input y : UInt<2> + | input z : UInt<2> + | output out : UInt<6> + |""".stripMargin + val input1 = header + """ + | out <= cat(z, div(x, y))""".stripMargin + val input2 = header + """ + | node n = div(x, y) + | node m = cat(z, n) + | out <= m""".stripMargin + val expected = s""" + |module DivideRef( + | input [3:0] x, + | input [1:0] y, + | input [1:0] z, + | output [5:0] out + |); + | wire [3:0] w = x / y; + | assign out = {z, w}; + |endmodule""".stripMargin + firrtlEquivalenceWithVerilog(input1, expected) + firrtlEquivalenceWithVerilog(input2, expected) + } + + "signed mul followed by cat" should "be correct" in { + val header = s""" + |circuit SignedMultiply : + | module SignedMultiply : + | input x : SInt<4> + | input y : SInt<2> + | input z : SInt<2> + | output out : UInt<8> + |""".stripMargin + val input1 = header + """ + | out <= cat(z, mul(x, y))""".stripMargin + val input2 = header + """ + | node n = mul(x, y) + | node m = cat(z, n) + | out <= m""".stripMargin + val expected = s""" + |module SignedMultiplyRef( + | input signed [3:0] x, + | input signed [1:0] y, + | input signed [1:0] z, + | output [7:0] out + |); + | wire [5:0] w = x * y; + | assign out = {z, w}; + |endmodule""".stripMargin + firrtlEquivalenceWithVerilog(input1, expected) + firrtlEquivalenceWithVerilog(input2, expected) + } + + "signed div followed by cat" should "be correct" in { + val header = s""" + |circuit SignedDivide : + | module SignedDivide : + | input x : SInt<4> + | input y : SInt<2> + | input z : SInt<2> + | output out : UInt<7> + |""".stripMargin + val input1 = header + """ + | out <= cat(z, div(x, y))""".stripMargin + val input2 = header + """ + | node n = div(x, y) + | node m = cat(z, n) + | out <= m""".stripMargin + val expected = s""" + |module SignedDivideRef( + | input signed [3:0] x, + | input signed [1:0] y, + | input signed [1:0] z, + | output [6:0] out + |); + | wire [4:0] w = x / y; + | assign out = {z, w}; + |endmodule""".stripMargin + firrtlEquivalenceWithVerilog(input1, expected) + firrtlEquivalenceWithVerilog(input2, expected) + } +} diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala index 99ebbdd350..654c6f4294 100644 --- a/src/test/scala/firrtlTests/ZeroWidthTests.scala +++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala @@ -220,6 +220,86 @@ class ZeroWidthTests extends FirrtlFlatSpec { | x <= UInt<1>(1)""".stripMargin (parse(exec(input))) should be(parse(check)) } + + "Cat of SInt with zero-width" should "keep type correctly" in { + val input = + """circuit Top : + | module Top : + | input x : SInt<0> + | input y : SInt<1> + | output z : UInt<1> + | z <= cat(y, x)""".stripMargin + val check = + """circuit Top : + | module Top : + | input y : SInt<1> + | output z : UInt<1> + | z <= asUInt(y)""".stripMargin + (parse(exec(input))) should be(parse(check)) + } + + "Memories with zero-width data-type" should "be fully removed" in { + val input = + """circuit Foo: + | module Foo: + | input clock: Clock + | input rAddr: UInt<4> + | input rEn: UInt<1> + | output rData: UInt<0> + | input wAddr: UInt<4> + | input wEn: UInt<1> + | input wMask: UInt<1> + | input wData: UInt<0> + | input rwEn: UInt<1> + | input rwMode: UInt<1> + | input rwAddr: UInt<1> + | input rwMask: UInt<1> + | input rwDataIn: UInt<0> + | output rwDataOut: UInt<0> + | + | mem memory: + | data-type => UInt<0> + | depth => 16 + | reader => r + | writer => w + | readwriter => rw + | read-latency => 0 + | write-latency => 1 + | read-under-write => undefined + | + | memory.r.clk <= clock + | memory.r.en <= rEn + | memory.r.addr <= rAddr + | rData <= memory.r.data + | memory.w.clk <= clock + | memory.w.en <= wEn + | memory.w.addr <= wAddr + | memory.w.mask <= wMask + | memory.w.data <= wData + | memory.rw.clk <= clock + | memory.rw.en <= rwEn + | memory.rw.addr <= rwAddr + | memory.rw.wmode <= rwMode + | memory.rw.wmask <= rwMask + | memory.rw.wdata <= rwDataIn + | rwDataOut <= memory.rw.rdata""".stripMargin + val check = + s"""circuit Foo: + | module Foo: + | input clock: Clock + | input rAddr: UInt<4> + | input rEn: UInt<1> + | input wAddr: UInt<4> + | input wEn: UInt<1> + | input wMask: UInt<1> + | input rwEn: UInt<1> + | input rwMode: UInt<1> + | input rwAddr: UInt<1> + | input rwMask: UInt<1> + | + |${Seq.tabulate(17)(_ => " skip").mkString("\n")}""".stripMargin + parse(exec(input)) should be(parse(check)) + } } class ZeroWidthVerilog extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala b/src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala new file mode 100644 index 0000000000..38266efeb1 --- /dev/null +++ b/src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests.annotationTests + +import firrtl.analyses.InstanceKeyGraph.InstanceKey +import firrtl.annotations._ +import firrtl.annotations.TargetToken._ +import firrtl.annotations.TargetUtils._ +import firrtl.testutils.FirrtlFlatSpec + +class TargetUtilsSpec extends FirrtlFlatSpec { + + behavior.of("instKeyPathToTarget") + + it should "create a ModuleTarget for the top module" in { + val input = InstanceKey("Top", "Top") :: Nil + val expected = ModuleTarget("Top", "Top") + instKeyPathToTarget(input) should be(expected) + } + + it should "create absolute InstanceTargets" in { + val input = InstanceKey("Top", "Top") :: + InstanceKey("foo", "Foo") :: + InstanceKey("bar", "Bar") :: + Nil + val expected = InstanceTarget("Top", "Top", Seq((Instance("foo"), OfModule("Foo"))), "bar", "Bar") + instKeyPathToTarget(input) should be(expected) + } + + it should "support starting somewhere down the path" in { + val input = InstanceKey("Top", "Top") :: + InstanceKey("foo", "Foo") :: + InstanceKey("bar", "Bar") :: + InstanceKey("fizz", "Fizz") :: + Nil + val expected = InstanceTarget("Top", "Bar", Seq(), "fizz", "Fizz") + instKeyPathToTarget(input, Some("Bar")) should be(expected) + } + + behavior.of("unfoldInstanceTargets") + + it should "return nothing for ModuleTargets" in { + val input = ModuleTarget("Top", "Foo") + unfoldInstanceTargets(input) should be(Iterable()) + } + + it should "return all other InstanceTargets to the same instance" in { + val input = ModuleTarget("Top", "Top").instOf("foo", "Foo").instOf("bar", "Bar").instOf("fizz", "Fizz") + val expected = + input :: + ModuleTarget("Top", "Foo").instOf("bar", "Bar").instOf("fizz", "Fizz") :: + ModuleTarget("Top", "Bar").instOf("fizz", "Fizz") :: + Nil + unfoldInstanceTargets(input) should be(expected) + } +} diff --git a/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala b/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala new file mode 100644 index 0000000000..1f1f19680f --- /dev/null +++ b/src/test/scala/firrtlTests/passes/RemoveAccessesSpec.scala @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests +package passes + +import firrtl._ +import firrtl.testutils._ +import firrtl.stage.TransformManager +import firrtl.options.Dependency +import firrtl.passes._ + +class RemoveAccessesSpec extends FirrtlFlatSpec { + def compile(input: String): String = { + val manager = new TransformManager(Dependency(RemoveAccesses) :: Nil) + val result = manager.execute(CircuitState(parse(input), Nil)) + val checks = List( + CheckHighForm, + CheckTypes, + CheckFlows + ) + for (check <- checks) { check.run(result.circuit) } + result.circuit.serialize + } + def circuit(body: String): String = { + """|circuit Test : + | module Test : + |""".stripMargin + body.stripMargin.split("\n").mkString(" ", "\n ", "\n") + } + + behavior.of("RemoveAccesses") + + it should "handle a simple RHS subaccess" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |wire _in_idx : UInt<8> + |_in_idx is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx <= in[0] + |when eq(UInt<1>("h1"), idx) : + | _in_idx <= in[1] + |when eq(UInt<2>("h2"), idx) : + | _in_idx <= in[2] + |when eq(UInt<2>("h3"), idx) : + | _in_idx <= in[3] + |out <= _in_idx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "support complex expressions" in { + val input = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |out <= in[mux(sel, r, idx)] + |r <= not(idx)""" + ) + val expected = circuit( + s"""|input clock : Clock + |input in : UInt<8>[4] + |input idx : UInt<2> + |input sel : UInt<1> + |output out : UInt<8> + |reg r : UInt<2>, clock + |wire _in_mux : UInt<8> + |_in_mux is invalid + |when eq(UInt<1>("h0"), mux(sel, r, idx)) : + | _in_mux <= in[0] + |when eq(UInt<1>("h1"), mux(sel, r, idx)) : + | _in_mux <= in[1] + |when eq(UInt<2>("h2"), mux(sel, r, idx)) : + | _in_mux <= in[2] + |when eq(UInt<2>("h3"), mux(sel, r, idx)) : + | _in_mux <= in[3] + |out <= _in_mux + |r <= not(idx)""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "support nested subaccesses" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |out <= in[idx[jdx]]""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2>[4] + |input jdx : UInt<2> + |output out : UInt<8> + |wire _idx_jdx : UInt<2> + |_idx_jdx is invalid + |when eq(UInt<1>("h0"), jdx) : + | _idx_jdx <= idx[0] + |when eq(UInt<1>("h1"), jdx) : + | _idx_jdx <= idx[1] + |when eq(UInt<2>("h2"), jdx) : + | _idx_jdx <= idx[2] + |when eq(UInt<2>("h3"), jdx) : + | _idx_jdx <= idx[3] + |wire _in_idx_jdx : UInt<8> + |_in_idx_jdx is invalid + |when eq(UInt<1>("h0"), _idx_jdx) : + | _in_idx_jdx <= in[0] + |when eq(UInt<1>("h1"), _idx_jdx) : + | _in_idx_jdx <= in[1] + |when eq(UInt<2>("h2"), _idx_jdx) : + | _in_idx_jdx <= in[2] + |when eq(UInt<2>("h3"), _idx_jdx) : + | _in_idx_jdx <= in[3] + |out <= _in_idx_jdx""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "avoid name collisions" in { + val input = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |out <= in[idx] + |node _in_idx = not(idx)""" + ) + val expected = circuit( + s"""|input in : UInt<8>[4] + |input idx : UInt<2> + |output out : UInt<8> + |wire _in_idx_0 : UInt<8> + |_in_idx_0 is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx_0 <= in[0] + |when eq(UInt<1>("h1"), idx) : + | _in_idx_0 <= in[1] + |when eq(UInt<2>("h2"), idx) : + | _in_idx_0 <= in[2] + |when eq(UInt<2>("h3"), idx) : + | _in_idx_0 <= in[3] + |out <= _in_idx_0 + |node _in_idx = not(idx)""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "handle a simple LHS subaccess" in { + val input = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |out[idx] <= in""" + ) + val expected = circuit( + s"""|input in : UInt<8> + |input idx : UInt<2> + |output out : UInt<8>[4] + |wire _out_idx : UInt<8> + |when eq(UInt<1>("h0"), idx) : + | out[0] <= _out_idx + |when eq(UInt<1>("h1"), idx) : + | out[1] <= _out_idx + |when eq(UInt<2>("h2"), idx) : + | out[2] <= _out_idx + |when eq(UInt<2>("h3"), idx) : + | out[3] <= _out_idx + |_out_idx <= in""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "linearly expand RHS subaccesses of aggregate-typed vecs" in { + val input = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> }[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> } + |out.foo <= in[idx].foo + |out.bar <= in[idx].bar""" + ) + val expected = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8>}[4] + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8>} + |wire _in_idx_foo : UInt<8> + |_in_idx_foo is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx_foo <= in[0].foo + |when eq(UInt<1>("h1"), idx) : + | _in_idx_foo <= in[1].foo + |when eq(UInt<2>("h2"), idx) : + | _in_idx_foo <= in[2].foo + |when eq(UInt<2>("h3"), idx) : + | _in_idx_foo <= in[3].foo + |out.foo <= _in_idx_foo + |wire _in_idx_bar : UInt<8> + |_in_idx_bar is invalid + |when eq(UInt<1>("h0"), idx) : + | _in_idx_bar <= in[0].bar + |when eq(UInt<1>("h1"), idx) : + | _in_idx_bar <= in[1].bar + |when eq(UInt<2>("h2"), idx) : + | _in_idx_bar <= in[2].bar + |when eq(UInt<2>("h3"), idx) : + | _in_idx_bar <= in[3].bar + |out.bar <= _in_idx_bar""" + ) + compile(input) should be(parse(expected).serialize) + } + + it should "linearly expand LHS subaccesses of aggregate-typed vecs" in { + val input = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> } + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> }[4] + |out[idx].foo <= in.foo + |out[idx].bar <= in.bar""" + ) + val expected = circuit( + s"""|input in : { foo : UInt<8>, bar : UInt<8> } + |input idx : UInt<2> + |output out : { foo : UInt<8>, bar : UInt<8> }[4] + |wire _out_idx_foo : UInt<8> + |when eq(UInt<1>("h0"), idx) : + | out[0].foo <= _out_idx_foo + |when eq(UInt<1>("h1"), idx) : + | out[1].foo <= _out_idx_foo + |when eq(UInt<2>("h2"), idx) : + | out[2].foo <= _out_idx_foo + |when eq(UInt<2>("h3"), idx) : + | out[3].foo <= _out_idx_foo + |_out_idx_foo <= in.foo + |wire _out_idx_bar : UInt<8> + |when eq(UInt<1>("h0"), idx) : + | out[0].bar <= _out_idx_bar + |when eq(UInt<1>("h1"), idx) : + | out[1].bar <= _out_idx_bar + |when eq(UInt<2>("h2"), idx) : + | out[2].bar <= _out_idx_bar + |when eq(UInt<2>("h3"), idx) : + | out[3].bar <= _out_idx_bar + |_out_idx_bar <= in.bar""" + ) + compile(input) should be(parse(expected).serialize) + } +} diff --git a/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala b/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala index ea590d26fd..eaf48b4990 100644 --- a/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala +++ b/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala @@ -164,16 +164,6 @@ class FirrtlMainSpec |""".stripMargin } - /** This returns a string containing the default standard out string based on the Scala version. E.g., if there are - * version-specific deprecation warnings, those are available here and can be passed to tests that should have them. - */ - val defaultStdOut: Option[String] = BuildInfo.scalaVersion.split("\\.").toList match { - case "2" :: v :: _ :: Nil if v.toInt <= 11 => - Some(CheckScalaVersion.deprecationMessage("2.11", s"--${WarnNoScalaVersionDeprecation.longOption}")) - case x => - None - } - info("As a FIRRTL command line user") info("I want to compile some FIRRTL") Feature("FirrtlMain command line interface") { @@ -205,58 +195,48 @@ class FirrtlMainSpec Seq( /* Test all standard emitters with and without annotation file outputs */ FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl"), files = Seq("Top.fir")), - FirrtlMainTest(args = Array("-X", "high", "-E", "high"), stdout = defaultStdOut, files = Seq("Top.hi.fir")), + FirrtlMainTest(args = Array("-X", "high", "-E", "high"), files = Seq("Top.hi.fir")), FirrtlMainTest( args = Array("-X", "middle", "-E", "middle", "-foaf", "Top"), - stdout = defaultStdOut, files = Seq("Top.mid.fir", "Top.anno.json") ), FirrtlMainTest( args = Array("-X", "low", "-E", "low", "-foaf", "annotations.anno.json"), - stdout = defaultStdOut, files = Seq("Top.lo.fir", "annotations.anno.json") ), FirrtlMainTest( args = Array("-X", "verilog", "-E", "verilog", "-foaf", "foo.anno"), - stdout = defaultStdOut, files = Seq("Top.v", "foo.anno.anno.json") ), FirrtlMainTest( args = Array("-X", "sverilog", "-E", "sverilog", "-foaf", "foo.json"), - stdout = defaultStdOut, files = Seq("Top.sv", "foo.json.anno.json") ), /* Test all one file per module emitters */ FirrtlMainTest(args = Array("-X", "none", "-e", "chirrtl"), files = Seq("Top.fir", "Child.fir")), FirrtlMainTest( args = Array("-X", "high", "-e", "high"), - stdout = defaultStdOut, files = Seq("Top.hi.fir", "Child.hi.fir") ), FirrtlMainTest( args = Array("-X", "middle", "-e", "middle"), - stdout = defaultStdOut, files = Seq("Top.mid.fir", "Child.mid.fir") ), FirrtlMainTest( args = Array("-X", "low", "-e", "low"), - stdout = defaultStdOut, files = Seq("Top.lo.fir", "Child.lo.fir") ), FirrtlMainTest( args = Array("-X", "verilog", "-e", "verilog"), - stdout = defaultStdOut, files = Seq("Top.v", "Child.v") ), FirrtlMainTest( args = Array("-X", "sverilog", "-e", "sverilog"), - stdout = defaultStdOut, files = Seq("Top.sv", "Child.sv") ), /* Test mixing of -E with -e */ FirrtlMainTest( args = Array("-X", "middle", "-E", "high", "-e", "middle"), - stdout = defaultStdOut, files = Seq("Top.hi.fir", "Top.mid.fir", "Child.mid.fir"), notFiles = Seq("Child.hi.fir") ), @@ -264,33 +244,27 @@ class FirrtlMainSpec FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl", "-o", "foo"), files = Seq("foo.fir")), FirrtlMainTest( args = Array("-X", "high", "-E", "high", "-o", "foo"), - stdout = defaultStdOut, files = Seq("foo.hi.fir") ), FirrtlMainTest( args = Array("-X", "middle", "-E", "middle", "-o", "foo.middle"), - stdout = defaultStdOut, files = Seq("foo.middle.mid.fir") ), FirrtlMainTest( args = Array("-X", "low", "-E", "low", "-o", "foo.lo.fir"), - stdout = defaultStdOut, files = Seq("foo.lo.fir") ), FirrtlMainTest( args = Array("-X", "verilog", "-E", "verilog", "-o", "foo.sv"), - stdout = defaultStdOut, files = Seq("foo.sv.v") ), FirrtlMainTest( args = Array("-X", "sverilog", "-E", "sverilog", "-o", "Foo"), - stdout = defaultStdOut, files = Seq("Foo.sv") ), /* Test that an output is generated if no emitter is specified */ FirrtlMainTest( args = Array("-X", "verilog", "-o", "Foo"), - stdout = defaultStdOut, files = Seq("Foo.v") ) ) diff --git a/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala b/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala new file mode 100644 index 0000000000..2f633e0e52 --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests.transforms + +import org.scalatest.featurespec.AnyFeatureSpec +import org.scalatest.GivenWhenThen +import firrtl.testutils.FirrtlMatchers +import java.io.File + +import firrtl.graph.DiGraph +import firrtl.analyses.InstanceKeyGraph +import firrtl.annotations.CircuitTarget +import firrtl.annotations.TargetToken.OfModule +import firrtl.transforms._ +import firrtl.transforms.MustDeduplicateTransform._ +import firrtl.transforms.MustDeduplicateTransform.DisjointChildren._ +import firrtl.util.BackendCompilationUtilities.createTestDirectory +import firrtl.stage.{FirrtlSourceAnnotation, RunFirrtlTransformAnnotation} +import firrtl.options.{TargetDirAnnotation} +import logger.{LogLevel, LogLevelAnnotation, Logger} + +class MustDedupSpec extends AnyFeatureSpec with FirrtlMatchers with GivenWhenThen { + + Feature("When you have a simple non-deduping hierarcy") { + val text = """ + |circuit A : + | module C : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module C_1 : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= and(io.in, UInt("hff")) + | module B : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of C + | io <= c.io + | module B_1 : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of C_1 + | io <= c.io + | module A : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst b of B + | inst b_1 of B_1 + | io.out <= and(b.io.out, b_1.io.out) + | b.io.in <= io.in + | b_1.io.in <= io.in + """.stripMargin + val top = CircuitTarget("A") + val bdedup = MustDeduplicateAnnotation(Seq(top.module("B"), top.module("B_1"))) + val igraph = InstanceKeyGraph(parse(text)) + + Scenario("Full compilation should fail and dump reports to disk") { + val testDir = createTestDirectory("must_dedup") + val reportDir = new File(testDir, "reports") + val annos = Seq( + TargetDirAnnotation(testDir.toString), + FirrtlSourceAnnotation(text), + RunFirrtlTransformAnnotation(new MustDeduplicateTransform), + MustDeduplicateReportDirectory(reportDir.toString), + bdedup + ) + + a[DeduplicationFailureException] shouldBe thrownBy { + (new firrtl.stage.FirrtlPhase).transform(annos) + } + + reportDir should exist + + val report0 = new File(reportDir, "report_0.rpt") + report0 should exist + + val expectedModules = Seq("B", "B_1", "C", "C_1") + for (mod <- expectedModules) { + new File(reportDir, s"modules/$mod.fir") should exist + } + } + + Scenario("Non-deduping children should give actionable debug information") { + When("Finding dedup failures") + val failure = findDedupFailures(Seq(OfModule("B"), OfModule("B_1")), igraph) + + Then("The children should appear as a failure candidate") + failure.candidates should be(Seq(LikelyShouldMatch(OfModule("C"), OfModule("C_1")))) + + And("There should be a pretty DiGraph showing context") + val got = makeDedupFailureDiGraph(failure, igraph.graph.transformNodes(_.module)) + val expected = DiGraph("A" -> "(B)", "A" -> "(B_1)", "(B)" -> "C [0]", "(B_1)" -> "C_1 [0]") + // DiGraph uses referential equality so compare serialized form + got.prettyTree() should be(expected.prettyTree()) + } + + Scenario("Unrelated hierarchies should give actionable debug information") { + When("Finding dedup failures") + val failure = findDedupFailures(Seq(OfModule("B"), OfModule("C_1")), igraph) + + Then("The failure should note the hierarchies don't match") + failure.candidates should be(Seq(DisjointChildren(OfModule("B"), OfModule("C_1"), Left))) + + And("There should be a pretty DiGraph showing context") + val got = makeDedupFailureDiGraph(failure, igraph.graph.transformNodes(_.module)) + val expected = DiGraph("A" -> "(B) [0]", "(B) [0]" -> "C", "B_1" -> "(C_1) [0]") + // DiGraph uses referential equality so compare serialized form + got.prettyTree() should be(expected.prettyTree()) + } + } + + Feature("When you have a deep, non-deduping hierarchy") { + // Shadow hierarchy just to get an InstanceKeyGraph which can only be made from a circuit + val text = parse(""" + |circuit A : + | module E: + | skip + | module F : + | skip + | module F_1 : + | inst e of E + | module D : + | skip + | module D_1 : + | skip + | module C : + | inst d of D + | inst f of F + | module C_1 : + | inst d of D_1 + | inst f of F_1 + | module B : + | inst c of C + | inst e of E + | module B_1 : + | inst c of C_1 + | inst e of E + | module A : + | inst b of B + | inst b_1 of B_1 + |""".stripMargin) + val igraph = InstanceKeyGraph(text) + + Scenario("Non-deduping children should give actionable debug information") { + When("Finding dedup failures") + val failure = findDedupFailures(Seq(OfModule("B"), OfModule("B_1")), igraph) + + Then("The children should appear as a failure candidate") + failure.candidates should be( + Seq(LikelyShouldMatch(OfModule("D"), OfModule("D_1")), DisjointChildren(OfModule("F"), OfModule("F_1"), Right)) + ) + + And("There should be a pretty DiGraph showing context") + val got = makeDedupFailureDiGraph(failure, igraph.graph.transformNodes(_.module)) + val expected = DiGraph( + "A" -> "(B)", + "A" -> "(B_1)", + "(B)" -> "C", + "C" -> "D [0]", + "C" -> "F [1]", + "(B_1)" -> "C_1", + "C_1" -> "D_1 [0]", + "C_1" -> "F_1 [1]", + "F_1 [1]" -> "E", + // These last 2 are undesirable but E is included because it's a submodule of disjoint F and F_1 + "(B)" -> "E", + "(B_1)" -> "E" + ) + // DiGraph uses referential equality so compare serialized form + got.prettyTree() should be(expected.prettyTree()) + } + } + + Feature("When you have multiple modules that should dedup, but don't") { + // Shadow hierarchy just to get an InstanceKeyGraph which can only be made from a circuit + val text = parse(""" + |circuit A : + | module D : + | skip + | module D_1 : + | skip + | module C : + | skip + | module C_1 : + | skip + | module B : + | inst c of C + | inst d of D + | module B_1 : + | inst c of C_1 + | inst d of D + | module B_2 : + | inst c of C + | inst d of D_1 + | module A : + | inst b of B + | inst b_1 of B_1 + | inst b_2 of B_2 + |""".stripMargin) + val igraph = InstanceKeyGraph(text) + + Scenario("Non-deduping children should give actionable debug information") { + When("Finding dedup failures") + val failure = findDedupFailures(Seq(OfModule("B"), OfModule("B_1"), OfModule("B_2")), igraph) + + Then("The children should appear as a failure candidate") + failure.candidates should be( + Seq(LikelyShouldMatch(OfModule("C"), OfModule("C_1")), LikelyShouldMatch(OfModule("D"), OfModule("D_1"))) + ) + + And("There should be a pretty DiGraph showing context") + val got = makeDedupFailureDiGraph(failure, igraph.graph.transformNodes(_.module)) + val expected = DiGraph( + "A" -> "(B)", + "A" -> "(B_1)", + "A" -> "(B_2)", + "(B)" -> "C [0]", + "(B)" -> "D [1]", + "(B_1)" -> "C_1 [0]", + "(B_1)" -> "D [1]", + "(B_2)" -> "C [0]", + "(B_2)" -> "D_1 [1]" + ) + // DiGraph uses referential equality so compare serialized form + got.prettyTree() should be(expected.prettyTree()) + } + } + + Feature("When you have modules that should dedup, and they do") { + val text = """ + |circuit A : + | module C : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module C_1 : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module B : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of C + | io <= c.io + | module B_1 : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of C_1 + | io <= c.io + | module A : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst b of B + | inst b_1 of B_1 + | io.out <= and(b.io.out, b_1.io.out) + | b.io.in <= io.in + | b_1.io.in <= io.in + """.stripMargin + val top = CircuitTarget("A") + val bdedup = MustDeduplicateAnnotation(Seq(top.module("B"), top.module("B_1"))) + + Scenario("Full compilation should succeed") { + val testDir = createTestDirectory("must_dedup") + val reportDir = new File(testDir, "reports") + val annos = Seq( + TargetDirAnnotation(testDir.toString), + FirrtlSourceAnnotation(text), + RunFirrtlTransformAnnotation(new MustDeduplicateTransform), + MustDeduplicateReportDirectory(reportDir.toString), + bdedup + ) + + (new firrtl.stage.FirrtlPhase).transform(annos) + } + } +} diff --git a/src/test/scala/firrtlTests/transforms/SortModulesSpec.scala b/src/test/scala/firrtlTests/transforms/SortModulesSpec.scala new file mode 100644 index 0000000000..4aabeeeaf1 --- /dev/null +++ b/src/test/scala/firrtlTests/transforms/SortModulesSpec.scala @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests.transforms + +import firrtl.{ir, CircuitState, Parser} +import firrtl.transforms.SortModules +import firrtl.traversals.Foreachers._ + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import scala.collection.mutable + +class SortModulesSpec extends AnyFlatSpec with Matchers { + + private def collectModules(names: mutable.ArrayBuffer[String], module: ir.DefModule): Unit = names += module.name + + behavior.of("SortModules") + + it should "enforce define before use of modules" in { + + val input = + """|circuit Foo: + | module Foo: + | inst bar of Bar + | module Bar: + | inst baz of Baz + | extmodule Baz: + | input a: UInt<1> + |""".stripMargin + + val state = CircuitState(Parser.parse(input), Seq.empty) + val moduleNames = mutable.ArrayBuffer.empty[String] + + (new SortModules) + .execute(state) + .circuit + .foreach(collectModules(moduleNames, _: ir.DefModule)) + + (moduleNames should contain).inOrderOnly("Baz", "Bar", "Foo") + } + +}