From 09c0c21fc8ffd803cc55ac6e79bfa10646f4d735 Mon Sep 17 00:00:00 2001 From: "clara.bayley" Date: Wed, 17 Apr 2024 23:12:39 +0200 Subject: [PATCH] feat(pySD): edit config file compatible with yaml configs instead of txt --- pySD/editconfigfile.py | 52 +++++++++++++++++++++--------------------- requirements.txt | 1 + 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/pySD/editconfigfile.py b/pySD/editconfigfile.py index 83326e47a..09fd107c9 100644 --- a/pySD/editconfigfile.py +++ b/pySD/editconfigfile.py @@ -18,35 +18,35 @@ File Description: ''' -# TODO(CB): rewrite funciton for .yaml not .txt config file +from ruamel.yaml import YAML + +def update_param(node, param, new_value): + '''Function to recursively searches for 'param' key in YAML node + and updates it's value to with 'new_value' when found ''' + if isinstance(node, dict): + if param in node: + node[param] = new_value # update value + else: + for key, val in node.items(): + update_param(val, param, new_value) + elif isinstance(node, list): + for item in node: + update_param(item, param, new_value) def edit_config_params(filename, params2change): - """rewrites config file with parameters listed in - dict params2change edited to new values also in dict""" + ''' rewrites config YAML file with key,value pairs listed in params2change updated to new values + whilst preserving original YAML file's formatting and comments etc. ''' - wlines=[] + yaml = YAML() - with open(filename) as file: - filelines = file.readlines() - for line in filelines: - wlines.append(line) + # Load the YAML file + with open(filename, 'r') as file: + data = yaml.load(file) - for l, line in enumerate(wlines): - if line[0] != "#" and line[0] != "/" and "=" in line: - for key, value in params2change.items(): - if key in line: + # Update the parameters from the YAML file + for param, new_value in params2change.items(): + update_param(data, param, new_value) - # create line with new value for key - newline = key+" = "+str(value) - newline = newline.ljust(40) - - # add comment to new line if there is one - ind = line.find("#") - newline = newline+line[ind:] - - # overwrite line with newline - wlines[l] = newline - - file = open(filename, "w") - file.writelines(wlines) - file.close() + # Overwrite the YAML file + with open(filename, 'w') as file: + yaml.dump(data, file) diff --git a/requirements.txt b/requirements.txt index 5f50c9b35..d9ec4e1f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ awkward zarr pre-commit mpi4py +ruamel.yaml