diff --git a/CHANGELOG.org b/CHANGELOG.org index 20a6f3d2..f95cc421 100644 --- a/CHANGELOG.org +++ b/CHANGELOG.org @@ -4,6 +4,8 @@ *** Added *** Changed ++ *tetragono*: Do not measure observers by term by default in the neural network state; this must now be enabled manually with the function + =set_term_observer=. *** Deprecated *** Removed *** Fixed diff --git a/tetragono/tetragono/sampling_neural_state/observer.py b/tetragono/tetragono/sampling_neural_state/observer.py index 38933ce0..48315f5c 100644 --- a/tetragono/tetragono/sampling_neural_state/observer.py +++ b/tetragono/tetragono/sampling_neural_state/observer.py @@ -28,7 +28,7 @@ class Observer(): """ __slots__ = [ - "owner", "_observer", "_enable_gradient", "_enable_natural", "_start", "_result_reweight", + "owner", "_observer", "_term_observer", "_enable_gradient", "_enable_natural", "_start", "_result_reweight", "_result_reweight_square", "_result_square_reweight_square", "_count", "_total_weight", "_total_weight_square", "_total_log_ws", "_whole_result_reweight", "_whole_result_reweight_square", "_whole_result_square_reweight_square", "_total_imaginary_energy_reweight", "_Delta", "_EDelta", "_Deltas" @@ -42,17 +42,17 @@ def __enter__(self): self._result_reweight = { name: { positions: 0.0 for positions, observer in observers.items() - } for name, observers in self._observer.items() + } for name, observers in self._observer.items() if name in self._term_observer } self._result_reweight_square = { name: { positions: 0.0 for positions, observer in observers.items() - } for name, observers in self._observer.items() + } for name, observers in self._observer.items() if name in self._term_observer } self._result_square_reweight_square = { name: { positions: 0.0 for positions, observer in observers.items() - } for name, observers in self._observer.items() + } for name, observers in self._observer.items() if name in self._term_observer } self._count = 0 self._total_weight = 0.0 @@ -77,10 +77,11 @@ def __exit__(self, 
exc_type, exc_val, exc_tb): return False buffer = [] for name, observers in self._observer.items(): - for positions in observers: - buffer.append(self._result_reweight[name][positions]) - buffer.append(self._result_reweight_square[name][positions]) - buffer.append(self._result_square_reweight_square[name][positions]) + if name in self._term_observer: + for positions in observers: + buffer.append(self._result_reweight[name][positions]) + buffer.append(self._result_reweight_square[name][positions]) + buffer.append(self._result_square_reweight_square[name][positions]) buffer.append(self._count) buffer.append(self._total_weight) buffer.append(self._total_weight_square) @@ -105,10 +106,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._total_weight = buffer.pop() self._count = buffer.pop() for name, observers in reversed(self._observer.items()): - for positions in reversed(observers): - self._result_square_reweight_square[name][positions] = buffer.pop() - self._result_reweight_square[name][positions] = buffer.pop() - self._result_reweight[name][positions] = buffer.pop() + if name in self._term_observer: + for positions in reversed(observers): + self._result_square_reweight_square[name][positions] = buffer.pop() + self._result_reweight_square[name][positions] = buffer.pop() + self._result_reweight[name][positions] = buffer.pop() if self._enable_gradient: allreduce_buffer(self._Delta) @@ -143,6 +145,7 @@ def __init__( # The observables need to measure. # dict[str, dict[tuple[tuple[int, int, int], ...], Tensor]] self._observer = {} + self._term_observer = set() self._enable_gradient = False self._enable_natural = False @@ -178,7 +181,24 @@ def __init__( if enable_natural_gradient: self.enable_natural_gradient() - def add_observer(self, name, observers): + def set_term_observer(self, name, value): + """ + Set to observe specific observer by term or not. + + Parameters + ---------- + name : str + The observer set name. + value : bool + Whether to observe it by term. 
+ """ + assert name in self._observer + if value and name not in self._term_observer: + self._term_observer.add(name) + if not value and name in self._term_observer: + self._term_observer.remove(name) + + def add_observer(self, name, observers, term_observer=False): """ Add an observer set into this observer object, cannot add observer once observer started. @@ -188,6 +208,8 @@ def add_observer(self, name, observers): This observer set name. observers : dict[tuple[tuple[int, int, int], ...], Tensor] The observer map. + term_observer : bool, default=False + Whether to observe the observers by term. """ if self._start: raise RuntimeError("Cannot enable hole after sampling start") @@ -219,6 +241,8 @@ def add_observer(self, name, observers): ) self._observer[name] = result + if term_observer: + self._term_observer.add(name) def _fermi_sign(self, x, y, indices): if self.owner.op_pool is None: @@ -332,7 +356,7 @@ def __call__(self, configurations, amplitudes, weights, multiplicities): result = [{ name: { positions: 0.0 for positions, observer in observers.items() - } for name, observers in self._observer.items() + } for name, observers in self._observer.items() if name in self._term_observer } for _ in range(batch_size)] whole_result = [{name: 0.0 for name in self._observer} for _ in range(batch_size)] @@ -381,11 +405,12 @@ def __call__(self, configurations, amplitudes, weights, multiplicities): self._total_log_ws += multiplicity * amplitude.abs().log().item() for name, observers in self._observer.items(): - for positions in observers: - to_save = result[batch_index][name][positions].real - self._result_reweight[name][positions] += multiplicity * to_save * reweight - self._result_reweight_square[name][positions] += multiplicity * to_save * reweight**2 - self._result_square_reweight_square[name][positions] += multiplicity * to_save**2 * reweight**2 + if name in self._term_observer: + for positions in observers: + to_save = result[batch_index][name][positions].real + 
self._result_reweight[name][positions] += multiplicity * to_save * reweight + self._result_reweight_square[name][positions] += multiplicity * to_save * reweight**2 + self._result_square_reweight_square[name][positions] += multiplicity * to_save**2 * reweight**2 to_save = whole_result[batch_index][name].real self._whole_result_reweight[name] += multiplicity * to_save * reweight self._whole_result_reweight_square[name] += multiplicity * to_save * reweight**2 @@ -492,7 +517,7 @@ def result(self): self._result_reweight_square[name][positions], self._result_square_reweight_square[name][positions]) for positions in data - } for name, data in self._observer.items() + } for name, data in self._observer.items() if name in self._term_observer } @property