diff --git a/src/sensai/torch/torch_opt.py b/src/sensai/torch/torch_opt.py index 0467c63a..b092aaae 100644 --- a/src/sensai/torch/torch_opt.py +++ b/src/sensai/torch/torch_opt.py @@ -538,7 +538,15 @@ def get_validation_metric_name(self): class NNOptimiserParams(ToStringMixin): REMOVED_PARAMS = {"cuda"} RENAMED_PARAMS = { - "optimiserClip": "shrinkageClip" + "optimiserClip": "optimiser_clip", + "lossEvaluator": "loss_evaluator", + "optimiserLR": "optimiser_lr", + "earlyStoppingEpochs": "early_stopping_epochs", + "batchSize": "batch_size", + "trainFraction": "train_fraction", + "scaledOutputs": "scaled_outputs", + "useShrinkage": "use_shrinkage", + "shrinkageClip": "shrinkage_clip", } def __init__(self,