Skip to content

Commit

Permalink
Implement prototype for torch based fermionic library.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 26, 2023
0 parents commit bdfdcd9
Show file tree
Hide file tree
Showing 17 changed files with 4,351 additions and 0 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Continuous-integration workflow: lint, type-check and test the package.
name: CI

# Triggered by every push and every pull request.
on: [push, pull_request]

jobs:
CI:

runs-on: ubuntu-latest
strategy:
# Let the remaining matrix entries finish even when one combination fails.
fail-fast: false
matrix:
# Explicit list of (Python, PyTorch) pairs to test.
include:
- python-version: "3.10"
pytorch-version: "1.12"
- python-version: "3.10"
pytorch-version: "1.13"
- python-version: "3.10"
pytorch-version: "2.0"
- python-version: "3.10"
pytorch-version: "2.1"

# No PyTorch 1.12 entry is listed for Python 3.11.
- python-version: "3.11"
pytorch-version: "1.13"
- python-version: "3.11"
pytorch-version: "2.0"
- python-version: "3.11"
pytorch-version: "2.1"

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install requirements
# Tooling versions are pinned; torch is installed as the "+cpu" build from the PyTorch CPU wheel index.
run: |
pip install pylint==2.17 mypy==1.6 pytest==7.4 pytest-cov==4.1
pip install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
pip install multimethod
# The three checks below all run from the repository root.
- name: Run pylint
run: pylint tat tests
working-directory: ${{ github.workspace }}
- name: Run mypy
run: mypy tat tests
working-directory: ${{ github.workspace }}
- name: Run pytest
run: pytest
working-directory: ${{ github.workspace }}
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.coverage
.mypy_cache
__pycache__
env
675 changes: 675 additions & 0 deletions LICENSE.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# TAT

A Fermionic tensor library based on pytorch.
32 changes: 32 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Package metadata.
[project]
name = "tat"
version = "0.4.0"
authors = [
# NOTE(review): the address below looks like a redaction placeholder left by the page export - confirm the real value.
{email = "[email protected]", name = "Hao Zhang"}
]
description = "A Fermionic tensor library based on pytorch."
readme = "README.md"
requires-python = ">=3.10"
license = {text = "GPL-3.0-or-later"}
# Runtime dependencies.
dependencies = [
"multimethod>=1.9",
"torch>=1.12",
]

# Linter settings.
[tool.pylint]
max-line-length = 120
# Treat every attribute under torch.* as generated, so pylint does not report them as missing members.
generated-members = "torch.*"
# Append the current directory to sys.path before linting.
init-hook="import sys; sys.path.append(\".\")"

# Formatter settings; the column limit matches pylint's max-line-length.
[tool.yapf]
based_on_style = "google"
column_limit = 120

# Type-checker settings: every function must be annotated, and unannotated bodies are still checked.
[tool.mypy]
check_untyped_defs = true
disallow_untyped_defs = true

# Test-runner settings: collect from tests/ with "." on the import path and report coverage for tat.
[tool.pytest.ini_options]
pythonpath = "."
testpaths = ["tests",]
addopts = "--cov=tat"
6 changes: 6 additions & 0 deletions tat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
The tat is a Fermionic tensor library based on pytorch.
"""

from .edge import Edge
from .tensor import Tensor
233 changes: 233 additions & 0 deletions tat/_qr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
"""
This module implements QR decomposition based on Givens rotation and Householder reflection.
"""

import typing
import torch

# pylint: disable=invalid-name


@torch.jit.script
def _syminvadj(X: torch.Tensor) -> torch.Tensor:
    """
    Return ``X + X^H`` with the real part of its diagonal halved.

    The input is left untouched; a new tensor is returned.
    """
    total = X + X.H
    # The slice assignment writes through the diagonal view, so `total` itself is updated.
    diagonal_real = total.diagonal().real
    diagonal_real[:] *= 0.5
    return total


@torch.jit.script
def _triliminvadjskew(X: torch.Tensor) -> torch.Tensor:
    """
    Return the lower triangle of ``X - X^H``; for complex input the imaginary part of the diagonal is halved.

    The input is left untouched; a new tensor is returned.
    """
    lower = torch.tril(X - X.H)
    # Only complex tensors expose a writable imaginary part, hence the guard.
    if torch.is_complex(X):
        diagonal_imag = lower.diagonal().imag
        diagonal_imag[:] *= 0.5
    return lower


@torch.jit.script
def _qr_backward(
    Q: torch.Tensor,
    R: torch.Tensor,
    Q_grad: typing.Optional[torch.Tensor],
    R_grad: typing.Optional[torch.Tensor],
) -> typing.Optional[torch.Tensor]:
    """
    Compute the gradient of the input matrix from the gradients of its reduced QR factors.

    Parameters
    ----------
    Q, R : torch.Tensor
        The factors produced by the forward pass.
    Q_grad, R_grad : torch.Tensor or None
        Gradients with respect to Q and R; either one may be missing.

    Returns
    -------
    torch.Tensor or None
        The gradient with respect to the decomposed matrix, or None when both incoming gradients are None.
    """
    # see https://arxiv.org/pdf/2009.10071.pdf section 4.3 and 4.5
    # see pytorch torch/csrc/autograd/FunctionsManual.cpp:linalg_qr_backward
    rows = Q.size(0)
    cols = R.size(1)

    # M_H = R_grad R^H - Q^H Q_grad, dropping the term of whichever gradient is absent.
    if R_grad is not None:
        if Q_grad is not None:
            M_H = R_grad @ R.H - Q.H @ Q_grad
        else:
            M_H = R_grad @ R.H
    else:
        if Q_grad is not None:
            M_H = -Q.H @ Q_grad
        else:
            return None

    if rows < cols:
        # Wide matrix: solve against the leading square part of R, then pad with zero columns.
        rhs = Q @ (_triliminvadjskew(-M_H))
        square_part = torch.linalg.solve_triangular(R[:, :rows].H, rhs, upper=False, left=False)
        padding = torch.zeros([rows, cols - rows], dtype=square_part.dtype, device=square_part.device)
        grad = torch.cat([square_part, padding], dim=1)
        if R_grad is not None:
            grad = grad + Q @ R_grad
        return grad

    # Deep and square matrix
    rhs = Q @ _syminvadj(torch.triu(M_H))
    if Q_grad is not None:
        rhs = rhs + Q_grad
    return torch.linalg.solve_triangular(R.H, rhs, upper=False, left=False)


class CommonQR(torch.autograd.Function):
    """
    Shared backward pass for the QR autograd functions.

    Subclasses supply `forward` and must save the pair ``(Q, R)`` with ``ctx.save_for_backward``.
    """

    # pylint: disable=abstract-method

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: typing.Any,
        Q_grad: torch.Tensor | None,
        R_grad: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """
        Turn the gradients of Q and R into the gradient of the decomposed matrix.
        """
        # pylint: disable=arguments-differ
        saved_q, saved_r = ctx.saved_tensors
        return _qr_backward(saved_q, saved_r, Q_grad, R_grad)


@torch.jit.script
def _normalize_diagonal(a: torch.Tensor) -> torch.Tensor:
    """
    Return the element-wise phase ``a / |a|``, using 1 wherever the magnitude is zero.
    """
    magnitude = torch.sqrt(a.conj() * a)
    zero = torch.zeros([], dtype=a.dtype, device=a.device)
    one = torch.ones([], dtype=a.dtype, device=a.device)
    # The division yields NaN where the magnitude is zero; `where` discards those entries.
    phase = a / magnitude
    return torch.where(magnitude == zero, one, phase)


@torch.jit.script
def _givens_parameter(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return the pair ``(a / r, b / r)`` with ``r = sqrt(|a|^2 + |b|^2)``.

    Wherever ``b`` is zero the pair is ``(1, 0)`` instead, which also covers ``a == b == 0``.
    """
    radius = torch.sqrt(a.conj() * a + b.conj() * b)
    zero = torch.zeros([], dtype=a.dtype, device=a.device)
    one = torch.ones([], dtype=a.dtype, device=a.device)
    nothing_to_rotate = b == zero
    cosine = torch.where(nothing_to_rotate, one, a / radius)
    sine = torch.where(nothing_to_rotate, zero, b / radius)
    return cosine, sine


@torch.jit.script
def _givens_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute a reduced QR decomposition of a matrix with Givens rotations.

    Returns ``(Q, R)`` where, with ``rows, cols = A.shape`` and ``rank = min(rows, cols)``, Q has shape
    ``(rows, rank)`` and R has shape ``(rank, cols)``. A is not modified.
    """
    rows, cols = A.shape
    rank = min(rows, cols)
    # Q accumulates the rotations row-wise; it is conjugate-transposed on return.
    Q = torch.eye(rows, dtype=A.dtype, device=A.device)
    R = A.clone(memory_format=torch.contiguous_format)

    # Parallel strategy
    # Every row is rotated into the row directly above it; the rotations of one sweep touch disjoint row pairs.
    for start in range(rows - 1, 0, -1):
        # rotate R[start, 0], R[start+2, 1], R[start+4, 2], ...
        for lower, col in zip(range(start, rows, 2), range(cols)):
            upper = lower - 1
            # Rotate inside column col, from row lower to row upper
            cosine, sine = _givens_parameter(R[upper, col], R[lower, col])
            # Both new rows are computed before either one is written back.
            q_lower = -sine * Q[upper] + cosine * Q[lower]
            q_upper = cosine.conj() * Q[upper] + sine.conj() * Q[lower]
            Q[lower] = q_lower
            Q[upper] = q_upper
            r_lower = -sine * R[upper] + cosine * R[lower]
            r_upper = cosine.conj() * R[upper] + sine.conj() * R[lower]
            R[lower] = r_lower
            R[upper] = r_upper
    for start in range(1, rank):
        # rotate R[start+1, start], R[start+1+2, start+1], R[start+1+4, start+2], ...
        for lower, col in zip(range(start + 1, rows, 2), range(start, cols)):
            upper = lower - 1
            # Rotate inside column col, from row lower to row upper
            cosine, sine = _givens_parameter(R[upper, col], R[lower, col])
            q_lower = -sine * Q[upper] + cosine * Q[lower]
            q_upper = cosine.conj() * Q[upper] + sine.conj() * Q[lower]
            Q[lower] = q_lower
            Q[upper] = q_upper
            r_lower = -sine * R[upper] + cosine * R[lower]
            r_upper = cosine.conj() * R[upper] + sine.conj() * R[lower]
            R[lower] = r_lower
            R[upper] = r_upper

    # A sequential alternative walks the columns in order and rotates every row below the diagonal
    # into the diagonal row with the same update formulas.

    # Make diagonal positive
    phase = _normalize_diagonal(R.diagonal()).conj()
    row_scale = torch.unsqueeze(phase, 1)
    Q[:rank] *= row_scale
    R[:rank] *= row_scale

    return Q[:rank].H, R[:rank]


class GivensQR(CommonQR):
    """
    Autograd function computing the reduced QR decomposition with Givens rotations.
    """

    # pylint: disable=abstract-method

    @staticmethod
    def forward(  # type: ignore[override]
        ctx: torch.autograd.function.FunctionCtx,
        A: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Decompose A and keep both factors for the backward pass.
        """
        # pylint: disable=arguments-differ
        factors = _givens_qr(A)
        ctx.save_for_backward(*factors)
        return factors


@torch.jit.script
def _normalize_delta(a: torch.Tensor) -> torch.Tensor:
    """
    Return ``a`` divided by its norm, or zeros when the norm is zero.
    """
    length = a.norm()
    zero = torch.zeros([], dtype=a.dtype, device=a.device)
    # The division yields NaN for a zero vector; `where` replaces it with zeros.
    unit = a / length
    return torch.where(length == zero, zero, unit)


@torch.jit.script
def _reflect_target(x: torch.Tensor) -> torch.Tensor:
    """
    Return the first entry of the vector that a Householder reflection should map ``x`` onto.

    The result has magnitude ``|x|`` and the phase of ``-x[0]`` (a zero ``x[0]`` is treated as having phase 1).
    Choosing the sign opposite to ``x[0]`` keeps the subtraction ``v - x`` free of cancellation: with the same
    sign, a vector that is already almost aligned with the first axis gives ``v - x`` a first component that
    rounds to zero, and the resulting reflection no longer annihilates the remaining entries of ``x``.
    The inner product ``<v|x>`` is real for either sign, which is what a reflection of complex vectors needs.
    """
    head = x[0]
    magnitude = torch.sqrt(head.conj() * head)
    # Same convention as `_normalize_diagonal`: an exactly zero entry gets phase 1.
    phase = torch.where(
        magnitude == torch.zeros([], dtype=x.dtype, device=x.device),
        torch.ones([], dtype=x.dtype, device=x.device),
        head / magnitude,
    )
    return -(torch.norm(x) * phase)


@torch.jit.script
def _householder_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute a reduced QR decomposition of a matrix with Householder reflections.

    Returns ``(Q, R)`` where, with ``rows, cols = A.shape`` and ``rank = min(rows, cols)``, Q has shape
    ``(rows, rank)`` and R has shape ``(rank, cols)``. A is not modified.
    """
    rows, cols = A.shape
    rank = min(rows, cols)
    # Q accumulates the reflections row-wise; it is conjugate-transposed on return.
    Q = torch.eye(rows, dtype=A.dtype, device=A.device)
    R = A.clone(memory_format=torch.contiguous_format)

    for pivot in range(rank):
        column_tail = R[pivot:, pivot]
        # The target keeps the norm of the column tail but is non-zero only in its first entry.
        # For complex input a reflection can only map x onto v when <v|x> is real, so the first
        # entry shares the phase of x[0] up to sign.
        target = torch.zeros_like(column_tail)
        target[0] = _reflect_target(column_tail)
        # Unit vector along which the tail is reflected onto the target
        delta = _normalize_delta(target - column_tail)
        # H = 1 - 2 |delta><delta|, applied to the trailing rows of both R and Q
        R[pivot:] -= 2 * torch.outer(delta, delta.conj() @ R[pivot:])
        Q[pivot:] -= 2 * torch.outer(delta, delta.conj() @ Q[pivot:])

    # Make diagonal positive
    phase = _normalize_diagonal(R.diagonal()).conj()
    row_scale = torch.unsqueeze(phase, 1)
    Q[:rank] *= row_scale
    R[:rank] *= row_scale

    return Q[:rank].H, R[:rank]


class HouseholderQR(CommonQR):
    """
    Autograd function computing the reduced QR decomposition with Householder reflections.
    """

    # pylint: disable=abstract-method

    @staticmethod
    def forward(  # type: ignore[override]
        ctx: torch.autograd.function.FunctionCtx,
        A: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Decompose A and keep both factors for the backward pass.
        """
        # pylint: disable=arguments-differ
        factors = _householder_qr(A)
        ctx.save_for_backward(*factors)
        return factors


# Function-style entry points: each takes a matrix A and returns the pair (Q, R),
# going through the autograd Function so that the custom backward pass is used.
givens_qr = GivensQR.apply
householder_qr = HouseholderQR.apply
Loading

0 comments on commit bdfdcd9

Please sign in to comment.