Skip to content

Commit

Permalink
Implement prototype for torch based fermionic library.
Browse files Browse the repository at this point in the history
Some function not implemented or defined
[ ] merge_edge
[ ] split_edge
[ ] contract
[ ] trace
[ ] identity
[ ] exponential
[ ] conjugate
[ ] svd
[ ] qr
  • Loading branch information
hzhangxyz committed Nov 3, 2023
0 parents commit e6e6781
Show file tree
Hide file tree
Showing 11 changed files with 1,576 additions and 0 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: CI

on: [push, pull_request]

jobs:
CI:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install CI tools
run: pip install pylint mypy pytest pytest-cov
- name: Show directory
run: pwd && ls
working-directory: ${{ runner.workspace }}
- name: Install requirements
run: pip install .
working-directory: ${{ runner.workspace }}
- name: Run pytest
run: pytest --cov=tat
working-directory: ${{ runner.workspace }}
- name: Run mypy
run: mypy tat
working-directory: ${{ runner.workspace }}
- name: Run pylint
run: pylint tat
working-directory: ${{ runner.workspace }}
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.coverage
__pycache__
3 changes: 3 additions & 0 deletions README.org
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
* TAT

A Fermionic tensor library based on pytorch.
25 changes: 25 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[tool.pylint]
max-line-length = 120
generated-members = 'torch.*'

[tool.yapf]
based_on_style = 'google'
column_limit = 120

[tool.mypy]
check_untyped_defs = true
disallow_untyped_defs = true

[project]
name = 'tat'
version = '0.4.0'
description = "A Fermionic tensor library based on pytorch."
requires-python = ">=3.7"
authors = [
{email = "[email protected]"},
{name = "Hao Zhang"}
]
dependencies = [
'multimethod',
'torch',
]
6 changes: 6 additions & 0 deletions tat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
The tat is a Fermionic tensor library based on pytorch.
"""

from .edge import Edge
from .tensor import Tensor
191 changes: 191 additions & 0 deletions tat/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""
This file implement a compat layer for legacy TAT interface.
"""

from __future__ import annotations
import typing
from multimethod import multimethod
import torch
from .edge import Edge as E
from .tensor import Tensor as T

# pylint: disable=too-few-public-methods
# pylint: disable=too-many-instance-attributes


class CompatSymmetry:
"""
The common Symmetry namespace
"""

def __init__(self: CompatSymmetry, fermion: tuple[bool, ...], dtypes: tuple[torch.dtype, ...]) -> None:
self.fermion: tuple[bool, ...] = fermion
self.dtypes: tuple[torch.dtype, ...] = dtypes

# pylint: disable=invalid-name
self.S: CompatScalar
self.D: CompatScalar
self.C: CompatScalar
self.Z: CompatScalar
self.float32: CompatScalar
self.float64: CompatScalar
self.float: CompatScalar
self.complex64: CompatScalar
self.complex128: CompatScalar
self.complex: CompatScalar

self.S = self.float32 = CompatScalar(self, torch.float32)
self.D = self.float64 = self.float = CompatScalar(self, torch.float64)
self.C = self.complex64 = CompatScalar(self, torch.complex64)
self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128)

def _parse_segments(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]:
# Segments may be [Sym] or [(Sym, Size)]
try:
# try [(Sym, Size)] first
return self._parse_segments_kernel(segments)
except TypeError:
# Cannot unpack is a type error, value[index] is a type error, too
# convert [Sym] to [(Sym, Size)]
return self._parse_segments_kernel([(sym, 1) for sym in segments])

def _parse_segments_kernel(self: CompatSymmetry, segments: list) -> tuple[tuple[torch.Tensor, ...], int]:
# [(Sym, Size)] for every element
dimension = sum(dim for _, dim in segments)
symmetry = tuple(
torch.tensor(
sum(
([self._parse_segments_get_subsymmetry(sym, index)] * dim
for sym, dim in segments),
[],
), # Concat all segment for this subsymmetry
dtype=sub_symmetry,
) # Generate subsymmetry one by one
for index, sub_symmetry in enumerate(self.dtypes))
return symmetry, dimension

def _parse_segments_get_subsymmetry(self: CompatSymmetry, sym: object, index: int) -> object:
# Most of time, symmetry is a tuple of subsymmetry
# But if there is only ome subsymmetry, it could not be a tuple but subsymmetry itself.
# pylint: disable=no-else-return
if isinstance(sym, tuple):
return sym[index]
else:
if len(self.fermion) == 1:
return sym
else:
raise TypeError(f"{sym=} is not subscriptable")

@multimethod
def Edge(self: CompatSymmetry, dimension: int) -> E:
"""
Create edge with compat interface.
"""
# pylint: disable=invalid-name
symmetry = tuple(torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes)
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False)

@Edge.register
def _(self: CompatSymmetry, segments: list, arrow: bool = False) -> E:
symmetry, dimension = self._parse_segments(segments)
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow)

@Edge.register
def _(self: CompatSymmetry, segments_and_bool: tuple[list, bool]) -> E:
segments, arrow = segments_and_bool
symmetry, dimension = self._parse_segments(segments)
return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow)


class CompatScalar:
"""
The common Scalar namespace.
"""

def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None:
self.symmetry: CompatSymmetry = symmetry
self.dtype: torch.dtype = dtype

@multimethod
def Tensor(self: CompatScalar, names: list[str], edges: list) -> T:
"""
Create tensor with compat names and edges.
"""
# pylint: disable=invalid-name
return T(
tuple(names),
tuple(self.symmetry.Edge(edge) for edge in edges),
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
dtype=self.dtype,
)

@Tensor.register
def _(self: CompatScalar) -> T:
result = T(
(),
(),
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
dtype=self.dtype,
)
result.data.reshape([-1])[0] = 1
return result

@Tensor.register
def _(
self: CompatScalar,
number: typing.Any,
names: list[str] | None = None,
edge_symmetry: list | None = None,
edge_arrow: list[bool] | None = None,
) -> T:
# Create high rank tensor with only one element
if names is None:
names = []
if edge_symmetry is None:
edge_symmetry = [None for _ in names]
if edge_arrow is None:
edge_arrow = [False for _ in names]
result = T(
tuple(names),
tuple(
E(
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
symmetry=tuple(
torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype)
for index, dtype in enumerate(self.symmetry.dtypes)),
dimension=1,
arrow=arrow,
)
for symmetry, arrow in zip(edge_symmetry, edge_arrow)),
fermion=self.symmetry.fermion,
dtypes=self.symmetry.dtypes,
dtype=self.dtype,
)
result.data.reshape([-1])[0] = number
return result

def _create_size1_get_subsymmetry(self: CompatScalar, sym: object, index: int) -> object:
# pylint: disable=no-else-return
if sym is None:
return 0
elif isinstance(sym, tuple):
return sym[index]
else:
if len(self.symmetry.fermion) == 1:
return sym
else:
raise TypeError(f"{sym=} is not subscriptable")


No: CompatSymmetry = CompatSymmetry(fermion=(), dtypes=())
Z2: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.bool,))
U1: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.int,))
Fermi: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.int,))
FermiZ2: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.bool))
FermiU1: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.int))
Parity: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.bool,))
FermiFermi: CompatSymmetry = CompatSymmetry(fermion=(True, True), dtypes=(torch.int, torch.int))
Normal: CompatSymmetry = No
Loading

0 comments on commit e6e6781

Please sign in to comment.