Skip to content

Commit

Permalink
some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
privefl committed Aug 13, 2024
1 parent 42ee7be commit c0f15be
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 3 deletions.
56 changes: 56 additions & 0 deletions tmp-tests/test-hybrid-gauss-seidel.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
A <- diag(4)
A[1, 2] <- A[2, 1] <- 0.8
A[3, 4] <- A[4, 3] <- -0.5
A2 <- A
A[2, 3] <- A[3, 2] <- 0.3
A
A2
ld <- colSums(A**2)

x0 <- rnorm(4)
b <- A %*% x0 + rnorm(1) / 1000

cbind(x0, solve(A, b), solve(A2, b))

gauss_seidel <- function(A, b, divide_diag = TRUE, niter = 20) {
K <- length(b)
x <- rep(0, K)
for (iter in 1:niter) {
# print(x)
for (j in 1:K) {
x[j] <- (b[j] - A[j, -j] %*% x[-j]) / `if`(divide_diag, A[j, j], 1)
}
# c(crossprod(x, A %*% x), crossprod(x, b))
}
x
}

cbind(x0, solve(A, b), solve(A2, b), gauss_seidel(A, b), gauss_seidel(A2, b))

solve(A[1:3, 1:3], b[1:3])
solve(A[2:4, 2:4], b[2:4])

K <- length(b)
x <- rep(0, K)
for (iter in 1:30) {
x1 <- x[1:3]; x2 <- x[2:4]
# print(x)
for (j in 1:3) x1[j] <- (b[1:3][j] - A[1:3, 1:3][j, -j] %*% x1[-j])
for (j in 1:3) x2[j] <- (b[2:4][j] - A[2:4, 2:4][j, -j] %*% x2[-j])
# w2 <- c(crossprod(A[c(1, 3), 2]), crossprod(A[c(3, 4), 2]))
w2 <- c(crossprod(A[c(1, 3), 2], solve(A[c(1, 3), c(1, 3)], A[c(1, 3), 2])),
crossprod(A[c(4, 3), 2], solve(A[c(4, 3), c(4, 3)], A[c(4, 3), 2])))
# w3 <- c(crossprod(A[c(1, 2), 3]), crossprod(A[c(2, 4), 3]))
w3 <- c(crossprod(A[c(1, 2), 3], solve(A[c(1, 2), c(1, 2)], A[c(1, 2), 3])),
crossprod(A[c(4, 2), 3], solve(A[c(4, 2), c(4, 2)], A[c(4, 2), 3])))
w2 <- w2 / sum(w2); w3 <- w3 / sum(w3)
x <- c(x1[1], x1[2] * w2[1] + x2[1] * w2[2], x1[3] * w3[1] + x2[2] * w3[2], x2[3])
print(cbind(c(x1, 0), c(0, x2), x))
}
(res <- cbind(x0, A = solve(A, b), A2 = solve(A2, b),
Al = solve(A + 0.01 * diag(4), b),
As = solve(0.9 * A + 0.1 * diag(4), b), x))

FCR <- 45
curve(FCR + (185 - FCR) * x, from = 0.5); curve(185 * x, add = TRUE, col = "blue")
62 + (185 - 62) * c(0.6, 0.7)
55 changes: 55 additions & 0 deletions tmp-tests/test-hybrid-gauss-seidel2.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
set.seed(1)
A <- diag(4)
A[1, 2] <- A[2, 1] <- 0.8
A[3, 4] <- A[4, 3] <- -0.5
A_block <- A
A[2, 3] <- A[3, 2] <- 0.3
A
A_block

x0 <- rnorm(4)
b <- A %*% x0 + rnorm(1) / 1000

# iterative algorithm to solve Ax=b
Gauss_Seidel <- function(A, b, niter = 20, verbose = FALSE) {
K <- length(b)
x <- rep(0, K)
for (iter in 1:niter) {
for (j in 1:K) x[j] <- (b[j] - A[j, -j] %*% x[-j]) / A[j, j]
if (verbose) print(x)
}
x
}

cbind.data.frame(x0, solve(A, b), Gauss_Seidel(A, b),
solve(A_block, b), Gauss_Seidel(A_block, b))



# Try overlapping block strategy
b1 <- b[1:3]; b2 <- b[2:4]; A1 <- A[1:3, 1:3]; A2 <- A[2:4, 2:4]

# w2 <- c(crossprod(A1[-2, 2]), crossprod(A2[-1, 1]))
w2 <- c(crossprod(A1[-2, 2], solve(A1[-2, -2], A1[-2, 2])),
crossprod(A2[-1, 1], solve(A2[-1, -1], A2[-1, 1])))
# w3 <- c(crossprod(A1[-3, 3]), crossprod(A2[-2, 2]))
w3 <- c(crossprod(A1[-3, 3], solve(A1[-3, -3], A1[-3, 3])),
crossprod(A2[-2, 2], solve(A2[-2, -2], A2[-2, 2])))
w2 <- w2 / sum(w2); w3 <- w3 / sum(w3)
list(w2, w3) # weights to merge overlapping results for several blocks

K <- length(b)
x <- rep(0, K)
for (iter in 1:30) {
x1 <- x[1:3]; x2 <- x[2:4]
for (j in 1:3) x1[j] <- (b1[j] - A1[j, -j] %*% x1[-j]) / A1[j, j]
for (j in 1:3) x2[j] <- (b2[j] - A2[j, -j] %*% x2[-j]) / A2[j, j]
x <- c(x1[1], x1[2] * w2[1] + x2[1] * w2[2], x1[3] * w3[1] + x2[2] * w3[2], x2[3])
print(cbind.data.frame(x1 = c(x1, 0), x2 = c(0, x2), x))
}
cbind.data.frame(x0, solve(A, b), solve(A_block, b), x)


cbind(A, bigutilsr::regul_glasso(A, 0.01))
bigutilsr::regul_glasso(A_block, 0.01)
A_block
10 changes: 7 additions & 3 deletions tmp-tests/test-ldpred2-auto-correct-h2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ List ldpred2_gibbs_auto(Environment corr,
const NumericVector& beta_hat,
const NumericVector& n_vec,
const IntegerVector& ind_sub,
const NumericVector& delta,
double p_init,
double h2_init,
int burn_in,
Expand Down Expand Up @@ -60,7 +61,9 @@ List ldpred2_gibbs_auto(Environment corr,

int j2 = ind_sub[j];
double dotprod = dotprods[j2];
double res_beta_hat_j = beta_hat[j] - shrink_corr * (dotprod - curr_beta[j]);
double delta_j = delta[j];
double res_beta_hat_j =
(beta_hat[j] - shrink_corr * (dotprod - curr_beta[j])) / (1 + delta_j);

double C1 = sigma2 * n_vec[j];
double C2 = 1 / (1 + 1 / C1);
Expand All @@ -71,7 +74,8 @@ List ldpred2_gibbs_auto(Environment corr,
(1 + inv_odd_p * ::sqrt(1 + C1) * ::exp(-C3 * C3 / C4 / 2));

double prev_beta = curr_beta[j];
double dotprod_shrunk = shrink_corr * dotprod + (1 - shrink_corr) * prev_beta;
double dotprod_shrunk =
shrink_corr * dotprod + (1 + delta_j - shrink_corr) * prev_beta;

if (k >= burn_in) {
avg_postp[j] += postp;
Expand All @@ -98,7 +102,7 @@ List ldpred2_gibbs_auto(Environment corr,
}

if (diff != 0) {
cur_h2_est += diff * (2 * dotprod_shrunk + diff);
cur_h2_est += diff * (2 * dotprod_shrunk + diff * (1 + delta_j));
dotprods = sfbm->incr_mult_col(j2, dotprods, diff);
}
}
Expand Down

0 comments on commit c0f15be

Please sign in to comment.