From 90c8b6b5c630efd8a2f5d18518a0a2093c9338af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20Gel=C3=9F?= <38036185+PGelss@users.noreply.github.com> Date: Thu, 23 May 2024 18:26:51 +0200 Subject: [PATCH] build_core can now also be used for MPS --- scikit_tt/tensor_train.py | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/scikit_tt/tensor_train.py b/scikit_tt/tensor_train.py index 55380f3..88d75be 100644 --- a/scikit_tt/tensor_train.py +++ b/scikit_tt/tensor_train.py @@ -2081,14 +2081,24 @@ def build_core_vector(matrix_list: List[Union[np.ndarray, int]], field_type: str m, n = matrix_list[i].shape # Determine matrix order - except ValueError: # Check that we do not consider other n-dimensional arrays. + except: - print("List has a non-two-dimensional array. Lists must contain matrices.\n") + try: - raise + m = matrix_list[i].shape + + except ValueError: + + print("List contains elements which are neither vectors nor matrices.\n") + + raise break # Stop once order of matrix (mxn) is determined. + if len(matrix_list[i].shape) == 1: + m = matrix_list[i].shape[0] + n = 1 + core = np.zeros((r1, m, n, 1), dtype = field_type) for i in range(r1): @@ -2107,7 +2117,7 @@ def build_core_vector(matrix_list: List[Union[np.ndarray, int]], field_type: str try: - core[i, :, :, 0] = matrix_list[i] + core[i, :, :, 0] = matrix_list[i].reshape([m, n]) except ValueError: @@ -2173,12 +2183,21 @@ def build_core(matrix_list: Union[ List[List[Union[np.ndarray, int]]], List[Unio m, n = list_element.shape - except ValueError: + except: - print("List has a non-two-dimensional array. Lists must contain matrices.\n") + try: - raise + m = list_element.shape + + except ValueError: + + print("List contains elements which are neither vectors nor matrices.\n") + + raise + if len(list_element.shape) == 1: + m = list_element.shape[0] + n = 1 core = np.zeros((r1, m, n, r2), dtype = field_type) @@ -2208,7 +2227,7 @@ def build_core(matrix_list: Union[ List[List[Union[np.ndarray, int]]], List[Unio try: - core[i, :, :, j] = matrix_list_element + core[i, :, :, j] = matrix_list_element.reshape([m, n]) except TypeError: