Skip to content

Commit

Permalink
Add compat function index_by_point and point_by_index for edge.
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Nov 19, 2023
1 parent 09399e9 commit 1e2a7ac
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions tat/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,9 @@ def parity(int_parity: int) -> bool:
# Segment index


def _get_index_for_position(position: tuple[typing.Any, int], edge: E) -> int:
sym, index = position
if not isinstance(sym, tuple):
sym = (sym,)
return next(total_index for total_index in range(edge.dimension) if all(
sub_sym == sub_symmetry[total_index] for sub_sym, sub_symmetry in zip(sym, edge.symmetry))) + index


@T._prepare_position.register # pylint: disable=protected-access,no-member
def _(self: T, position: dict[str, tuple[typing.Any, int]]) -> tuple[int, ...]:
return tuple(_get_index_for_position(position[name], edge) for name, edge in zip(self.names, self.edges))
return tuple(index_by_point(edge, position[name]) for name, edge in zip(self.names, self.edges))


# Function renames
Expand Down Expand Up @@ -323,6 +315,28 @@ def exponential(self: T, pairs: set[tuple[str, str]], step: int | None = None) -
return origin_exponential(self, pairs)


# Edge point conversion


@_compat_function(E)
def index_by_point(self: E, point: tuple[typing.Any, int]) -> int:
"Get index by point on an edge"
sym, sub_index = point
if not isinstance(sym, tuple):
sym = (sym,)
return next(total_index for total_index in range(self.dimension) if all(
sub_sym == sub_symmetry[total_index] for sub_sym, sub_symmetry in zip(sym, self.symmetry))) + sub_index


@_compat_function(E)
def point_by_index(self: E, index: int) -> tuple[typing.Any, int]:
"Get point by index on an edge"
sym = tuple(sub_symmetry[index] for sub_symmetry in self.symmetry)
sub_index = sum(
1 for i in range(index) if all(sub_sym == sub_symmetry[i] for sub_sym, sub_symmetry in zip(sym, self.symmetry)))
return sym, sub_index


# Random utility


Expand Down

0 comments on commit 1e2a7ac

Please sign in to comment.