Skip to content

Commit

Permalink
[tet.py] Add support for plain direct sampling.
Browse files Browse the repository at this point in the history
Add support for plain direct sampling, which is similar to direct sampling, but
it only uses the configurations, it does not use weights and multiplicities,
which would be set to one trivially.
  • Loading branch information
hzhangxyz committed Sep 11, 2024
1 parent 2e6c101 commit 2668d0c
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.org
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
** [[https://github.com/USTC-TNS/TNSP/compare/v0.3.17...dev][Unreleased]]

*** Added
+ *tetragono*: Add plain direct sampling method for neural network state, which is similar to direct sampling, but it
only uses the configurations themselves, without the weights and multiplicities.
*** Changed
+ *tetragono*: Do not measure observers term by term by default in neural network state; this needs to be toggled by the
function =set_term_observer= manually.
Expand Down
2 changes: 1 addition & 1 deletion tetragono/tetragono/sampling_neural_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@

if torch_installed:
from .state import Configuration, SamplingNeuralState
from .sampling import SweepSampling, ErgodicSampling, DirectSampling
from .sampling import SweepSampling, ErgodicSampling, DirectSampling, PlainDirectSampling
from .observer import Observer
9 changes: 6 additions & 3 deletions tetragono/tetragono/sampling_neural_state/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
import torch
import TAT
from ..sampling_neural_state import SamplingNeuralState, Observer, SweepSampling, DirectSampling, ErgodicSampling
from ..sampling_neural_state import SamplingNeuralState, Observer, SweepSampling, DirectSampling, PlainDirectSampling, ErgodicSampling
from ..utility import (show, showln, mpi_rank, mpi_size, seed_differ, write_to_file, get_imported_function,
bcast_number, bcast_buffer, write_configurations, allreduce_number)

Expand Down Expand Up @@ -135,8 +135,8 @@ def gradient_descent(
# About sampling
expect_unique_sampling_step : int, optional
The expect unique sampling step count.
sampling_method : "sweep" | "direct" | "ergodic", default="sweep"
The sampling method, which could be one of sweep, direct and ergodic.
sampling_method : "sweep" | "direct" | "plain_direct" | "ergodic", default="sweep"
The sampling method, which could be one of sweep, direct, plain_direct and ergodic.
sampling_configurations : object, default=zero_configuration
The initial configuration used in sweep sampling methods. All sampling methods will save the last configuration
into this sampling_configurations variable. If the function is invoked from gm_run(_g) interface, this parameter
Expand Down Expand Up @@ -286,6 +286,9 @@ def gradient_descent(
elif sampling_method == "direct":
sampling = DirectSampling(state, sampling_total_step, sweep_alpha)
configurations_pool, amplitudes_pool, weights_pool, multiplicities_pool = sampling()
elif sampling_method == "plain_direct":
sampling = PlainDirectSampling(state, sampling_total_step, sweep_alpha)
configurations_pool, amplitudes_pool, weights_pool, multiplicities_pool = sampling()
elif sampling_method == "ergodic":
sampling = ErgodicSampling(state)
configurations_pool, amplitudes_pool, weights_pool, multiplicities_pool = sampling()
Expand Down
25 changes: 25 additions & 0 deletions tetragono/tetragono/sampling_neural_state/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,31 @@ def __call__(self):
return scatter_sampling(self.owner, configurations, amplitudes, weights, multiplicities)


class PlainDirectSampling:
    """
    Plain direct sampling.

    Like direct sampling, except that the weights and multiplicities produced
    by the network generator are discarded and replaced with all-ones tensors,
    so only the sampled configurations themselves carry information.
    """

    __slots__ = ["owner", "total_size", "alpha"]

    def __init__(self, owner, total_size, alpha):
        # owner: the sampling neural state that owns the generating network.
        # total_size: total number of configurations to generate.
        # alpha: generation parameter forwarded to the network generator.
        self.owner = owner
        self.total_size = total_size
        self.alpha = alpha

    def __call__(self):
        # Only the root rank generates; other ranks receive their share via
        # scatter_sampling below.
        configurations = amplitudes = weights = multiplicities = None
        if mpi_rank == 0:
            generated = self.owner.network.generate(self.total_size, self.alpha)
            configurations, amplitudes, weights, multiplicities = generated
            # Plain direct sampling treats every sample equally: overwrite the
            # generator's weights and multiplicities with ones of the same
            # shape and dtype.
            weights = torch.ones_like(weights)
            multiplicities = torch.ones_like(multiplicities)
        return scatter_sampling(self.owner, configurations, amplitudes, weights, multiplicities)


class ErgodicSampling:
"""
Ergodic sampling.
Expand Down

0 comments on commit 2668d0c

Please sign in to comment.