Skip to content

Commit

Permalink
sampling first
Browse files Browse the repository at this point in the history
因为不像PEPS,我们没有aux tensor, 完全可以先采样完 (Unlike PEPS, we have no aux tensors, so the sampling can be completed first.)
  • Loading branch information
hzhangxyz committed Jul 29, 2024
1 parent a847fb5 commit fcb6d82
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions tetragono/tetragono/sampling_neural_state/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,9 @@ def gradient_descent(
sampling_total_step = sampling_total_step
sampling_batch_size = len(sampling_configurations)
elif sampling_method == "direct":
sampling = DirectSampling(state, sweep_alpha, len(sampling_configurations))
sampling = DirectSampling(state, sweep_alpha, sampling_total_step)
sampling_total_step = sampling_total_step
sampling_batch_size = len(sampling_configurations)
sampling_batch_size = sampling_total_step
elif sampling_method == "ergodic":
sampling = ErgodicSampling(state, len(sampling_configurations))
sampling_total_step = sampling.total_step
Expand All @@ -323,22 +323,34 @@ def gradient_descent(
raise ValueError("Invalid sampling method")
# Sampling run
sampling_total_step_rank = sampling_total_step // mpi_size + (mpi_rank < sampling_total_step % mpi_size)
cons = []
amps = []
weis = []
i = 1
for sampling_step in range(0, sampling_total_step_rank, sampling_batch_size):
show("start sampling" + "." * i)
i += 1
batch_size = min(sampling_batch_size, sampling_total_step_rank - sampling_step)
configurations, amplitudes, weights = sampling()
configurations_to_be_saved = configurations
configurations = configurations[:batch_size]
amplitudes = amplitudes[:batch_size]
weights = weights[:batch_size]
configurations, amplitudes, weights = unique_sampling((configurations, amplitudes, weights))
if need_energy_observer:
configuration_pool.append((configurations, amplitudes, weights))
            slice_size = 16  # The maximum number of configurations observed at the same time
n = len(weights)
for i in range(0, n, slice_size):
observer(configurations[i:i + slice_size], amplitudes[i:i + slice_size], weights[i:i + slice_size])
process = (sampling_step + i / n * sampling_batch_size) / sampling_total_step_rank
show(f"sampling {100*process:.2f}%, energy={observer.energy}")
cons.append(configurations)
amps.append(amplitudes)
weis.append(weights)
configurations = torch.cat(cons)
amplitudes = torch.cat(amps)
weights = torch.cat(weis)
configurations, amplitudes, weights = unique_sampling((configurations, amplitudes, weights))
if need_energy_observer:
configuration_pool.append((configurations, amplitudes, weights))
        slice_size = 1024  # The maximum number of configurations observed at the same time
n = len(weights)
for i in range(0, n, slice_size):
observer(configurations[i:i + slice_size], amplitudes[i:i + slice_size], weights[i:i + slice_size])
process = (sampling_step + i / n * sampling_batch_size) / sampling_total_step_rank
show(f"sampling {100*process:.2f}%, energy={observer.energy}")
# Save configuration
if mpi_rank < sampling_total_step and sampling_method != "ergodic":
new_configurations = configurations_to_be_saved.cpu().numpy()
Expand Down

0 comments on commit fcb6d82

Please sign in to comment.