Skip to content

Commit

Permalink
build_core can now also be used for MPS
Browse files Browse the repository at this point in the history
  • Loading branch information
PGelss authored May 23, 2024
1 parent a4d1a8e commit 90c8b6b
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions scikit_tt/tensor_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2081,14 +2081,24 @@ def build_core_vector(matrix_list: List[Union[np.ndarray, int]], field_type: str

m, n = matrix_list[i].shape # Determine matrix order

except ValueError: # Check that we do not consider other n-dimensional arrays.
except:

print("List has a non-two-dimensional array. Lists must contain matrices.\n")
try:

raise
m = matrix_list[i].shape

except ValueError:

print("List contains elements which are neither vectors nor matrices.\n")

raise

break # Stop once order of matrix (mxn) is determined.

if len(matrix_list[i].shape) == 1:
m = matrix_list[i].shape[0]
n = 1

core = np.zeros((r1, m, n, 1), dtype = field_type)

for i in range(r1):
Expand All @@ -2107,7 +2117,7 @@ def build_core_vector(matrix_list: List[Union[np.ndarray, int]], field_type: str

try:

core[i, :, :, 0] = matrix_list[i]
core[i, :, :, 0] = matrix_list[i].reshape([m, n])

except ValueError:

Expand Down Expand Up @@ -2173,12 +2183,21 @@ def build_core(matrix_list: Union[ List[List[Union[np.ndarray, int]]], List[Unio

m, n = list_element.shape

except ValueError:
except:

print("List has a non-two-dimensional array. Lists must contain matrices.\n")
try:

raise
m = list_element.shape

except ValueError:

print("List contains elements which are neither vectors nor matrices.\n")

raise

if len(list_element.shape) == 1:
m = list_element.shape[0]
n = 1

core = np.zeros((r1, m, n, r2), dtype = field_type)

Expand Down Expand Up @@ -2208,7 +2227,7 @@ def build_core(matrix_list: Union[ List[List[Union[np.ndarray, int]]], List[Unio

try:

core[i, :, :, j] = matrix_list_element
core[i, :, :, j] = matrix_list_element.reshape([m, n])

except TypeError:

Expand Down

0 comments on commit 90c8b6b

Please sign in to comment.