From 2668d0c133f9ca9bbe71a92f3e91b3badee7f91b Mon Sep 17 00:00:00 2001
From: Hao Zhang
Date: Wed, 11 Sep 2024 14:40:18 +0800
Subject: [PATCH] [tet.py] Add support for plain direct sampling.

Add support for plain direct sampling, which is similar to direct
sampling but uses only the configurations; the weights and
multiplicities are trivially set to one.
---
 CHANGELOG.org                                 |  2 ++
 .../sampling_neural_state/__init__.py         |  2 +-
 .../sampling_neural_state/gradient.py         |  9 ++++---
 .../sampling_neural_state/sampling.py         | 25 +++++++++++++++++++
 4 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.org b/CHANGELOG.org
index f95cc421..dd59254a 100644
--- a/CHANGELOG.org
+++ b/CHANGELOG.org
@@ -3,6 +3,8 @@
 ** [[https://github.com/USTC-TNS/TNSP/compare/v0.3.17...dev][Unreleased]]
 *** Added
++ *tetragono*: Add plain direct sampling method for neural network state, which is similar to direct sampling, but only
+  using the configurations themselves, without the weights and multiplicities.
 *** Changed
 + *tetragono*: Do not measure observers by term by default in neural network state, which need to be toggled by function
   =set_term_observer= manually.

diff --git a/tetragono/tetragono/sampling_neural_state/__init__.py b/tetragono/tetragono/sampling_neural_state/__init__.py
index 690c0b94..d1b305e8 100644
--- a/tetragono/tetragono/sampling_neural_state/__init__.py
+++ b/tetragono/tetragono/sampling_neural_state/__init__.py
@@ -24,5 +24,5 @@
 if torch_installed:
     from .state import Configuration, SamplingNeuralState
-    from .sampling import SweepSampling, ErgodicSampling, DirectSampling
+    from .sampling import SweepSampling, ErgodicSampling, DirectSampling, PlainDirectSampling
     from .observer import Observer

diff --git a/tetragono/tetragono/sampling_neural_state/gradient.py b/tetragono/tetragono/sampling_neural_state/gradient.py
index 6f2fda62..7a3d28a6 100644
--- a/tetragono/tetragono/sampling_neural_state/gradient.py
+++ b/tetragono/tetragono/sampling_neural_state/gradient.py
@@ -20,7 +20,7 @@
 import numpy as np
 import torch
 import TAT
-from ..sampling_neural_state import SamplingNeuralState, Observer, SweepSampling, DirectSampling, ErgodicSampling
+from ..sampling_neural_state import SamplingNeuralState, Observer, SweepSampling, DirectSampling, PlainDirectSampling, ErgodicSampling
 from ..utility import (show, showln, mpi_rank, mpi_size, seed_differ, write_to_file, get_imported_function,
                        bcast_number, bcast_buffer, write_configurations, allreduce_number)
@@ -135,8 +135,8 @@ def gradient_descent(
     # About sampling
     expect_unique_sampling_step : int, optional
         The expect unique sampling step count.
-    sampling_method : "sweep" | "direct" | "ergodic", default="sweep"
-        The sampling method, which could be one of sweep, direct and ergodic.
+    sampling_method : "sweep" | "direct" | "plain_direct" | "ergodic", default="sweep"
+        The sampling method, which could be one of sweep, direct, plain_direct and ergodic.
     sampling_configurations : object, default=zero_configuration
         The initial configuration used in sweep sampling methods. All sampling methods will save the last configuration
         into this sampling_configurations variable. If the function is invoked from gm_run(_g) interface, this parameter
@@ -286,6 +286,9 @@ def gradient_descent(
     elif sampling_method == "direct":
         sampling = DirectSampling(state, sampling_total_step, sweep_alpha)
         configurations_pool, amplitudes_pool, weights_pool, multiplicities_pool = sampling()
+    elif sampling_method == "plain_direct":
+        sampling = PlainDirectSampling(state, sampling_total_step, sweep_alpha)
+        configurations_pool, amplitudes_pool, weights_pool, multiplicities_pool = sampling()
     elif sampling_method == "ergodic":
         sampling = ErgodicSampling(state)
         configurations_pool, amplitudes_pool, weights_pool, multiplicities_pool = sampling()

diff --git a/tetragono/tetragono/sampling_neural_state/sampling.py b/tetragono/tetragono/sampling_neural_state/sampling.py
index 7ee231cc..a374752e 100644
--- a/tetragono/tetragono/sampling_neural_state/sampling.py
+++ b/tetragono/tetragono/sampling_neural_state/sampling.py
@@ -150,6 +150,31 @@ def __call__(self):
         return scatter_sampling(self.owner, configurations, amplitudes, weights, multiplicities)


+class PlainDirectSampling:
+    """
+    Plain direct sampling: like direct sampling, but with weights and multiplicities trivially set to one.
+    """
+
+    __slots__ = ["owner", "total_size", "alpha"]
+
+    def __init__(self, owner, total_size, alpha):
+        self.owner = owner
+        self.total_size = total_size
+        self.alpha = alpha
+
+    def __call__(self):
+        if mpi_rank == 0:
+            configurations, amplitudes, weights, multiplicities = self.owner.network.generate(
+                self.total_size,
+                self.alpha,
+            )
+            weights = torch.ones_like(weights)
+            multiplicities = torch.ones_like(multiplicities)
+        else:
+            configurations, amplitudes, weights, multiplicities = None, None, None, None
+        return scatter_sampling(self.owner, configurations, amplitudes, weights, multiplicities)
+
+
 class ErgodicSampling:
     """
     Ergodic sampling.
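
Usage sketch: the new class can also be driven by hand, as below. Here
`state` is an assumed, already-initialized SamplingNeuralState, and the
values 1024 and 0.1 for total_size and alpha are illustrative only;
inside gradient_descent the same code path is selected by passing
sampling_method="plain_direct", as wired in the gradient.py hunk above.

    # Sketch only: `state` is assumed to be a prepared SamplingNeuralState;
    # 1024 and 0.1 are illustrative values, not recommended defaults.
    from tetragono.sampling_neural_state import PlainDirectSampling

    sampling = PlainDirectSampling(state, 1024, 0.1)
    configurations, amplitudes, weights, multiplicities = sampling()
    # By construction every weight and multiplicity equals one, so any
    # downstream estimator treats each sampled configuration uniformly.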