Skip to content

Commit

Permalink
fix floor_sum
Browse files Browse the repository at this point in the history
  • Loading branch information
hitonanode committed Sep 22, 2024
1 parent 1b07158 commit 5772950
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions utilities/floor_sum.hpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
#pragma once
#include <utility>

// CUT begin
// \sum_{i=0}^{n-1} floor((ai + b) / m)
// 0 <= n < 2e32
// 1 <= m < 2e32
// 0 <= n < 2e32 (if Int is long long)
// 1 <= m < 2e32 (if Int is long long)
// 0 <= a, b < m
// Complexity: O(lg(m))
long long floor_sum(long long n, long long m, long long a, long long b) {
auto safe_mod = [](long long x, long long m) -> long long {
template <class Int = long long, class Unsigned = unsigned long long>
Int floor_sum(Int n, Int m, Int a, Int b) {
static_assert(-Int(1) < 0, "Int must be signed");
static_assert(-Unsigned(1) > 0, "Unsigned must be unsigned");
static_assert(sizeof(Unsigned) >= sizeof(Int), "Unsigned must be larger than Int");

auto safe_mod = [](Int x, Int m) -> Int {
x %= m;
if (x < 0) x += m;
return x;
};
auto floor_sum_unsigned = [](unsigned long long n, unsigned long long m, unsigned long long a,
unsigned long long b) -> unsigned long long {
unsigned long long ans = 0;
auto floor_sum_unsigned = [](Unsigned n, Unsigned m, Unsigned a, Unsigned b) -> Unsigned {
Unsigned ans = 0;
while (true) {
if (a >= m) {
ans += n * (n - 1) / 2 * (a / m);
Expand All @@ -26,26 +29,26 @@ long long floor_sum(long long n, long long m, long long a, long long b) {
b %= m;
}

unsigned long long y_max = a * n + b;
Unsigned y_max = a * n + b;
if (y_max < m) break;
// y_max < m * (n + 1)
// floor(y_max / m) <= n
n = (unsigned long long)(y_max / m);
b = (unsigned long long)(y_max % m);
n = (Unsigned)(y_max / m);
b = (Unsigned)(y_max % m);
std::swap(m, a);
}
return ans;
};

unsigned long long ans = 0;
Unsigned ans = 0;
if (a < 0) {
unsigned long long a2 = safe_mod(a, m);
ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m);
Unsigned a2 = safe_mod(a, m);
ans -= Unsigned(1) * n * (n - 1) / 2 * ((a2 - a) / m);
a = a2;
}
if (b < 0) {
unsigned long long b2 = safe_mod(b, m);
ans -= 1ULL * n * ((b2 - b) / m);
Unsigned b2 = safe_mod(b, m);
ans -= Unsigned(1) * n * ((b2 - b) / m);
b = b2;
}
return ans + floor_sum_unsigned(n, m, a, b);
Expand Down

0 comments on commit 5772950

Please sign in to comment.