diff --git a/model_config/dspn_result_short.yaml b/model_config/dspn_result_short.yaml
index c2770c8..a682652 100644
--- a/model_config/dspn_result_short.yaml
+++ b/model_config/dspn_result_short.yaml
@@ -4,7 +4,7 @@ latent_dim: 2
 activation_function: "ELU"
 depth: 5
 weight_decay: 0
-max_epochs: 10
+max_epochs: 2
 rel_name: 'result'
 batch_size: 1024
 divide_output_layer: True
@@ -12,4 +12,4 @@ latent_l1_norm: 10.0
 max_set_size: 200
 dspn_iter: 4
 dspn_lr: 800
-cap_dataset: 1000
+cap_dataset: 500
diff --git a/models/dspn_autoencoder.py b/models/dspn_autoencoder.py
index 4e96911..d7f4fab 100644
--- a/models/dspn_autoencoder.py
+++ b/models/dspn_autoencoder.py
@@ -11,6 +11,7 @@ import dspn.dspn
 from models import diagnostics
 from common import utils, config
+from models import models_storage
 
 
 class InvariantModel(torch.nn.Module): #FIXME: delete?
     def __init__(self, phi, rho):
@@ -29,8 +30,13 @@ class DSPNAE(generic_model.GenericModel):
     DSPNAE is an acronym for Deep Set Prediction Network AutoEncoder
     """
 
+    @classmethod
+    def storage(cls):
+        return models_storage.DSPNAEModelsStorage()
+
     with_set_index = True
+
 
     class CollateFn(object):
         """
         CollateFn is being used to use the start-item-index
diff --git a/models/generic_model.py b/models/generic_model.py
index cd27fc6..bd12193 100644
--- a/models/generic_model.py
+++ b/models/generic_model.py
@@ -150,24 +150,21 @@ class GenericModel(pl.LightningModule):
 
     with_set_index = None # please set in subclass
 
+    @classmethod
+    def storage(cls):
+        raise Exception("implement in subclass")
+
     @property
     def classname(self):
         return self.__class__.__name__
 
-    @property
-    def name(self):
-        return f"{self.classname}_{self.rel.name}"
+    #@classmethod FIXME DELME
+    #def name(cls,rel):
+    #    return f"{cls.__name__}_{rel.name}"
 
     @property
-    def kwargs_filename(self):
-        os.path.join(
-            config.trained_models_dirpath,
-            f"{self.name}_kwargs.pickle"
-        )
-
-    def dump_kwargs(self):
-        with open(self.kwargs_filename, 'wb') as handle:
-            pickle.dump(self.kwargs, handle, protocol=pickle.HIGHEST_PROTOCOL)
+    def name(self):
+        return f"{self.__class__.__name__}_{self.rel.name}"
 
     def make_train_loader(self,tsets):
         raise Exception("implement in subclass")
diff --git a/models/models_storage.py b/models/models_storage.py
index 7958d47..c93a153 100644
--- a/models/models_storage.py
+++ b/models/models_storage.py
@@ -2,14 +2,19 @@ import os
 import sys
 import pickle
-
+import logging
+import pytorch_lightning as pl
+import glob
+import re
 
 project_root_dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__))+"/..")
 sys.path.insert(0, project_root_dir)
 
-from common import utils, relspecs
+from common import utils, relspecs, config
 from models import dspn_autoencoder
 
-class DSPNAEModelsStorage(utils.Collection): # FIXME: abstraction?
+
+
+class DSPNAEModelsStorage(utils.Collection): # FIXME: classname-parameterized?
     """
     We need to store the DSPNAE models somewhere and to recall them easily.
     This class offers a straightforward interface to load
@@ -21,25 +26,129 @@ def __init__(self):
             model = self.load(rel)
             self[rel.name] = model
 
-    def load(self, rel):
-        model_filename = os.path.join(
+    def create_write_callback(self, model):
+        """
+        Creates a model checkpoint dumping callback for the training
+        of the given model.
+        Also allows for model versioning, which is taken care of
+        by the pytorch_lightning library.
+
+        :param model: model to be saved
+        :return: the callback object to be passed in the callbacks=[...]
+            parameter of the pytorch_lightning Trainer
+        """
+        model_filename = model.name
+        checkpoint_callback = pl.callbacks.ModelCheckpoint(
+            monitor="train_loss",
+            dirpath=config.trained_models_dirpath,
+            filename=model_filename,
+            save_top_k=1,
+            save_last=False,
+            mode="min",
+        )
+        return checkpoint_callback
+
+    def generate_kwargs_filename(self, model):
+
+        # using the versioning of the last model filename
+        # as it is saved before the kwargs dump
+        version = self.last_version(model.rel)
+
+        ret = os.path.join(
+            config.trained_models_dirpath,
+            f"{model.name}-v{version}.kwargs.pickle"
+        )
+        return ret
+
+    def dump_kwargs(self,model):
+        kwargs_filename = self.generate_kwargs_filename(model)
+        with open(kwargs_filename, 'wb') as handle:
+            pickle.dump(model.kwargs, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+    def filenames(self,rel,extension):
+        filenames_glob = os.path.join(
             "trained_models",
-            f"DSPNAE_{rel.name}.ckpt"
+            f"DSPNAE_{rel.name}*.{extension}"
         )
+        print("filenames_glob",filenames_glob)
+        ret = {}
+        filenames = glob.glob(filenames_glob)
+        for curr in filenames:
+            m = re.match(rf'.*-v(\d+)\.{extension}', curr)
+            if m:
+                # has version number
+                print(curr, m.groups())
+                version = int(m.groups()[0])
+            else:
+                # no version number in filename: this was the first
+                print(f'filename {curr} not matching versioned pattern')
+                version = 0
+            ret[version] = curr
+        print("ret",ret)
+        return ret
+
+    def kwargs_filenames(self,rel):
+        return self.filenames(rel,'kwargs.pickle')
+
+    def models_filenames(self,rel):
+        return self.filenames(rel,'ckpt')
+
+    def last_version(self, rel):
+        filenames = self.models_filenames(rel)
+        versions = filenames.keys()
+        if len(versions) == 0:
+            # there was no model stored
+            return None
+        last_version = max(versions)
+        return last_version
+
+    def most_recent_kwargs_filename(self, rel):
+        last_version = self.last_version(rel)
+        if last_version is None:
+            # there was no model stored
+            return None
+        filenames = self.kwargs_filenames(rel)
+        if last_version not in filenames.keys():
+            raise Exception(f"cannot find kwargs file for version {last_version}")
+        return filenames[last_version]
+
+    def most_recent_model_filename(self, rel):
+        filenames = self.models_filenames(rel)
+        last_version = self.last_version(rel)
+        if last_version is None:
+            # there was no model stored
+            return None
+        return filenames[last_version]
+
+    def load(self, rel):
+        """
+        :param rel: the relation this model has been trained on
+        :return: the loaded DSPNAE model, or None if no checkpoint was found
+        """
+        filename = self.most_recent_model_filename(rel)
+
+        if filename is None or not os.path.exists(filename):
+            logging.warning(f"could not find model saved in {filename}")
+            return
 
         # FIXME: duplicated from GenericModel
-        kwargs_filename = f"DSPNAE_{rel.name}_kwargs.pickle"
+        kwargs_filename = self.most_recent_kwargs_filename(rel)
         with open(kwargs_filename) as f:
             kwargs = pickle.load(f)
 
-        logging.info(f"loading {model_filename}..")
-        if not os.path.exists(model_filename):
-            logging.warning(f"could not find model saved in {model_filename}")
-            return
+        logging.info(f"loading {filename}..")
         model = dspn_autoencoder.DSPNAE(**kwargs)
-        model.load_from_checkpoint(model_filename)
+        model.load_from_checkpoint(filename)
         return model
 
+    def test(self, rel):
+        """
+        Tests that a loaded model is functioning correctly.
+        :param rel: the relation
+        :return:
+        """
+        model = self.load(rel)
+        # FIXME: TODO
+
 def test():
     ms = DSPNAEModelsStorage()
diff --git a/models/run.py b/models/run.py
index 318a786..c3b4893 100644
--- a/models/run.py
+++ b/models/run.py
@@ -8,7 +8,7 @@ import argparse
 import pickle
 
 from common import utils, relspecs, persistency, config
-from models import diagnostics, measurements as ms
+from models import diagnostics, measurements as ms, models_storage
 
 def get_args():
     args = {}
@@ -153,20 +153,13 @@ def run(Model,config_name, dynamic_config={}):
         item_dim=tsets.item_dim,
         **model_config
     ).to(device)
-    model.dump_kwargs()
+    storage = model.storage()
     train_loader = model.make_train_loader(tsets)
    test_loader = model.make_test_loader(tsets)
-    model_filename = model.name
-    checkpoint_callback = pl.callbacks.ModelCheckpoint(
-        monitor="train_loss",
-        dirpath=config.trained_models_dirpath,
-        filename=model_filename,
-        save_top_k=1,
-        mode="min",
-    )
+    model_write_callback = storage.create_write_callback(model)
     callbacks = [
         MeasurementsCallback(rel=rel,model=model),
-        checkpoint_callback
+        model_write_callback
     ]
     trainer = pl.Trainer(
         limit_train_batches=1.0,
@@ -174,5 +167,6 @@
         max_epochs=model_config['max_epochs']
     )
     trainer.fit(model, train_loader, test_loader)
+    storage.dump_kwargs(model)
     print("current mlflow run:",mlflow.active_run().info.run_id, " - all done.")
     #log_net_visualization(model,torch.zeros(model_config['batch_size'], tsets.item_dim))
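
Usage sketch (editor's note, not part of the patch): a minimal example of how the new DSPNAEModelsStorage API introduced above is intended to be used outside the training loop, assuming relspecs.rels is the repository's list of relation objects (each with a .name attribute) and that config.trained_models_dirpath already contains at least one trained checkpoint and kwargs pickle. Only the method names come from the patch; everything else is illustrative.

# hypothetical usage of DSPNAEModelsStorage (illustrative sketch, not part of the patch)
from common import relspecs
from models import models_storage

storage = models_storage.DSPNAEModelsStorage()  # __init__ eagerly loads one model per relation

rel = relspecs.rels[0]  # any relation a model was trained on

# resolves the newest "-vN" checkpoint, or version 0 if only an unversioned file exists
print(storage.most_recent_model_filename(rel))

# load() restores the DSPNAE from that checkpoint and its pickled kwargs,
# or returns None when no checkpoint could be found
model = storage.load(rel)
if model is not None:
    model.eval()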