Skip to content

Commit

Permalink
Use pytorch internal QR instead, it seems work.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 21, 2023
1 parent 858a40c commit 16396f5
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions tat/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from multimethod import multimethod
import torch
from . import _utility
from ._qr import givens_qr
from ._qr import givens_qr, householder_qr # pylint: disable=unused-import
from ._svd import svd as manual_svd # pylint: disable=unused-import
from .edge import Edge

Expand Down Expand Up @@ -1753,12 +1753,13 @@ def qr(
names=("QR_Q", "QR_R"),
)

if self.fermion:
# Blocked tensor, use Givens rotation
data_q, data_r = givens_qr(tensor.data)
else:
# Non-blocked tensor, use Householder reflection
data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced")
# if self.fermion:
# # Blocked tensor, use Givens rotation
# data_q, data_r = givens_qr(tensor.data)
# else:
# # Non-blocked tensor, use Householder reflection
# data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced")
data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced")

free_edge_q = tensor.edges[0]
common_edge_q = Tensor._guess_edge(torch.abs(data_q), free_edge_q, True)
Expand All @@ -1769,11 +1770,10 @@ def qr(
dtypes=self.dtypes,
data=data_q,
)
# tensor_q._ensure_mask()
tensor_q._ensure_mask() # pylint: disable=protected-access
free_edge_r = tensor.edges[1]
# common_edge_r = Tensor._guess_edge(torch.abs(data_r).transpose(0, 1), free_edge_r, False)
# Sometimes R matrix maybe singular, guess edge will return arbitary symmetry, so do not use guessed edge.
# In the other hand, QR based on Givens rotation always gives blocked result, which can be believed.
common_edge_r = common_edge_q.conjugate()
tensor_r = Tensor(
names=(common_name_r, "QR_R"),
Expand All @@ -1782,7 +1782,7 @@ def qr(
dtypes=self.dtypes,
data=data_r,
)
# tensor_r._ensure_mask()
tensor_r._ensure_mask() # pylint: disable=protected-access
assert common_edge_q.conjugate() == common_edge_r

tensor_q = tensor_q.split_edge({"QR_Q": ordered_free_edges_q}, False, set())
Expand Down

0 comments on commit 16396f5

Please sign in to comment.