Skip to content

Commit

Permalink
Add implementation for get and set states
Browse files Browse the repository at this point in the history
  • Loading branch information
NLGithubWP committed Sep 4, 2023
1 parent bb878cd commit 7f0152a
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions examples/model_selection_psql/ms_mlp/train_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,18 @@ def step(self):
self.dam_value.copy_from(dam_value)
self.decay_value.copy_from(decay_value)

def get_states(self):
states = super().get_states()
if self.mom_value > 0:
states[
'moments'] = self.moments # a dict for 1st order moments tensors
return states

def set_states(self, states):
super().set_states(states)
if 'moments' in states:
self.moments = states['moments']
self.mom_value = self.momentum(self.step_counter)

# Data augmentation
def augmentation(x, batch_size):
Expand Down

0 comments on commit 7f0152a

Please sign in to comment.