From 57729501c6cb987fb351dfc8c46cbbbf2ed8c455 Mon Sep 17 00:00:00 2001 From: hitonanode <32937551+hitonanode@users.noreply.github.com> Date: Sun, 22 Sep 2024 10:23:03 +0900 Subject: [PATCH] fix floor_sum --- utilities/floor_sum.hpp | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/utilities/floor_sum.hpp b/utilities/floor_sum.hpp index 451226af..01b4f8ea 100644 --- a/utilities/floor_sum.hpp +++ b/utilities/floor_sum.hpp @@ -1,21 +1,24 @@ #pragma once #include -// CUT begin // \sum_{i=0}^{n-1} floor((ai + b) / m) -// 0 <= n < 2e32 -// 1 <= m < 2e32 +// 0 <= n < 2e32 (if Int is long long) +// 1 <= m < 2e32 (if Int is long long) // 0 <= a, b < m // Complexity: O(lg(m)) -long long floor_sum(long long n, long long m, long long a, long long b) { - auto safe_mod = [](long long x, long long m) -> long long { +template +Int floor_sum(Int n, Int m, Int a, Int b) { + static_assert(-Int(1) < 0, "Int must be signed"); + static_assert(-Unsigned(1) > 0, "Unsigned must be unsigned"); + static_assert(sizeof(Unsigned) >= sizeof(Int), "Unsigned must be larger than Int"); + + auto safe_mod = [](Int x, Int m) -> Int { x %= m; if (x < 0) x += m; return x; }; - auto floor_sum_unsigned = [](unsigned long long n, unsigned long long m, unsigned long long a, - unsigned long long b) -> unsigned long long { - unsigned long long ans = 0; + auto floor_sum_unsigned = [](Unsigned n, Unsigned m, Unsigned a, Unsigned b) -> Unsigned { + Unsigned ans = 0; while (true) { if (a >= m) { ans += n * (n - 1) / 2 * (a / m); @@ -26,26 +29,26 @@ long long floor_sum(long long n, long long m, long long a, long long b) { b %= m; } - unsigned long long y_max = a * n + b; + Unsigned y_max = a * n + b; if (y_max < m) break; // y_max < m * (n + 1) // floor(y_max / m) <= n - n = (unsigned long long)(y_max / m); - b = (unsigned long long)(y_max % m); + n = (Unsigned)(y_max / m); + b = (Unsigned)(y_max % m); std::swap(m, a); } return ans; }; - unsigned long long ans = 0; + Unsigned ans = 0; if (a < 0) { - unsigned long long a2 = safe_mod(a, m); - ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m); + Unsigned a2 = safe_mod(a, m); + ans -= Unsigned(1) * n * (n - 1) / 2 * ((a2 - a) / m); a = a2; } if (b < 0) { - unsigned long long b2 = safe_mod(b, m); - ans -= 1ULL * n * ((b2 - b) / m); + Unsigned b2 = safe_mod(b, m); + ans -= Unsigned(1) * n * ((b2 - b) / m); b = b2; } return ans + floor_sum_unsigned(n, m, a, b);