Skip to content

Commit

Permalink
chore: Reduce unnecessary modulus in BitVec.
Browse files Browse the repository at this point in the history
This introduces proofs on natural number bitvector operations to
reduce modular arithmetic.
  • Loading branch information
joehendrix committed Nov 15, 2023
1 parent 8708106 commit e6c22fb
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Std/Data/BinomialHeap/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ theorem Heap.deleteMin_fst : ((s : Heap α).deleteMin le).map (·.1) = s.head? l
| .nil, _ => rfl
| .cons .., ⟨_, h₁, h₂⟩ => by
simp [size, Nat.shiftLeft, size_eq h₂, Nat.pow_succ, Nat.mul_succ]
simp [Nat.add_assoc, Nat.one_shiftLeft, h₁.realSize_eq, h₂.size_eq]
simp [Nat.add_assoc, Nat.shiftLeft_eq, h₁.realSize_eq, h₂.size_eq]

end Imp
40 changes: 34 additions & 6 deletions Std/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ institutional affiliations. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joe Hendrix, Wojciech Nawrocki, Leonardo de Moura, Mario Carneiro, Alex Keizer
-/
import Std.Data.Nat.Init.Lemmas
import Std.Data.Fin.Basic
import Std.Data.Int.Basic
import Std.Data.Nat.Bitwise
import Std.Tactic.Alias

namespace Std
Expand Down Expand Up @@ -50,8 +50,18 @@ protected def ofNat (n : Nat) (i : Nat) : BitVec n where
(zero-cost) wrapper around a `Nat`. -/
protected def toNat (a : BitVec n) : Nat := a.toFin.val

/-- Return the bound in terms of toNat. -/
theorem isLt (x : BitVec w) : x.toNat < 2^w := x.toFin.isLt

/-- Prove equality of bitvectors in terms of nat operations. -/
theorem eq_of_toNat_eq {n} : ∀ {i j : BitVec n}, i.toNat = j.toNat → i = j
| ⟨_, _⟩, ⟨_, _⟩, rfl => rfl

theorem toNat_eq (x y : BitVec n) : x = y ↔ x.toNat = y.toNat :=
Iff.intro (congrArg BitVec.toNat) eq_of_toNat_eq

/-- Return the `i`-th least significant bit or `false` if `i ≥ w`. -/
@[inline] def getLsb (x : BitVec w) (i : Nat) : Bool := x.toNat &&& (1 <<< i) != 0
@[inline] def getLsb : BitVec w -> Nat -> Bool | ⟨x,_⟩, i => (x >>> i) % 2 == 1

/-- Return the `i`-th most significant bit or `false` if `i ≥ w`. -/
@[inline] def getMsb (x : BitVec w) (i : Nat) : Bool := i < w && getLsb x (w-1-i)
Expand Down Expand Up @@ -271,7 +281,8 @@ Bitwise AND for bit vectors.
SMT-Lib name: `bvand`.
-/
protected def and (x y : BitVec n) : BitVec n where toFin := x.toFin &&& y.toFin
protected def and (x y : BitVec n) : BitVec n where toFin :=
⟨x.toNat &&& y.toNat, Nat.land_lt_2_pow x.isLt y.isLt⟩
instance : AndOp (BitVec w) := ⟨.and⟩

/--
Expand All @@ -283,7 +294,8 @@ Bitwise OR for bit vectors.
SMT-Lib name: `bvor`.
-/
protected def or (x y : BitVec n) : BitVec n where toFin := x.toFin ||| y.toFin
protected def or (x y : BitVec n) : BitVec n where toFin :=
⟨x.toNat ||| y.toNat, Nat.lor_lt_2_pow x.isLt y.isLt⟩
instance : OrOp (BitVec w) := ⟨.or⟩

/--
Expand All @@ -295,7 +307,8 @@ instance : OrOp (BitVec w) := ⟨.or⟩
SMT-Lib name: `bvxor`.
-/
protected def xor (x y : BitVec n) : BitVec n where toFin := x.toFin ^^^ y.toFin
protected def xor (x y : BitVec n) : BitVec n where toFin :=
⟨x.toNat ^^^ y.toNat, Nat.xor_lt_2_pow x.isLt y.isLt⟩
instance : Xor (BitVec w) := ⟨.xor⟩

/--
Expand Down Expand Up @@ -361,14 +374,29 @@ SMT-Lib name: `rotate_right` except this operator uses a `Nat` shift amount.
-/
def rotateRight (x : BitVec w) (n : Nat) : BitVec w := x >>> n ||| x <<< (w - n)

/--
A version of `zeroExtend` that requires a proof, but is a noop.
-/
def zeroExtend' (w:Nat) (x : BitVec n) (le : n ≤ w) : BitVec w :=
⟨x.toNat, by
apply Nat.lt_of_lt_of_le x.isLt
exact Nat.pow_le_pow_of_le_right (by trivial) le⟩

/--
Concatenation of bitvectors. This uses the "big endian" convention that the more significant
input is on the left, so `0xab#8 ++ 0xcd#8 = 0xabcd#16`.
SMT-Lib name: `concat`.
-/
def append (msbs : BitVec n) (lsbs : BitVec m) : BitVec (n+m) :=
.ofNat (n + m) (msbs.toNat <<< m ||| lsbs.toNat)
⟨msbs.toNat <<< m, by
apply Nat.shiftLeft_lt_2_pow
simp only [Nat.add_sub_cancel]
exact msbs.isLt
⟩ ||| zeroExtend' (n+m) lsbs (by apply Nat.le_add_left)

-- (by apply Nat.le_add_right) <<< m)


instance : HAppend (BitVec w) (BitVec v) (BitVec (w + v)) := ⟨.append⟩

Expand Down
92 changes: 92 additions & 0 deletions Std/Data/Nat/Bitwise.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/-
Copyright (c) 2023 Lean FRO. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joe Hendrix
-/

/-
This module defines properties of the bitwise operations on Natural numbers.
It is primarily intended to support the bitvector library.
-/
import Std.Data.Nat.Lemmas

namespace Nat

@[local simp]
private theorem eq_0_of_lt_one (x:Nat) : x < 1 ↔ x = 0 :=
Iff.intro
(fun p =>
match x with
| 0 => Eq.refl 0
| _+1 => False.elim (not_lt_zero _ (Nat.lt_of_succ_lt_succ p)))
(fun p => by simp [p, Nat.zero_lt_succ])

private theorem eq_0_of_lt (x:Nat) : x < 2^ 0 ↔ x = 0 := eq_0_of_lt_one x

@[local simp]
private theorem zero_lt_pow (n:Nat) : 0 < 2^n := by
induction n
case zero => simp [eq_0_of_lt]
case succ n hyp =>
simp [pow_succ]
exact (Nat.mul_lt_mul_of_pos_right hyp (by trivial : 2 > 0) : 0 < 2 ^ n * 2)

/-- This provides a bound on bitwise operations. -/
theorem bitwise_lt_2_pow (left : x < 2^n) (right : y < 2^n) : (Nat.bitwise f x y) < 2^n := by
induction n generalizing x y with
| zero =>
simp only [eq_0_of_lt] at left right
unfold bitwise
simp [left, right]
| succ n hyp =>
unfold bitwise
if x_zero : x = 0 then
simp only [x_zero, if_true]
by_cases p : f false true = true <;> simp [p, right]
else if y_zero : y = 0 then
simp only [x_zero, y_zero, if_false, if_true]
by_cases p : f true false = true <;> simp [p, left]
else
simp only [x_zero, y_zero, if_false]
have lt : 0 < 2 := by trivial
have xlb : x / 2 < 2^n := by simp [div_lt_iff_lt_mul lt]; exact left
have ylb : y / 2 < 2^n := by simp [div_lt_iff_lt_mul lt]; exact right
have hyp1 := hyp xlb ylb
by_cases p : f (decide (x % 2 = 1)) (decide (y % 2 = 1)) = true <;>
simp [p, pow_succ, mul_succ, Nat.add_assoc]
case pos =>
apply lt_of_succ_le
simp only [← Nat.succ_add]
apply Nat.add_le_add <;> exact hyp1
case neg =>
apply Nat.add_lt_add <;> exact hyp1

theorem lor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ||| y) < 2^n :=
bitwise_lt_2_pow left right

theorem land_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x &&& y) < 2^n :=
bitwise_lt_2_pow left right

theorem xor_lt_2_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : (x ^^^ y) < 2^n :=
bitwise_lt_2_pow left right

theorem shiftLeft_lt_2_pow {x m n : Nat} (bound : x < 2^(n-m)) : (x <<< m) < 2^n := by
induction m generalizing x n with
| zero => exact bound
| succ m hyp =>
simp [shiftLeft_succ_inside]
apply hyp
revert bound
rw [Nat.sub_succ]
match n - m with
| 0 =>
intro bound
simp [eq_0_of_lt_one] at bound
simp [bound]
| d + 1 =>
intro bound
simp [Nat.pow_succ, Nat.mul_comm _ 2]
exact Nat.mul_lt_mul_of_pos_left bound (by trivial : 0 < 2)

end Nat
46 changes: 33 additions & 13 deletions Std/Data/Nat/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -987,14 +987,6 @@ theorem pow_succ' {m n : Nat} : m ^ n.succ = m * m ^ n := by

@[simp] theorem pow_eq {m n : Nat} : m.pow n = m ^ n := rfl

@[simp] theorem shiftLeft_eq (a b : Nat) : a <<< b = a * 2 ^ b :=
match b with
| 0 => (Nat.mul_one _).symm
| b+1 => (shiftLeft_eq _ b).trans <| by
simp [pow_succ, Nat.mul_assoc, Nat.mul_left_comm, Nat.mul_comm]

theorem one_shiftLeft (n : Nat) : 1 <<< n = 2 ^ n := by rw [shiftLeft_eq, Nat.one_mul]

attribute [simp] Nat.pow_zero

protected theorem zero_pow {n : Nat} (H : 0 < n) : 0 ^ n = 0 := by
Expand Down Expand Up @@ -1171,22 +1163,50 @@ protected theorem dvd_of_mul_dvd_mul_right (kpos : 0 < k) (H : m * k ∣ n * k)
@[simp] theorem sum_append : Nat.sum (l₁ ++ l₂) = Nat.sum l₁ + Nat.sum l₂ := by
induction l₁ <;> simp [*, Nat.add_assoc]

/-! ### shiftRight -/
/-! ### shiftLeft and shiftRight -/

theorem shiftLeft_eq (a b : Nat) : a <<< b = a * 2 ^ b :=
match b with
| 0 => (Nat.mul_one _).symm
| b+1 => (shiftLeft_eq _ b).trans <| by
simp [pow_succ, Nat.mul_assoc, Nat.mul_left_comm, Nat.mul_comm]

@[deprecated]
theorem one_shiftLeft (n : Nat) : 1 <<< n = 2 ^ n := by rw [shiftLeft_eq, Nat.one_mul]

@[simp] theorem shiftLeft_zero : n <<< 0 = n := rfl

/-- Shiftleft on successor with multiple moved inside. -/
theorem shiftLeft_succ_inside (m n : Nat) : m <<< (n+1) = (2*m) <<< n := rfl

/-- Shiftleft on successor with multiple moved to outside. -/
theorem shiftLeft_succ : ∀(m n), m <<< (n + 1) = 2 * (m <<< n)
| m, 0 => rfl
| m, k + 1 => by
rw [shiftLeft_succ_inside _ (k+1)]
rw [shiftLeft_succ _ k, shiftLeft_succ_inside]

@[simp] theorem shiftRight_zero : n >>> 0 = n := rfl

@[simp] theorem shiftRight_succ (m n) : m >>> (n + 1) = (m >>> n) / 2 := rfl
theorem shiftRight_succ (m n) : m >>> (n + 1) = (m >>> n) / 2 := rfl

/-- Shiftleft on successor with division moved inside. -/
theorem shiftRight_succ_inside : ∀m n, m >>> (n+1) = (m/2) >>> n
| m, 0 => rfl
| m, k + 1 => by
rw [shiftRight_succ _ (k+1)]
rw [shiftRight_succ_inside _ k, shiftRight_succ]

@[simp] theorem zero_shiftRight : ∀ n, 0 >>> n = 0
| 0 => by simp [shiftRight]
| n + 1 => by simp [shiftRight, zero_shiftRight]
| n + 1 => by simp [shiftRight, zero_shiftRight, shiftRight_succ]

theorem shiftRight_add (m n : Nat) : ∀ k, m >>> (n + k) = (m >>> n) >>> k
| 0 => rfl
| k + 1 => by simp [add_succ, shiftRight_add]
| k + 1 => by simp [add_succ, shiftRight_add, shiftRight_succ]

theorem shiftRight_eq_div_pow (m : Nat) : ∀ n, m >>> n = m / 2 ^ n
| 0 => (Nat.div_one _).symm
| k + 1 => by
rw [shiftRight_add, shiftRight_eq_div_pow m k]
simp [Nat.div_div_eq_div_mul, ← Nat.pow_succ]
simp [Nat.div_div_eq_div_mul, ← Nat.pow_succ, shiftRight_succ]

0 comments on commit e6c22fb

Please sign in to comment.