[tet.py] Do not measure observers by term by default in nn state.
Summing over terms is slow when the Hamiltonian contains many terms, so
per-term measurement is disabled by default. Users who want to measure by
term need to enable it via `set_term_observer`.
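
For example, a minimal sketch of the new workflow, assuming `observer` is an already-constructed Observer instance; the set name "energy" and the term dictionary `energy_terms` are hypothetical placeholders:

    # Per-term accumulation is now off by default for a newly registered set.
    observer.add_observer("energy", energy_terms)

    # Opt back in to per-term measurement for that set:
    observer.set_term_observer("energy", True)

    # Or request it directly at registration time:
    observer.add_observer("energy", energy_terms, term_observer=True)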
hzhangxyz committed Sep 11, 2024
1 parent c6ba115 commit 2e6c101
Showing 2 changed files with 47 additions and 20 deletions.
CHANGELOG.org: 2 changes (2 additions & 0 deletions)
@@ -4,6 +4,8 @@

*** Added
*** Changed
+ *tetragono*: Do not measure observers by term by default in the neural network state; per-term measurement now needs to be
  toggled manually via the =set_term_observer= function.
*** Deprecated
*** Removed
*** Fixed
tetragono/tetragono/sampling_neural_state/observer.py: 65 changes (45 additions & 20 deletions)
@@ -28,7 +28,7 @@ class Observer():
"""

__slots__ = [
"owner", "_observer", "_enable_gradient", "_enable_natural", "_start", "_result_reweight",
"owner", "_observer", "_term_observer", "_enable_gradient", "_enable_natural", "_start", "_result_reweight",
"_result_reweight_square", "_result_square_reweight_square", "_count", "_total_weight", "_total_weight_square",
"_total_log_ws", "_whole_result_reweight", "_whole_result_reweight_square",
"_whole_result_square_reweight_square", "_total_imaginary_energy_reweight", "_Delta", "_EDelta", "_Deltas"
@@ -42,17 +42,17 @@ def __enter__(self):
self._result_reweight = {
name: {
positions: 0.0 for positions, observer in observers.items()
} for name, observers in self._observer.items()
} for name, observers in self._observer.items() if name in self._term_observer
}
self._result_reweight_square = {
name: {
positions: 0.0 for positions, observer in observers.items()
} for name, observers in self._observer.items()
} for name, observers in self._observer.items() if name in self._term_observer
}
self._result_square_reweight_square = {
name: {
positions: 0.0 for positions, observer in observers.items()
} for name, observers in self._observer.items()
} for name, observers in self._observer.items() if name in self._term_observer
}
self._count = 0
self._total_weight = 0.0
@@ -77,10 +77,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return False
buffer = []
for name, observers in self._observer.items():
for positions in observers:
buffer.append(self._result_reweight[name][positions])
buffer.append(self._result_reweight_square[name][positions])
buffer.append(self._result_square_reweight_square[name][positions])
if name in self._term_observer:
for positions in observers:
buffer.append(self._result_reweight[name][positions])
buffer.append(self._result_reweight_square[name][positions])
buffer.append(self._result_square_reweight_square[name][positions])
buffer.append(self._count)
buffer.append(self._total_weight)
buffer.append(self._total_weight_square)
@@ -105,10 +106,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._total_weight = buffer.pop()
self._count = buffer.pop()
for name, observers in reversed(self._observer.items()):
for positions in reversed(observers):
self._result_square_reweight_square[name][positions] = buffer.pop()
self._result_reweight_square[name][positions] = buffer.pop()
self._result_reweight[name][positions] = buffer.pop()
if name in self._term_observer:
for positions in reversed(observers):
self._result_square_reweight_square[name][positions] = buffer.pop()
self._result_reweight_square[name][positions] = buffer.pop()
self._result_reweight[name][positions] = buffer.pop()

if self._enable_gradient:
allreduce_buffer(self._Delta)
@@ -143,6 +145,7 @@ def __init__(
# The observables need to measure.
# dict[str, dict[tuple[tuple[int, int, int], ...], Tensor]]
self._observer = {}
self._term_observer = set()
self._enable_gradient = False
self._enable_natural = False

@@ -178,7 +181,24 @@ def __init__(
if enable_natural_gradient:
self.enable_natural_gradient()

def add_observer(self, name, observers):
def set_term_observer(self, name, value):
"""
Set whether to observe a specific observer set by term or not.
Parameters
----------
name : str
The observer set name.
value : bool
Whether to observe it by term.
"""
assert name in self._observer
if value and name not in self._term_observer:
self._term_observer.add(name)
if not value and name in self._term_observer:
self._term_observer.remove(name)

def add_observer(self, name, observers, term_observer=False):
"""
Add an observer set into this observer object, cannot add observer once observer started.
@@ -188,6 +208,8 @@ def add_observer(self, name, observers):
This observer set name.
observers : dict[tuple[tuple[int, int, int], ...], Tensor]
The observer map.
term_observer : bool, default=False
Whether to observe the observers by term.
"""
if self._start:
raise RuntimeError("Cannot enable hole after sampling start")
@@ -219,6 +241,8 @@ def add_observer(self, name, observers):
)

self._observer[name] = result
if term_observer:
self._term_observer.add(name)

def _fermi_sign(self, x, y, indices):
if self.owner.op_pool is None:
@@ -332,7 +356,7 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
result = [{
name: {
positions: 0.0 for positions, observer in observers.items()
} for name, observers in self._observer.items()
} for name, observers in self._observer.items() if name in self._term_observer
} for _ in range(batch_size)]
whole_result = [{name: 0.0 for name in self._observer} for _ in range(batch_size)]

@@ -381,11 +405,12 @@ def __call__(self, configurations, amplitudes, weights, multiplicities):
self._total_log_ws += multiplicity * amplitude.abs().log().item()

for name, observers in self._observer.items():
for positions in observers:
to_save = result[batch_index][name][positions].real
self._result_reweight[name][positions] += multiplicity * to_save * reweight
self._result_reweight_square[name][positions] += multiplicity * to_save * reweight**2
self._result_square_reweight_square[name][positions] += multiplicity * to_save**2 * reweight**2
if name in self._term_observer:
for positions in observers:
to_save = result[batch_index][name][positions].real
self._result_reweight[name][positions] += multiplicity * to_save * reweight
self._result_reweight_square[name][positions] += multiplicity * to_save * reweight**2
self._result_square_reweight_square[name][positions] += multiplicity * to_save**2 * reweight**2
to_save = whole_result[batch_index][name].real
self._whole_result_reweight[name] += multiplicity * to_save * reweight
self._whole_result_reweight_square[name] += multiplicity * to_save * reweight**2
@@ -492,7 +517,7 @@ def result(self):
self._result_reweight_square[name][positions],
self._result_square_reweight_square[name][positions])
for positions in data
} for name, data in self._observer.items()
} for name, data in self._observer.items() if name in self._term_observer
}

@property
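
A practical consequence that can be read off the hunks above (it is not spelled out in the commit message): whole-set averages are still accumulated for every observer set, but the per-term breakdown returned by result now only covers sets listed in _term_observer. A rough sketch, assuming result is exposed as a property and that a sampling pass has already finished:

    # Only term-observed sets appear in the per-term results after this change.
    per_term = observer.result
    print("energy" in per_term)  # True only if per-term measurement was enabled for "energy"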
