Skip to content

Commit

Permalink
fix: handling of missing stored models
Browse files Browse the repository at this point in the history
  • Loading branch information
Francesco Stablum committed Nov 29, 2021
1 parent 4c51197 commit b44db9d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
50 changes: 29 additions & 21 deletions models/models_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ class DSPNAEModelsStorage(utils.Collection): # FIXME: classname-parameterized?
"""
def __init__(self):
os.chdir(project_root_dir)

def load_all_models(self):
os.chdir(project_root_dir)
for rel in relspecs.rels:
model = self.load(rel)
self[rel.name] = model
Expand Down Expand Up @@ -48,7 +51,7 @@ def create_write_callback(self, model):

def generate_kwargs_filename(self, model):

# using the versioning of the last model filename
# using tha version from the last model filename
# as it is saved before the kwargs dump
version = self.last_version(model.rel)

Expand All @@ -60,8 +63,8 @@ def generate_kwargs_filename(self, model):

def dump_kwargs(self,model):
kwargs_filename = self.generate_kwargs_filename(model)
with open(kwargs_filename, 'wb') as handle:
pickle.dump(model.kwargs, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(kwargs_filename, 'wb') as f:
pickle.dump(model.kwargs, f)

def filenames(self,rel,extension):
filenames_glob = os.path.join(
Expand Down Expand Up @@ -118,39 +121,44 @@ def most_recent_model_filename(self, rel):
return None
return filenames[last_version]

def rel_has_stored_model(self,rel):
kwargs_filename = self.most_recent_kwargs_filename(rel)
model_filename = self.most_recent_model_filename(rel)
if None in (kwargs_filename, model_filename):
return False
if os.path.exists(kwargs_filename) and os.path.exists(model_filename):
return True
return False

def load(self, rel):
"""
:param rel: the relation this model has been trained on
:return:
"""
filename = self.most_recent_model_filename(rel)

if filename is None or not os.path.exists(filename):
logging.warning(f"could not find model saved in {filename}")
if not self.rel_has_stored_model(rel):
logging.warning(f"no model for rel {rel.name}")
return

# FIXME: duplicated from GenericModel
kwargs_filename = self.most_recent_kwargs_filename(rel)
with open(kwargs_filename) as f:
with open(kwargs_filename, 'rb') as f:
kwargs = pickle.load(f)
logging.info(f"loading {filename}..")

model = dspn_autoencoder.DSPNAE(**kwargs)
model.load_from_checkpoint(filename)
return model
model_filename = self.most_recent_model_filename(rel)
if model_filename is None or not os.path.exists(model_filename):
logging.warning(f"could not find model saved in {model_filename}")
return

def test(self, rel):
"""
testing that a loaded model is functioning correctly
:param rel: the relation
:return:
"""
model = self.load(rel)
# FIXME: TODO
# FIXME: duplicated from GenericModel
logging.info(f"loading {model_filename}..")

# FIXME: kwargs provided twice?
model = dspn_autoencoder.DSPNAE(**kwargs)
model.load_from_checkpoint(model_filename, **kwargs)
return model

def test():
ms = DSPNAEModelsStorage()
ms.load_all_models()
print("ms",ms)
print("done.")

Expand Down
2 changes: 2 additions & 0 deletions models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def run(Model,config_name, dynamic_config={}):
max_epochs=model_config['max_epochs']
)
trainer.fit(model, train_loader, test_loader)

storage.dump_kwargs(model)

print("current mlflow run:",mlflow.active_run().info.run_id, " - all done.")
#log_net_visualization(model,torch.zeros(model_config['batch_size'], tsets.item_dim))

0 comments on commit b44db9d

Please sign in to comment.