diff --git a/amlrt_project/train.py b/amlrt_project/train.py index dca758e..8579478 100644 --- a/amlrt_project/train.py +++ b/amlrt_project/train.py @@ -73,6 +73,8 @@ def main(): output_dir = os.path.join(args.tmp_folder, 'output') if not os.path.exists(output_dir): os.makedirs(output_dir) + if os.path.exists(args.output): + rsync_folder(args.output, args.tmp_folder) else: data_dir = args.data output_dir = args.output @@ -165,7 +167,7 @@ def train_impl(model, datamodule, output, hyper_params, use_progress_bar, save_top_k=1, verbose=use_progress_bar, monitor="val_loss", - mode="max", + mode="min", every_n_epochs=1, )