diff --git a/common/tsets.py b/common/tsets.py index 409612c..04aa9cb 100644 --- a/common/tsets.py +++ b/common/tsets.py @@ -37,12 +37,12 @@ def __init__( ): self.rel = rel kwargs.update(dict.fromkeys(self.tsets_names, None)) + self.creation_time = kwargs['creation_time'] super().__init__(**kwargs) self.load_data() for which_tset in self.tsets_names: - if self.with_set_index is False: # removes the set_index column from the glued-up tensor self[which_tset] = self[which_tset][:, 1:] diff --git a/models/run.py b/models/run.py index e8df0a4..75187b2 100644 --- a/models/run.py +++ b/models/run.py @@ -123,7 +123,8 @@ def run(Model,config_name, dynamic_config={}): model_config = utils.load_model_config(config_name, dynamic_config=dynamic_config) mlflow.set_experiment(model_config['experiment_name']) mlflow.pytorch.autolog() - with mlflow.start_run(run_name=model_config['config_name']): + run_name = f"{model_config['config_name']}_{model_config['rel_name']}" + with mlflow.start_run(run_name=run_name): mlflow.log_params(model_config) print("__file__",__file__) mlflow.log_artifact(__file__) @@ -135,7 +136,7 @@ def run(Model,config_name, dynamic_config={}): with_set_index=Model.with_set_index, cap=model_config['cap_dataset'] ) - + mlflow.log_param('tsets_creation_time',tsets.creation_time) for curr in tsets.tsets_names: mlflow.log_param(f"{curr}_datapoints",tsets[curr].shape[0])