Skip to content

Commit

Permalink
Merge pull request #1239 from OpenXiangShan/dev-wrbypass
Browse files Browse the repository at this point in the history
bpu: extract wrbypass to be a module
  • Loading branch information
Lingrui98 authored Nov 17, 2021
2 parents 5551d32 + 569b279 commit 0bbc9ca
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 273 deletions.
59 changes: 18 additions & 41 deletions src/main/scala/xiangshan/frontend/Bim.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,55 +64,32 @@ class BIM(implicit p: Parameters) extends BasePredictor with BimParams with BPUU
// Update logic
val u_valid = RegNext(io.update.valid)
val update = RegNext(io.update.bits)

val u_idx = bimAddr.getIdx(update.pc)

val update_mask = LowerMask(PriorityEncoderOH(update.preds.br_taken_mask.asUInt))
val newCtrs = Wire(Vec(numBr, UInt(2.W)))
val need_to_update = VecInit((0 until numBr).map(i => u_valid && update.ftb_entry.brValids(i) && update_mask(i)))

// Bypass logic
val wrbypass_ctrs = RegInit(0.U.asTypeOf(Vec(bypassEntries, Vec(numBr, UInt(2.W)))))
val wrbypass_ctr_valids = RegInit(0.U.asTypeOf(Vec(bypassEntries, Vec(numBr, Bool()))))
val wrbypass_idx = RegInit(0.U.asTypeOf(Vec(bypassEntries, UInt(log2Up(bimSize).W))))
val wrbypass_enq_ptr = RegInit(0.U(log2Up(bypassEntries).W))

val wrbypass_hits = VecInit((0 until bypassEntries).map(i =>
!doing_reset && wrbypass_idx(i) === u_idx))
val wrbypass_hit = wrbypass_hits.reduce(_||_)
val wrbypass_hit_idx = PriorityEncoder(wrbypass_hits)

val oldCtrs = VecInit((0 until numBr).map(i =>
Mux(wrbypass_hit && wrbypass_ctr_valids(wrbypass_hit_idx)(i),
wrbypass_ctrs(wrbypass_hit_idx)(i), update.meta(2*i+1, 2*i))))
// Bypass logic
val wrbypass = Module(new WrBypass(UInt(2.W), bypassEntries, log2Up(bimSize), numWays = numBr))
wrbypass.io.wen := need_to_update.reduce(_||_)
wrbypass.io.write_idx := u_idx
wrbypass.io.write_data := newCtrs
wrbypass.io.write_way_mask.map(_ := need_to_update)

val oldCtrs =
VecInit((0 until numBr).map(i =>
Mux(wrbypass.io.hit && wrbypass.io.hit_data(i).valid,
wrbypass.io.hit_data(i).bits,
update.meta(2*i+1, 2*i))
))

val newTakens = update.preds.br_taken_mask
val newCtrs = VecInit((0 until numBr).map(i =>
newCtrs := VecInit((0 until numBr).map(i =>
satUpdate(oldCtrs(i), 2, newTakens(i))
))

val update_mask = LowerMask(PriorityEncoderOH(update.preds.br_taken_mask.asUInt))
val need_to_update = VecInit((0 until numBr).map(i => u_valid && update.ftb_entry.brValids(i) && update_mask(i)))

when (reset.asBool) { wrbypass_ctr_valids.foreach(_ := VecInit(Seq.fill(numBr)(false.B)))}

for (i <- 0 until numBr) {
when(need_to_update.reduce(_||_)) {
when(wrbypass_hit) {
when(need_to_update(i)) {
wrbypass_ctrs(wrbypass_hit_idx)(i) := newCtrs(i)
wrbypass_ctr_valids(wrbypass_hit_idx)(i) := true.B
}
}.otherwise {
wrbypass_ctr_valids(wrbypass_enq_ptr)(i) := false.B
when(need_to_update(i)) {
wrbypass_ctrs(wrbypass_enq_ptr)(i) := newCtrs(i)
wrbypass_ctr_valids(wrbypass_enq_ptr)(i) := true.B
}
}
}
}

when (need_to_update.reduce(_||_) && !wrbypass_hit) {
wrbypass_idx(wrbypass_enq_ptr) := u_idx
wrbypass_enq_ptr := (wrbypass_enq_ptr + 1.U)(log2Up(bypassEntries)-1, 0)
}

bim.io.w.apply(
valid = need_to_update.asUInt.orR || doing_reset,
Expand Down
78 changes: 6 additions & 72 deletions src/main/scala/xiangshan/frontend/ITTAGE.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,6 @@ class ITTageTable
val wrBypassEntries = 4
val phistLen = if (PathHistoryLength > histLen) histLen else PathHistoryLength

// def compute_tag_and_hash(unhashed_idx: UInt, hist: UInt, phist: UInt) = {
// val idx_history = compute_folded_ghist(hist, log2Ceil(nRows))
// // val idx = (unhashed_idx ^ (unhashed_idx >> (log2Ceil(nRows)-tableIdx+1)) ^ idx_history ^ idx_phist)(log2Ceil(nRows) - 1, 0)
// val idx = (unhashed_idx ^ idx_history)(log2Ceil(nRows) - 1, 0)
// val tag_history = compute_folded_ghist(hist, tagLen)
// val alt_tag_history = compute_folded_ghist(hist, tagLen-1)
// // Use another part of pc to make tags
// val tag = (
// if (tagLen > 1)
// ((unhashed_idx >> log2Ceil(nRows)) ^ tag_history ^ (alt_tag_history << 1)) (tagLen - 1, 0)
// else 0.U
// )
// (idx, tag)
// }

require(histLen == 0 && tagLen == 0 || histLen != 0 && tagLen != 0)
val idxFhInfo = (histLen, min(log2Ceil(nRows), histLen))
val tagFhInfo = (histLen, min(histLen, tagLen))
Expand Down Expand Up @@ -312,37 +297,14 @@ class ITTageTable
waymask = true.B
)

val wrbypass_tags = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(tagLen.W))))
val wrbypass_idxs = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W))))
val wrbypass_ctrs = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(ITTageCtrBits.W))))
val wrbypass_enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W))


val wrbypass_hits = VecInit((0 until wrBypassEntries) map { i =>
wrbypass_tags(i) === update_tag &&
wrbypass_idxs(i) === update_idx
})


val wrbypass_hit = wrbypass_hits.reduce(_||_)
// val wrbypass_rhit = wrbypass_rhits.reduce(_||_)
val wrbypass_hit_idx = ParallelPriorityEncoder(wrbypass_hits)
// val wrbypass_rhit_idx = PriorityEncoder(wrbypass_rhits)

// val wrbypass_rctr_hits = VecInit((0 until TageBanks).map( b => wrbypass_ctr_valids(wrbypass_rhit_idx)(b)))

// val rhit_ctrs = RegEnable(wrbypass_ctrs(wrbypass_rhit_idx), wrbypass_rhit)

// when (RegNext(wrbypass_rhit)) {
// for (b <- 0 until TageBanks) {
// when (RegNext(wrbypass_rctr_hits(b.U + baseBank))) {
// io.resp(b).bits.ctr := rhit_ctrs(s2_bankIdxInOrder(b))
// }
// }
// }
val wrbypass = Module(new WrBypass(UInt(ITTageCtrBits.W), wrBypassEntries, log2Ceil(nRows), tagWidth=tagLen))

wrbypass.io.wen := io.update.valid
wrbypass.io.write_idx := update_idx
wrbypass.io.write_tag.map(_ := update_tag)
wrbypass.io.write_data.map(_ := update_wdata.ctr)

val old_ctr = Mux(wrbypass_hit, wrbypass_ctrs(wrbypass_hit_idx), io.update.oldCtr)
val old_ctr = Mux(wrbypass.io.hit, wrbypass.io.hit_data(0).bits, io.update.oldCtr)
update_wdata.ctr := Mux(io.update.alloc, 2.U, inc_ctr(old_ctr, io.update.correct))
update_wdata.valid := true.B
update_wdata.tag := update_tag
Expand All @@ -352,22 +314,6 @@ class ITTageTable
update_hi_wdata := io.update.u(1)
update_lo_wdata := io.update.u(0)

when (io.update.valid) {
when (wrbypass_hit) {
wrbypass_ctrs(wrbypass_hit_idx) := update_wdata.ctr
} .otherwise {
wrbypass_ctrs(wrbypass_enq_idx) := update_wdata.ctr
}
}

when (io.update.valid && !wrbypass_hit) {
wrbypass_tags(wrbypass_enq_idx) := update_tag
wrbypass_idxs(wrbypass_enq_idx) := update_idx
wrbypass_enq_idx := (wrbypass_enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1,0)
}

XSPerfAccumulate("ittage_table_wrbypass_hit", io.update.valid && wrbypass_hit)
XSPerfAccumulate("ittage_table_wrbypass_enq", io.update.valid && !wrbypass_hit)
XSPerfAccumulate("ittage_table_hits", io.resp.valid)

if (BPUDebug && debug) {
Expand All @@ -388,20 +334,8 @@ class ITTageTable
p"update ITTAGE Table: writing tag:${update_tag}, " +
p"ctr: ${update_wdata.ctr}, target:${Hexadecimal(update_wdata.target)}" +
p" in idx $update_idx\n")
val hitCtr = wrbypass_ctrs(wrbypass_hit_idx)
XSDebug(wrbypass_hit && io.update.valid,
p"wrbypass hit wridx:$wrbypass_hit_idx, idx:$update_idx, tag: $update_tag, " +
p"ctr:$hitCtr, newCtr:${update_wdata.ctr}\n")

XSDebug(RegNext(io.req.valid) && !s1_req_rhit, "TageTableResp: no hits!\n")

// when (wrbypass_rhit && wrbypass_ctr_valids(wrbypass_rhit_idx).reduce(_||_)) {
// for (b <- 0 until TageBanks) {
// XSDebug(wrbypass_ctr_valids(wrbypass_rhit_idx)(b),
// "wrbypass rhits, wridx:%d, tag:%x, idx:%d, hitctr:%d, bank:%d\n",
// wrbypass_rhit_idx, tag, idx, wrbypass_ctrs(wrbypass_rhit_idx)(b), b.U)
// }
// }

// ------------------------------Debug-------------------------------------
val valids = RegInit(0.U.asTypeOf(Vec(nRows, Bool())))
Expand Down
64 changes: 5 additions & 59 deletions src/main/scala/xiangshan/frontend/SC.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,67 +116,19 @@ class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int)(implicit p: Pa

val wrBypassEntries = 4

class SCWrBypass extends XSModule {
val io = IO(new Bundle {
val wen = Input(Bool())
val update_idx = Input(UInt(log2Ceil(nRows).W))
val update_ctrs = Flipped(ValidIO(SInt(ctrBits.W)))
val update_ctrPos = Input(UInt(log2Ceil(2).W))
val update_altPos = Input(UInt(log2Ceil(2).W))

val hit = Output(Bool())
val ctrs = Vec(2, ValidIO(SInt(ctrBits.W)))
})

val idxes = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W))))
val ctrs = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, Vec(2, SInt(ctrBits.W)))))
val ctr_valids = RegInit(0.U.asTypeOf(Vec(wrBypassEntries, Vec(2, Bool()))))
val enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W))

val hits = VecInit((0 until wrBypassEntries).map { i => idxes(i) === io.update_idx })

val hit = hits.reduce(_||_)
val hit_idx = ParallelPriorityEncoder(hits)

io.hit := hit

for (i <- 0 until 2) {
io.ctrs(i).valid := ctr_valids(hit_idx)(i)
io.ctrs(i).bits := ctrs(hit_idx)(i)
}

when (io.wen) {
when (hit) {
ctrs(hit_idx)(io.update_ctrPos) := io.update_ctrs.bits
ctr_valids(hit_idx)(io.update_ctrPos) := io.update_ctrs.valid
}.otherwise {
ctr_valids(enq_idx)(io.update_altPos) := false.B
ctr_valids(enq_idx)(io.update_ctrPos) := io.update_ctrs.valid
ctrs(enq_idx)(io.update_ctrPos) := io.update_ctrs.bits
}
}

when(io.wen && !hit) {
idxes(enq_idx) := io.update_idx
enq_idx := (enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1, 0)
}
}

val wrbypass = Module(new SCWrBypass)
val wrbypass = Module(new WrBypass(SInt(ctrBits.W), wrBypassEntries, log2Ceil(nRows), numWays=2))

val ctrPos = io.update.tagePred
val altPos = !io.update.tagePred
val bypass_ctr = wrbypass.io.ctrs(ctrPos)
val bypass_ctr = wrbypass.io.hit_data(ctrPos)
val hit_and_valid = wrbypass.io.hit && bypass_ctr.valid
val oldCtr = Mux(hit_and_valid, bypass_ctr.bits, io.update.oldCtr)
update_wdata := ctrUpdate(oldCtr, io.update.taken)

wrbypass.io.wen := io.update.mask
wrbypass.io.update_ctrs.valid := io.update.mask
wrbypass.io.update_ctrs.bits := update_wdata
wrbypass.io.update_idx := update_idx
wrbypass.io.update_ctrPos := ctrPos
wrbypass.io.update_altPos := altPos
wrbypass.io.write_data.map(_ := update_wdata) // only one of them are used
wrbypass.io.write_idx := update_idx
wrbypass.io.write_way_mask.map(_ := UIntToOH(ctrPos).asTypeOf(Vec(2, Bool())))

val u = io.update
XSDebug(io.req.valid,
Expand All @@ -188,12 +140,6 @@ class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int)(implicit p: Pa
XSDebug(io.update.mask,
p"update Table: pc:${Hexadecimal(u.pc)}, " +
p"tageTaken:${u.tagePred}, taken:${u.taken}, oldCtr:${u.oldCtr}\n")
val updateCtrPos = io.update.tagePred
val hitCtr = wrbypass.io.ctrs(updateCtrPos).bits
XSDebug(wrbypass.io.hit && wrbypass.io.ctrs(updateCtrPos).valid && io.update.mask,
p"wrbypass hit idx:$update_idx, ctr:$hitCtr, " +
p"taken:${io.update.taken} newCtr:${update_wdata}\n")

}

class SCThreshold(val ctrBits: Int = 6)(implicit p: Parameters) extends SCBundle {
Expand Down
Loading

0 comments on commit 0bbc9ca

Please sign in to comment.