From 16396f5371b2df30e33959b2ad899f4030fce172 Mon Sep 17 00:00:00 2001 From: Hao Zhang Date: Tue, 21 Nov 2023 09:18:53 +0800 Subject: [PATCH] Use pytorch internal QR instead, it seems work. --- tat/tensor.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tat/tensor.py b/tat/tensor.py index c122db337..5e0abcf88 100644 --- a/tat/tensor.py +++ b/tat/tensor.py @@ -9,7 +9,7 @@ from multimethod import multimethod import torch from . import _utility -from ._qr import givens_qr +from ._qr import givens_qr, householder_qr # pylint: disable=unused-import from ._svd import svd as manual_svd # pylint: disable=unused-import from .edge import Edge @@ -1753,12 +1753,13 @@ def qr( names=("QR_Q", "QR_R"), ) - if self.fermion: - # Blocked tensor, use Givens rotation - data_q, data_r = givens_qr(tensor.data) - else: - # Non-blocked tensor, use Householder reflection - data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced") + # if self.fermion: + # # Blocked tensor, use Givens rotation + # data_q, data_r = givens_qr(tensor.data) + # else: + # # Non-blocked tensor, use Householder reflection + # data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced") + data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced") free_edge_q = tensor.edges[0] common_edge_q = Tensor._guess_edge(torch.abs(data_q), free_edge_q, True) @@ -1769,11 +1770,10 @@ def qr( dtypes=self.dtypes, data=data_q, ) - # tensor_q._ensure_mask() + tensor_q._ensure_mask() # pylint: disable=protected-access free_edge_r = tensor.edges[1] # common_edge_r = Tensor._guess_edge(torch.abs(data_r).transpose(0, 1), free_edge_r, False) # Sometimes R matrix maybe singular, guess edge will return arbitary symmetry, so do not use guessed edge. - # In the other hand, QR based on Givens rotation always gives blocked result, which can be believed. common_edge_r = common_edge_q.conjugate() tensor_r = Tensor( names=(common_name_r, "QR_R"), @@ -1782,7 +1782,7 @@ def qr( dtypes=self.dtypes, data=data_r, ) - # tensor_r._ensure_mask() + tensor_r._ensure_mask() # pylint: disable=protected-access assert common_edge_q.conjugate() == common_edge_r tensor_q = tensor_q.split_edge({"QR_Q": ordered_free_edges_q}, False, set())