diff --git a/tetragono/tetragono/sampling_neural_state/observer.py b/tetragono/tetragono/sampling_neural_state/observer.py index 38933ce0..51fab1de 100644 --- a/tetragono/tetragono/sampling_neural_state/observer.py +++ b/tetragono/tetragono/sampling_neural_state/observer.py @@ -76,11 +76,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None: return False buffer = [] - for name, observers in self._observer.items(): - for positions in observers: - buffer.append(self._result_reweight[name][positions]) - buffer.append(self._result_reweight_square[name][positions]) - buffer.append(self._result_square_reweight_square[name][positions]) + #for name, observers in self._observer.items(): + # for positions in observers: + # buffer.append(self._result_reweight[name][positions]) + # buffer.append(self._result_reweight_square[name][positions]) + # buffer.append(self._result_square_reweight_square[name][positions]) buffer.append(self._count) buffer.append(self._total_weight) buffer.append(self._total_weight_square) @@ -104,11 +104,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._total_weight_square = buffer.pop() self._total_weight = buffer.pop() self._count = buffer.pop() - for name, observers in reversed(self._observer.items()): - for positions in reversed(observers): - self._result_square_reweight_square[name][positions] = buffer.pop() - self._result_reweight_square[name][positions] = buffer.pop() - self._result_reweight[name][positions] = buffer.pop() + #for name, observers in reversed(self._observer.items()): + # for positions in reversed(observers): + # self._result_square_reweight_square[name][positions] = buffer.pop() + # self._result_reweight_square[name][positions] = buffer.pop() + # self._result_reweight[name][positions] = buffer.pop() if self._enable_gradient: allreduce_buffer(self._Delta) @@ -381,11 +381,11 @@ def __call__(self, configurations, amplitudes, weights, multiplicities): self._total_log_ws += multiplicity * amplitude.abs().log().item() for name, observers in self._observer.items(): - for positions in observers: - to_save = result[batch_index][name][positions].real - self._result_reweight[name][positions] += multiplicity * to_save * reweight - self._result_reweight_square[name][positions] += multiplicity * to_save * reweight**2 - self._result_square_reweight_square[name][positions] += multiplicity * to_save**2 * reweight**2 + # for positions in observers: + # to_save = result[batch_index][name][positions].real + # self._result_reweight[name][positions] += multiplicity * to_save * reweight + # self._result_reweight_square[name][positions] += multiplicity * to_save * reweight**2 + # self._result_square_reweight_square[name][positions] += multiplicity * to_save**2 * reweight**2 to_save = whole_result[batch_index][name].real self._whole_result_reweight[name] += multiplicity * to_save * reweight self._whole_result_reweight_square[name] += multiplicity * to_save * reweight**2