diff --git a/bvm/Shaders/Math.h b/bvm/Shaders/Math.h index e3dc79126..3ad35cb71 100644 --- a/bvm/Shaders/Math.h +++ b/bvm/Shaders/Math.h @@ -724,6 +724,7 @@ namespace MultiPrecision } friend struct Float; + friend class FloatEx; template void SetDivResidNormalized(UInt& __restrict__ resid, const UInt& __restrict__ div) @@ -1373,9 +1374,500 @@ namespace MultiPrecision } }; + + class FloatEx + { + uint64_t m_Mantissa; + int32_t m_Order; + // value == 1.mantissa_without_hibit << order + // sign - hi bit (positive) + + static const int32_t s_Zero = TypeTraits::Min; + static const int32_t s_NaN = TypeTraits::Max; + + static const uint32_t s_Bits = sizeof(m_Mantissa) * 8; + static const uint64_t s_HiBit = 1ull << (s_Bits - 1); + + template + void NormalizeToPositive() + { + assert(IsNumberNnz()); + + auto nz = BitUtils::clz(m_Mantissa); + switch (nz) + { + case 0: + break; // ok + + case s_Bits: + assert(!m_Mantissa); + m_Order = s_Zero; + return; + + default: + m_Mantissa <<= nz; + OrderAddInternal(-(signed) nz); + } + } + + template + void OrderAddInternal(int32_t d) + { + assert(IsNumberNnz()); + + int32_t n0 = m_Order; + m_Order += d; + + if constexpr (bCareful) + { + // if order arithmetics reaches Zero or Nan - it's fine. We only need to take care of overflow + if (d < 0) + { + if (m_Order > n0) + m_Order = s_Zero; + } + else + { + if (m_Order < n0) + m_Order = s_NaN; + } + } + else + assert(IsNumberNnz()); + } + + template + void AssignUnsAsPositive(T val) + { + static_assert(sizeof(val) <= sizeof(m_Mantissa), ""); + + m_Order = s_Bits - 1; + m_Mantissa = val; + NormalizeToPositive(); + } + + bool HaveHiBit() const + { + return !!(s_HiBit & m_Mantissa); + } + + uint64_t get_WithHiBit() const { return s_HiBit | m_Mantissa; } + + void AddOrSub_Ord(const FloatEx& __restrict__ b, bool bAdd, bool bFlip) + { + assert(m_Order >= b.m_Order); + if (IsNaN()) + return; + assert(!b.IsNaN()); + + uint32_t d = m_Order - b.m_Order; + if ((d < s_Bits) && !b.IsZero()) + { + assert(IsNumberNnz()); + + bool bNeg = !HaveHiBit(); + if (bNeg) + { + m_Mantissa |= s_HiBit; + bFlip = !bFlip; + } + + auto bVal = b.get_WithHiBit() >> d; + + if (b.HaveHiBit() == bNeg) + bAdd = !bAdd; + + if (bAdd) + { + m_Mantissa += bVal; + if (m_Mantissa < bVal) + AddMsb(); // overflow + } + else + { + if (m_Mantissa >= bVal) + m_Mantissa -= bVal; + else + { + assert(!d); + bFlip = !bFlip; + m_Mantissa = bVal - m_Mantissa; + } + + NormalizeToPositive(); + } + } + + if (bFlip) + Negate(); + } + + void AddOrSub(const FloatEx a, FloatEx b, bool bAdd) + { + if (a.m_Order >= b.m_Order) + { + *this = a; + AddOrSub_Ord(b, bAdd, false); + } + else + { + *this = b; + AddOrSub_Ord(a, bAdd, !bAdd); + } + } + + int cmp_Uns(const FloatEx& x) const + { + assert(IsNumberNnz() && x.IsNumberNnz()); + assert(!(s_HiBit & (m_Mantissa ^ x.m_Mantissa))); + if (m_Order > x.m_Order) + return 1; + if (m_Order < x.m_Order) + return -1; + if (m_Mantissa > x.m_Mantissa) + return 1; + if (m_Mantissa < x.m_Mantissa) + return -1; + + return 0; + } + + void AddMsb() + { + m_Mantissa = (m_Mantissa >> 1) | s_HiBit; + OrderAddInternal(1); + } + + static uint64_t DivInternal(uint64_t a, uint64_t b) + { + assert((s_HiBit & b) && (a < b)); + + MultiPrecision::UInt<4> nom; + nom.Set<2>(a); + + MultiPrecision::UInt<2> res; + res.SetDivResidNormalized(nom, MultiPrecision::From(b)); + + return res.Get<0, uint64_t>(); + } + + int get_Class() const + { + switch (m_Order) + { + case s_Zero: + return 0; + case s_NaN: + return 2; + } + return HaveHiBit() ? 1 : -1; + } + + public: + + FloatEx() { Set0(); } + + template + FloatEx(const T& x) { Assign(x); } + + bool IsNaN() const { return s_NaN == m_Order; } + bool IsZero() const { return s_Zero == m_Order; } + bool IsNumber() const { return s_NaN != m_Order; } + + bool IsNumberNnz() const + { + return IsNumber() && !IsZero(); + } + + bool IsPositive() const + { + return IsNumberNnz() && HaveHiBit(); + } + + bool IsNegative() const + { + return IsNumberNnz() && !HaveHiBit(); + } + + void Set0() { m_Order = s_Zero; } + void SetNaN() { m_Order = s_NaN; } + + void Negate() { m_Mantissa ^= s_HiBit; } // safe even if 0/NaN + + void AddOrder(int32_t n) + { + if (IsNumberNnz()) + OrderAddInternal(n); + } + void Assign(const FloatEx& x) + { + m_Mantissa = x.m_Mantissa; + m_Order = x.m_Order; + } + + template + void Assign(T val) + { + if constexpr (TypeTraits::IsSigned) + { + if (val < 0) + { + AssignUnsAsPositive(-val); + Negate(); + return; + } + } + AssignUnsAsPositive(val); + } + + template + FloatEx& operator = (const T& x) + { + Assign(x); + return *this; + } + + static FloatEx get_0() + { + FloatEx x; + x.Set0(); + return x; + } + + static FloatEx get_NaN() + { + FloatEx x; + x.SetNaN(); + return x; + } + + static FloatEx get_1() + { + FloatEx x; + x.m_Mantissa = s_HiBit; + x.m_Order = 0; + return x; + } + + static FloatEx get_Half() + { + FloatEx x; + x.m_Mantissa = s_HiBit; + x.m_Order = -1; + return x; + } + + static FloatEx get_1_minus_eps() + { + FloatEx x; + x.m_Mantissa = static_cast(-1); + x.m_Order = -1; + return x; + } + + template + bool RoundDown(T& ret) const + { + if (m_Order < 0) + { + ret = 0; + return true; // also covers Zero + } + + if (IsNaN()) + { + ret = 0; + return false; + } + + typedef TypeTraits Type; + static_assert(sizeof(T) <= sizeof(m_Mantissa)); + + constexpr int32_t nOrderMax = Type::Bits - !!Type::IsSigned - 1; + + if (HaveHiBit()) + { + if (m_Order > nOrderMax) + { + ret = Type::Max; // overflow/inf + return false; // overflow + } + } + else + { + // negative + if constexpr (!Type::IsSigned) + { + ret = 0; + return false; + } + else + { + if (m_Order > nOrderMax) + { + ret = Type::Min; // underflow + return false; + } + } + } + + assert((m_Order > 0) && (m_Order <= nOrderMax)); + uint32_t rs = s_Bits - 1 - m_Order; + + ret = static_cast(get_WithHiBit() >> rs); + + if constexpr (Type::IsSigned) + { + if (!HaveHiBit()) + ret = -ret; + } + + return true; + } + + template + bool Round(T& ret) const { + return (*this + get_Half()).RoundDown(ret); + } + + template + bool RoundUp(T& ret) const { + return (*this + get_1_minus_eps()).RoundDown(ret); + } + + FloatEx operator + (const FloatEx& b) const + { + FloatEx res; + res.AddOrSub(*this, b, true); + return res; + } + + FloatEx operator - (const FloatEx& b) const + { + FloatEx res; + res.AddOrSub(*this, b, false); + return res; + } + + FloatEx operator * (FloatEx b) const + { + if (IsNaN()) + return *this; + if (!b.IsNumberNnz()) + return b; + if (IsZero()) + return *this; + + auto x = MultiPrecision::From(get_WithHiBit()) * MultiPrecision::From(b.get_WithHiBit()); + + // the result may loose at most 1 msb + FloatEx res; + res.m_Order = m_Order; + res.OrderAddInternal(b.m_Order); + x.Get<2>(res.m_Mantissa); + + if (res.HaveHiBit()) + res.AddOrder(1); + else + res.m_Mantissa = (res.m_Mantissa << 1) | (x.get_Val<2>() >> (MultiPrecision::nWordBits - 1)); + + res.m_Mantissa ^= ((m_Mantissa ^ b.m_Mantissa) & s_HiBit); + + return res; + } + + FloatEx operator / (FloatEx b) const + { + if (!b.IsNumberNnz()) + return get_NaN(); + if (!IsNumberNnz()) + return *this; + + FloatEx res; + res.m_Order = m_Order; + res.OrderAddInternal(-b.m_Order); + res.AddOrder(-1); + + auto aVal = get_WithHiBit(); + auto bVal = b.get_WithHiBit(); + // since both operands are normalized, the result must be within (1/2 .. 2) + if (aVal >= bVal) + { + res.m_Mantissa = DivInternal(aVal - bVal, bVal); + res.AddMsb(); + } + else + { + res.m_Mantissa = DivInternal(aVal, bVal); + assert(s_HiBit & res.m_Mantissa); + } + + res.m_Mantissa ^= ((m_Mantissa ^ b.m_Mantissa) & s_HiBit); + + return res; + } + + FloatEx operator << (int32_t n) + { + auto res = *this; + res.AddOrder(n); + return res; + } + + FloatEx operator >> (int32_t n) + { + auto res = *this; + res.AddOrder(-n); + return res; + } + + FloatEx& operator += (FloatEx b) { + return *this = *this + b; + } + + FloatEx& operator -= (FloatEx b) { + return *this = *this - b; + } + + FloatEx& operator *= (FloatEx b) { + return *this = *this * b; + } + + FloatEx& operator /= (FloatEx b) { + return *this = *this / b; + } + + int cmp(const FloatEx& x) const + { + // negative, zero, positive, NaN + auto c0 = get_Class(); + auto c1 = x.get_Class(); + if (c0 < c1) + return -1; + if (c0 > c1) + return 1; + + if (!IsNumberNnz()) + return 0; + + int res = cmp_Uns(x); + return HaveHiBit() ? res : (-res); + } + + bool operator < (const FloatEx& x) const { return cmp(x) < 0; } + bool operator > (const FloatEx& x) const { return cmp(x) > 0; } + bool operator <= (const FloatEx& x) const { return cmp(x) <= 0; } + bool operator >= (const FloatEx& x) const { return cmp(x) >= 0; } + bool operator == (const FloatEx& x) const { return cmp(x) == 0; } + bool operator != (const FloatEx& x) const { return cmp(x) != 0; } + + }; + + #pragma pack (pop) static_assert(sizeof(Float) == 12); + static_assert(sizeof(FloatEx) == 12); static_assert(sizeof(FloatLegacy) == 16); } // namespace MultiPrecision