diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
new file mode 100644
index 000000000..bef185036
--- /dev/null
+++ b/.github/workflows/CI.yml
@@ -0,0 +1,48 @@
+name: CI
+
+on: [push, pull_request]
+
+jobs:
+ CI:
+
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - python-version: "3.10"
+ pytorch-version: "1.12"
+ - python-version: "3.10"
+ pytorch-version: "1.13"
+ - python-version: "3.10"
+ pytorch-version: "2.0"
+ - python-version: "3.10"
+ pytorch-version: "2.1"
+
+ - python-version: "3.11"
+ pytorch-version: "1.13"
+ - python-version: "3.11"
+ pytorch-version: "2.0"
+ - python-version: "3.11"
+ pytorch-version: "2.1"
+
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install requirements
+ run: |
+ pip install pylint==2.17 mypy==1.6 pytest==7.4 pytest-cov==4.1
+ pip install torch==${{ matrix.pytorch-version }}+cpu --index-url https://download.pytorch.org/whl/cpu
+ pip install multimethod
+ - name: Run pylint
+ run: pylint tat tests
+ working-directory: ${{ github.workspace }}
+ - name: Run mypy
+ run: mypy tat tests
+ working-directory: ${{ github.workspace }}
+ - name: Run pytest
+ run: pytest
+ working-directory: ${{ github.workspace }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..a5d2ad6c8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+.coverage
+.mypy_cache
+__pycache__
+env
\ No newline at end of file
diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 000000000..496acdb2a
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,675 @@
+# GNU GENERAL PUBLIC LICENSE
+
+Version 3, 29 June 2007
+
+Copyright (C) 2007 Free Software Foundation, Inc.
+<https://fsf.org/>
+
+Everyone is permitted to copy and distribute verbatim copies of this
+license document, but changing it is not allowed.
+
+## Preamble
+
+The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom
+to share and change all versions of a program--to make sure it remains
+free software for all its users. We, the Free Software Foundation, use
+the GNU General Public License for most of our software; it applies
+also to any other work released this way by its authors. You can apply
+it to your programs, too.
+
+When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you
+have certain responsibilities if you distribute copies of the
+software, or if you modify it: responsibilities to respect the freedom
+of others.
+
+For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the
+manufacturer can do so. This is fundamentally incompatible with the
+aim of protecting users' freedom to change the software. The
+systematic pattern of such abuse occurs in the area of products for
+individuals to use, which is precisely where it is most unacceptable.
+Therefore, we have designed this version of the GPL to prohibit the
+practice for those products. If such problems arise substantially in
+other domains, we stand ready to extend this provision to those
+domains in future versions of the GPL, as needed to protect the
+freedom of users.
+
+Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish
+to avoid the special danger that patents applied to a free program
+could make it effectively proprietary. To prevent this, the GPL
+assures that patents cannot be used to render the program non-free.
+
+The precise terms and conditions for copying, distribution and
+modification follow.
+
+## TERMS AND CONDITIONS
+
+### 0. Definitions.
+
+"This License" refers to version 3 of the GNU General Public License.
+
+"Copyright" also means copyright-like laws that apply to other kinds
+of works, such as semiconductor masks.
+
+"The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of
+an exact copy. The resulting work is called a "modified version" of
+the earlier work or a work "based on" the earlier work.
+
+A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user
+through a computer network, with no transfer of a copy, is not
+conveying.
+
+An interactive user interface displays "Appropriate Legal Notices" to
+the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+### 1. Source Code.
+
+The "source code" for a work means the preferred form of the work for
+making modifications to it. "Object code" means any non-source form of
+a work.
+
+A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+The Corresponding Source need not include anything that users can
+regenerate automatically from other parts of the Corresponding Source.
+
+The Corresponding Source for a work in source code form is that same
+work.
+
+### 2. Basic Permissions.
+
+All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+You may make, run and propagate covered works that you do not convey,
+without conditions so long as your license otherwise remains in force.
+You may convey covered works to others for the sole purpose of having
+them make modifications exclusively for you, or provide you with
+facilities for running those works, provided that you comply with the
+terms of this License in conveying all material for which you do not
+control copyright. Those thus making or running the covered works for
+you must do so exclusively on your behalf, under your direction and
+control, on terms that prohibit them from making any copies of your
+copyrighted material outside their relationship with you.
+
+Conveying under any other circumstances is permitted solely under the
+conditions stated below. Sublicensing is not allowed; section 10 makes
+it unnecessary.
+
+### 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such
+circumvention is effected by exercising rights under this License with
+respect to the covered work, and you disclaim any intention to limit
+operation or modification of the work as a means of enforcing, against
+the work's users, your or third parties' legal rights to forbid
+circumvention of technological measures.
+
+### 4. Conveying Verbatim Copies.
+
+You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+### 5. Conveying Modified Source Versions.
+
+You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these
+conditions:
+
+- a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+- b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under
+ section 7. This requirement modifies the requirement in section 4
+ to "keep intact all notices".
+- c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+- d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+### 6. Conveying Non-Source Forms.
+
+You may convey a covered work in object code form under the terms of
+sections 4 and 5, provided that you also convey the machine-readable
+Corresponding Source under the terms of this License, in one of these
+ways:
+
+- a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+- b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the Corresponding
+ Source from a network server at no charge.
+- c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+- d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+- e) Convey the object code using peer-to-peer transmission,
+ provided you inform other peers where the object code and
+ Corresponding Source of the work are being offered to the general
+ public at no charge under subsection 6d.
+
+A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal,
+family, or household purposes, or (2) anything designed or sold for
+incorporation into a dwelling. In determining whether a product is a
+consumer product, doubtful cases shall be resolved in favor of
+coverage. For a particular product received by a particular user,
+"normally used" refers to a typical or common use of that class of
+product, regardless of the status of the particular user or of the way
+in which the particular user actually uses, or expects or is expected
+to use, the product. A product is a consumer product regardless of
+whether the product has substantial commercial, industrial or
+non-consumer uses, unless such uses represent the only significant
+mode of use of the product.
+
+"Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to
+install and execute modified versions of a covered work in that User
+Product from a modified version of its Corresponding Source. The
+information must suffice to ensure that the continued functioning of
+the modified object code is in no case prevented or interfered with
+solely because modification has been made.
+
+If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or
+updates for a work that has been modified or installed by the
+recipient, or for the User Product in which it has been modified or
+installed. Access to a network may be denied when the modification
+itself materially and adversely affects the operation of the network
+or violates the rules and protocols for communication across the
+network.
+
+Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+### 7. Additional Terms.
+
+"Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders
+of that material) supplement the terms of this License with terms:
+
+- a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+- b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+- c) Prohibiting misrepresentation of the origin of that material,
+ or requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+- d) Limiting the use for publicity purposes of names of licensors
+ or authors of the material; or
+- e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+- f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions
+ of it) with contractual assumptions of liability to the recipient,
+ for any liability that these contractual assumptions directly
+ impose on those licensors and authors.
+
+All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions; the
+above requirements apply either way.
+
+### 8. Termination.
+
+You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+However, if you cease all violation of this License, then your license
+from a particular copyright holder is reinstated (a) provisionally,
+unless and until the copyright holder explicitly and finally
+terminates your license, and (b) permanently, if the copyright holder
+fails to notify you of the violation by some reasonable means prior to
+60 days after the cessation.
+
+Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+### 9. Acceptance Not Required for Having Copies.
+
+You are not required to accept this License in order to receive or run
+a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+### 10. Automatic Licensing of Downstream Recipients.
+
+Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+### 11. Patents.
+
+A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+A contributor's "essential patent claims" are all patent claims owned
+or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+A patent license is "discriminatory" if it does not include within the
+scope of its coverage, prohibits the exercise of, or is conditioned on
+the non-exercise of one or more of the rights that are specifically
+granted under this License. You may not convey a covered work if you
+are a party to an arrangement with a third party that is in the
+business of distributing software, under which you make payment to the
+third party based on the extent of your activity of conveying the
+work, and under which the third party grants, to any of the parties
+who would receive the covered work from you, a discriminatory patent
+license (a) in connection with copies of the covered work conveyed by
+you (or copies made from those copies), or (b) primarily for and in
+connection with specific products or compilations that contain the
+covered work, unless you entered into that arrangement, or that patent
+license was granted, prior to 28 March 2007.
+
+Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+### 12. No Surrender of Others' Freedom.
+
+If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under
+this License and any other pertinent obligations, then as a
+consequence you may not convey it at all. For example, if you agree to
+terms that obligate you to collect a royalty for further conveying
+from those to whom you convey the Program, the only way you could
+satisfy both those terms and this License would be to refrain entirely
+from conveying the Program.
+
+### 13. Use with the GNU Affero General Public License.
+
+Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+### 14. Revised Versions of this License.
+
+The Free Software Foundation may publish revised and/or new versions
+of the GNU General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in
+detail to address new problems or concerns.
+
+Each version is given a distinguishing version number. If the Program
+specifies that a certain numbered version of the GNU General Public
+License "or any later version" applies to it, you have the option of
+following the terms and conditions either of that numbered version or
+of any later version published by the Free Software Foundation. If the
+Program does not specify a version number of the GNU General Public
+License, you may choose any version ever published by the Free
+Software Foundation.
+
+If the Program specifies that a proxy can decide which future versions
+of the GNU General Public License can be used, that proxy's public
+statement of acceptance of a version permanently authorizes you to
+choose that version for the Program.
+
+Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+### 15. Disclaimer of Warranty.
+
+THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT
+WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND
+PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE
+DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR
+CORRECTION.
+
+### 16. Limitation of Liability.
+
+IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR
+CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
+INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES
+ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT
+NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR
+LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM
+TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER
+PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
+
+### 17. Interpretation of Sections 15 and 16.
+
+If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+END OF TERMS AND CONDITIONS
+
+## How to Apply These Terms to Your New Programs
+
+If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these
+terms.
+
+To do so, attach the following notices to the program. It is safest to
+attach them to the start of each source file to most effectively state
+the exclusion of warranty; and each file should have at least the
+"copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper
+mail.
+
+If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ <program> Copyright (C) <year> <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands \`show w' and \`show c' should show the
+appropriate parts of the General Public License. Of course, your
+program's commands might be different; for a GUI interface, you would
+use an "about box".
+
+You should also get your employer (if you work as a programmer) or
+school, if any, to sign a "copyright disclaimer" for the program, if
+necessary. For more information on this, and how to apply and follow
+the GNU GPL, see <https://www.gnu.org/licenses/>.
+
+The GNU General Public License does not permit incorporating your
+program into proprietary programs. If your program is a subroutine
+library, you may consider it more useful to permit linking proprietary
+applications with the library. If this is what you want to do, use the
+GNU Lesser General Public License instead of this License. But first,
+please read <https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..3370c23ae
--- /dev/null
+++ b/README.md
@@ -0,0 +1,3 @@
+# TAT
+
+A Fermionic tensor library based on pytorch.
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..66bd33d52
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,32 @@
+[project]
+name = "tat"
+version = "0.4.0"
+authors = [
+ {email = "zh970205@mail.ustc.edu.cn", name = "Hao Zhang"}
+]
+description = "A Fermionic tensor library based on pytorch."
+readme = "README.md"
+requires-python = ">=3.10"
+license = {text = "GPL-3.0-or-later"}
+dependencies = [
+ "multimethod>=1.9",
+ "torch>=1.12",
+]
+
+[tool.pylint]
+max-line-length = 120
+generated-members = "torch.*"
+init-hook = "import sys; sys.path.append(\".\")"
+
+[tool.yapf]
+based_on_style = "google"
+column_limit = 120
+
+[tool.mypy]
+check_untyped_defs = true
+disallow_untyped_defs = true
+
+[tool.pytest.ini_options]
+pythonpath = "."
+testpaths = ["tests",]
+addopts = "--cov=tat"
diff --git a/tat/__init__.py b/tat/__init__.py
new file mode 100644
index 000000000..9f8cc4bcb
--- /dev/null
+++ b/tat/__init__.py
@@ -0,0 +1,6 @@
+"""
+tat is a Fermionic tensor library based on pytorch.
+"""
+
+from .edge import Edge
+from .tensor import Tensor
diff --git a/tat/_qr.py b/tat/_qr.py
new file mode 100644
index 000000000..595279610
--- /dev/null
+++ b/tat/_qr.py
@@ -0,0 +1,233 @@
+"""
+This module implements QR decomposition based on Givens rotation and Householder reflection.
+"""
+
+import typing
+import torch
+
+# pylint: disable=invalid-name
+
+
+@torch.jit.script
+def _syminvadj(X: torch.Tensor) -> torch.Tensor:
+ ret = X + X.H
+ ret.diagonal().real[:] *= 1 / 2
+ return ret
+
+
+@torch.jit.script
+def _triliminvadjskew(X: torch.Tensor) -> torch.Tensor:
+ ret = torch.tril(X - X.H)
+ if torch.is_complex(X):
+ ret.diagonal().imag[:] *= 1 / 2
+ return ret
+
+
+@torch.jit.script
+def _qr_backward(
+ Q: torch.Tensor,
+ R: torch.Tensor,
+ Q_grad: typing.Optional[torch.Tensor],
+ R_grad: typing.Optional[torch.Tensor],
+) -> typing.Optional[torch.Tensor]:
+ # See https://arxiv.org/pdf/2009.10071.pdf, sections 4.3 and 4.5.
+ # See pytorch torch/csrc/autograd/FunctionsManual.cpp:linalg_qr_backward.
+ m = Q.size(0)
+ n = R.size(1)
+
+ if Q_grad is not None:
+ if R_grad is not None:
+ MH = R_grad @ R.H - Q.H @ Q_grad
+ else:
+ MH = -Q.H @ Q_grad
+ else:
+ if R_grad is not None:
+ MH = R_grad @ R.H
+ else:
+ return None
+
+ # pylint: disable=no-else-return
+ if m >= n:
+ # Tall or square matrix
+ b = Q @ _syminvadj(torch.triu(MH))
+ if Q_grad is not None:
+ b = b + Q_grad
+ return torch.linalg.solve_triangular(R.H, b, upper=False, left=False)
+ else:
+ # Wide matrix
+ b = Q @ (_triliminvadjskew(-MH))
+ result = torch.linalg.solve_triangular(R[:, :m].H, b, upper=False, left=False)
+ result = torch.cat((result, torch.zeros([m, n - m], dtype=result.dtype, device=result.device)), dim=1)
+ if R_grad is not None:
+ result = result + Q @ R_grad
+ return result
+
+
+class CommonQR(torch.autograd.Function):
+ """
+ Implement the autograd function for QR.
+ """
+
+ # pylint: disable=abstract-method
+
+ @staticmethod
+ def backward( # type: ignore[override]
+ ctx: typing.Any,
+ Q_grad: torch.Tensor | None,
+ R_grad: torch.Tensor | None,
+ ) -> torch.Tensor | None:
+ # pylint: disable=arguments-differ
+ Q, R = ctx.saved_tensors
+ return _qr_backward(Q, R, Q_grad, R_grad)
+
+
+@torch.jit.script
+def _normalize_diagonal(a: torch.Tensor) -> torch.Tensor:
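+ # The elementwise phase a / |a|, with zeros mapped to one.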
+ r = torch.sqrt(a.conj() * a)
+ return torch.where(
+ r == torch.zeros([], dtype=a.dtype, device=a.device),
+ torch.ones([], dtype=a.dtype, device=a.device),
+ a / r,
+ )
+
+
+@torch.jit.script
+def _givens_parameter(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
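+ # Return (c, s) with |c|^2 + |s|^2 = 1 such that the rotation [[c.conj(), s.conj()], [-s, c]]
+ # maps (a, b) to (r, 0) with r = sqrt(|a|^2 + |b|^2); when b == 0 it degenerates to the identity (1, 0).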
+ r = torch.sqrt(a.conj() * a + b.conj() * b)
+ return torch.where(
+ b == torch.zeros([], dtype=a.dtype, device=a.device),
+ torch.ones([], dtype=a.dtype, device=a.device),
+ a / r,
+ ), torch.where(
+ b == torch.zeros([], dtype=a.dtype, device=a.device),
+ torch.zeros([], dtype=a.dtype, device=a.device),
+ b / r,
+ )
+
+
+@torch.jit.script
+def _givens_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ m, n = A.shape
+ k = min(m, n)
+ Q = torch.eye(m, dtype=A.dtype, device=A.device)
+ R = A.clone(memory_format=torch.contiguous_format)
+
+ # Parallel strategy
+ # Every row rotated to the nearest row above
+ for g in range(m - 1, 0, -1):
+ # rotate R[g, 0], R[g+2, 1], R[g+4, 2], ...
+ for i, col in zip(range(g, m, 2), range(n)):
+ j = i - 1
+ # Rotate inside column col
+ # Rotate from row i to row j
+ c, s = _givens_parameter(R[j, col], R[i, col])
+ Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i]
+ R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i]
+ for g in range(1, k):
+ # rotate R[g+1, g], R[g+1+2, g+1], R[g+1+4, g+2], ...
+ for i, col in zip(range(g + 1, m, 2), range(g, n)):
+ j = i - 1
+ # Rotate inside column col
+ # Rotate from row i to row j
+ c, s = _givens_parameter(R[j, col], R[i, col])
+ Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i]
+ R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i]
+
+ # for j in range(n):
+ # for i in range(j + 1, m):
+ # col = j
+ # # Rotate inside column col
+ # # Rotate from row i to row j
+ # c, s = _givens_parameter(R[j, col], R[i, col])
+ # Q[i], Q[j] = -s * Q[j] + c * Q[i], c.conj() * Q[j] + s.conj() * Q[i]
+ # R[i], R[j] = -s * R[j] + c * R[i], c.conj() * R[j] + s.conj() * R[i]
+
+ # Make diagonal positive
+ c = _normalize_diagonal(R.diagonal()).conj()
+ Q[:k] *= torch.unsqueeze(c, 1)
+ R[:k] *= torch.unsqueeze(c, 1)
+
+ Q, R = Q[:k].H, R[:k]
+ return Q, R
+
+
+class GivensQR(CommonQR):
+ """
+ Compute the reduced QR decomposition using Givens rotation.
+ """
+
+ # pylint: disable=abstract-method
+
+ @staticmethod
+ def forward( # type: ignore[override]
+ ctx: torch.autograd.function.FunctionCtx,
+ A: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # pylint: disable=arguments-differ
+ Q, R = _givens_qr(A)
+ ctx.save_for_backward(Q, R)
+ return Q, R
+
+
+@torch.jit.script
+def _normalize_delta(a: torch.Tensor) -> torch.Tensor:
+ norm = a.norm()
+ return torch.where(
+ norm == torch.zeros([], dtype=a.dtype, device=a.device),
+ torch.zeros([], dtype=a.dtype, device=a.device),
+ a / norm,
+ )
+
+
+@torch.jit.script
+def _reflect_target(x: torch.Tensor) -> torch.Tensor:
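+ # The reflection target (r, 0, ..., 0): |r| = norm(x) and arg(r) = arg(x[0]), so that <v|x> is real.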
+ return torch.norm(x) * _normalize_diagonal(x[0])
+
+
+@torch.jit.script
+def _householder_qr(A: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ m, n = A.shape
+ k = min(m, n)
+ Q = torch.eye(m, dtype=A.dtype, device=A.device)
+ R = A.clone(memory_format=torch.contiguous_format)
+
+ for i in range(k):
+ x = R[i:, i]
+ v = torch.zeros_like(x)
+ # For a complex matrix, it requires <v|x> = <x|v>, i.e. v[0] and x[0] have the same argument.
+ v[0] = _reflect_target(x)
+ # Reflect x to v
+ delta = _normalize_delta(v - x)
+ # H = 1 - 2 |Delta><Delta|
+ R[i:, :] -= 2 * torch.outer(delta, delta.conj() @ R[i:, :])
+ Q[i:, :] -= 2 * torch.outer(delta, delta.conj() @ Q[i:, :])
+
+ # Make diagonal positive
+ c = _normalize_diagonal(R.diagonal()).conj()
+ Q[:k] *= torch.unsqueeze(c, 1)
+ R[:k] *= torch.unsqueeze(c, 1)
+
+ Q, R = Q[:k].H, R[:k]
+ return Q, R
+
+
+class HouseholderQR(CommonQR):
+ """
+ Compute the reduced QR decomposition using Householder reflection.
+ """
+
+ # pylint: disable=abstract-method
+
+ @staticmethod
+ def forward( # type: ignore[override]
+ ctx: torch.autograd.function.FunctionCtx,
+ A: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # pylint: disable=arguments-differ
+ Q, R = _householder_qr(A)
+ ctx.save_for_backward(Q, R)
+ return Q, R
+
+
+givens_qr = GivensQR.apply
+householder_qr = HouseholderQR.apply
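+
+# A minimal usage sketch (an illustration, not part of the module's API surface):
+# both entry points return the reduced decomposition, with Q having orthonormal
+# columns and R upper triangular.
+#
+# A = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
+# Q, R = givens_qr(A)
+# assert torch.allclose(Q @ R, A)
+# Q, R = householder_qr(A)
+# assert torch.allclose(Q @ R, A)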
diff --git a/tat/_svd.py b/tat/_svd.py
new file mode 100644
index 000000000..bfbc9dcc4
--- /dev/null
+++ b/tat/_svd.py
@@ -0,0 +1,286 @@
+"""
+This module implements singular value decomposition based on Givens rotations, without Householder reflections.
+"""
+
+import typing
+import torch
+from ._qr import _normalize_diagonal, _givens_parameter
+
+# pylint: disable=invalid-name
+
+
+@torch.jit.script
+def _svd(A: torch.Tensor, error: float) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # pylint: disable=too-many-locals
+ # pylint: disable=too-many-branches
+ # pylint: disable=too-many-statements
+ # pylint: disable=too-many-nested-blocks
+
+ # see https://web.stanford.edu/class/cme335/lecture6.pdf
+ m, n = A.shape
+ trans = False
+ if m < n:
+ trans = True
+ A = A.transpose(0, 1)
+ m, n = n, m
+ U = torch.eye(m, dtype=A.dtype, device=A.device)
+ V = torch.eye(n, dtype=A.dtype, device=A.device)
+
+ # Make bidiagonal matrix
+ B = A.clone(memory_format=torch.contiguous_format)
+ for i in range(n):
+ # (i:, i)
+ for j in range(m - 1, i, -1):
+ col = i
+ # Rotate inside col i
+ # Rotate from row j to j-1
+ c, s = _givens_parameter(B[j - 1, col], B[j, col])
+ U[j], U[j - 1] = -s * U[j - 1] + c * U[j], c.conj() * U[j - 1] + s.conj() * U[j]
+ B[j], B[j - 1] = -s * B[j - 1] + c * B[j], c.conj() * B[j - 1] + s.conj() * B[j]
+ # x = B[i:, i]
+ # v = torch.zeros_like(x)
+ # v[0] = _reflect_target(x)
+ # delta = _normalize_delta(v - x)
+ # B[i:, :] -= 2 * torch.outer(delta, delta.conj() @ B[i:, :])
+ # U[i:, :] -= 2 * torch.outer(delta, delta.conj() @ U[i:, :])
+
+ # (i, i+1:)/H
+ if i == n - 1:
+ break
+ for j in range(n - 1, i + 1, -1):
+ row = i
+ # Rotate inside row i
+ # Rotate from col j to j-1
+ c, s = _givens_parameter(B[row, j - 1], B[row, j])
+ V[j], V[j - 1] = -s * V[j - 1] + c * V[j], c.conj() * V[j - 1] + s.conj() * V[j]
+ B[:, j], B[:, j - 1] = -s * B[:, j - 1] + c * B[:, j], c.conj() * B[:, j - 1] + s.conj() * B[:, j]
+ # x = B[i, i + 1:]
+ # v = torch.zeros_like(x)
+ # v[0] = _reflect_target(x)
+ # delta = _normalize_delta(v - x)
+ # B[:, i + 1:] -= 2 * torch.outer(B[:, i + 1:] @ delta.conj(), delta)
+ # V[i + 1:, :] -= 2 * torch.outer(delta, delta.conj() @ V[i + 1:, :])
+ B = B[:n]
+ U = U[:n]
+ # print(B)
+ # error_decomp = torch.max(torch.abs(U.H @ B @ V.H.T - A)).item()
+ # assert error_decomp < 1e-4
+
+ # QR iteration with implicit Q
+ S = torch.diagonal(B).clone(memory_format=torch.contiguous_format)
+ F = torch.diagonal(B, offset=1).clone(memory_format=torch.contiguous_format)
+ F.resize_(S.size(0))
+ F[-1] = 0
+ X = F[-1]
+ stack: list[tuple[int, int]] = [(0, n - 1)]
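+ # S and F hold the diagonal and superdiagonal of the bidiagonal matrix B;
+ # stack holds the (low, high) ranges of the blocks still to be diagonalized.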
+ while stack:
+ # B.zero_()
+ # B.diagonal()[:] = S
+ # B.diagonal(offset = 1)[:] = F[:-1]
+ # error_decomp = torch.max(torch.abs(U.H @ B @ V.H.T - A)).item()
+ # assert error_decomp < 1e-4
+
+ low = stack[-1][0]
+ high = stack[-1][1]
+
+ if low == high:
+ stack.pop()
+ continue
+
+ # Track the largest diagonal magnitude; it sets the scale for the relative off-diagonal test below.
+ max_diagonal = torch.abs(S[low])
+ for b in range(low, high + 1):
+ Sb = torch.abs(S[b])
+ if Sb > max_diagonal:
+ max_diagonal = Sb
+ # Check if S[b] is zero
+ if Sb < error:
+ # pylint: disable=no-else-continue
+ if b == low:
+ X = F[b].clone()
+ F[b] = 0
+ for i in range(b + 1, high + 1):
+ c, s = _givens_parameter(S[i], X)
+ U[b], U[i] = -s * U[i] + c * U[b], c.conj() * U[i] + s.conj() * U[b]
+
+ S[i] = c.conj() * S[i] + s.conj() * X
+ if i != high:
+ X, F[i] = -s * F[i] + c * X, c.conj() * F[i] + s.conj() * X
+ stack.pop()
+ stack.append((b + 1, high))
+ stack.append((low, b))
+ continue
+ else:
+ X = F[b - 1].clone()
+ F[b - 1] = 0
+ for i in range(b - 1, low - 1, -1):
+ c, s = _givens_parameter(S[i], X)
+ V[b], V[i] = -s * V[i] + c * V[b], c.conj() * V[i] + s.conj() * V[b]
+
+ S[i] = c.conj() * S[i] + s.conj() * X
+ if i != low:
+ X, F[i - 1] = -s * F[i - 1] + c * X, c.conj() * F[i - 1] + s.conj() * X
+ stack.pop()
+ stack.append((b, high))
+ stack.append((low, b - 1))
+ continue
+
+ b = int(torch.argmin(torch.abs(F[low:high]))) + low
+ if torch.abs(F[b]) < max_diagonal * error:
+ F[b] = 0
+ stack.pop()
+ stack.append((b + 1, high))
+ stack.append((low, b))
+ continue
+
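+ # Wilkinson-like shift: mu is the eigenvalue of the 2x2 block of B.H @ B at
+ # rows b and b + 1 that is closer to its last diagonal entry.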
+ tdn = (S[b + 1].conj() * S[b + 1] + F[b].conj() * F[b]).real
+ tdn_1 = (S[b].conj() * S[b] + F[b - 1].conj() * F[b - 1]).real
+ tsn_1 = F[b].conj() * S[b]
+ d = (tdn_1 - tdn) / 2
+ mu = tdn + d - torch.sign(d) * torch.sqrt(d**2 + tsn_1.conj() * tsn_1)
+ for i in range(low, high):
+ if i == low:
+ c, s = _givens_parameter(S[low].conj() * S[low] - mu, S[low].conj() * F[low])
+ else:
+ c, s = _givens_parameter(F[i - 1], X)
+ V[i + 1], V[i] = -s * V[i] + c * V[i + 1], c.conj() * V[i] + s.conj() * V[i + 1]
+ if i != low:
+ F[i - 1] = c.conj() * F[i - 1] + s.conj() * X
+ F[i], S[i] = -s * S[i] + c * F[i], c.conj() * S[i] + s.conj() * F[i]
+ S[i + 1], X = c * S[i + 1], s.conj() * S[i + 1]
+
+ c, s = _givens_parameter(S[i], X)
+ U[i + 1], U[i] = -s * U[i] + c * U[i + 1], c.conj() * U[i] + s.conj() * U[i + 1]
+
+ S[i] = c.conj() * S[i] + s.conj() * X
+ S[i + 1], F[i] = -s * F[i] + c * S[i + 1], c.conj() * F[i] + s.conj() * S[i + 1]
+ if i != high - 1:
+ F[i + 1], X = c * F[i + 1], s.conj() * F[i + 1]
+
+ # Make diagonal positive
+ c = _normalize_diagonal(S).conj()
+ V *= c.unsqueeze(1) # U is larger than V
+ S *= c
+ S = S.real
+
+ # Sort
+ S, order = torch.sort(S, descending=True)
+ U = U[order]
+ V = V[order]
+
+ # pylint: disable=no-else-return
+ if trans:
+ return V.H, S, U.H.T
+ else:
+ return U.H, S, V.H.T
+
+
+@torch.jit.script
+def _skew(A: torch.Tensor) -> torch.Tensor:
+ return A - A.H
+
+
+@torch.jit.script
+def _svd_backward(
+ U: torch.Tensor,
+ S: torch.Tensor,
+ Vh: torch.Tensor,
+ gU: typing.Optional[torch.Tensor],
+ gS: typing.Optional[torch.Tensor],
+ gVh: typing.Optional[torch.Tensor],
+) -> typing.Optional[torch.Tensor]:
+ # pylint: disable=too-many-locals
+ # pylint: disable=too-many-branches
+ # pylint: disable=too-many-arguments
+
+ # See pytorch torch/csrc/autograd/FunctionsManual.cpp:svd_backward
+ if gS is None and gU is None and gVh is None:
+ return None
+
+ m = U.size(0)
+ n = Vh.size(1)
+
+ if gU is None and gVh is None:
+ assert gS is not None
+ # pylint: disable=no-else-return
+ if m >= n:
+ return U @ (gS.unsqueeze(1) * Vh)
+ else:
+ return (U * gS.unsqueeze(0)) @ Vh
+
+ is_complex = torch.is_complex(U)
+
+ UhgU = _skew(U.H @ gU) if gU is not None else None
+ VhgV = _skew(Vh @ gVh.H) if gVh is not None else None
+
+ S2 = S * S
+ E = S2.unsqueeze(0) - S2.unsqueeze(1)
+ E.diagonal()[:] = 1
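+ # E[i, j] = S[j]^2 - S[i]^2; the diagonal is set to one so that the elementwise
+ # divisions below are well defined (the diagonal contributions are handled separately).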
+
+ if gU is not None:
+ if gVh is not None:
+ assert UhgU is not None
+ assert VhgV is not None
+ gA = (UhgU * S.unsqueeze(0) + S.unsqueeze(1) * VhgV) / E
+ else:
+ assert UhgU is not None
+ gA = (UhgU / E) * S.unsqueeze(0)
+ else:
+ assert VhgV is not None
+ gA = S.unsqueeze(1) * (VhgV / E)
+
+ if gS is not None:
+ gA = gA + torch.diag(gS)
+
+ if is_complex and gU is not None and gVh is not None:
+ assert UhgU is not None
+ gA = gA + torch.diag(UhgU.diagonal() / (2 * S))
+
+ if m > n and gU is not None:
+ gA = U @ gA
+ gUSinv = gU / S.unsqueeze(0)
+ gA = gA + gUSinv - U @ (U.H @ gUSinv)
+ gA = gA @ Vh
+ elif m < n and gVh is not None:
+ gA = gA @ Vh
+ SinvgVh = gVh / S.unsqueeze(1)
+ gA = gA + SinvgVh - (SinvgVh @ Vh.H) @ Vh
+ gA = U @ gA
+ elif m >= n:
+ gA = U @ (gA @ Vh)
+ else:
+ gA = (U @ gA) @ Vh
+
+ return gA
+
+
+class SVD(torch.autograd.Function):
+ """
+ Compute the singular value decomposition without Householder reflections.
+ """
+
+ # pylint: disable=abstract-method
+
+ @staticmethod
+ def forward( # type: ignore[override]
+ ctx: torch.autograd.function.FunctionCtx,
+ A: torch.Tensor,
+ error: float,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # pylint: disable=arguments-differ
+ U, S, V = _svd(A, error)
+ ctx.save_for_backward(U, S, V)
+ return U, S, V
+
+ @staticmethod
+ def backward( # type: ignore[override]
+ ctx: typing.Any,
+ U_grad: torch.Tensor | None,
+ S_grad: torch.Tensor | None,
+ V_grad: torch.Tensor | None,
+ ) -> tuple[torch.Tensor | None, None]:
+ # pylint: disable=arguments-differ
+ U, S, V = ctx.saved_tensors
+ return _svd_backward(U, S, V, U_grad, S_grad, V_grad), None
+
+
+svd = SVD.apply
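+
+# A minimal usage sketch (an illustration, not part of the module's API surface);
+# `error` is the relative tolerance used to deflate small off-diagonal elements
+# during the implicit QR iteration.
+#
+# A = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
+# U, S, V = svd(A, 1e-10)
+# assert torch.allclose(U @ torch.diag(S) @ V, A)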
diff --git a/tat/_utility.py b/tat/_utility.py
new file mode 100644
index 000000000..05491dcc6
--- /dev/null
+++ b/tat/_utility.py
@@ -0,0 +1,41 @@
+"""
+Some internal utilities used by tat.
+"""
+
+import torch
+
+# pylint: disable=missing-function-docstring
+# pylint: disable=no-else-return
+
+
+def unsqueeze(tensor: torch.Tensor, index: int, rank: int) -> torch.Tensor:
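+ # Reshape a 1-D tensor to the given rank so that it broadcasts along the given index.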
+ return tensor.reshape([-1 if i == index else 1 for i in range(rank)])
+
+
+def neg_symmetry(tensor: torch.Tensor) -> torch.Tensor:
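+ # A bool sub-symmetry is Z2-like and is its own inverse; an integer sub-symmetry is negated.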
+ if tensor.dtype is torch.bool:
+ return tensor
+ else:
+ return -tensor
+
+
+def add_symmetry(tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
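+ # Z2-like (bool) sub-symmetries combine by xor; integer sub-symmetries combine by addition.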
+ if tensor_1.dtype is torch.bool:
+ return tensor_1 ^ tensor_2
+ else:
+ return tensor_1 + tensor_2
+
+
+def zero_symmetry(tensor: torch.Tensor) -> torch.Tensor:
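+ # Mark the entries that carry the trivial (zero) symmetry sector.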
+ # pylint: disable=singleton-comparison
+ if tensor.dtype is torch.bool:
+ return tensor == False
+ else:
+ return tensor == 0
+
+
+def parity(tensor: torch.Tensor) -> torch.Tensor:
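+ # Fermionic parity: a bool sub-symmetry is the parity itself; an integer one is odd/even.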
+ if tensor.dtype is torch.bool:
+ return tensor
+ else:
+ return tensor % 2 != 0
diff --git a/tat/compat.py b/tat/compat.py
new file mode 100644
index 000000000..f4d0c739b
--- /dev/null
+++ b/tat/compat.py
@@ -0,0 +1,430 @@
+"""
+This file implements a compat layer for the legacy TAT interface.
+"""
+
+from __future__ import annotations
+import typing
+from multimethod import multimethod
+import torch
+from .edge import Edge as E
+from .tensor import Tensor as T
+
+# pylint: disable=too-few-public-methods
+# pylint: disable=too-many-instance-attributes
+# pylint: disable=redefined-outer-name
+
+
+class Symmetry(tuple):
+ """
+ The compat symmetry constructor, without detailed type checks.
+ """
+
+ def __new__(cls: type[Symmetry], *sym: typing.Any) -> Symmetry:
+ if len(sym) == 1 and isinstance(sym[0], tuple):
+ sym = sym[0]
+ return tuple.__new__(Symmetry, sym)
+
+ def __neg__(self: Symmetry) -> Symmetry:
+ return Symmetry(tuple(sub_sym if isinstance(sub_sym, bool) else -sub_sym for sub_sym in self))
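+
+# A behavior sketch: bool sub-symmetries act like Z2 charges and are their own
+# inverse, while integer sub-symmetries act like U1 charges and are negated,
+# e.g. -Symmetry(True, 2) == Symmetry(True, -2).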
+
+
+class CompatSymmetry:
+ """
+ The common Symmetry namespace.
+ """
+
+ def __init__(self: CompatSymmetry, fermion: tuple[bool, ...], dtypes: tuple[torch.dtype, ...]) -> None:
+ # This creates a fake module like TAT.No, TAT.Z2 or similar things; it needs the symmetry attributes to be
+ # specified. The symmetry is set by two attributes: fermion and dtypes.
+ self.fermion: tuple[bool, ...] = fermion
+ self.dtypes: tuple[torch.dtype, ...] = dtypes
+
+ # pylint: disable=invalid-name
+ self.S: CompatScalar
+ self.D: CompatScalar
+ self.C: CompatScalar
+ self.Z: CompatScalar
+ self.float32: CompatScalar
+ self.float64: CompatScalar
+ self.float: CompatScalar
+ self.complex64: CompatScalar
+ self.complex128: CompatScalar
+ self.complex: CompatScalar
+
+ # In the old TAT, something like TAT.No.D is a submodule for tensors with a specific scalar type.
+ # In this compat library, it is implemented by another fake module: CompatScalar.
+ self.S = self.float32 = CompatScalar(self, torch.float32)
+ self.D = self.float64 = self.float = CompatScalar(self, torch.float64)
+ self.C = self.complex64 = CompatScalar(self, torch.complex64)
+ self.Z = self.complex128 = self.complex = CompatScalar(self, torch.complex128)
+
+ self.Edge: CompatEdge = CompatEdge(self)
+ self.Symmetry: type[Symmetry] = Symmetry
+
+
+class CompatEdge:
+ """
+ The compat edge constructor.
+ """
+
+ def __init__(self: CompatEdge, owner: CompatSymmetry) -> None:
+ self.fermion: tuple[bool, ...] = owner.fermion
+ self.dtypes: tuple[torch.dtype, ...] = owner.dtypes
+
+ def _parse_segments(self: CompatEdge, segments: list) -> tuple[tuple[torch.Tensor, ...], int]:
+ # In TAT, the user could use [Sym] or [(Sym, Size)] to set the segments of an edge, where [(Sym, Size)] is
+ # nothing but the symmetry and size of every block, while [Sym] acts like [(Sym, 1)]. So try to treat the
+ # input as [(Sym, Size)] first; if an error is raised, convert it from [Sym] to [(Sym, 1)] and try again.
+ try:
+ # try [(Sym, Size)] first
+ return self._parse_segments_kernel(segments)
+ except TypeError:
+ # Failing to unpack raises a TypeError, and so does subscripting a non-subscriptable value, so only
+ # TypeError is caught here. Convert [Sym] to [(Sym, 1)] and retry.
+ return self._parse_segments_kernel([(sym, 1) for sym in segments])
+ # This function returns the symmetry tuple and the dimension.
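+ # For example, with a single U1-like integer sub-symmetry, [(0, 2), (1, 3)] describes an edge of
+ # dimension 5 whose first two indices carry charge 0 and whose last three carry charge 1, while
+ # [0, 1] is shorthand for [(0, 1), (1, 1)].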
+
+ def _parse_segments_kernel(
+ self: CompatEdge,
+ segments: list[tuple[typing.Any, int]],
+ ) -> tuple[tuple[torch.Tensor, ...], int]:
+ # [(Sym, Size)] for every element
+ dimension = sum(dim for _, dim in segments)
+ symmetry = tuple(
+ torch.tensor(
+ # tat.Edge needs torch.Tensor as its symmetry, so convert it to a torch.Tensor with the specific dtype.
+ sum(
+ # Concatenate all segments for this sub-symmetry, starting from an empty list.
+ # Every segment is just sym[index] * dim; sometimes sym may be the sub-symmetry itself instead of a
+ # tuple of sub-symmetries, so the utility function _parse_segments_get_subsymmetry is called here.
+ ([self._parse_segments_get_subsymmetry(sym, index)] * dim
+ for sym, dim in segments),
+ [],
+ ),
+ dtype=sub_symmetry,
+ )
+ # Generate sub symmetry one by one
+ for index, sub_symmetry in enumerate(self.dtypes))
+ return symmetry, dimension
+
+ def _parse_segments_get_subsymmetry(self: CompatEdge, sym: object, index: int) -> object:
+ # Most of the time, the symmetry is a tuple of sub symmetries.
+ # But if there is only one sub symmetry, it may be given as the sub symmetry itself rather than a tuple.
+ # pylint: disable=no-else-return
+ if isinstance(sym, tuple):
+ # If it is tuple, there is no need to do any other check
+ return sym[index]
+ else:
+ # If it is not a tuple, it should be the sub symmetry itself, which is only valid when the symmetry
+ # owns a single sub symmetry; check that.
+ if len(self.fermion) == 1:
+ return sym
+ else:
+ raise TypeError(f"{sym=} is not subscript-able")
+
+ @multimethod
+ def __call__(self: CompatEdge, edge: E) -> E:
+ """
+ Create an edge with the compat interface.
+
+ It may be created by
+ 1. Edge(dimension): create an edge of the given dimension with trivial symmetry.
+ 2. Edge(segments, arrow): create an edge with the given segments and arrow.
+ 3. Edge(segments_arrow_tuple): create an edge from a single (segments, arrow) tuple.
+ """
+ # pylint: disable=invalid-name
+ return edge
+
+ @__call__.register
+ def _(self: CompatEdge, dimension: int) -> E:
+ # Generate a trivial symmetry tuple. In this tuple, every sub symmetry is a torch.zeros tensor with specific
+ # dtype and the same dimension.
+ symmetry = tuple(torch.zeros(dimension, dtype=sub_symmetry) for sub_symmetry in self.dtypes)
+ return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=False)
+
+ @__call__.register
+ def _(self: CompatEdge, segments: list, arrow: bool = False) -> E:
+ symmetry, dimension = self._parse_segments(segments)
+ return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow)
+
+ @__call__.register
+ def _(self: CompatEdge, segments_and_bool: tuple[list, bool]) -> E:
+ segments, arrow = segments_and_bool
+ symmetry, dimension = self._parse_segments(segments)
+ return E(fermion=self.fermion, dtypes=self.dtypes, symmetry=symmetry, dimension=dimension, arrow=arrow)
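+
+ # A minimal usage sketch (assuming the U1 namespace defined at the bottom of this file):
+ # U1.Edge(4) # trivial symmetry, dimension 4
+ # U1.Edge([(-1, 2), (0, 1), (1, 2)], True) # segments plus an arrow
+ # U1.Edge(([(-1, 2), (0, 1), (1, 2)], True)) # the same, packed into a single tuple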
+
+
+class CompatScalar:
+ """
+ The common Scalar namespace.
+ """
+
+ def __init__(self: CompatScalar, symmetry: CompatSymmetry, dtype: torch.dtype) -> None:
+ # This is a fake module like TAT.No.D or TAT.Fermi.complex; it records the parent symmetry information
+ # and its own dtype.
+ self.symmetry: CompatSymmetry = symmetry
+ self.dtype: torch.dtype = dtype
+ # pylint: disable=invalid-name
+ self.Tensor: CompatTensor = CompatTensor(self)
+
+
+class CompatTensor:
+ """
+ The compat tensor constructor.
+ """
+
+ def __init__(self: CompatTensor, owner: CompatScalar) -> None:
+ self.symmetry: CompatSymmetry = owner.symmetry
+ self.dtype: torch.dtype = owner.dtype
+ self.model: CompatSymmetry = owner.symmetry
+ self.is_complex: bool = self.dtype.is_complex
+ self.is_real: bool = self.dtype.is_floating_point
+
+ @multimethod
+ def __call__(self: CompatTensor, tensor: T) -> T:
+ """
+ Create a tensor with compat names and edges.
+
+ It may be created by
+ 1. Tensor(names, edges): the most used interface.
+ 2. Tensor(): create a rank-0 tensor filled with the number 1.
+ 3. Tensor(number, names=[], edge_symmetry=[], edge_arrow=[]): create a size-1 tensor with the specified
+ edges, filled with the specified number.
+ """
+ # pylint: disable=invalid-name
+ return tensor
+
+ @__call__.register
+ def _(self: CompatTensor, names: list[str], edges: list) -> T:
+ return T(
+ tuple(names),
+ tuple(self.symmetry.Edge(edge) for edge in edges),
+ fermion=self.symmetry.fermion,
+ dtypes=self.symmetry.dtypes,
+ dtype=self.dtype,
+ )
+
+ @__call__.register
+ def _(self: CompatTensor) -> T:
+ result = T(
+ (),
+ (),
+ fermion=self.symmetry.fermion,
+ dtypes=self.symmetry.dtypes,
+ data=torch.ones([], dtype=self.dtype),
+ )
+ return result
+
+ @__call__.register
+ def _(
+ self: CompatTensor,
+ number: typing.Any,
+ names: list[str] | None = None,
+ edge_symmetry: list | None = None,
+ edge_arrow: list[bool] | None = None,
+ ) -> T:
+ # Create a high rank tensor with only one element
+ if names is None:
+ names = []
+ if edge_symmetry is None:
+ edge_symmetry = [None for _ in names]
+ if edge_arrow is None:
+ edge_arrow = [False for _ in names]
+ result = T(
+ tuple(names),
+ tuple(
+ # Create an edge for every rank, given its only symmetry (maybe None) and arrow.
+ E(
+ fermion=self.symmetry.fermion,
+ dtypes=self.symmetry.dtypes,
+ # For every edge, its symmetry is a tuple of all sub symmetries.
+ symmetry=tuple(
+ # For every sub symmetry, take its single value, since the dimension of each edge is 1.
+ # Note that the given symmetry may be None, a tuple, or the sub symmetry itself.
+ torch.tensor([self._create_size1_get_subsymmetry(symmetry, index)], dtype=dtype)
+ for index, dtype in enumerate(self.symmetry.dtypes)),
+ dimension=1,
+ arrow=arrow,
+ )
+ for symmetry, arrow in zip(edge_symmetry, edge_arrow)),
+ fermion=self.symmetry.fermion,
+ dtypes=self.symmetry.dtypes,
+ data=torch.full([1 for _ in names], number, dtype=self.dtype),
+ )
+ return result
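+
+ # A minimal usage sketch (assuming the No namespace defined at the bottom of this file):
+ # No.D.Tensor(["i", "j"], [2, 3]) # the most used names-and-edges interface
+ # No.D.Tensor() # rank-0 tensor filled with 1
+ # No.D.Tensor(2.5, names=["i"]) # size-1 tensor filled with 2.5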
+
+ def _create_size1_get_subsymmetry(self: CompatTensor, sym: object, index: int) -> object:
+ # pylint: disable=no-else-return
+ # sym may be None, tuple or sub symmetry directly.
+ if sym is None:
+ # If it is None, the user wants a trivial symmetry, which is 0 for int and False for bool; always
+ # return 0 here, since torch.tensor will convert it to the correct dtype.
+ return 0
+ elif isinstance(sym, tuple):
+ # If it is tuple, there is no need to do any other check
+ return sym[index]
+ else:
+ # If it is not a tuple, it should be the sub symmetry itself, which is only valid when the symmetry
+ # owns a single sub symmetry; check that.
+ if len(self.symmetry.fermion) == 1:
+ return sym
+ else:
+ raise TypeError(f"{sym=} is not subscript-able")
+
+
+# Create fake sub modules for every symmetry compiled in the old version of TAT
+No: CompatSymmetry = CompatSymmetry(fermion=(), dtypes=())
+Z2: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.bool,))
+U1: CompatSymmetry = CompatSymmetry(fermion=(False,), dtypes=(torch.int,))
+Fermi: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.int,))
+FermiZ2: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.bool))
+FermiU1: CompatSymmetry = CompatSymmetry(fermion=(True, False), dtypes=(torch.int, torch.int))
+Parity: CompatSymmetry = CompatSymmetry(fermion=(True,), dtypes=(torch.bool,))
+FermiFermi: CompatSymmetry = CompatSymmetry(fermion=(True, True), dtypes=(torch.int, torch.int))
+Normal: CompatSymmetry = No
+
+# SJ Dong's convention
+
+
+def arrow(int_arrow: int) -> bool:
+ "SJ Dong's convention of arrow"
+ # pylint: disable=no-else-return
+ if int_arrow == +1:
+ return False
+ elif int_arrow == -1:
+ return True
+ else:
+ raise ValueError("int arrow should be +1 or -1.")
+
+
+def parity(int_parity: int) -> bool:
+ "SJ Dong's convention of parity"
+ # pylint: disable=no-else-return
+ if int_parity == +1:
+ return False
+ elif int_parity == -1:
+ return True
+ else:
+ raise ValueError("int parity should be +1 or -1.")
+
+
+# Segment index
+
+
+@T._prepare_position.register # pylint: disable=protected-access,no-member
+def _(self: T, position: dict[str, tuple[typing.Any, int]]) -> tuple[int, ...]:
+ return tuple(index_by_point(edge, position[name]) for name, edge in zip(self.names, self.edges))
+
+
+# Add some compat interface
+
+
+def _compat_function(focus_type: type, name: str | None = None) -> typing.Callable[[typing.Callable], typing.Callable]:
+
+ def _result(function: typing.Callable) -> typing.Callable:
+ if name is None:
+ attr_name = function.__name__
+ else:
+ attr_name = name
+ setattr(focus_type, attr_name, function)
+ return function
+
+ return _result
+
+
+@property # type: ignore[misc]
+def storage(self: T) -> typing.Any:
+ "Get the storage of the tensor"
+ assert self.data.is_contiguous()
+ return self.data.reshape([-1])
+
+
+@_compat_function(T, name="storage") # type: ignore[misc]
+@storage.setter
+def storage(self: T, value: typing.Any) -> None:
+ "Set the storage of the tensor"
+ assert self.data.is_contiguous()
+ self.data.reshape([-1])[:] = torch.as_tensor(value)
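+
+
+# A usage sketch of the storage view (values are illustrative):
+# t.storage = range(t.data.nelement()) # fill the flattened data
+# flat = t.storage # read it back as a 1D view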
+
+
+@_compat_function(T)
+def range_(self: T, first: float = 0, step: float = 1) -> T:
+ "Compat Interface: Set range inplace for this tensor."
+ result = self.range(first, step)
+ self._data = result._data # pylint: disable=protected-access
+ return self
+
+
+@_compat_function(T)
+def identity_(self: T, pairs: set[tuple[str, str]]) -> T:
+ "Compat Interface: Set idenity inplace for this tensor."
+ result = self.identity(pairs).transpose(self.names)
+ self._data = result._data # pylint: disable=protected-access
+ return self
+
+
+# Exponential arguments
+
+origin_exponential = T.exponential
+
+
+@_compat_function(T)
+def exponential(self: T, pairs: set[tuple[str, str]], step: int | None = None) -> T:
+ "Compat Interface: Get the exponential tensor of this tensor."
+ # pylint: disable=unused-argument
+ return origin_exponential(self, pairs)
+
+
+# Edge point conversion
+
+
+@_compat_function(E)
+def index_by_point(self: E, point: tuple[typing.Any, int]) -> int:
+ "Get index by point on an edge"
+ sym, sub_index = point
+ if not isinstance(sym, tuple):
+ sym = (sym,)
+ for total_index in range(self.dimension):
+ if all(sub_sym == sub_symmetry[total_index] for sub_sym, sub_symmetry in zip(sym, self.symmetry)):
+ if sub_index == 0:
+ return total_index
+ sub_index = sub_index - 1
+ raise ValueError("Invalid input point")
+
+
+@_compat_function(E)
+def point_by_index(self: E, index: int) -> tuple[typing.Any, int]:
+ "Get point by index on an edge"
+ sym = Symmetry(tuple(sub_symmetry[index] for sub_symmetry in self.symmetry))
+ sub_index = sum(
+ 1 for i in range(index) if all(sub_sym == sub_symmetry[i] for sub_sym, sub_symmetry in zip(sym, self.symmetry)))
+ return sym, sub_index
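+
+
+# A round-trip sketch (a hypothetical U1-like edge with segments [(-1, 2), (0, 1)]):
+# edge.index_by_point((-1, 1)) # -> 1, the second index inside the -1 segment
+# edge.point_by_index(2) # -> the 0 charge with sub index 0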
+
+
+# Random utility
+
+
+class CompatRandom:
+ """
+ Fake module for compat random utility in TAT.
+ """
+
+ def uniform_int(self: CompatRandom, low: int, high: int) -> typing.Callable[[], int]:
+ "Generator for integer uniform distribution"
+ # Mypy cannot recognize that the item of an int64 tensor is an int, so cast it manually.
+ return staticmethod(lambda: int(torch.randint(low, high + 1, [], dtype=torch.int64).item()))
+
+ def uniform_real(self: CompatRandom, low: float, high: float) -> typing.Callable[[], float]:
+ "Generator for float uniform distribution"
+ return staticmethod(lambda: torch.rand([], dtype=torch.float64).item() * (high - low) + low)
+
+ def normal(self: CompatRandom, mean: float, stddev: float) -> typing.Callable[[], float]:
+ "Generator for float normal distribution"
+ return staticmethod(lambda: torch.normal(mean, stddev, [], dtype=torch.float64).item())
+
+ def seed(self: CompatRandom, new_seed: int) -> None:
+ "Set the seed for random generator manually"
+ torch.manual_seed(new_seed)
+
+
+random = CompatRandom()
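+
+# A usage sketch of the random helpers (staticmethod objects are directly callable on
+# python >= 3.10, which the CI targets):
+# random.seed(42)
+# gen = random.uniform_real(0.0, 1.0)
+# value = gen() # a float in [0, 1)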
diff --git a/tat/edge.py b/tat/edge.py
new file mode 100644
index 000000000..de3aafba6
--- /dev/null
+++ b/tat/edge.py
@@ -0,0 +1,290 @@
+"""
+This file contains the definition of tensor edge.
+"""
+
+from __future__ import annotations
+import functools
+import operator
+import torch
+from . import _utility
+
+
+class Edge:
+ """
+ The edge type of tensor.
+ """
+
+ __slots__ = "_fermion", "_dtypes", "_symmetry", "_dimension", "_arrow", "_parity"
+
+ @property
+ def fermion(self: Edge) -> tuple[bool, ...]:
+ """
+ A tuple records whether every sub symmetry is fermionic. Its length is the number of sub symmetry.
+ """
+ return self._fermion
+
+ @property
+ def dtypes(self: Edge) -> tuple[torch.dtype, ...]:
+ """
+ A tuple records the basic dtype of every sub symmetry. Its length is the number of sub symmetry.
+ """
+ return self._dtypes
+
+ @property
+ def symmetry(self: Edge) -> tuple[torch.Tensor, ...]:
+ """
+ A tuple containing all symmetry of this edge. Its length is the number of sub symmetry. Every element of it is a
+ sub symmetry.
+ """
+ return self._symmetry
+
+ @property
+ def dimension(self: Edge) -> int:
+ """
+ The dimension of this edge.
+ """
+ return self._dimension
+
+ @property
+ def arrow(self: Edge) -> bool:
+ """
+ The arrow of this edge.
+ """
+ return self._arrow
+
+ @property
+ def parity(self: Edge) -> torch.Tensor:
+ """
+ The parity of this edge.
+ """
+ return self._parity
+
+ def __init__(
+ self: Edge,
+ *,
+ fermion: tuple[bool, ...] | None = None,
+ dtypes: tuple[torch.dtype, ...] | None = None,
+ symmetry: tuple[torch.Tensor, ...] | None = None,
+ dimension: int | None = None,
+ arrow: bool | None = None,
+ **kwargs: torch.Tensor,
+ ) -> None:
+ """
+ Create an edge with essential information.
+
+ Examples:
+ - Edge(dimension=5)
+ - Edge(symmetry=(torch.tensor([False, False, True, True]),))
+ - Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([False, True])), arrow=True)
+
+ Parameters
+ ----------
+ fermion : tuple[bool, ...], optional
+ Whether each sub symmetry is a fermionic symmetry; its length should be the same as that of symmetry. It
+ may be left empty, in which case a totally bosonic edge is created.
+ dtypes : tuple[torch.dtype, ...], optional
+ The basic dtype identifying each sub symmetry; its length should be the same as that of symmetry, and it
+ is nothing but the dtype of each tensor in the symmetry. It may be left empty, in which case it is
+ derived from symmetry.
+ symmetry : tuple[torch.Tensor, ...], optional
+ The symmetry information of every sub symmetry. Each sub symmetry should be a one dimensional tensor of
+ length dimension, and its dtype should be an integral type, i.e. int or bool.
+ dimension : int, optional
+ The dimension of the edge; if not specified, it is detected from symmetry.
+ arrow : bool, optional
+ The arrow direction of the edge. It is essential for a fermionic edge, i.e. an edge with fermionic sub
+ symmetry.
+ """
+ # Symmetry could be left empty to create no symmetry edge
+ if symmetry is None:
+ symmetry = ()
+
+ # Fermion could be empty if it is total bosonic edge
+ if fermion is None:
+ fermion = tuple(False for _ in symmetry)
+
+ # Dtypes could be empty and derived from symmetry
+ if dtypes is None:
+ dtypes = tuple(sub_symmetry.dtype for sub_symmetry in symmetry)
+ # Check dtype is compatible with symmetry
+ assert all(sub_symmetry.dtype is sub_dtype for sub_symmetry, sub_dtype in zip(symmetry, dtypes))
+ # Check dtype is valid, aka, bool or int
+ assert all(not (sub_symmetry.is_floating_point() or sub_symmetry.is_complex()) for sub_symmetry in symmetry)
+
+ # The fermion, dtypes and symmetry information should have the same length
+ assert len(fermion) == len(dtypes) == len(symmetry)
+
+ # If dimension is not set, derive it from the first sub symmetry, so either dimension or a non-empty
+ # symmetry must be given.
+ if dimension is None:
+ dimension = len(symmetry[0])
+ # Check if the dimensions of different sub_symmetry mismatch
+ assert all(sub_symmetry.size() == (dimension,) for sub_symmetry in symmetry)
+
+ if arrow is None:
+ # Arrow not set, it should be bosonic edge.
+ arrow = False
+ assert not any(fermion)
+
+ self._fermion: tuple[bool, ...] = fermion
+ self._dtypes: tuple[torch.dtype, ...] = dtypes
+ self._symmetry: tuple[torch.Tensor, ...] = symmetry
+ self._dimension: int = dimension
+ self._arrow: bool = arrow
+
+ self._parity: torch.Tensor
+ if "parity" in kwargs:
+ self._parity = kwargs.pop("parity")
+ assert self.parity.size() == (self.dimension,)
+ assert self.parity.dtype is torch.bool
+ else:
+ self._parity = self._generate_parity()
+ assert not kwargs
+
+ def _generate_parity(self: Edge) -> torch.Tensor:
+ return functools.reduce(
+ # Reduce sub parity for all sub symmetry by logical xor
+ torch.logical_xor,
+ (
+ # The parity of sub symmetry
+ _utility.parity(sub_symmetry)
+ # Loop all sub symmetry
+ for sub_symmetry, sub_fermion in zip(self.symmetry, self.fermion)
+ # But only reduce if it is fermion sub symmetry
+ if sub_fermion),
+ # Reduce with start as tensor filled with False
+ torch.zeros(self.dimension, dtype=torch.bool),
+ )
+
+ def conjugate(self: Edge) -> Edge:
+ """
+ Get the conjugated edge.
+
+ Returns
+ -------
+ Edge
+ The conjugated edge.
+ """
+ # The only two differences of the conjugated edge are the symmetry and the arrow
+ return Edge(
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ symmetry=tuple(
+ _utility.neg_symmetry(sub_symmetry) # bool -> same, int -> neg
+ for sub_symmetry in self.symmetry),
+ dimension=self.dimension,
+ arrow=not self.arrow,
+ parity=self.parity,
+ )
+
+ def __eq__(self: Edge, other: object) -> bool:
+ if not isinstance(other, Edge):
+ return NotImplemented
+ return (
+ # Compare the int dimension and the bool arrow first since they are fast to compare
+ self.dimension == other.dimension and
+ # Even if the arrows differ, two bosonic edges still compare equal
+ (self.arrow == other.arrow or not any(self.fermion)) and
+ # Then the tuples of bool are compared
+ self.fermion == other.fermion and
+ # Then the tuples of dtypes are compared
+ self.dtypes == other.dtypes and
+ # All symmetries are compared last, since they are the largest data
+ all(
+ torch.equal(self_sub_symmetry, other_sub_symmetry)
+ for self_sub_symmetry, other_sub_symmetry in zip(self.symmetry, other.symmetry)))
+
+ def __str__(self: Edge) -> str:
+ # pylint: disable=no-else-return
+ if any(self.fermion):
+ # Fermionic edge
+ fermion = ','.join(str(sub_fermion) for sub_fermion in self.fermion)
+ symmetry = ','.join(
+ f"[{','.join(str(sub_sym.item()) for sub_sym in sub_symmetry)}]" for sub_symmetry in self.symmetry)
+ return f"(dimension={self.dimension}, arrow={self.arrow}, fermion=({fermion}), symmetry=({symmetry}))"
+ elif self.fermion:
+ # Bosonic edge that still carries symmetry (sub symmetries exist, but none is fermionic)
+ symmetry = ','.join(
+ f"[{','.join(str(sub_sym.item()) for sub_sym in sub_symmetry)}]" for sub_symmetry in self.symmetry)
+ return f"(dimension={self.dimension}, symmetry=({symmetry}))"
+ else:
+ # Trivial edge
+ return f"(dimension={self.dimension})"
+
+ def __repr__(self: Edge) -> str:
+ return f"Edge{self.__str__()}"
+
+ @staticmethod
+ def merge_edges(
+ edges: tuple[Edge, ...],
+ *,
+ fermion: tuple[bool, ...] | None = None,
+ dtypes: tuple[torch.dtype, ...] | None = None,
+ arrow: bool | None = None,
+ ) -> Edge:
+ """
+ Merge several edges into one edge.
+
+ Parameters
+ ----------
+ edges : tuple[Edge, ...]
+ The edges to be merged.
+ fermion : tuple[bool, ...], optional
+ Whether each sub symmetry is fermionic; it can be left empty to be derived from edges.
+ dtypes : tuple[torch.dtype, ...], optional
+ The base types of the sub symmetries; they can be left empty to be derived from edges.
+ arrow : bool, optional
+ The arrow of all the edges; it is useful if edges is empty.
+
+ Returns
+ -------
+ Edge
+ The result edge merged by edges.
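+
+ For instance (an illustrative sketch): merging two trivial edges of dimension 2 and 3 yields a trivial
+ edge of dimension 6; for symmetric edges every pair of sub symmetries is combined according to its dtype
+ and flattened in row-major order.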
+ """
+ # If fermion not set, get it from edges
+ if fermion is None:
+ fermion = edges[0].fermion
+ # All edge should share the same fermion
+ assert all(fermion == edge.fermion for edge in edges)
+ # If dtypes not set, get it from edges
+ if dtypes is None:
+ dtypes = edges[0].dtypes
+ # All edge should share the same dtypes
+ assert all(dtypes == edge.dtypes for edge in edges)
+ # If arrow set, check it directly, if not set, set to False or get from edges
+ if arrow is None:
+ if any(fermion):
+ # It is fermionic edge.
+ arrow = edges[0].arrow
+ else:
+ # It is bosonic edge, set to False directly since it is useless.
+ arrow = False
+ # All edge should share the same arrow for fermionic edge
+ assert (not any(fermion)) or all(arrow == edge.arrow for edge in edges)
+
+ rank = len(edges)
+ # Merge edge
+ dimension = functools.reduce(operator.mul, (edge.dimension for edge in edges), 1)
+ symmetry = tuple(
+ # Every merged sub symmetry is calculated by reduction and flattening
+ functools.reduce(
+ # The reduce operator depends on the dtype of this sub symmetry
+ _utility.add_symmetry,
+ (
+ # The sub symmetry of every edge will be reshaped before being reduced.
+ _utility.unsqueeze(edge.symmetry[sub_symmetry_index], current_index, rank)
+ # The sub symmetry of every edge is reduced one by one
+ for current_index, edge in enumerate(edges)),
+ # Reduce from a rank-0 tensor
+ torch.zeros([], dtype=sub_symmetry_dtype),
+ ).reshape([-1])
+ # Merge every sub symmetry one by one
+ for sub_symmetry_index, sub_symmetry_dtype in enumerate(dtypes))
+
+ # parity is not set here since it needs recalculation
+ return Edge(
+ fermion=fermion,
+ dtypes=dtypes,
+ symmetry=symmetry,
+ dimension=dimension,
+ arrow=arrow,
+ )
diff --git a/tat/tensor.py b/tat/tensor.py
new file mode 100644
index 000000000..8aee3a462
--- /dev/null
+++ b/tat/tensor.py
@@ -0,0 +1,1940 @@
+"""
+This file defined the core tensor type for tat package.
+"""
+
+from __future__ import annotations
+import typing
+import operator
+import functools
+from multimethod import multimethod
+import torch
+from . import _utility
+from ._qr import givens_qr, householder_qr # pylint: disable=unused-import
+from ._svd import svd as manual_svd # pylint: disable=unused-import
+from .edge import Edge
+
+# pylint: disable=too-many-public-methods
+# pylint: disable=too-many-lines
+
+
+class Tensor:
+ """
+ The main tensor type, which wraps a pytorch tensor and provides edge names and fermionic functionality.
+ """
+
+ __slots__ = "_fermion", "_dtypes", "_names", "_edges", "_data", "_mask"
+
+ def __str__(self: Tensor) -> str:
+ return f"(names={self.names}, edges={self.edges}, data={self.data})"
+
+ def __repr__(self: Tensor) -> str:
+ return f"Tensor(names={self.names}, edges={self.edges})"
+
+ @property
+ def fermion(self: Tensor) -> tuple[bool, ...]:
+ """
+ A tuple records whether every sub symmetry is fermionic. Its length is the number of sub symmetry.
+ """
+ return self._fermion
+
+ @property
+ def dtypes(self: Tensor) -> tuple[torch.dtype, ...]:
+ """
+ A tuple records the basic dtype of every sub symmetry. Its length is the number of sub symmetry.
+ """
+ return self._dtypes
+
+ @property
+ def names(self: Tensor) -> tuple[str, ...]:
+ """
+ The edge names of this tensor.
+ """
+ return self._names
+
+ @property
+ def edges(self: Tensor) -> tuple[Edge, ...]:
+ """
+ The edges information of this tensor.
+ """
+ return self._edges
+
+ @property
+ def data(self: Tensor) -> torch.Tensor:
+ """
+ The content data of this tensor.
+ """
+ return self._data
+
+ @property
+ def mask(self: Tensor) -> torch.Tensor:
+ """
+ The content data mask of this tensor.
+ """
+ return self._mask
+
+ @property
+ def rank(self: Tensor) -> int:
+ """
+ The rank of this tensor.
+ """
+ return len(self._names)
+
+ @property
+ def dtype(self: Tensor) -> torch.dtype:
+ """
+ The data type of the content in this tensor.
+ """
+ return self.data.dtype
+
+ @property
+ def btype(self: Tensor) -> str:
+ """
+ The data type of the content in this tensor, represented in BLAS/LAPACK convention.
+ """
+ if self.dtype is torch.float32:
+ return 'S'
+ if self.dtype is torch.float64:
+ return 'D'
+ if self.dtype is torch.complex64:
+ return 'C'
+ if self.dtype is torch.complex128:
+ return 'Z'
+ return '?'
+
+ @property
+ def is_complex(self: Tensor) -> bool:
+ """
+ Whether it is a complex tensor
+ """
+ return self.dtype.is_complex
+
+ @property
+ def is_real(self: Tensor) -> bool:
+ """
+ Whether it is a real tensor
+ """
+ return self.dtype.is_floating_point
+
+ def edge_by_name(self: Tensor, name: str) -> Edge:
+ """
+ Get edge by the edge name of this tensor.
+
+ Parameters
+ ----------
+ name : str
+ The given edge name.
+
+ Returns
+ -------
+ Edge
+ The edge with the given edge name.
+ """
+ assert name in self.names
+ return self.edges[self.names.index(name)]
+
+ def _arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor:
+ new_data: torch.Tensor
+ if isinstance(other, Tensor):
+ # If it is tensor, check same shape and transpose before calculating.
+ assert self.same_shape_with(other)
+ new_data = operate(self.data, other.transpose(self.names).data)
+ if operate is torch.div:
+ # In div, it may generate nan
+ new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype))
+ else:
+ # Otherwise treat other as a scalar, mask should be applied later.
+ new_data = operate(self.data, other)
+ new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype))
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=new_data,
+ mask=self.mask,
+ )
+
+ def __add__(self: Tensor, other: object) -> Tensor:
+ return self._arithmetic_operator(other, torch.add)
+
+ def __sub__(self: Tensor, other: object) -> Tensor:
+ return self._arithmetic_operator(other, torch.sub)
+
+ def __mul__(self: Tensor, other: object) -> Tensor:
+ return self._arithmetic_operator(other, torch.mul)
+
+ def __truediv__(self: Tensor, other: object) -> Tensor:
+ return self._arithmetic_operator(other, torch.div)
+
+ def _right_arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor:
+ new_data: torch.Tensor
+ if isinstance(other, Tensor):
+ # If it is tensor, check same shape and transpose before calculating.
+ assert self.same_shape_with(other)
+ new_data = operate(other.transpose(self.names).data, self.data)
+ if operate is torch.div:
+ # In div, it may generate nan
+ new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype))
+ else:
+ # Otherwise treat other as a scalar, mask should be applied later.
+ new_data = operate(other, self.data)
+ new_data = torch.where(self.mask, new_data, torch.zeros([], dtype=self.dtype))
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=new_data,
+ mask=self.mask,
+ )
+
+ def __radd__(self: Tensor, other: object) -> Tensor:
+ return self._right_arithmetic_operator(other, torch.add)
+
+ def __rsub__(self: Tensor, other: object) -> Tensor:
+ return self._right_arithmetic_operator(other, torch.sub)
+
+ def __rmul__(self: Tensor, other: object) -> Tensor:
+ return self._right_arithmetic_operator(other, torch.mul)
+
+ def __rtruediv__(self: Tensor, other: object) -> Tensor:
+ return self._right_arithmetic_operator(other, torch.div)
+
+ def _inplace_arithmetic_operator(self: Tensor, other: object, operate: typing.Callable) -> Tensor:
+ if isinstance(other, Tensor):
+ # If it is tensor, check same shape and transpose before calculating.
+ assert self.same_shape_with(other)
+ operate(self.data, other.transpose(self.names).data, out=self.data)
+ if operate is torch.div:
+ # In div, it may generate nan
+ torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data)
+ else:
+ # Otherwise treat other as a scalar, mask should be applied later.
+ operate(self.data, other, out=self.data)
+ torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data)
+ return self
+
+ def __iadd__(self: Tensor, other: object) -> Tensor:
+ return self._inplace_arithmetic_operator(other, torch.add)
+
+ def __isub__(self: Tensor, other: object) -> Tensor:
+ return self._inplace_arithmetic_operator(other, torch.sub)
+
+ def __imul__(self: Tensor, other: object) -> Tensor:
+ return self._inplace_arithmetic_operator(other, torch.mul)
+
+ def __itruediv__(self: Tensor, other: object) -> Tensor:
+ return self._inplace_arithmetic_operator(other, torch.div)
+
+ def __pos__(self: Tensor) -> Tensor:
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=+self.data,
+ mask=self.mask,
+ )
+
+ def __neg__(self: Tensor) -> Tensor:
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=-self.data,
+ mask=self.mask,
+ )
+
+ def __float__(self: Tensor) -> float:
+ return float(self.data)
+
+ def __complex__(self: Tensor) -> complex:
+ return complex(self.data)
+
+ def norm(self: Tensor, order: typing.Any) -> float:
+ """
+ Get the norm of the tensor. This function flattens the tensor before calculating the norm.
+
+ Parameters
+ ----------
+ order
+ The order of norm.
+
+ Returns
+ -------
+ float
+ The norm of the tensor.
+ """
+ return float(torch.linalg.vector_norm(self.data, ord=order))
+
+ def norm_max(self: Tensor) -> float:
+ "max norm"
+ return self.norm(+torch.inf)
+
+ def norm_min(self: Tensor) -> float:
+ "min norm"
+ return self.norm(-torch.inf)
+
+ def norm_num(self: Tensor) -> float:
+ "0-norm"
+ return self.norm(0)
+
+ def norm_sum(self: Tensor) -> float:
+ "1-norm"
+ return self.norm(1)
+
+ def norm_2(self: Tensor) -> float:
+ "2-norm"
+ return self.norm(2)
+
+ def copy(self: Tensor) -> Tensor:
+ """
+ Get a copy of this tensor
+
+ Returns
+ -------
+ Tensor
+ The copy of this tensor
+ """
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=torch.clone(self.data, memory_format=torch.contiguous_format),
+ mask=self.mask,
+ )
+
+ def __copy__(self: Tensor) -> Tensor:
+ return self.copy()
+
+ def __deepcopy__(self: Tensor, _: typing.Any = None) -> Tensor:
+ return self.copy()
+
+ def same_shape(self: Tensor) -> Tensor:
+ """
+ Get a tensor with the same shape as this tensor
+
+ Returns
+ -------
+ Tensor
+ A new tensor with the same shape as this tensor
+ """
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=torch.zeros_like(self.data),
+ mask=self.mask,
+ )
+
+ def zero_(self: Tensor) -> Tensor:
+ """
+ Set all elements of this tensor to zero
+
+ Returns
+ -------
+ Tensor
+ Return this tensor itself.
+ """
+ self.data.zero_()
+ return self
+
+ def sqrt(self: Tensor) -> Tensor:
+ """
+ Get the sqrt of the tensor.
+
+ Returns
+ -------
+ Tensor
+ The sqrt of this tensor.
+ """
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=torch.sqrt(torch.abs(self.data)),
+ mask=self.mask,
+ )
+
+ def reciprocal(self: Tensor) -> Tensor:
+ """
+ Get the reciprocal of the tensor.
+
+ Returns
+ -------
+ Tensor
+ The reciprocal of this tensor.
+ """
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=torch.where(self.data == 0, self.data, 1 / self.data),
+ mask=self.mask,
+ )
+
+ def range(self: Tensor, first: typing.Any = 0, step: typing.Any = 1) -> Tensor:
+ """
+ A utility to get a tensor of the same shape, filled with simple sequential data for testing.
+
+ Parameters
+ ----------
+ first, step
+ Parameters to generate data.
+
+ Returns
+ -------
+ Tensor
+ Returns the tensor filled with simple data for test.
+ """
+ data = torch.cumsum(self.mask.reshape([-1]), dim=0, dtype=self.dtype).reshape(self.data.size())
+ data = (data - 1) * step + first
+ data = torch.where(self.mask, data, torch.zeros([], dtype=self.dtype))
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data,
+ mask=self.mask,
+ )
+
+ def to(self: Tensor, new_type: typing.Any) -> Tensor:
+ """
+ Convert this tensor to other scalar type.
+
+ Parameters
+ ----------
+ new_type
+ The scalar data type of the new tensor.
+ """
+ # pylint: disable=invalid-name
+ if new_type is int:
+ new_type = torch.int64
+ if new_type is float:
+ new_type = torch.float64
+ if new_type is complex:
+ new_type = torch.complex128
+ if isinstance(new_type, str):
+ if new_type in ["float32", "S"]:
+ new_type = torch.float32
+ elif new_type in ["float64", "float", "D"]:
+ new_type = torch.float64
+ elif new_type in ["complex64", "C"]:
+ new_type = torch.complex64
+ elif new_type in ["complex128", "complex", "Z"]:
+ new_type = torch.complex128
+ if self.dtype.is_complex and new_type.is_floating_point:
+ data = self.data.real.to(new_type)
+ else:
+ data = self.data.to(new_type)
+ return Tensor(
+ names=self.names,
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data,
+ mask=self.mask,
+ )
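+
+ # For illustration, the following spellings all convert to torch.complex128 here:
+ # t.to("Z"), t.to("complex128"), t.to("complex"), t.to(complex)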
+
+ def __init__(
+ self: Tensor,
+ names: tuple[str, ...],
+ edges: tuple[Edge, ...],
+ *,
+ dtype: torch.dtype | None = None,
+ fermion: tuple[bool, ...] | None = None,
+ dtypes: tuple[torch.dtype, ...] | None = None,
+ **kwargs: torch.Tensor,
+ ) -> None:
+ """
+ Create a tensor with specific shape.
+
+ Parameters
+ ----------
+ names : tuple[str, ...]
+ The edge names of the tensor, whose length is just the tensor rank.
+ edges : tuple[Edge, ...]
+ The detailed information of each edge, whose length is just the tensor rank.
+ dtype : torch.dtype, optional
+ The dtype of the tensor; leave it empty to let pytorch choose the default dtype.
+ fermion : tuple[bool, ...], optional
+ Whether each sub symmetry is fermionic; it can be left empty to be derived from edges.
+ dtypes : tuple[torch.dtype, ...], optional
+ The base types of the sub symmetries; they can be left empty to be derived from edges.
+ """
+ # Check the rank is correct in names and edges
+ assert len(names) == len(edges)
+ # Check whether there are duplicated names
+ assert len(set(names)) == len(names)
+ # If fermion not set, get it from edges
+ if fermion is None:
+ fermion = edges[0].fermion
+ # If dtypes not set, get it from edges
+ if dtypes is None:
+ dtypes = edges[0].dtypes
+ # Check if fermion is correct
+ assert all(edge.fermion == fermion for edge in edges)
+ # Check if dtypes is correct
+ assert all(edge.dtypes == dtypes for edge in edges)
+
+ self._fermion: tuple[bool, ...] = fermion
+ self._dtypes: tuple[torch.dtype, ...] = dtypes
+ self._names: tuple[str, ...] = names
+ self._edges: tuple[Edge, ...] = edges
+
+ self._data: torch.Tensor
+ if "data" in kwargs:
+ self._data = kwargs.pop("data")
+ else:
+ if dtype is None:
+ self._data = torch.zeros([edge.dimension for edge in self.edges])
+ else:
+ self._data = torch.zeros([edge.dimension for edge in self.edges], dtype=dtype)
+ assert self.data.size() == tuple(edge.dimension for edge in self.edges)
+ assert dtype is None or self.dtype is dtype
+
+ self._mask: torch.Tensor
+ if "mask" in kwargs:
+ self._mask = kwargs.pop("mask")
+ else:
+ self._mask = self._generate_mask()
+ assert self.mask.size() == tuple(edge.dimension for edge in self.edges)
+ assert self.mask.dtype is torch.bool
+
+ assert not kwargs
+
+ def _generate_mask(self: Tensor) -> torch.Tensor:
+ return functools.reduce(
+ # The mask is valid where every sub symmetry gives a valid mask.
+ torch.logical_and,
+ (
+ # The mask is valid where the total symmetry is False (for bool) or 0 (for int)
+ _utility.zero_symmetry(
+ # The total sub symmetry is calculated by a reduction
+ functools.reduce(
+ # The reduce operator depends on the dtype of this sub symmetry
+ _utility.add_symmetry,
+ (
+ # The sub symmetry of every edge will be reshaped before being reduced.
+ _utility.unsqueeze(edge.symmetry[sub_symmetry_index], current_index, self.rank)
+ # The sub symmetry of every edge is reduced one by one
+ for current_index, edge in enumerate(self.edges)),
+ # Reduce from a rank-0 tensor
+ torch.zeros([], dtype=sub_symmetry_dtype),
+ ))
+ # Calculate mask on every sub symmetry one by one
+ for sub_symmetry_index, sub_symmetry_dtype in enumerate(self.dtypes)),
+ # Reduce from all true mask
+ torch.ones(self.data.size(), dtype=torch.bool),
+ )
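+
+ # For intuition (hypothetical U1-like edges): with symmetry ([-1, 0, 1],) on one edge and
+ # ([1, 0, -1],) on the other, an element survives the mask only where the two sub symmetries
+ # add to zero, giving a block-diagonal sparsity pattern.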
+
+ @multimethod
+ def _prepare_position(self: Tensor, position: tuple[int, ...]) -> tuple[int, ...]:
+ indices: tuple[int, ...] = position
+ assert len(indices) == self.rank
+ assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices))
+ return indices
+
+ @_prepare_position.register
+ def _(self: Tensor, position: tuple[slice, ...]) -> tuple[int, ...]:
+ index_by_name: dict[str, int] = {s.start: s.stop for s in position}
+ indices: tuple[int, ...] = tuple(index_by_name[name] for name in self.names)
+ assert len(indices) == self.rank
+ assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices))
+ return indices
+
+ @_prepare_position.register
+ def _(self: Tensor, position: dict[str, int]) -> tuple[int, ...]:
+ indices: tuple[int, ...] = tuple(position[name] for name in self.names)
+ assert len(indices) == self.rank
+ assert all(0 <= index < edge.dimension for edge, index in zip(self.edges, indices))
+ return indices
+
+ def __getitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | dict[str, int]) -> typing.Any:
+ """
+ Get the element of the tensor
+
+ Parameters
+ ----------
+ position : tuple[int, ...] | tuple[slice, ...] | dict[str, int]
+ The position of the element, either a tuple of indices directly or a map from edge name to the index
+ in the corresponding edge.
+ """
+ indices: tuple[int, ...] = self._prepare_position(position)
+ return self.data[indices]
+
+ def __setitem__(self: Tensor, position: tuple[int, ...] | tuple[slice, ...] | dict[str, int],
+ value: typing.Any) -> None:
+ """
+ Set the element of the tensor
+
+ Parameters
+ ----------
+ position : tuple[int, ...] | tuple[slice, ...] | dict[str, int]
+ The position of the element, either a tuple of indices directly or a map from edge name to the index
+ in the corresponding edge.
+ """
+ indices = self._prepare_position(position)
+ if self.mask[indices]:
+ self.data[indices] = value
+ else:
+ raise IndexError("The indices specified are masked, so it is invalid to set item here.")
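+
+ # A usage sketch of the three accepted position formats (edge names are illustrative):
+ # t[0, 1] # plain indices, in edge order
+ # t[{"i": 0, "j": 1}] # a map from edge name to index
+ # t["i":0, "j":1] # slice syntax, name:index pairs in any order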
+
+ def clear_symmetry(self: Tensor) -> Tensor:
+ """
+ Clear all symmetry of this tensor.
+
+ Returns
+ -------
+ Tensor
+ The result tensor with symmetry cleared.
+ """
+ # Mask must be generated again here
+ # pylint: disable=no-else-return
+ if any(self.fermion):
+ return Tensor(
+ names=self.names,
+ edges=tuple(
+ Edge(
+ fermion=(True,),
+ dtypes=(torch.bool,),
+ symmetry=(edge.parity,),
+ dimension=edge.dimension,
+ arrow=edge.arrow,
+ parity=edge.parity,
+ ) for edge in self.edges),
+ fermion=(True,),
+ dtypes=(torch.bool,),
+ data=self.data,
+ )
+ else:
+ return Tensor(
+ names=self.names,
+ edges=tuple(
+ Edge(
+ fermion=(),
+ dtypes=(),
+ symmetry=(),
+ dimension=edge.dimension,
+ arrow=edge.arrow,
+ parity=edge.parity,
+ ) for edge in self.edges),
+ fermion=(),
+ dtypes=(),
+ data=self.data,
+ )
+
+ def randn_(self: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
+ """
+ Fill the tensor with random numbers from a normal distribution.
+
+ Parameters
+ ----------
+ mean, std : float
+ The parameters of the normal distribution.
+
+ Returns
+ -------
+ Tensor
+ Return this tensor itself.
+ """
+ self.data.normal_(mean, std)
+ torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data)
+ return self
+
+ def rand_(self: Tensor, low: float = 0., high: float = 1.) -> Tensor:
+ """
+ Fill the tensor with random numbers from a uniform distribution.
+
+ Parameters
+ ----------
+ low, high : float
+ The parameters of the uniform distribution.
+
+ Returns
+ -------
+ Tensor
+ Return this tensor itself.
+ """
+ self.data.uniform_(low, high)
+ torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype), out=self.data)
+ return self
+
+ def same_type_with(self: Tensor, other: Tensor) -> bool:
+ """
+ Check whether two tensors have the same type, that is to say they share the same symmetry type.
+ """
+ return self.fermion == other.fermion and self.dtypes == other.dtypes
+
+ def same_shape_with(self: Tensor, other: Tensor, *, allow_transpose: bool = True) -> bool:
+ """
+ Check whether two tensors have the same shape, that is to say the only differences between them are a
+ transposition and the data.
+ """
+ if not self.same_type_with(other):
+ return False
+ # pylint: disable=no-else-return
+ if allow_transpose:
+ return dict(zip(self.names, self.edges)) == dict(zip(other.names, other.edges))
+ else:
+ return self.names == other.names and self.edges == other.edges
+
+ def edge_rename(self: Tensor, name_map: dict[str, str]) -> Tensor:
+ """
+ Rename the edge names of this tensor.
+
+ Parameters
+ ----------
+ name_map : dict[str, str]
+ The name map used to rename the edge names.
+
+ Returns
+ -------
+ Tensor
+ The tensor with names renamed.
+ """
+ return Tensor(
+ names=tuple(name_map.get(name, name) for name in self.names),
+ edges=self.edges,
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=self.data,
+ mask=self.mask,
+ )
+
+ def transpose(self: Tensor, names: tuple[str, ...]) -> Tensor:
+ """
+ Transpose the tensor out of place.
+
+ Parameters
+ ----------
+ names : tuple[str, ...]
+ The new edge order identified by edge names.
+
+ Returns
+ -------
+ Tensor
+ The transposed tensor.
+ """
+ if names == self.names:
+ return self
+ assert len(names) == len(self.names)
+ assert set(names) == set(self.names)
+ before_by_after = tuple(self.names.index(name) for name in names)
+ after_by_before = tuple(names.index(name) for name in self.names)
+ data = self.data.permute(before_by_after)
+ mask = self.mask.permute(before_by_after)
+ if any(self.fermion):
+ # It is fermionic tensor
+ parities_before_transpose = tuple(
+ _utility.unsqueeze(edge.parity, current_index, self.rank)
+ for current_index, edge in enumerate(self.edges))
+ # Generate the parity by xor-ing over all inverted pairs
+ parity = functools.reduce(
+ torch.logical_xor,
+ (
+ torch.logical_and(parities_before_transpose[i], parities_before_transpose[j])
+ # Loop every 0 <= i < j < rank
+ for j in range(self.rank)
+ for i in range(0, j)
+ if after_by_before[i] > after_by_before[j]),
+ torch.zeros([], dtype=torch.bool))
+ # parity True -> -x
+ # parity False -> +x
+ data = torch.where(parity.permute(before_by_after), -data, +data)
+ return Tensor(
+ names=names,
+ edges=tuple(self.edges[index] for index in before_by_after),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data,
+ mask=mask,
+ )
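+
+ # For intuition: every pair of parity-odd edges whose relative order is inverted by the transpose
+ # contributes one fermionic sign, so swapping the two edges of a rank-2 tensor negates exactly the
+ # elements where both edges carry odd parity.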
+
+ def reverse_edge(
+ self: Tensor,
+ reversed_names: set[str],
+ apply_parity: bool = False,
+ parity_exclude_names: set[str] | None = None,
+ ) -> Tensor:
+ """
+ Reverse some edge in the tensor.
+
+ Parameters
+ ----------
+ reversed_names : set[str]
+ The edge names of those edges which will be reversed
+ apply_parity : bool, default=False
+ Whether to apply the parity caused by reversing edges, since reversing an edge generates half of a sign.
+ parity_exclude_names : set[str], optional
+ The names of edges whose behavior differs from the default set by apply_parity.
+
+ Returns
+ -------
+ Tensor
+ The tensor with edges reversed.
+ """
+ if not any(self.fermion):
+ return self
+ if parity_exclude_names is None:
+ parity_exclude_names = set()
+ assert all(name in self.names for name in reversed_names)
+ assert all(name in reversed_names for name in parity_exclude_names)
+ data = self.data
+ if any(self.fermion):
+ # Parity is xor of all valid reverse parity
+ parity = functools.reduce(
+ torch.logical_xor,
+ (
+ _utility.unsqueeze(edge.parity, current_index, self.rank)
+ # Loop over all edge
+ for current_index, [name, edge] in enumerate(zip(self.names, self.edges))
+ # Check if this edge is reversed and if this edge will be applied parity
+ if (name in reversed_names) and (apply_parity ^ (name in parity_exclude_names))),
+ torch.zeros([], dtype=torch.bool),
+ )
+ data = torch.where(parity, -data, +data)
+ return Tensor(
+ names=self.names,
+ edges=tuple(
+ Edge(
+ fermion=edge.fermion,
+ dtypes=edge.dtypes,
+ symmetry=edge.symmetry,
+ dimension=edge.dimension,
+ arrow=not edge.arrow if self.names[current_index] in reversed_names else edge.arrow,
+ parity=edge.parity,
+ ) for current_index, edge in enumerate(self.edges)),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data,
+ mask=self.mask,
+ )
+
+ @staticmethod
+ def _split_edge_get_name_group(
+ name: str,
+ split_map: dict[str, tuple[tuple[str, Edge], ...]],
+ ) -> list[str]:
+ split_group: tuple[tuple[str, Edge], ...] | None = split_map.get(name, None)
+ # pylint: disable=no-else-return
+ if split_group is None:
+ return [name]
+ else:
+ return [new_name for new_name, _ in split_group]
+
+ @staticmethod
+ def _split_edge_get_edge_group(
+ name: str,
+ edge: Edge,
+ split_map: dict[str, tuple[tuple[str, Edge], ...]],
+ ) -> list[Edge]:
+ split_group: tuple[tuple[str, Edge], ...] | None = split_map.get(name, None)
+ # pylint: disable=no-else-return
+ if split_group is None:
+ return [edge]
+ else:
+ return [new_edge for _, new_edge in split_group]
+
+ def split_edge(
+ self: Tensor,
+ split_map: dict[str, tuple[tuple[str, Edge], ...]],
+ apply_parity: bool = False,
+ parity_exclude_names: set[str] | None = None,
+ ) -> Tensor:
+ """
+ Split some edges in this tensor.
+
+ Parameters
+ ----------
+ split_map : dict[str, tuple[tuple[str, Edge], ...]]
+ The edge splitting plan.
+ apply_parity : bool, default=False
+ Whether to apply the parity caused by splitting edges, since splitting an edge generates half of a sign.
+ parity_exclude_names : set[str], optional
+ The names of edges whose behavior differs from the default set by apply_parity.
+
+ Returns
+ -------
+ Tensor
+ The tensor with edges split.
+ """
+ if parity_exclude_names is None:
+ parity_exclude_names = set()
+ # Check that every edge to be split can be re-merged from its result edges.
+ assert all(
+ self.edge_by_name(name) == Edge.merge_edges(
+ tuple(new_edge for _, new_edge in split_result),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ arrow=self.edge_by_name(name).arrow,
+ ) for name, split_result in split_map.items())
+ assert all(name in split_map for name in parity_exclude_names)
+ # Calculate the result components
+ names: tuple[str, ...] = tuple(
+ # Convert the list generated by reduce to tuple
+ functools.reduce(
+ # Concat list
+ operator.add,
+ # If name in split_map, use the new names list, otherwise use name itself as a length-1 list
+ (Tensor._split_edge_get_name_group(name, split_map) for name in self.names),
+ # Reduce from [] to concat all list
+ [],
+ ))
+ edges: tuple[Edge, ...] = tuple(
+ # Convert the list generated by reduce to tuple
+ functools.reduce(
+ # Concat list
+ operator.add,
+ # If name in split_map, use the new edges list, otherwise use the edge itself as a length-1 list
+ (Tensor._split_edge_get_edge_group(name, edge, split_map) for name, edge in zip(self.names, self.edges)
+ ),
+ # Reduce from [] to concat all list
+ [],
+ ))
+ new_size = [edge.dimension for edge in edges]
+ data = self.data.reshape(new_size)
+ mask = self.mask.reshape(new_size)
+
+ # Apply parity
+ if any(self.fermion):
+ # It is a fermionic tensor, so parity needs to be applied
+ new_rank = len(names)
+ # Parity is xor of all valid splitting parity
+ parity = functools.reduce(
+ torch.logical_xor,
+ (
+ # Apply the parity for this splitting group here
+ # It needs to perform a total transpose on this split group,
+ # whose sign is {sum 0<=i<j p(i)p(j)} over the parities inside the group
+
+ def _merge_edge_get_names(self: Tensor, merge_map: dict[str, tuple[str, ...]]) -> tuple[str, ...]:
+ reversed_names: list[str] = []
+ for name in reversed(self.names):
+ found_in_merge_map: tuple[str, tuple[str, ...]] | None = next(
+ ((new_name, old_names) for new_name, old_names in merge_map.items() if name in old_names), None)
+ if found_in_merge_map is None:
+ # This edge will not be merged
+ reversed_names.append(name)
+ else:
+ new_name, old_names = found_in_merge_map
+ # This edge will be merged
+ if name == old_names[-1]:
+ # Add new edge only if it is the last edge
+ reversed_names.append(new_name)
+ # Some edge is merged from no edges, it should be considered
+ for new_name, old_names in merge_map.items():
+ if not old_names:
+ reversed_names.append(new_name)
+ return tuple(reversed(reversed_names))
+
+ @staticmethod
+ def _merge_edge_get_name_group(name: str, merge_map: dict[str, tuple[str, ...]]) -> list[str]:
+ merge_group: tuple[str, ...] | None = merge_map.get(name, None)
+ # pylint: disable=no-else-return
+ if merge_group is None:
+ return [name]
+ else:
+ return list(merge_group)
+
+ def merge_edge(
+ self: Tensor,
+ merge_map: dict[str, tuple[str, ...]],
+ apply_parity: bool = False,
+ parity_exclude_names: set[str] | None = None,
+ *,
+ merge_arrow: dict[str, bool] | None = None,
+ names: tuple[str, ...] | None = None,
+ ) -> Tensor:
+ """
+ Merge some edges in this tensor.
+
+ Parameters
+ ----------
+ merge_map : dict[str, tuple[str, ...]]
+ The edge merging plan.
+ apply_parity : bool, default=False
+ Whether to apply the parity caused by merging edges, since merging an edge generates half of a sign.
+ parity_exclude_names : set[str], optional
+ The names of edges whose behavior differs from the default set by apply_parity.
+ merge_arrow : dict[str, bool], optional
+ When merging an edge from zero edges, the arrow cannot be determined automatically; the user must set it
+ manually.
+ names : tuple[str, ...], optional
+ The edge order of the result; the user may sometimes want to specify it manually.
+
+ Returns
+ -------
+ Tensor
+ The tensor with edges merged.
+ """
+ # pylint: disable=too-many-locals
+ if parity_exclude_names is None:
+ parity_exclude_names = set()
+ if merge_arrow is None:
+ merge_arrow = {}
+ assert all(all(old_name in self.names for old_name in old_names) for _, old_names in merge_map.items())
+ assert all(name in merge_map for name in parity_exclude_names)
+ # Two steps: 1. Transpose 2. Merge
+ if names is None:
+ names = self._merge_edge_get_names(merge_map)
+ transposed_names: tuple[str, ...] = tuple(
+ # Convert the list generated by reduce to tuple
+ functools.reduce(
+ # Concat list
+ operator.add,
+ # If name in merge_map, use the old names list, otherwise use name itself as a length-1 list
+ (Tensor._merge_edge_get_name_group(name, merge_map) for name in names),
+ # Reduce from [] to concat all list
+ [],
+ ))
+ transposed_tensor = self.transpose(transposed_names)
+ # Prepare a name to index map, since we need to look up it frequently later.
+ transposed_name_map = {name: index for index, name in enumerate(transposed_tensor.names)}
+ edges = tuple(
+ # If name is created by merging, call Edge.merge_edges to get the merged edge, otherwise get it directly
+ # from transposed_tensor.
+ Edge.merge_edges(
+ edges=tuple(transposed_tensor.edges[transposed_name_map[old_name]]
+ for old_name in merge_map[name]),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ arrow=merge_arrow.get(name, None),
+ # When merging an edge from zero edges, the arrow needs to be set manually
+ ) if name in merge_map else transposed_tensor.edges[transposed_name_map[name]]
+ # Loop over names
+ for name in names)
+ transposed_data = transposed_tensor.data
+ transposed_mask = transposed_tensor.mask
+
+ # Apply parity
+ if any(self.fermion):
+ # It is a fermionic tensor, so parity needs to be applied
+ # Parity is xor of all valid merging parity
+ parity = functools.reduce(
+ torch.logical_xor,
+ (
+ # Apply the parity for this merging group here
+ # It needs to perform a total transpose on this merging group,
+ # whose sign is {sum 0<=i<j p(i)p(j)} over the parities inside the group
+
+ def contract(
+ self: Tensor,
+ other: Tensor,
+ contract_pairs: set[tuple[str, str]],
+ fuse_names: set[str] | None = None,
+ ) -> Tensor:
+ """
+ Contract two tensors.
+
+ Parameters
+ ----------
+ other : Tensor
+ Another tensor to be contracted.
+ contract_pairs : set[tuple[str, str]]
+ The pairs of edges to be contracted between the two tensors.
+ fuse_names : set[str], optional
+ The set of edges to be fused.
+
+ Returns
+ -------
+ Tensor
+ The result of contracting the two tensors.
+ """
+ # pylint: disable=too-many-locals
+ # Only tensors of the same type can be contracted.
+ assert self.same_type_with(other)
+
+ if fuse_names is None:
+ fuse_names = set()
+ # Fused edges should not carry any symmetry
+ assert all(
+ all(_utility.zero_symmetry(sub_symmetry)
+ for sub_symmetry in self.edge_by_name(fuse_name).symmetry)
+ for fuse_name in fuse_names)
+
+ # Alias tensor
+ tensor_1: Tensor = self
+ tensor_2: Tensor = other
+
+ # Check if contract edge and fuse edge compatible
+ assert all(tensor_1.edge_by_name(name) == tensor_2.edge_by_name(name) for name in fuse_names)
+ assert all(
+ tensor_1.edge_by_name(name_1).conjugate() == tensor_2.edge_by_name(name_2)
+ for name_1, name_2 in contract_pairs)
+
+ # All tensor edges are grouped into three parts: fuse edges, contract edges and free edges
+
+ # Contracting two tensors has 5 steps:
+ # 1. reverse arrow
+ # reverse all free edges and fuse edges to arrow False, without applying parity.
+ # reverse contract edges to the two arrows False (tensor_2) and True (tensor_1), applying parity to
+ # only one of the tensors.
+ # 2. merge edge
+ # merge all edges of the same part into one edge, applying parity only to the contract edge of one
+ # tensor; free edges and fuse edges get no parity.
+ # 3. contract matrix
+ # call matrix multiplication
+ # 4. split edge
+ # split the edges merged in step 2, without applying parity
+ # 5. reverse arrow
+ # reverse the arrows reversed in step 1, without applying parity
+
+ # Step 1
+ contract_names_1: set[str] = {name_1 for name_1, name_2 in contract_pairs}
+ contract_names_2: set[str] = {name_2 for name_1, name_2 in contract_pairs}
+ arrow_true_names_1: set[str] = {name for name, edge in zip(tensor_1.names, tensor_1.edges) if edge.arrow}
+ arrow_true_names_2: set[str] = {name for name, edge in zip(tensor_2.names, tensor_2.edges) if edge.arrow}
+
+ # tensor 1: contract_names & arrow_false | not contract_names & arrow_true -> contract_names ^ arrow_true
+ tensor_1 = tensor_1.reverse_edge(contract_names_1 ^ arrow_true_names_1, False,
+ contract_names_1 - arrow_true_names_1)
+ tensor_2 = tensor_2.reverse_edge(arrow_true_names_2, False, set())
+
+ # Step 2
+ free_edges_1: tuple[tuple[str, Edge], ...] = tuple((name, edge)
+ for name, edge in zip(tensor_1.names, tensor_1.edges)
+ if name not in fuse_names and name not in contract_names_1)
+ free_names_1: tuple[str, ...] = tuple(name for name, _ in free_edges_1)
+ free_edges_2: tuple[tuple[str, Edge], ...] = tuple((name, edge)
+ for name, edge in zip(tensor_2.names, tensor_2.edges)
+ if name not in fuse_names and name not in contract_names_2)
+ free_names_2: tuple[str, ...] = tuple(name for name, _ in free_edges_2)
+ # Check which tensor is bigger, and use it to determine the fuse and contract edge order.
+ ordered_fuse_edges: tuple[tuple[str, Edge], ...]
+ ordered_fuse_names: tuple[str, ...]
+ ordered_contract_names_1: tuple[str, ...]
+ ordered_contract_names_2: tuple[str, ...]
+ if tensor_1.data.nelement() > tensor_2.data.nelement():
+ # Tensor 1 larger, fit by tensor 1
+ ordered_fuse_edges = tuple(
+ (name, edge) for name, edge in zip(tensor_1.names, tensor_1.edges) if name in fuse_names)
+ ordered_fuse_names = tuple(name for name, _ in ordered_fuse_edges)
+
+ # pylint: disable=unnecessary-comprehension
+ contract_names_map = {name_1: name_2 for name_1, name_2 in contract_pairs}
+ ordered_contract_names_1 = tuple(name for name in tensor_1.names if name in contract_names_1)
+ ordered_contract_names_2 = tuple(contract_names_map[name] for name in ordered_contract_names_1)
+ else:
+ # Tensor 2 larger, fit by tensor 2
+ ordered_fuse_edges = tuple(
+ (name, edge) for name, edge in zip(tensor_2.names, tensor_2.edges) if name in fuse_names)
+ ordered_fuse_names = tuple(name for name, _ in ordered_fuse_edges)
+
+ contract_names_map = {name_2: name_1 for name_1, name_2 in contract_pairs}
+ ordered_contract_names_2 = tuple(name for name in tensor_2.names if name in contract_names_2)
+ ordered_contract_names_1 = tuple(contract_names_map[name] for name in ordered_contract_names_2)
+
+ put_contract_1_right: bool = next(
+ (name in contract_names_1 for name in reversed(tensor_1.names) if name not in fuse_names), True)
+ put_contract_2_right: bool = next(
+ (name in contract_names_2 for name in reversed(tensor_2.names) if name not in fuse_names), False)
+
+ tensor_1 = tensor_1.merge_edge(
+ {
+ "Free_1": free_names_1,
+ "Contract_1": ordered_contract_names_1,
+ "Fuse_1": ordered_fuse_names,
+ },
+ False,
+ {"Contract_1"},
+ merge_arrow={
+ "Free_1": False,
+ "Contract_1": True,
+ "Fuse_1": False,
+ },
+ names=("Fuse_1", "Free_1", "Contract_1") if put_contract_1_right else ("Fuse_1", "Contract_1", "Free_1"),
+ )
+ tensor_2 = tensor_2.merge_edge(
+ {
+ "Free_2": free_names_2,
+ "Contract_2": ordered_contract_names_2,
+ "Fuse_2": ordered_fuse_names,
+ },
+ False,
+ set(),
+ merge_arrow={
+ "Free_2": False,
+ "Contract_2": False,
+ "Fuse_2": False,
+ },
+ names=("Fuse_2", "Free_2", "Contract_2") if put_contract_2_right else ("Fuse_2", "Contract_2", "Free_2"),
+ )
+ # C[fuse, free1, free2] = A[fuse, free1, contract] B[fuse, contract, free2]
+ assert tensor_1.edge_by_name("Fuse_1") == tensor_2.edge_by_name("Fuse_2")
+ assert tensor_1.edge_by_name("Contract_1").conjugate() == tensor_2.edge_by_name("Contract_2")
+
+ # Step 3
+ # The standard arrow is
+ # (0, False, True) (0, False, False)
+ # aka: (a b) (c d) (c+ b+) = (a d)
+ # since: EPR pair order is (False True)
+ # put_contract_1_right should be True
+ # put_contract_2_right should be False
+ # Every mismatch generates a sign
+ # Total sign is
+ # (!put_contract_1_right) ^ (put_contract_2_right) = put_contract_1_right == put_contract_2_right
+ dtype = torch.result_type(tensor_1.data, tensor_2.data)
+ data = torch.einsum(
+ "b" + ("ic" if put_contract_1_right else "ci") + ",b" + ("jc" if put_contract_2_right else "cj") + "->bij",
+ tensor_1.data.to(dtype=dtype), tensor_2.data.to(dtype=dtype))
+ if put_contract_1_right == put_contract_2_right:
+ data = torch.where(tensor_2.edge_by_name("Free_2").parity.reshape([1, 1, -1]), -data, +data)
+ tensor = Tensor(
+ names=("Fuse", "Free_1", "Free_2"),
+ edges=(tensor_1.edge_by_name("Fuse_1"), tensor_1.edge_by_name("Free_1"), tensor_2.edge_by_name("Free_2")),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data,
+ )
+
+ # Step 4
+ tensor = tensor.split_edge({
+ "Fuse": ordered_fuse_edges,
+ "Free_1": free_edges_1,
+ "Free_2": free_edges_2
+ }, False, set())
+
+ # Step 5
+ tensor = tensor.reverse_edge(
+ (arrow_true_names_1 - contract_names_1) | (arrow_true_names_2 - contract_names_2),
+ False,
+ set(),
+ )
+
+ return tensor
+
+ def _trace_group_edge(
+ self: Tensor,
+ trace_pairs: set[tuple[str, str]],
+ fuse_names: dict[str, tuple[str, str]],
+ ) -> tuple[
+ tuple[str, ...],
+ tuple[str, ...],
+ tuple[str, ...],
+ tuple[str, ...],
+ tuple[str, ...],
+ tuple[str, ...],
+ tuple[int, ...],
+ tuple[int, ...],
+ ]:
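+ # Returns, in order: trace_names_1, trace_names_2, fuse_names_1, fuse_names_2,
+ # free_names, fuse_names_result, free_index and fuse_index_result, matching
+ # the unpacking at the call site in trace below.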
+ # pylint: disable=too-many-locals
+ # pylint: disable=unnecessary-comprehension
+ trace_map = {
+ old_name_1: old_name_2 for old_name_1, old_name_2 in trace_pairs
+ } | {
+ old_name_2: old_name_1 for old_name_1, old_name_2 in trace_pairs
+ }
+ fuse_map = {
+ old_name_1: (old_name_2, new_name) for new_name, [old_name_1, old_name_2] in fuse_names.items()
+ } | {
+ old_name_2: (old_name_1, new_name) for new_name, [old_name_1, old_name_2] in fuse_names.items()
+ }
+ reversed_trace_names_1: list[str] = []
+ reversed_trace_names_2: list[str] = []
+ reversed_fuse_names_1: list[str] = []
+ reversed_fuse_names_2: list[str] = []
+ reversed_free_names: list[str] = []
+ reversed_fuse_names_result: list[str] = []
+ reversed_free_index: list[int] = []
+ reversed_fuse_index_result: list[int] = []
+ added_names: set[str] = set()
+ for index, name in zip(reversed(range(self.rank)), reversed(self.names)):
+ if name not in added_names:
+ trace_name: str | None = trace_map.get(name, None)
+ fuse_name: tuple[str, str] | None = fuse_map.get(name, None)
+ if trace_name is not None:
+ reversed_trace_names_2.append(name)
+ reversed_trace_names_1.append(trace_name)
+ added_names.add(trace_name)
+ elif fuse_name is not None:
+ # fuse_name = (the other old name, the new name)
+ reversed_fuse_names_2.append(name)
+ reversed_fuse_names_1.append(fuse_name[0])
+ added_names.add(fuse_name[0])
+ reversed_fuse_names_result.append(fuse_name[1])
+ reversed_fuse_index_result.append(index)
+ else:
+ reversed_free_names.append(name)
+ reversed_free_index.append(index)
+ return (
+ tuple(reversed(reversed_trace_names_1)),
+ tuple(reversed(reversed_trace_names_2)),
+ tuple(reversed(reversed_fuse_names_1)),
+ tuple(reversed(reversed_fuse_names_2)),
+ tuple(reversed(reversed_free_names)),
+ tuple(reversed(reversed_fuse_names_result)),
+ tuple(reversed(reversed_free_index)),
+ tuple(reversed(reversed_fuse_index_result)),
+ )
+
+ def trace(
+ self: Tensor,
+ trace_pairs: set[tuple[str, str]],
+ fuse_names: dict[str, tuple[str, str]] | None = None,
+ ) -> Tensor:
+ """
+ Trace a tensor.
+
+ Parameters
+ ----------
+ trace_pairs : set[tuple[str, str]]
+ The pairs of edge names to be traced.
+ fuse_names : dict[str, tuple[str, str]], default=None
+ Mapping from a new edge name to the pair of edge names to be fused.
+
+ Returns
+ -------
+ Tensor
+ The traced tensor.
+ """
+ # pylint: disable=too-many-locals
+ if fuse_names is None:
+ fuse_names = {}
+ # Edges to be fused must not carry any symmetry
+ assert all(
+ all(_utility.zero_symmetry(sub_symmetry)
+ for sub_symmetry in self.edge_by_name(old_name_1).symmetry)
+ for new_name, [old_name_1, old_name_2] in fuse_names.items())
+ # Each pair of edges to be fused must be identical
+ assert all(
+ self.edge_by_name(old_name_1) == self.edge_by_name(old_name_2)
+ for new_name, [old_name_1, old_name_2] in fuse_names.items())
+ # Each pair of traced edges must be mutually conjugate
+ assert all(
+ self.edge_by_name(old_name_1).conjugate() == self.edge_by_name(old_name_2)
+ for old_name_1, old_name_2 in trace_pairs)
+
+ # Split trace pairs and fuse names into their two sides before the main part of the trace.
+ [
+ trace_names_1,
+ trace_names_2,
+ fuse_names_1,
+ fuse_names_2,
+ free_names,
+ fuse_names_result,
+ free_index,
+ fuse_index_result,
+ ] = self._trace_group_edge(trace_pairs, fuse_names)
+
+ # Make alias
+ tensor = self
+
+ # The tensor edges are merged into 5 parts: fuse edge 1, fuse edge 2, trace edge 1, trace edge 2 and the free edge
+ # The trace consists of 5 steps:
+ # 1. reverse all arrows to False, except trace edge 1, which is set to True; apply parity to only one of the two trace edges
+ # 2. merge all edges into the 5 parts above; apply parity to only one of the two trace edges
+ # 3. trace the trivial tensor
+ # 4. split the edges merged in step 2, without applying parity
+ # 5. reverse the arrows reversed in step 1, without applying parity
+
+ # Step 1
+ arrow_true_names = {name for name, edge in zip(tensor.names, tensor.edges) if edge.arrow}
+ unordered_trace_names_1 = set(trace_names_1)
+ tensor = tensor.reverse_edge(unordered_trace_names_1 ^ arrow_true_names, False,
+ unordered_trace_names_1 - arrow_true_names)
+
+ # Step 2
+ free_edges: tuple[tuple[str, Edge], ...] = tuple(
+ (name, tensor.edges[index]) for name, index in zip(free_names, free_index))
+ fuse_edges_result: tuple[tuple[str, Edge], ...] = tuple(
+ (name, tensor.edges[index]) for name, index in zip(fuse_names_result, fuse_index_result))
+ tensor = tensor.merge_edge(
+ {
+ "Trace_1": trace_names_1,
+ "Trace_2": trace_names_2,
+ "Fuse_1": fuse_names_1,
+ "Fuse_2": fuse_names_2,
+ "Free": free_names,
+ },
+ False,
+ {"Trace_1"},
+ merge_arrow={
+ "Trace_1": True,
+ "Trace_2": False,
+ "Fuse_1": False,
+ "Fuse_2": False,
+ "Free": False,
+ },
+ names=("Trace_1", "Trace_2", "Fuse_1", "Fuse_2", "Free"),
+ )
+ # B[fuse, free] = A[trace, trace, fuse, fuse, free]
+ assert tensor.edges[2] == tensor.edges[3]
+ assert tensor.edges[0].conjugate() == tensor.edges[1]
+
+ # Step 3
+ # As tested, the order of edges in this einsum is not important
+ # ttffi->fi, fftti->fi, ffitt->fi, ttiff->if, ittff->if, ifftt->if
+ data = torch.einsum("ttffi->fi", tensor.data)
+ tensor = Tensor(
+ names=("Fuse", "Free"),
+ edges=(tensor.edges[2], tensor.edges[4]),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data,
+ )
+
+ # Step 4
+ tensor = tensor.split_edge({
+ "Fuse": fuse_edges_result,
+ "Free": free_edges,
+ }, False, set())
+
+ # Step 5
+ tensor = tensor.reverse_edge(
+ # Free edge with arrow true
+ {name for name in free_names if name in arrow_true_names} |
+ # New edge from fused edge with arrow true
+ {new_name for old_name, new_name in zip(fuse_names_1, fuse_names_result) if old_name in arrow_true_names},
+ False,
+ set(),
+ )
+
+ return tensor
+
+ def conjugate(self: Tensor, trivial_metric: bool = False) -> Tensor:
+ """
+ Get the conjugate of this tensor.
+
+ Parameters
+ ----------
+ trivial_metric : bool, default=False
+ A fermionic tensor network may produce a non-positive-definite metric. Set trivial_metric to True to
+ ensure the metric is positive, though this breaks the associative law with tensor contraction.
+
+ Returns
+ -------
+ Tensor
+ The conjugated tensor.
+ """
+ data = torch.conj(self.data)
+
+ # Usually, only a full-transpose sign is applied.
+ # If trivial_metric is True, the parity of edges with arrow True is also applied.
+
+ # Apply parity
+ if any(self.fermion):
+ # It is a fermionic tensor, so parity needs to be applied
+
+ # The parity sign is the one generated by a full transpose:
+ # {sum 0<=i<j<rank} p_i p_j
...
+ @staticmethod
+ def _guess_edge(matrix: torch.Tensor, edge: Edge, arrow: bool) -> Edge:
+ # Used in matrix decomposition: SVD and QR
+ # It relies on the fact that the decomposition of a block tensor is also a block tensor.
+ # Otherwise it cannot guess the correct edge
+ # QR
+ # Full rank case:
+ # QR is unique up to a diagonal unitary matrix freedom in the full-rank case,
+ # and a diagonal unitary does not change the block structure. Since we know there is at least one
+ # decomposition result which is a block matrix, every possible decomposition is blocked.
+ # Proof:
+ # shape of A is m * n
+ # if m >= n:
+ # A = [Q1 U1] [[R1] [0]] = [Q2 U2] [[R2] [0]]
+ # A is full rank => R1 and R2 are invertible
+ # Q1 R1 = Q2 R2 and (R1, R2 invertible) => Q2^dagger Q1 = R2 R1^-1, Q1^dagger Q2 = R1 R2^-1
+ # lemma: products and inverses of upper triangular matrices are also upper triangular.
+ # Q2^dagger Q1 and Q1^dagger Q2 are both upper triangular; the dagger of the latter shows that
+ # Q2^dagger Q1 is lower triangular as well.
+ # => Q2^dagger Q1 is diagonal => Q2^dagger Q1 = R2 R1^-1 = S, where S is a diagonal matrix.
+ # => Q1 = Q1 R1 R1^-1 = Q2 R2 R1^-1 = Q2 S; since Q1 and Q2 have orthonormal columns, S is a diagonal unitary.
+ # At last, we have Q1 = Q2 S, where S is a diagonal unitary matrix, while S R1 = R2
+ # if m < n:
+ # A = Q1 [R1 N1] = Q2 [R2 N2], so we have Q1 R1 = Q2 R2
+ # This reduces to the m = n case, so Q1 = Q2 S and S R1 = R2.
+ # At last, Q1 N1 = Q2 S N1 = Q2 N2 implies S N1 = N2.
+ # Where S is diagonal unitary.
+ # Rank-deficient case:
+ # It is hard to reach a conclusion. The program may break in this situation.
+ # SVD
+ # For the non-singular case
+ # SVD is unique up to a block unitary matrix freedom, which preserves each singular value subspace.
+ # So edge guessing fails iff the same singular value crosses different quantum numbers.
+ # In this case, the program may break.
+ # Proof:
+ # Let m <= n, since the argument is symmetric in the dimensions.
+ # A = U1 S1 V1 = U2 S2 V2 => A A^dagger = U1 S1^2 U1^dagger = U2 S2^2 U2^dagger
+ # The eigenvalues are unique in descending order, and singular values are non-negative real numbers.
+ # => S1 = S2 = S, and for the eigenvectors, U1 = U2 Q where Q is a unitary matrix with [Q, S] = 0
+ # => U1 S V1 = U2 Q S V1 = U2 S Q V1 = U2 S V2 => S Q V1 = S V2; since S is non-singular, Q V1 = V2.
+ # At last, U1 = U2 Q, S1 = S2, Q V1 = V2.
+ # For the singular case
+ # The unitary is not determined on the singular part. It is similar to the non-singular case,
+ # but at the last step, S Q V1 = S V2 => Q' V1 = V2, where Q' agrees with Q only on the non-singular part.
+ # So it breaks blocks only if the blocks have already been broken by a repeated singular value.
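+ # A small numerical check of the QR uniqueness argument above (a sketch, not
+ # used by the library): flip the sign of some Q columns and absorb the flip
+ # into R; S = Q2^dagger Q1 is then a diagonal unitary matrix.
+ #     A = torch.randn(5, 3, dtype=torch.float64)
+ #     Q1, R1 = torch.linalg.qr(A, mode="reduced")
+ #     S = torch.diag(torch.tensor([1.0, -1.0, 1.0], dtype=torch.float64))
+ #     Q2, R2 = Q1 @ S, S @ R1  # another valid QR decomposition of A
+ #     assert torch.allclose(Q2 @ R2, A)
+ #     assert torch.allclose(Q2.mT @ Q1, S)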
+ # pylint: disable=invalid-name
+ m, n = matrix.size()
+ assert edge.dimension == m
+ argmax = torch.argmax(matrix, dim=0)
+ assert argmax.size() == (n,)
+ return Edge(
+ fermion=edge.fermion,
+ dtypes=edge.dtypes,
+ symmetry=tuple(_utility.neg_symmetry(sub_symmetry[argmax]) for sub_symmetry in edge.symmetry),
+ dimension=n,
+ arrow=arrow,
+ parity=edge.parity[argmax],
+ )
+
+ def _ensure_mask(self: Tensor) -> None:
+ """
+ Currently this function is only called from the SVD and QR decompositions. It ensures that every element
+ whose mask is False is very small, and then sets those elements to exactly zero.
+
+ Operations other than SVD and QR would not break a blocked tensor. A QR implemented by Givens rotations
+ would preserve the blocks by itself, so the mask would not need to be ensured there.
+ """
+ assert torch.allclose(torch.where(self.mask, torch.zeros([], dtype=self.dtype), self.data),
+ torch.zeros([], dtype=self.dtype))
+ self._data = torch.where(self.mask, self.data, torch.zeros([], dtype=self.dtype))
+
+ def svd(
+ self: Tensor,
+ free_names_u: set[str],
+ common_name_u: str,
+ common_name_v: str,
+ singular_name_u: str,
+ singular_name_v: str,
+ cut: int = -1,
+ ) -> tuple[Tensor, Tensor, Tensor]:
+ """
+ SVD decomposition of a tensor. Because the edge created by SVD is guessed from the SVD result, the program
+ may break if there are repeated singular values, which may produce a non-blocked decomposition result.
+
+ Parameters
+ ----------
+ free_names_u : set[str]
+ The names of the free edges assigned to the U tensor in the SVD result.
+ common_name_u, common_name_v, singular_name_u, singular_name_v : str
+ The names of the edges generated by the decomposition.
+ cut : int, default=-1
+ The number of singular values to keep; -1 keeps all of them.
+
+ Returns
+ -------
+ tuple[Tensor, Tensor, Tensor]
+ U, S, V tensor, the result of SVD.
+ """
+ # pylint: disable=too-many-arguments
+ # pylint: disable=too-many-locals
+
+ free_names_v = {name for name in self.names if name not in free_names_u}
+
+ assert all(name in self.names for name in free_names_u)
+ assert common_name_u not in free_names_u
+ assert common_name_v not in free_names_v
+
+ arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow}
+
+ tensor = self.reverse_edge(arrow_true_names, False, set())
+
+ ordered_free_edges_u: tuple[tuple[str, Edge], ...] = tuple(
+ (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_u)
+ ordered_free_edges_v: tuple[tuple[str, Edge], ...] = tuple(
+ (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_v)
+ ordered_free_names_u: tuple[str, ...] = tuple(name for name, _ in ordered_free_edges_u)
+ ordered_free_names_v: tuple[str, ...] = tuple(name for name, _ in ordered_free_edges_v)
+
+ put_v_right = next((name in free_names_v for name in reversed(tensor.names)), True)
+ tensor = tensor.merge_edge(
+ {
+ "SVD_U": ordered_free_names_u,
+ "SVD_V": ordered_free_names_v
+ },
+ False,
+ set(),
+ merge_arrow={
+ "SVD_U": False,
+ "SVD_V": False
+ },
+ names=("SVD_U", "SVD_V") if put_v_right else ("SVD_V", "SVD_U"),
+ )
+
+ # if self.fermion:
+ # data_1, data_s, data_2 = manual_svd(tensor.data, 1e-6)
+ # else:
+ # data_1, data_s, data_2 = torch.linalg.svd(tensor.data, full_matrices=False)
+ data_1, data_s, data_2 = torch.linalg.svd(tensor.data, full_matrices=False)
+
+ if cut != -1:
+ data_1 = data_1[:, :cut]
+ data_s = data_s[:cut]
+ data_2 = data_2[:cut, :]
+ data_s = torch.diag_embed(data_s)
+
+ free_edge_1 = tensor.edges[0]
+ common_edge_1 = Tensor._guess_edge(torch.abs(data_1), free_edge_1, True)
+ tensor_1 = Tensor(
+ names=("SVD_U", common_name_u) if put_v_right else ("SVD_V", common_name_v),
+ edges=(free_edge_1, common_edge_1),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data_1,
+ )
+ tensor_1._ensure_mask() # pylint: disable=protected-access
+ free_edge_2 = tensor.edges[1]
+ common_edge_2 = Tensor._guess_edge(torch.abs(data_2).transpose(0, 1), free_edge_2, False)
+ tensor_2 = Tensor(
+ names=(common_name_v, "SVD_V") if put_v_right else (common_name_u, "SVD_U"),
+ edges=(common_edge_2, free_edge_2),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data_2,
+ )
+ tensor_2._ensure_mask() # pylint: disable=protected-access
+ assert common_edge_1.conjugate() == common_edge_2
+ tensor_s = Tensor(
+ names=(singular_name_u, singular_name_v) if put_v_right else (singular_name_v, singular_name_u),
+ edges=(common_edge_2, common_edge_1),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data_s,
+ )
+
+ tensor_u = tensor_1 if put_v_right else tensor_2
+ tensor_v = tensor_2 if put_v_right else tensor_1
+
+ tensor_u = tensor_u.split_edge({"SVD_U": ordered_free_edges_u}, False, set())
+ tensor_v = tensor_v.split_edge({"SVD_V": ordered_free_edges_v}, False, set())
+
+ tensor_u = tensor_u.reverse_edge(arrow_true_names & free_names_u, False, set())
+ tensor_v = tensor_v.reverse_edge(arrow_true_names & free_names_v, False, set())
+
+ return tensor_u, tensor_s, tensor_v
+
+ def qr(
+ self: Tensor,
+ free_names_direction: str,
+ free_names: set[str],
+ common_name_q: str,
+ common_name_r: str,
+ ) -> tuple[Tensor, Tensor]:
+ """
+ QR decomposition of this tensor. Because the edge created by QR is guessed from the QR result, the
+ program may break if the tensor is rank deficient, which may produce a non-blocked decomposition result.
+
+ Parameters
+ ----------
+ free_names_direction : 'Q' | 'q' | 'R' | 'r'
+ Specify whether free_names lists the free edges of the Q side or of the R side.
+ free_names : set[str]
+ The names of free edges after QR decomposition.
+ common_name_q, common_name_r : str
+ The names of edges created by QR decomposition.
+
+ Returns
+ -------
+ tuple[Tensor, Tensor]
+ Tensor Q and R, the result of QR decomposition.
+ """
+ # pylint: disable=invalid-name
+ # pylint: disable=too-many-locals
+
+ if free_names_direction in {'Q', 'q'}:
+ free_names_q = free_names
+ free_names_r = {name for name in self.names if name not in free_names}
+ elif free_names_direction in {'R', 'r'}:
+ free_names_r = free_names
+ free_names_q = {name for name in self.names if name not in free_names}
+ else:
+ raise ValueError("free_names_direction must be one of 'Q', 'q', 'R' or 'r'")
+
+ assert all(name in self.names for name in free_names)
+ assert common_name_q not in free_names_q
+ assert common_name_r not in free_names_r
+
+ arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow}
+
+ tensor = self.reverse_edge(arrow_true_names, False, set())
+
+ ordered_free_edges_q: tuple[tuple[str, Edge], ...] = tuple(
+ (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_q)
+ ordered_free_edges_r: tuple[tuple[str, Edge], ...] = tuple(
+ (name, edge) for name, edge in zip(tensor.names, tensor.edges) if name in free_names_r)
+ ordered_free_names_q: tuple[str, ...] = tuple(name for name, _ in ordered_free_edges_q)
+ ordered_free_names_r: tuple[str, ...] = tuple(name for name, _ in ordered_free_edges_r)
+
+ # PyTorch does not provide LQ, so always put R on the right here
+ tensor = tensor.merge_edge(
+ {
+ "QR_Q": ordered_free_names_q,
+ "QR_R": ordered_free_names_r
+ },
+ False,
+ set(),
+ merge_arrow={
+ "QR_Q": False,
+ "QR_R": False
+ },
+ names=("QR_Q", "QR_R"),
+ )
+
+ # if self.fermion:
+ # # Blocked tensor, use Givens rotation
+ # data_q, data_r = givens_qr(tensor.data)
+ # else:
+ # # Non-blocked tensor, use Householder reflection
+ # data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced")
+ data_q, data_r = torch.linalg.qr(tensor.data, mode="reduced")
+
+ free_edge_q = tensor.edges[0]
+ common_edge_q = Tensor._guess_edge(torch.abs(data_q), free_edge_q, True)
+ tensor_q = Tensor(
+ names=("QR_Q", common_name_q),
+ edges=(free_edge_q, common_edge_q),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data_q,
+ )
+ tensor_q._ensure_mask() # pylint: disable=protected-access
+ free_edge_r = tensor.edges[1]
+ # common_edge_r = Tensor._guess_edge(torch.abs(data_r).transpose(0, 1), free_edge_r, False)
+ # Sometimes the R matrix may be singular, and the guessed edge would have arbitrary symmetry, so do not use the guessed edge.
+ common_edge_r = common_edge_q.conjugate()
+ tensor_r = Tensor(
+ names=(common_name_r, "QR_R"),
+ edges=(common_edge_r, free_edge_r),
+ fermion=self.fermion,
+ dtypes=self.dtypes,
+ data=data_r,
+ )
+ tensor_r._ensure_mask() # pylint: disable=protected-access
+ assert common_edge_q.conjugate() == common_edge_r
+
+ tensor_q = tensor_q.split_edge({"QR_Q": ordered_free_edges_q}, False, set())
+ tensor_r = tensor_r.split_edge({"QR_R": ordered_free_edges_r}, False, set())
+
+ tensor_q = tensor_q.reverse_edge(arrow_true_names & free_names_q, False, set())
+ tensor_r = tensor_r.reverse_edge(arrow_true_names & free_names_r, False, set())
+
+ return tensor_q, tensor_r
+
+ def identity(self: Tensor, pairs: set[tuple[str, str]]) -> Tensor:
+ """
+ Get the identity tensor with the same shape as this tensor.
+
+ Parameters
+ ----------
+ pairs : set[tuple[str, str]]
+ The pairs of edge names specifying how the edges pair up in the identity tensor.
+
+ Returns
+ -------
+ Tensor
+ The resulting identity tensor.
+ """
+ # The order of each edge pair before setting the identity should be (False, True)
+ # Merge the tensor directly into two edges, set the identity, and split it back directly.
+ # When splitting, apply parity to only one part of the edges
+
+ # pylint: disable=unnecessary-comprehension
+ pairs_map = {name_1: name_2 for name_1, name_2 in pairs} | {name_2: name_1 for name_1, name_2 in pairs}
+ added_names: set[str] = set()
+ reversed_names_1: list[str] = []
+ reversed_names_2: list[str] = []
+ for name in reversed(self.names):
+ if name not in added_names:
+ another_name = pairs_map[name]
+ reversed_names_2.append(name)
+ reversed_names_1.append(another_name)
+ added_names.add(another_name)
+ names_1 = tuple(reversed(reversed_names_1))
+ names_2 = tuple(reversed(reversed_names_2))
+ # unordered_names_1 = set(names_1)
+ unordered_names_2 = set(names_2)
+
+ arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow}
+
+ # Two edges, arrow of two edges are (False, True)
+ tensor = self.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names)
+
+ edges_1 = tuple((name, tensor.edge_by_name(name)) for name in names_1)
+ edges_2 = tuple((name, tensor.edge_by_name(name)) for name in names_2)
+
+ tensor = tensor.merge_edge(
+ {
+ "Identity_1": names_1,
+ "Identity_2": names_2
+ },
+ False,
+ {"Identity_2"},
+ merge_arrow={
+ "Identity_1": False,
+ "Identity_2": True
+ },
+ names=("Identity_1", "Identity_2"),
+ )
+
+ tensor = Tensor(
+ names=tensor.names,
+ edges=tensor.edges,
+ fermion=tensor.fermion,
+ dtypes=tensor.dtypes,
+ data=torch.eye(*tensor.data.size()),
+ mask=tensor.mask,
+ )
+
+ tensor = tensor.split_edge({"Identity_1": edges_1, "Identity_2": edges_2}, False, {"Identity_2"})
+
+ tensor = tensor.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names)
+
+ return tensor
+
+ def exponential(self: Tensor, pairs: set[tuple[str, str]]) -> Tensor:
+ """
+ Get the exponential tensor of this tensor.
+
+ Parameters
+ ----------
+ pairs : set[tuple[str, str]]
+ The pairs of edge names specifying how the edges pair up when computing the exponential tensor.
+
+ Returns
+ -------
+ Tensor
+ The resulting exponential tensor.
+ """
+ # The order of each edge pair before computing the exponential should be (False, True)
+ # Merge the tensor directly into two edges, compute the exponential, and split it back directly.
+ # When splitting, apply parity to only one part of the edges
+
+ unordered_names_1 = {name_1 for name_1, name_2 in pairs}
+ unordered_names_2 = {name_2 for name_1, name_2 in pairs}
+ if self.names and self.names[-1] in unordered_names_1:
+ unordered_names_1, unordered_names_2 = unordered_names_2, unordered_names_1
+ # pylint: disable=unnecessary-comprehension
+ pairs_map = {name_1: name_2 for name_1, name_2 in pairs} | {name_2: name_1 for name_1, name_2 in pairs}
+ reversed_names_1: list[str] = []
+ reversed_names_2: list[str] = []
+ for name in reversed(self.names):
+ if name in unordered_names_2:
+ another_name = pairs_map[name]
+ reversed_names_2.append(name)
+ reversed_names_1.append(another_name)
+ names_1 = tuple(reversed(reversed_names_1))
+ names_2 = tuple(reversed(reversed_names_2))
+
+ arrow_true_names = {name for name, edge in zip(self.names, self.edges) if edge.arrow}
+
+ # Two edges, arrow of two edges are (False, True)
+ tensor = self.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names)
+
+ edges_1 = tuple((name, tensor.edge_by_name(name)) for name in names_1)
+ edges_2 = tuple((name, tensor.edge_by_name(name)) for name in names_2)
+
+ tensor = tensor.merge_edge(
+ {
+ "Exponential_1": names_1,
+ "Exponential_2": names_2
+ },
+ False,
+ {"Exponential_2"},
+ merge_arrow={
+ "Exponential_1": False,
+ "Exponential_2": True
+ },
+ names=("Exponential_1", "Exponential_2"),
+ )
+
+ tensor = Tensor(
+ names=tensor.names,
+ edges=tensor.edges,
+ fermion=tensor.fermion,
+ dtypes=tensor.dtypes,
+ data=torch.linalg.matrix_exp(tensor.data),
+ mask=tensor.mask,
+ )
+
+ tensor = tensor.split_edge({"Exponential_1": edges_1, "Exponential_2": edges_2}, False, {"Exponential_2"})
+
+ tensor = tensor.reverse_edge(unordered_names_2 ^ arrow_true_names, False, unordered_names_2 - arrow_true_names)
+
+ return tensor
diff --git a/tests/test_compat.py b/tests/test_compat.py
new file mode 100644
index 000000000..7ed453903
--- /dev/null
+++ b/tests/test_compat.py
@@ -0,0 +1,122 @@
+"Test compat"
+
+import torch
+import tat
+from tat import compat as TAT
+
+# pylint: disable=missing-function-docstring
+# pylint: disable=invalid-name
+# pylint: disable=singleton-comparison
+
+ # It is strange, but pylint complains about too many function args, so disable it here
+# pylint: disable=too-many-function-args
+
+
+def test_edge_from_dimension() -> None:
+ assert TAT.No.Edge(4) == tat.Edge(dimension=4)
+ assert TAT.Fermi.Edge(4) == tat.Edge(fermion=(True,),
+ symmetry=(torch.tensor([0, 0, 0, 0], dtype=torch.int),),
+ arrow=False)
+ assert TAT.Z2.Edge(4) == tat.Edge(symmetry=(torch.tensor([False, False, False, False]),))
+
+
+def test_edge_from_segments() -> None:
+ assert TAT.Z2.Edge([
+ (False, 2),
+ (True, 3),
+ ]) == tat.Edge(symmetry=(torch.tensor([False, False, True, True, True]),),)
+ assert TAT.Fermi.Edge([
+ (-1, 1),
+ (0, 2),
+ (+1, 3),
+ ], True) == tat.Edge(
+ symmetry=(torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int),),
+ arrow=True,
+ fermion=(True,),
+ )
+ assert TAT.FermiFermi.Edge([
+ ((-1, -2), 1),
+ ((0, +1), 2),
+ ((+1, 0), 3),
+ ], True) == tat.Edge(
+ symmetry=(
+ torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int),
+ torch.tensor([-2, +1, +1, 0, 0, 0], dtype=torch.int),
+ ),
+ arrow=True,
+ fermion=(True, True),
+ )
+
+
+def test_edge_from_segments_without_dimension() -> None:
+ assert TAT.Z2.Edge([False, True]) == tat.Edge(symmetry=(torch.tensor([False, True]),))
+ assert TAT.Fermi.Edge([-1, 0, +1], True) == tat.Edge(
+ symmetry=(torch.tensor([-1, 0, +1], dtype=torch.int),),
+ arrow=True,
+ fermion=(True,),
+ )
+ assert TAT.FermiFermi.Edge([
+ (-1, -2),
+ (0, +1),
+ (+1, 0),
+ ], True) == tat.Edge(
+ symmetry=(torch.tensor([-1, 0, +1], dtype=torch.int), torch.tensor([-2, +1, 0], dtype=torch.int)),
+ arrow=True,
+ fermion=(True, True),
+ )
+
+
+def test_edge_from_tuple() -> None:
+ assert TAT.FermiFermi.Edge(([
+ ((-1, -2), 1),
+ ((0, +1), 2),
+ ((+1, 0), 3),
+ ], True)) == tat.Edge(
+ symmetry=(
+ torch.tensor([-1, 0, 0, +1, +1, +1], dtype=torch.int),
+ torch.tensor([-2, +1, +1, 0, 0, 0], dtype=torch.int),
+ ),
+ arrow=True,
+ fermion=(True, True),
+ )
+ assert TAT.FermiFermi.Edge(([
+ (-1, -2),
+ (0, +1),
+ (+1, 0),
+ ], True)) == tat.Edge(
+ symmetry=(torch.tensor([-1, 0, +1], dtype=torch.int), torch.tensor([-2, +1, 0], dtype=torch.int)),
+ arrow=True,
+ fermion=(True, True),
+ )
+
+
+def test_tensor() -> None:
+ a = TAT.FermiZ2.D.Tensor(["i", "j"], [
+ [(-1, False), (-1, True), (0, True), (0, False)],
+ [(+1, True), (+1, False), (0, False), (0, True)],
+ ])
+ b = tat.Tensor(
+ (
+ "i",
+ "j",
+ ),
+ (
+ tat.Edge(
+ fermion=(True, False),
+ symmetry=(
+ torch.tensor([-1, -1, 0, 0], dtype=torch.int),
+ torch.tensor([False, True, True, False]),
+ ),
+ arrow=False,
+ ),
+ tat.Edge(
+ fermion=(True, False),
+ symmetry=(
+ torch.tensor([+1, +1, 0, 0], dtype=torch.int),
+ torch.tensor([True, False, False, True]),
+ ),
+ arrow=False,
+ ),
+ ),
+ )
+ assert a.same_shape_with(b, allow_transpose=False)
diff --git a/tests/test_create_tensor.py b/tests/test_create_tensor.py
new file mode 100644
index 000000000..f784539c8
--- /dev/null
+++ b/tests/test_create_tensor.py
@@ -0,0 +1,103 @@
+"Test create tensor"
+
+import torch
+import tat
+
+# pylint: disable=missing-function-docstring
+# pylint: disable=invalid-name
+# pylint: disable=singleton-comparison
+
+
+def test_create_tensor() -> None:
+ a = tat.Tensor(
+ (
+ "i",
+ "j",
+ ),
+ (
+ tat.Edge(symmetry=(torch.tensor([False, False, True]),), fermion=(True,), arrow=True),
+ tat.Edge(symmetry=(torch.tensor([False, False, False, True, True]),), fermion=(True,), arrow=False),
+ ),
+ )
+ assert a.rank == 2
+ assert a.names == ("i", "j")
+ assert a.edges[0] == tat.Edge(symmetry=(torch.tensor([False, False, True]),), fermion=(True,), arrow=True)
+ assert a.edges[1] == tat.Edge(symmetry=(torch.tensor([False, False, False, True, True]),),
+ fermion=(True,),
+ arrow=False)
+ assert a.edges[0] == a.edge_by_name("i")
+ assert a.edges[1] == a.edge_by_name("j")
+
+
+def test_tensor_get_set_item() -> None:
+ a = tat.Tensor(
+ (
+ "i",
+ "j",
+ ),
+ (
+ tat.Edge(symmetry=(torch.tensor([False, False, True]),), fermion=(True,), arrow=True),
+ tat.Edge(symmetry=(torch.tensor([False, False, False, True, True]),), fermion=(True,), arrow=False),
+ ),
+ )
+ a[{"i": 0, "j": 0}] = 1
+ assert a[0, 0] == 1
+ a["i":2, "j":3] = 2 # type: ignore[misc]
+ assert a[{"i": 2, "j": 3}] == 2
+ try:
+ a[2, 0] = 3
+ assert False
+ except IndexError:
+ pass
+ assert a["i":2, "j":0] == 0 # type: ignore[misc]
+
+ b = tat.Tensor(
+ (
+ "i",
+ "j",
+ ),
+ (
+ tat.Edge(symmetry=(torch.tensor([0, 0, -1]),), fermion=(False,)),
+ tat.Edge(symmetry=(torch.tensor([0, 0, 0, +1, +1]),), fermion=(False,)),
+ ),
+ )
+ b[{"i": 0, "j": 0}] = 1
+ assert b[0, 0] == 1
+ b["i":2, "j":3] = 2 # type: ignore[misc]
+ assert b[{"i": 2, "j": 3}] == 2
+ try:
+ b[2, 0] = 3
+ assert False
+ except IndexError:
+ pass
+ assert b["i":2, "j":0] == 0 # type: ignore[misc]
+
+
+def test_create_randn_tensor() -> None:
+ a = tat.Tensor(
+ ("i", "j"),
+ (
+ tat.Edge(symmetry=(torch.tensor([False, True]),)),
+ tat.Edge(symmetry=(torch.tensor([False, True]),)),
+ ),
+ dtype=torch.float16,
+ ).randn_()
+ assert a.dtype == torch.float16
+ assert a[0, 0] != 0
+ assert a[1, 1] != 0
+ assert a[0, 1] == 0
+ assert a[1, 0] == 0
+
+ b = tat.Tensor(
+ ("i", "j"),
+ (
+ tat.Edge(symmetry=(torch.tensor([False, False]), torch.tensor([0, -1]))),
+ tat.Edge(symmetry=(torch.tensor([False, False]), torch.tensor([0, +1]))),
+ ),
+ dtype=torch.float16,
+ ).randn_()
+ assert b.dtype == torch.float16
+ assert b[0, 0] != 0
+ assert b[1, 1] != 0
+ assert b[0, 1] == 0
+ assert b[1, 0] == 0
diff --git a/tests/test_edge.py b/tests/test_edge.py
new file mode 100644
index 000000000..aa6f457f8
--- /dev/null
+++ b/tests/test_edge.py
@@ -0,0 +1,39 @@
+"Test edge"
+
+import torch
+from tat import Edge
+
+# pylint: disable=missing-function-docstring
+# pylint: disable=invalid-name
+# pylint: disable=singleton-comparison
+
+
+def test_create_edge_and_basic() -> None:
+ a = Edge(dimension=5)
+ assert a.arrow == False
+ assert a.dimension == 5
+ b = Edge(symmetry=(torch.tensor([False, False, True, True]),))
+ assert b.arrow == False
+ assert b.dimension == 4
+ c = Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([False, True])), arrow=True)
+ assert c.arrow == True
+ assert c.dimension == 2
+
+
+def test_edge_conjugate_and_equal() -> None:
+ a = Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([0, 1])), arrow=True)
+ b = Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([0, -1])), arrow=False)
+ assert a.conjugate() == b
+ assert a != 2
+
+
+def test_repr() -> None:
+ a = Edge(fermion=(False, True), symmetry=(torch.tensor([False, True]), torch.tensor([0, 1])), arrow=True)
+ repr_a = "Edge(dimension=2, arrow=True, fermion=(False,True), symmetry=([False,True],[0,1]))"
+ assert repr_a == repr(a)
+ b = Edge(symmetry=(torch.tensor([False, True]), torch.tensor([0, 1])))
+ repr_b = "Edge(dimension=2, symmetry=([False,True],[0,1]))"
+ assert repr_b == repr(b)
+ c = Edge(dimension=4)
+ repr_c = "Edge(dimension=4)"
+ assert repr_c == repr(c)
diff --git a/tests/test_qr.py b/tests/test_qr.py
new file mode 100644
index 000000000..5085cc261
--- /dev/null
+++ b/tests/test_qr.py
@@ -0,0 +1,59 @@
+"Test QR"
+
+import torch
+from tat._qr import givens_qr, householder_qr
+
+# pylint: disable=missing-function-docstring
+# pylint: disable=invalid-name
+
+
+def check_givens(A: torch.Tensor) -> None:
+ m, n = A.size()
+ Q, R = givens_qr(A)
+ assert torch.allclose(A, Q @ R)
+ assert torch.allclose(Q.H @ Q, torch.eye(min(m, n), dtype=A.dtype, device=A.device))
+ grad_check = torch.autograd.gradcheck(
+ givens_qr,
+ A,
+ eps=1e-8,
+ atol=1e-4,
+ )
+ assert grad_check
+
+
+def test_qr_real_givens() -> None:
+ check_givens(torch.randn(7, 5, dtype=torch.float64, requires_grad=True))
+ check_givens(torch.randn(5, 5, dtype=torch.float64, requires_grad=True))
+ check_givens(torch.randn(5, 7, dtype=torch.float64, requires_grad=True))
+
+
+def test_qr_complex_givens() -> None:
+ check_givens(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True))
+ check_givens(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True))
+ check_givens(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True))
+
+
+def check_householder(A: torch.Tensor) -> None:
+ m, n = A.size()
+ Q, R = householder_qr(A)
+ assert torch.allclose(A, Q @ R)
+ assert torch.allclose(Q.H @ Q, torch.eye(min(m, n), dtype=A.dtype, device=A.device))
+ grad_check = torch.autograd.gradcheck(
+ householder_qr,
+ A,
+ eps=1e-8,
+ atol=1e-4,
+ )
+ assert grad_check
+
+
+def test_qr_real_householder() -> None:
+ check_householder(torch.randn(7, 5, dtype=torch.float64, requires_grad=True))
+ check_householder(torch.randn(5, 5, dtype=torch.float64, requires_grad=True))
+ check_householder(torch.randn(5, 7, dtype=torch.float64, requires_grad=True))
+
+
+def test_qr_complex_householder() -> None:
+ check_householder(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True))
+ check_householder(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True))
+ check_householder(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True))
diff --git a/tests/test_svd.py b/tests/test_svd.py
new file mode 100644
index 000000000..14610352e
--- /dev/null
+++ b/tests/test_svd.py
@@ -0,0 +1,40 @@
+"Test SVD"
+
+import torch
+from tat._svd import svd
+
+# pylint: disable=missing-function-docstring
+# pylint: disable=invalid-name
+
+
+def svd_func(A: torch.Tensor) -> torch.Tensor:
+ U, S, V = svd(A, 1e-10)
+ return U @ torch.diag(S).to(dtype=A.dtype) @ V
+
+
+def check_svd(A: torch.Tensor) -> None:
+ m, n = A.size()
+ U, S, V = svd(A, 1e-10)
+ assert torch.allclose(U @ torch.diag(S.to(dtype=A.dtype)) @ V, A)
+ assert torch.allclose(U.H @ U, torch.eye(min(m, n), dtype=A.dtype, device=A.device))
+ assert torch.allclose(V @ V.H, torch.eye(min(m, n), dtype=A.dtype, device=A.device))
+ grad_check = torch.autograd.gradcheck(
+ svd_func,
+ A,
+ eps=1e-8,
+ atol=1e-4,
+ nondet_tol=1e-10,
+ )
+ assert grad_check
+
+
+def test_svd_real() -> None:
+ check_svd(torch.randn(7, 5, dtype=torch.float64, requires_grad=True))
+ check_svd(torch.randn(5, 5, dtype=torch.float64, requires_grad=True))
+ check_svd(torch.randn(5, 7, dtype=torch.float64, requires_grad=True))
+
+
+def test_svd_complex() -> None:
+ check_svd(torch.randn(7, 5, dtype=torch.complex128, requires_grad=True))
+ check_svd(torch.randn(5, 5, dtype=torch.complex128, requires_grad=True))
+ check_svd(torch.randn(5, 7, dtype=torch.complex128, requires_grad=True))