diff --git a/tetragono/tetragono/sampling_neural_state/gradient.py b/tetragono/tetragono/sampling_neural_state/gradient.py
index 894d1719..edecaac0 100644
--- a/tetragono/tetragono/sampling_neural_state/gradient.py
+++ b/tetragono/tetragono/sampling_neural_state/gradient.py
@@ -312,9 +312,9 @@ def gradient_descent(
         sampling_total_step = sampling_total_step
         sampling_batch_size = len(sampling_configurations)
     elif sampling_method == "direct":
-        sampling = DirectSampling(state, sweep_alpha, len(sampling_configurations))
+        sampling = DirectSampling(state, sweep_alpha, sampling_total_step)
         sampling_total_step = sampling_total_step
-        sampling_batch_size = len(sampling_configurations)
+        sampling_batch_size = sampling_total_step
     elif sampling_method == "ergodic":
         sampling = ErgodicSampling(state, len(sampling_configurations))
         sampling_total_step = sampling.total_step
@@ -323,22 +323,34 @@
         raise ValueError("Invalid sampling method")
     # Sampling run
     sampling_total_step_rank = sampling_total_step // mpi_size + (mpi_rank < sampling_total_step % mpi_size)
+    cons = []
+    amps = []
+    weis = []
+    i = 1
     for sampling_step in range(0, sampling_total_step_rank, sampling_batch_size):
+        show("start sampling" + "." * i)
+        i += 1
         batch_size = min(sampling_batch_size, sampling_total_step_rank - sampling_step)
         configurations, amplitudes, weights = sampling()
         configurations_to_be_saved = configurations
         configurations = configurations[:batch_size]
         amplitudes = amplitudes[:batch_size]
         weights = weights[:batch_size]
-        configurations, amplitudes, weights = unique_sampling((configurations, amplitudes, weights))
-        if need_energy_observer:
-            configuration_pool.append((configurations, amplitudes, weights))
-        slice_size = 16  # The must configuration observe at the same times
-        n = len(weights)
-        for i in range(0, n, slice_size):
-            observer(configurations[i:i + slice_size], amplitudes[i:i + slice_size], weights[i:i + slice_size])
-            process = (sampling_step + i / n * sampling_batch_size) / sampling_total_step_rank
-            show(f"sampling {100*process:.2f}%, energy={observer.energy}")
+        cons.append(configurations)
+        amps.append(amplitudes)
+        weis.append(weights)
+    configurations = torch.cat(cons)
+    amplitudes = torch.cat(amps)
+    weights = torch.cat(weis)
+    configurations, amplitudes, weights = unique_sampling((configurations, amplitudes, weights))
+    if need_energy_observer:
+        configuration_pool.append((configurations, amplitudes, weights))
+    slice_size = 1024  # Number of configurations fed to the observer per call
+    n = len(weights)
+    for i in range(0, n, slice_size):
+        observer(configurations[i:i + slice_size], amplitudes[i:i + slice_size], weights[i:i + slice_size])
+        process = i / n
+        show(f"sampling {100*process:.2f}%, energy={observer.energy}")
     # Save configuration
     if mpi_rank < sampling_total_step and sampling_method != "ergodic":
         new_configurations = configurations_to_be_saved.cpu().numpy()
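The patch restructures the sampling loop: instead of deduplicating and observing each batch as it arrives, all batches are accumulated, concatenated with `torch.cat`, deduplicated once over the whole run, and only then fed to the observer in fixed-size slices. Below is a minimal, self-contained sketch of that pattern under stated assumptions: `toy_sampler`, `dedupe`, and `toy_observer` are stand-ins invented for illustration, not tetragono APIs (the patch itself uses `sampling()`, `unique_sampling`, and `observer`).

```python
import torch


def toy_sampler(batch_size, sites=4):
    # Stand-in for sampling(): one batch of configurations with amplitudes/weights.
    configurations = torch.randint(0, 2, (batch_size, sites))
    amplitudes = torch.rand(batch_size)
    weights = torch.rand(batch_size)
    return configurations, amplitudes, weights


def dedupe(data):
    # Stand-in for unique_sampling(): keep the first occurrence of each
    # configuration. (The real function may merge duplicates differently.)
    configurations, amplitudes, weights = data
    seen, keep = set(), []
    for index, row in enumerate(configurations.tolist()):
        key = tuple(row)
        if key not in seen:
            seen.add(key)
            keep.append(index)
    keep = torch.tensor(keep)
    return configurations[keep], amplitudes[keep], weights[keep]


def toy_observer(configurations, amplitudes, weights):
    pass  # Stand-in for observer(...): measure on one slice of samples.


total_step_rank = 10_000  # steps assigned to this rank
batch_size = 2_000
slice_size = 1_024        # configurations per observer call

# 1. Accumulate every sampled batch instead of observing immediately.
cons, amps, weis = [], [], []
for step in range(0, total_step_rank, batch_size):
    n_batch = min(batch_size, total_step_rank - step)
    c, a, w = toy_sampler(n_batch)
    cons.append(c)
    amps.append(a)
    weis.append(w)

# 2. Concatenate once and deduplicate across ALL batches, so repeats that
#    land in different batches are also merged.
configurations = torch.cat(cons)
amplitudes = torch.cat(amps)
weights = torch.cat(weis)
configurations, amplitudes, weights = dedupe((configurations, amplitudes, weights))

# 3. Observe in fixed-size slices to bound peak memory per call.
n = len(weights)
for i in range(0, n, slice_size):
    toy_observer(configurations[i:i + slice_size], amplitudes[i:i + slice_size], weights[i:i + slice_size])
    print(f"observing {100 * min(i + slice_size, n) / n:.2f}%")
```

The design trade-off: global deduplication can merge duplicates that the old per-batch `unique_sampling` could never see, and the larger `slice_size` (1024 vs. 16) amortizes per-call overhead, at the cost of holding every sampled batch in memory until the sampling loop finishes.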