From 616ff8bbd5ab3eac2d748924138edb79ae881328 Mon Sep 17 00:00:00 2001 From: thbake Date: Thu, 11 Apr 2024 16:28:32 +0200 Subject: [PATCH] Updated ranks in tensorkrylov.py --- examples/tensorkrylov.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/examples/tensorkrylov.py b/examples/tensorkrylov.py index 80ce117..1214dd1 100644 --- a/examples/tensorkrylov.py +++ b/examples/tensorkrylov.py @@ -6,6 +6,7 @@ from numpy.random import seed, rand from scikit_tt.tensor_train import TT, build_core, residual_error, uniform, residual_error from scikit_tt.solvers.sle import als +from time import time class MatrixCollection(object): @@ -167,9 +168,13 @@ def _update_approximation(x_TT: "TT", V: "MatrixCollection", y_TT: "TT"): for s in range(x_TT.order): x_TT.cores[s] = np.sum(V[s][None, :, :, None, None] @ y_TT.cores[s][:, None, :, :, :], axis = 2) + x_TT.ranks[s] = x_TT.cores[s].shape[0] + x_TT.ranks[s + 1] = x_TT.cores[s].shape[3] + return +#def _residual_norm() def symmetric_tensorkrylov(A: "MatrixCollection", b: List[np.ndarray], rank: int, nmax: int, tol = 1e-9): @@ -213,11 +218,9 @@ def symmetric_tensorkrylov(A: "MatrixCollection", b: List[np.ndarray], rank: int y_TT = als(TT_operator, TT_guess, TT_rhs) _update_approximation(x_TT, V_minors, y_TT) - print(A_TT) - print(x_TT) - print(b_TT) - r_norm = residual_error(A_TT, x_TT, b_TT) + #r_norm = residual_error(A_TT, x_TT, b_TT) + #print(r_norm) if r_norm <= tol: return x_TT @@ -239,8 +242,23 @@ def random_rhs(n: int): A = MatrixCollection([ As for _ in range(d) ]) b = [ bs for _ in range(d) ] -rank = 5 +rank = 8 ranks = [1] + ([rank] * (d - 1)) + [1] +row_dims = [n for _ in range(d)] +col_dims = [1 for _ in range(d)] + +x_TT = scikit_tt.tensor_train.rand(row_dims, col_dims, ranks) +A_TT = _TT_operator(A, n - 1) +b_TT = _TT_rhs(b) + + +start = time() +x_TT = als(A_TT, x_TT, b_TT) +end = time() +print("Done", end - start) + +print(residual_error(A_TT, x_TT, b_TT)) + print(symmetric_tensorkrylov(A, b, rank, n, tol = 1e-9))