Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipVinc committed Jan 25, 2024
1 parent c6e1872 commit e5319d8
Showing 1 changed file with 83 additions and 52 deletions.
135 changes: 83 additions & 52 deletions netket_fidelity/infidelity/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,93 +19,124 @@ def InfidelityOperator(
sample_Upsi=False,
):
r"""
Operator I_op computing the infidelity I among two variational states |ψ⟩ and |Φ⟩ as:
Operator I_op computing the infidelity I among two variational states
:math:`|\psi\rangle` and :math:`|\phi\rangle` as:
.. math::
I = 1 - |⟨ψ|Φ⟩|^2 / ⟨ψ|ψ⟩ ⟨Φ|Φ⟩ = 1 - ⟨ψ|I_op|ψ⟩ / ⟨ψ|ψ⟩
I = 1 - \frac{|⟨\Psi|\Phi⟩|^2 }{ ⟨\Psi|\Psi⟩ ⟨\Phi|\Phi⟩ } = 1 - \frac{⟨\Psi|\hat{I}_{op}|\Psi⟩ }{ ⟨\Psi|\Psi⟩ }
where:
.. math::
.. math::
I_op = |Φ⟩⟨Φ| / ⟨Φ|Φ⟩
I_{op} = \frac {|\Phi\rangle\langle\Phi| }{ \langle\Phi|\Phi\rangle }
The state |Φ⟩ can be an autonomous state |Φ⟩ =|ϕ⟩ or an operator U applied to it, namely
|Φ⟩ = U|ϕ⟩. I_op is defined by the state |ϕ⟩ (called target) and, possibly, by the operator U.
If U is not passed, it is assumed |Φ⟩ =|ϕ⟩.
The state :math:`|\phi\rangle` can be an autonomous state :math:`|\Phi\rangle = |\phi\rangle`
or an operator :math:`U` applied to it, namely
:math:`|\Phi\rangle = U|\phi\rangle`. :math:`I_{op}` is defined by the
state :math:`|\phi\rangle` (called target) and, possibly, by the operator
:math:`U`. If :math:`U` is not specified, it is assumed :math:`|\Phi\rangle = |\phi\rangle`.
The Monte Carlo estimator of I is:
..math::
.. math::
I = \mathbb{E}_{χ}[ I_{loc}(\sigma,\eta) ] =
\mathbb{E}_{χ}\left[\frac{⟨\sigma|\Phi⟩ ⟨\eta|\Psi⟩}{⟨σ|\Psi⟩ ⟨η|\Phi⟩}\right]
where the sampled probability distribution :math:`χ` is defined as:
.. math::
\chi(\sigma, \eta) = \frac{|\psi(\sigma)|^2 |\Phi(\eta)|^2}{
\langle\Psi|\Psi\rangle \langle\Phi|\Phi\rangle}.
In practice, since I is a real quantity, :math:`\rm{Re}[I_{loc}(\sigma,\eta)]`
is used. This estimator can be utilized both when :math:`|\Phi\rangle =|\phi\rangle` and
when :math:`|\Phi\rangle =U|\phi\rangle`, with :math:`U` a (unitary or non-unitary) operator.
In the second case, we have to sample from :math:`U|\phi\rangle` and this is implemented in
the function :class:`netket_fidelity.infidelity.InfidelityUPsi` .
This works only with the operators provdided in the package.
We remark that sampling from :math:`U|\phi\rangle` requires to compute connected elements of
:math:`U` and so is more expensive than sampling from an autonomous state.
The choice of this estimator is specified by passing :code:`sample_Upsi=True`,
while the flag argument :code:`is_unitary` indicates whether :math:`U` is unitary or not.
I = \mathbb{E}_{χ}[ I_loc(σ,η) ] = \mathbb{E}_{χ}[ ⟨σ|Φ⟩ ⟨η|ψ⟩ / ⟨σ|ψ⟩ ⟨η|Φ⟩ ]
If :math:`U` is unitary, the following alternative estimator can be used:
where χ(σ, η) = |Ψ(σ)|^2 |Φ(η)|^2 / ⟨ψ|ψ⟩ ⟨Φ|Φ⟩. In practice, since I is a real quantity, Re{I_loc(σ,η)}
is used. This estimator can be utilized both when |Φ⟩ =|ϕ⟩ and when |Φ⟩ = U|ϕ⟩, with U a (unitary or
non-unitary) operator. In the second case, we have to sample from U|ϕ⟩ and this is implemented in
the function :ref:`jax.:ref:`InfidelityUPsi`. This works only with the operators provdided in the package.
We remark that sampling from U|ϕ⟩ requires to compute connected elements of U and so is more expensive
than sampling from an autonomous state. The choice of this estimator is specified by passing
`sample_Upsi=True`, while the flag argument `is_unitary` indicates whether U is unitary or not.
.. math::
I = \mathbb{E}_{χ'}\left[ I_{loc}(\sigma, \eta) \right] =
\mathbb{E}_{χ}\left[\frac{\langle\sigma|U|\phi\rangle \langle\eta|\psi\rangle}{
\langle\sigma|U^{\dagger}|\psi\rangle ⟨\eta|\phi⟩} \right].
If U is unitary, the following alternative estimator can be used:
where the sampled probability distribution :math:`\chi` is defined as:
..math::
.. math::
I = \mathbb{E}_{χ'}[ I_loc(σ,η) ] = \mathbb{E}_{χ}[ ⟨σ|U|ϕ⟩ ⟨η|ψ⟩ / ⟨σ|U^{\dagger}|ψ⟩ ⟨η|ϕ⟩ ].
\chi'(\sigma, \eta) = \frac{|\psi(\sigma)|^2 |\phi(\eta)|^2}{
\langle\Psi|\Psi\rangle \langle\phi|\phi\rangle}.
where χ'(σ, η) = |Ψ(σ)|^2 |ϕ(η)|^2 / ⟨ψ|ψ⟩ ⟨ϕ|ϕ⟩. This estimator is more efficient since it does not
require to sample from U|ϕ⟩, but only from |ϕ⟩. This choice of the estimator is the default and it works only
with `is_unitary==True` (besides `sample_Upsi=False`). When |Φ⟩ = |ϕ⟩ the two estimators coincides.
This estimator is more efficient since it does not require to sample from
:math:`U|\phi\rangle`, but only from :math:`|\phi\rangle`.
This choice of the estimator is the default and it works only
with `is_unitary==True` (besides :code:`sample_Upsi=False` ).
When :math:`|\Phi⟩ = |\phi⟩` the two estimators coincides.
To reduce the variance of the estimator, the Control Variates (CV) method can be applied. This consists
in modifying the estimator into:
..math::
.. math::
I_loc^{CV} = Re{I_loc(σ,η)} - c (|1 - I_loc(σ,η)^2| - 1)
I_{loc}^{CV} = \rm{Re}\left[I_{loc}(\sigma, \eta)\right] - c \left(|1 - I_{loc}(\sigma, \eta)^2| - 1\right)
where c ∈ \mathbb{R}. The constant c is chosen to minimize the variance of I_loc^{CV} as:
where :math:`c ∈ \mathbb{R}`. The constant c is chosen to minimize the variance of
:math:`I_{loc}^{CV}` as:
..math::
.. math::
c* = Cov_{χ}[ |1-I_loc|^2, Re{1-I_loc}] / Var_{χ}[ |1-I_loc|^2 ],
c* = \frac{\rm{Cov}_{χ}\left[ |1-I_{loc}|^2, \rm{Re}\left[1-I_{loc}\right]\right]}{
\rm{Var}_{χ}\left[ |1-I_{loc}|^2\right] },
where Cov[..., ...] indicates the covariance and Var[...] the variance. In the relevant limit
|Ψ⟩ →|Φ⟩, we have c*→-1/2. The value -1/2 is adopted as default value for c in the infidelity
where :math:`\rm{Cov}\left\cdot, \cdot\right]` indicates the covariance and :math:`\rm{Var}\left[\cdot\right]` the variance.
In the relevant limit :math:`|\Psi⟩ \rightarrow|\Phi⟩`, we have :math:`c^\star \rightarrow -1/2`. The value :math:`-1/2` is
adopted as default value for c in the infidelity
estimator. To not apply CV, set c=0.
Args:
target: target variational state |ϕ⟩.
U: operator U.
U_dagger: dagger operator U^{\dagger}.
target: target variational state :math:`|\phi⟩` .
U: operator :math:`\hat{U}`.
U_dagger: dagger operator :math:`\hat{U^\dagger}`.
cv_coeff: Control Variates coefficient c.
is_unitary: flag specifiying the unitarity of U. If True with `sample_Upsi=False`, the second estimator is used.
is_unitary: flag specifiying the unitarity of :math:`\hat{U}`. If True with
:code:`sample_Upsi=False`, the second estimator is used.
dtype: The dtype of the output of expectation value and gradient.
sample_Upsi: flag specifiying whether to sample from |ϕ⟩ or from U|ϕ⟩. If False with `is_unitary=False`, an error occurs.
sample_Upsi: flag specifiying whether to sample from |ϕ⟩ or from U|ϕ⟩. If False with `is_unitary=False` , an error occurs.
Returns:
Infidelity operator for which computing expected value and gradient.
Example:
import netket as nk
import netket_fidelity as nkf
hi = nk.hilbert.Spin(0.5, 4)
sampler = nk.sampler.MetropolisLocal(hilbert=hi, n_chains_per_rank=16)
model = nk.models.RBM(alpha=1, param_dtype=complex)
target_vstate = nk.vqs.MCState(sampler=sampler, model=model, n_samples=100)
# To optimise the overlap with |ϕ⟩
I_op = nkf.InfidelityOperator(target_vstate)
# To optimise the overlap with U|ϕ⟩ by sampling from |ψ⟩ and |ϕ⟩
U = nkf.operator.Rx(0.3)
I_op = nkf.InfidelityOperator(target_vstate, U=U, is_unitary=True)
# To optimise the overlap with U|ϕ⟩ by sampling from |ψ⟩ and U|ϕ⟩
I_op = nkf.InfidelityOperator(target_vstate, U=U, sample_Upsi=True)
Examples:
>>> import netket as nk
>>> import netket_fidelity as nkf
>>>
>>> hi = nk.hilbert.Spin(0.5, 4)
>>> sampler = nk.sampler.MetropolisLocal(hilbert=hi, n_chains_per_rank=16)
>>> model = nk.models.RBM(alpha=1, param_dtype=complex)
>>> target_vstate = nk.vqs.MCState(sampler=sampler, model=model, n_samples=100)
>>>
>>> # To optimise the overlap with |ϕ⟩
>>> I_op = nkf.InfidelityOperator(target_vstate)
>>>
>>> # To optimise the overlap with U|ϕ⟩ by sampling from |ψ⟩ and |ϕ⟩
>>> U = nkf.operator.Rx(0.3)
>>> I_op = nkf.InfidelityOperator(target_vstate, U=U, is_unitary=True)
>>>
>>> # To optimise the overlap with U|ϕ⟩ by sampling from |ψ⟩ and U|ϕ⟩
>>> I_op = nkf.InfidelityOperator(target_vstate, U=U, sample_Upsi=True)
"""
if U is None:
Expand Down

0 comments on commit e5319d8

Please sign in to comment.