Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Operations for bit representation of Nat and BitVec. #366

Merged
merged 31 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2e7ce57
chore: Reduce unnecessary modulus in BitVec.
joehendrix Nov 15, 2023
0b439fe
Fix missing import
joehendrix Nov 15, 2023
09d9c2d
Fix copyright notice.
joehendrix Nov 15, 2023
ae2863d
Introduce Fin.hIterate
joehendrix Nov 15, 2023
08b1624
chore: Remove unintentional committed lines.
joehendrix Nov 15, 2023
5447a01
fix: Cleanup shift left per Mario's suggestions.
joehendrix Nov 15, 2023
bcbc7fd
Move hIterate.loop so I can docstring it.
joehendrix Nov 15, 2023
bbca4c2
feat: Introduce Nat.testBit
joehendrix Nov 15, 2023
97875a1
chore: BitVec definition changes for better efficiency
joehendrix Nov 16, 2023
74b64d8
Bitvector addition as bitblast
joehendrix Nov 22, 2023
a6fa19f
chore: Drop unused lemmas
joehendrix Nov 22, 2023
aded617
chore: Further cleanups
joehendrix Nov 22, 2023
2e9d627
chore: Fixes from rebase
joehendrix Nov 23, 2023
a86adb1
chore: Rename carry_clean to carry.
joehendrix Nov 23, 2023
5baaea8
chore: more cleanups
joehendrix Nov 23, 2023
a94eae9
chore: Minor tweaks
joehendrix Nov 23, 2023
c500f63
chore: More minor cleanups
joehendrix Nov 23, 2023
8443c52
chore: Fix lint
joehendrix Nov 23, 2023
fc83cc6
chore: Rename cons_getLsb_truncat
joehendrix Nov 23, 2023
9399bec
chore: rehome bitvec lemmas
joehendrix Nov 23, 2023
2d39c7e
Update Std/Data/BitVec/Bitblast.lean
joehendrix Nov 27, 2023
a5a3c36
Minor cleanups
joehendrix Nov 27, 2023
f88b46b
chore: More cleanups
joehendrix Nov 27, 2023
c4e15c0
Update Std/Data/Fin/Iterate.lean
joehendrix Nov 28, 2023
1a06e6d
Update Std/Data/Fin/Iterate.lean
joehendrix Nov 28, 2023
936b664
Update Std/Data/Nat/Basic.lean
joehendrix Nov 28, 2023
27c7db4
chore: Cleanups per review suggestions.
joehendrix Nov 28, 2023
78d79b0
chore: More PR review cleanups
joehendrix Nov 28, 2023
dffb6db
chore: More cleanups
joehendrix Nov 28, 2023
1264243
chore: More cleanups
joehendrix Nov 28, 2023
d61e9e7
chore: Whitespace
joehendrix Nov 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Std.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ import Std.Data.BinomialHeap.Basic
import Std.Data.BinomialHeap.Lemmas
import Std.Data.BitVec
import Std.Data.BitVec.Basic
import Std.Data.BitVec.Bitblast
import Std.Data.BitVec.Folds
import Std.Data.BitVec.Lemmas
import Std.Data.Bool
import Std.Data.Char
import Std.Data.DList
import Std.Data.Fin.Basic
import Std.Data.Fin.Init.Lemmas
import Std.Data.Fin.Iterate
import Std.Data.Fin.Lemmas
import Std.Data.HashMap
import Std.Data.HashMap.Basic
Expand All @@ -48,6 +52,7 @@ import Std.Data.List.Pairwise
import Std.Data.MLList.Basic
import Std.Data.MLList.Heartbeats
import Std.Data.Nat.Basic
import Std.Data.Nat.Bitwise
import Std.Data.Nat.Gcd
import Std.Data.Nat.Init.Lemmas
import Std.Data.Nat.Lemmas
Expand Down
71 changes: 51 additions & 20 deletions Std/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ institutional affiliations. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joe Hendrix, Wojciech Nawrocki, Leonardo de Moura, Mario Carneiro, Alex Keizer
-/
import Std.Data.Nat.Init.Lemmas
import Std.Data.Fin.Basic
import Std.Data.Int.Basic
import Std.Data.Nat.Bitwise
import Std.Tactic.Alias

namespace Std
Expand Down Expand Up @@ -44,14 +44,21 @@ namespace BitVec
/-- The `BitVec` with value `i mod 2^n`. Treated as an operation on bitvectors,
this is truncation of the high bits when downcasting and zero-extension when upcasting. -/
protected def ofNat (n : Nat) (i : Nat) : BitVec n where
toFin := Fin.ofNat' i (Nat.pow_two_pos _)
toFin :=
let p : i &&& 2^n-1 < 2^n := by
apply Nat.and_lt_two_pow
exact Nat.sub_lt (Nat.pow_two_pos n) (Nat.le_refl 1)
⟨i &&& 2^n-1, p⟩

/-- Given a bitvector `a`, return the underlying `Nat`. This is O(1) because `BitVec` is a
(zero-cost) wrapper around a `Nat`. -/
protected def toNat (a : BitVec n) : Nat := a.toFin.val

/-- Return the bound in terms of toNat. -/
theorem isLt (x : BitVec w) : x.toNat < 2^w := x.toFin.isLt

/-- Return the `i`-th least significant bit or `false` if `i ≥ w`. -/
@[inline] def getLsb (x : BitVec w) (i : Nat) : Bool := x.toNat &&& (1 <<< i) != 0
@[inline] def getLsb (x : BitVec w) (i : Nat) : Bool := x.toNat.testBit i

/-- Return the `i`-th most significant bit or `false` if `i ≥ w`. -/
@[inline] def getMsb (x : BitVec w) (i : Nat) : Bool := i < w && getLsb x (w-1-i)
Expand All @@ -70,7 +77,7 @@ protected def toInt (a : BitVec n) : Int :=
if a.msb then Int.ofNat a.toNat - Int.ofNat (2^n) else a.toNat

/-- Return a bitvector `0` of size `n`. This is the bitvector with all zero bits. -/
protected def zero (n : Nat) : BitVec n := .ofNat n 0
protected def zero (n : Nat) : BitVec n := ⟨0, Nat.pow_two_pos n⟩

instance : Inhabited (BitVec n) where default := .zero n

Expand All @@ -89,7 +96,7 @@ attribute [match_pattern] BitVec.ofNat
| _ => throw ()

/-- Convert bitvector into a fixed-width hex number. -/
protected def toHex {n:Nat} (x:BitVec n) : String :=
protected def toHex {n : Nat} (x : BitVec n) : String :=
let s := (Nat.toDigits 16 x.toNat).asString
let t := (List.replicate ((n+3) / 4 - s.length) '0').asString
t ++ s
Expand Down Expand Up @@ -271,7 +278,8 @@ Bitwise AND for bit vectors.

SMT-Lib name: `bvand`.
-/
protected def and (x y : BitVec n) : BitVec n where toFin := x.toFin &&& y.toFin
protected def and (x y : BitVec n) : BitVec n where toFin :=
⟨x.toNat &&& y.toNat, Nat.and_lt_two_pow x.toNat y.isLt⟩
instance : AndOp (BitVec w) := ⟨.and⟩

/--
Expand All @@ -283,7 +291,8 @@ Bitwise OR for bit vectors.

SMT-Lib name: `bvor`.
-/
protected def or (x y : BitVec n) : BitVec n where toFin := x.toFin ||| y.toFin
protected def or (x y : BitVec n) : BitVec n where toFin :=
⟨x.toNat ||| y.toNat, Nat.or_lt_two_pow x.isLt y.isLt⟩
instance : OrOp (BitVec w) := ⟨.or⟩

/--
Expand All @@ -295,7 +304,8 @@ instance : OrOp (BitVec w) := ⟨.or⟩

SMT-Lib name: `bvxor`.
-/
protected def xor (x y : BitVec n) : BitVec n where toFin := x.toFin ^^^ y.toFin
protected def xor (x y : BitVec n) : BitVec n where toFin :=
⟨x.toNat ^^^ y.toNat, Nat.xor_lt_two_pow x.isLt y.isLt⟩
instance : Xor (BitVec w) := ⟨.xor⟩

/--
Expand Down Expand Up @@ -361,14 +371,33 @@ SMT-Lib name: `rotate_right` except this operator uses a `Nat` shift amount.
-/
def rotateRight (x : BitVec w) (n : Nat) : BitVec w := x >>> n ||| x <<< (w - n)

/--
A version of `zeroExtend` that requires a proof, but is a noop.
-/
def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w :=
⟨x.toNat, by
apply Nat.lt_of_lt_of_le x.isLt
exact Nat.pow_le_pow_of_le_right (by trivial) le⟩

/--
`shiftLeftZeroExtend x n` returns `zeroExtend (w+n) x <<< n` without
needing to compute `x % 2^(2+n)`.
-/
def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w+m) :=
let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w+m) := by
simp [Nat.shiftLeft_eq, Nat.pow_add]
apply Nat.mul_lt_mul_of_pos_right p
exact (Nat.pow_two_pos m)
⟨msbs.toNat <<< m, shiftLeftLt msbs.isLt m⟩

/--
Concatenation of bitvectors. This uses the "big endian" convention that the more significant
input is on the left, so `0xab#8 ++ 0xcd#8 = 0xabcd#16`.
input is on the left, so `0xAB#8 ++ 0xCD#8 = 0xABCD#16`.

SMT-Lib name: `concat`.
-/
def append (msbs : BitVec n) (lsbs : BitVec m) : BitVec (n+m) :=
.ofNat (n + m) (msbs.toNat <<< m ||| lsbs.toNat)
shiftLeftZeroExtend msbs m ||| zeroExtend' (Nat.le_add_left m n) lsbs

instance : HAppend (BitVec w) (BitVec v) (BitVec (w + v)) := ⟨.append⟩

Expand Down Expand Up @@ -405,7 +434,11 @@ If `v < w` then it truncates the high bits instead.

SMT-Lib name: `zero_extend`.
-/
def zeroExtend (v : Nat) (x : BitVec w) : BitVec v := .ofNat v x.toNat
def zeroExtend (v : Nat) (x : BitVec w) : BitVec v :=
if h : w ≤ v then
zeroExtend' h x
else
.ofNat v x.toNat

/--
Truncate the high bits of bitvector `x` of length `w`, resulting in a vector of length `v`.
Expand Down Expand Up @@ -435,25 +468,23 @@ def signExtend (v : Nat) (x : BitVec w) : BitVec v := .ofInt v x.toInt
@[simp] theorem mul_eq (x y : BitVec w) : BitVec.mul x y = x * y := rfl
@[simp] theorem zero_eq : BitVec.zero n = 0#n := rfl

@[simp]
theorem cast_ofNat {n m : Nat} (h : n = m) (x : Nat) :
@[simp] theorem cast_ofNat {n m : Nat} (h : n = m) (x : Nat) :
cast h (BitVec.ofNat n x) = BitVec.ofNat m x := by
subst h; rfl

@[simp]
theorem cast_cast {n m k : Nat} (h₁ : n = m) (h₂ : m = k) (x : BitVec n) :
@[simp] theorem cast_cast {n m k : Nat} (h₁ : n = m) (h₂ : m = k) (x : BitVec n) :
cast h₂ (cast h₁ x) = cast (h₁ ▸ h₂) x :=
rfl

@[simp]
theorem cast_eq {n : Nat} (h : n = n) (x : BitVec n) :
@[simp] theorem cast_eq {n : Nat} (h : n = n) (x : BitVec n) :
cast h x = x :=
rfl

/-- Turn a `Bool` into a bitvector of length `1` -/
def ofBool : Bool → BitVec 1
| false => 0
| true => 1
def ofBool (b : Bool) : BitVec 1 := cond b 1 0

@[simp] theorem ofBool_false : ofBool false = 0 := by trivial
@[simp] theorem ofBool_true : ofBool true = 1 := by trivial

/-- The empty bitvector -/
abbrev nil : BitVec 0 := 0
Expand Down
159 changes: 159 additions & 0 deletions Std/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/-
Copyright (c) 2023 by the authors listed in the file AUTHORS and their
institutional affiliations. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Harun Khan, Abdalrhman M Mohamed, Joe Hendrix
-/
import Std.Data.BitVec.Folds

/-!
# Bitblasting of bitvectors
joehendrix marked this conversation as resolved.
Show resolved Hide resolved

This module provides theorems for showing the equivalence between BitVec operations using
the `Fin 2^n` representation and Boolean vectors. It is still under development, but
intended to provide a path for converting SAT and SMT solver proofs about BitVectors
as vectors of bits into proofs about Lean `BitVec` values.

The module is named for the bit-blasting operation in an SMT solver that converts bitvector
expressions into expressions about individual bits in each vector.

## Main results
* `x + y : BitVec w` is equivalent to `adc x y false`.

## Future work
All other operations are to be PR'ed later and are already proved in
https://github.com/mhk119/lean-smt/blob/bitvec/Smt/Data/BitVec.lean.
-/

open Nat Bool

/-! ### Preliminaries -/

namespace Std.BitVec

private theorem testBit_limit {x i : Nat} (x_lt_succ : x < 2^(i+1)) :
testBit x i = decide (x ≥ 2^i) := by
cases xi : testBit x i with
| true =>
simp [testBit_implies_ge xi]
| false =>
simp
cases Nat.lt_or_ge x (2^i) with
| inl x_lt =>
exact x_lt
| inr x_ge =>
have ⟨j, ⟨j_ge, jp⟩⟩ := ge_two_pow_implies_high_bit_true x_ge
joehendrix marked this conversation as resolved.
Show resolved Hide resolved
cases Nat.lt_or_eq_of_le j_ge with
| inr x_eq =>
simp [x_eq, jp] at xi
| inl x_lt =>
exfalso
apply Nat.lt_irrefl
calc x < 2^(i+1) := x_lt_succ
_ ≤ 2 ^ j := Nat.pow_le_pow_of_le_right Nat.zero_lt_two x_lt
_ ≤ x := testBit_implies_ge jp

private theorem mod_two_pow_succ (x i : Nat) :
x % 2^(i+1) = 2^i*(x.testBit i).toNat + x % (2 ^ i):= by
apply Nat.eq_of_testBit_eq
intro j
simp only [Nat.mul_add_lt_is_or, testBit_or, testBit_mod_two_pow, testBit_shiftLeft,
Nat.testBit_bool_to_nat, Nat.sub_eq_zero_iff_le, Nat.mod_lt, Nat.pow_two_pos,
testBit_mul_pow_two]
rcases Nat.lt_trichotomy i j with i_lt_j | i_eq_j | j_lt_i
· have i_le_j : i ≤ j := Nat.le_of_lt i_lt_j
have not_j_le_i : ¬(j ≤ i) := Nat.not_le_of_lt i_lt_j
have not_j_lt_i : ¬(j < i) := Nat.not_lt_of_le i_le_j
have not_j_lt_i_succ : ¬(j < i + 1) :=
Nat.not_le_of_lt (Nat.succ_lt_succ i_lt_j)
simp [i_le_j, not_j_le_i, not_j_lt_i, not_j_lt_i_succ]
· simp [i_eq_j]
· have j_le_i : j ≤ i := Nat.le_of_lt j_lt_i
have j_le_i_succ : j < i + 1 := Nat.succ_le_succ j_le_i
have not_j_ge_i : ¬(j ≥ i) := Nat.not_le_of_lt j_lt_i
simp [j_lt_i, j_le_i, not_j_ge_i, j_le_i_succ]

private theorem mod_two_pow_lt (x i : Nat) : x % 2 ^ i < 2^i := Nat.mod_lt _ (Nat.pow_two_pos _)

/-! ### Addition -/

/-- carry w x y c returns true if the `w` carry bit is true when computing `x + y + c`. -/
def carry (w x y : Nat) (c : Bool) : Bool := decide (x % 2^w + y % 2^w + c.toNat ≥ 2^w)

@[simp] theorem carry_zero : carry 0 x y c = c := by
cases c <;> simp [carry, mod_one]

/-- Carry function for bitwise addition. -/
def adcb (x y c : Bool) : Bool × Bool := (x && y || x && c || y && c, Bool.xor x (Bool.xor y c))

/-- Bitwise addition implemented via a ripple carry adder. -/
def adc (x y : BitVec w) : Bool → Bool × BitVec w :=
iunfoldr fun (i : Fin w) c => adcb (x.getLsb i) (y.getLsb i) c

theorem adc_overflow_limit (x y i : Nat) (c : Bool) : x % 2^i + (y % 2^i + c.toNat) < 2^(i+1) := by
apply Nat.lt_of_succ_le
simp only [←Nat.succ_add, Nat.pow_succ, Nat.mul_two]
apply Nat.add_le_add (mod_two_pow_lt _ _)
apply Nat.le_trans
exact (Nat.add_le_add_left (Bool.toNat_le_one c) _)
exact Nat.mod_lt _ (Nat.pow_two_pos i)

theorem carry_succ (w x y : Nat) (c : Bool) : carry (succ w) x y c =
decide ((x.testBit w).toNat + (y.testBit w).toNat + (carry w x y c).toNat ≥ 2) := by
simp only [carry, mod_two_pow_succ _ w, decide_eq_decide]
generalize testBit x w = xh
generalize testBit y w = yh
have sum_bnd : x%2^w + (y%2^w + c.toNat) < 2*2^w := by
simp [Nat.mul_comm 2 _, ←Nat.pow_succ ]
exact adc_overflow_limit x y w c
simp only [Nat.pow_succ]
cases xh <;> cases yh <;> cases Decidable.em (x%2^w + (y%2^w + toNat c) ≥ 2 ^ w) with | _ pred =>
simp [Nat.one_shiftLeft, Nat.add_assoc, Nat.add_left_comm _ (2^_) _, Nat.mul_comm (2^_) _,
Nat.not_le_of_lt, Nat.add_succ, Nat.succ_le_succ,
mul_le_add_right, le_add_right, sum_bnd, pred]

theorem adc_value_step {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool) :
getLsb (x + y + zeroExtend w (ofBool c)) i =
Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x.toNat y.toNat c)) := by
let ⟨x, x_lt⟩ := x
let ⟨y, y_lt⟩ := y
simp only [getLsb, toNat_add, toNat_zeroExtend, i_lt, toNat_ofFin, toNat_ofBool,
Nat.mod_add_mod, Nat.add_mod_mod]
apply Eq.trans
rw [← Nat.div_add_mod x (2^i), ← Nat.div_add_mod y (2^i)]
simp only
[ Nat.testBit_mod_two_pow,
Nat.testBit_mul_two_pow_add_eq,
i_lt,
decide_True,
Bool.true_and,
Nat.add_assoc,
Nat.add_left_comm (_%_) (_ * _) _,
testBit_limit (adc_overflow_limit x y i c)
]
simp [testBit_to_div_mod, carry, Nat.add_assoc]

theorem adc_correct (x y : BitVec w) (c : Bool) :
adc x y c = (carry w x.toNat y.toNat c, x + y + zeroExtend w (ofBool c)) := by
simp only [adc]
apply iunfoldr_replace
(fun i => carry i x.toNat y.toNat c)
(x + y + zeroExtend w (ofBool c))
c
case init =>
simp [carry, Nat.mod_one]
cases c <;> rfl
case step =>
intro ⟨i, lt⟩
simp only [adcb, Prod.mk.injEq, carry_succ]
apply And.intro
case left =>
rw [testBit_toNat, testBit_toNat]
cases x.getLsb i <;>
cases y.getLsb i <;>
cases carry i x.toNat y.toNat c <;> simp [Nat.succ_le_succ_iff]
case right =>
simp [adc_value_step lt]

theorem add_as_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := by
simp [adc_correct]
Loading
Loading