Skip to content

Commit

Permalink
临时禁用分项测量
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhangxyz committed Sep 6, 2024
1 parent 3d1877a commit 142e791
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions tetragono/tetragono/sampling_neural_state/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
return False
buffer = []
for name, observers in self._observer.items():
for positions in observers:
buffer.append(self._result_reweight[name][positions])
buffer.append(self._result_reweight_square[name][positions])
buffer.append(self._result_square_reweight_square[name][positions])
#for name, observers in self._observer.items():
# for positions in observers:
# buffer.append(self._result_reweight[name][positions])
# buffer.append(self._result_reweight_square[name][positions])
# buffer.append(self._result_square_reweight_square[name][positions])
buffer.append(self._count)
buffer.append(self._total_weight)
buffer.append(self._total_weight_square)
Expand All @@ -104,11 +104,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._total_weight_square = buffer.pop()
self._total_weight = buffer.pop()
self._count = buffer.pop()
for name, observers in reversed(self._observer.items()):
for positions in reversed(observers):
self._result_square_reweight_square[name][positions] = buffer.pop()
self._result_reweight_square[name][positions] = buffer.pop()
self._result_reweight[name][positions] = buffer.pop()
#for name, observers in reversed(self._observer.items()):
# for positions in reversed(observers):
# self._result_square_reweight_square[name][positions] = buffer.pop()
# self._result_reweight_square[name][positions] = buffer.pop()
# self._result_reweight[name][positions] = buffer.pop()

if self._enable_gradient:
allreduce_buffer(self._Delta)
Expand Down Expand Up @@ -381,11 +381,11 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
self._total_log_ws += multiplicity * amplitude.abs().log().item()

for name, observers in self._observer.items():
for positions in observers:
to_save = result[batch_index][name][positions].real
self._result_reweight[name][positions] += multiplicity * to_save * reweight
self._result_reweight_square[name][positions] += multiplicity * to_save * reweight**2
self._result_square_reweight_square[name][positions] += multiplicity * to_save**2 * reweight**2
# for positions in observers:
# to_save = result[batch_index][name][positions].real
# self._result_reweight[name][positions] += multiplicity * to_save * reweight
# self._result_reweight_square[name][positions] += multiplicity * to_save * reweight**2
# self._result_square_reweight_square[name][positions] += multiplicity * to_save**2 * reweight**2
to_save = whole_result[batch_index][name].real
self._whole_result_reweight[name] += multiplicity * to_save * reweight
self._whole_result_reweight_square[name] += multiplicity * to_save * reweight**2
Expand Down

0 comments on commit 142e791

Please sign in to comment.