diff --git a/setup.py b/setup.py index 63c897d..0e55918 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ def readme(): return f.read() -VERSION = "1.3.21" +VERSION = "1.3.22" def write_version_py(filename="sigProfilerPlotting/version.py"): @@ -23,7 +23,7 @@ def write_version_py(filename="sigProfilerPlotting/version.py"): # THIS FILE IS GENERATED FROM SIGPROFILERPLOTTING SETUP.PY short_version = '%(version)s' version = '%(version)s' -update = 'Upgrade v1.3.21: Add CLI and container updates' +update = 'Upgrade v1.3.22: Update input_processing index.name handling to be more robust.' """ fh = open(filename, "w") diff --git a/sigProfilerPlotting/sigProfilerPlotting.py b/sigProfilerPlotting/sigProfilerPlotting.py index 4bc9f6d..87865ce 100644 --- a/sigProfilerPlotting/sigProfilerPlotting.py +++ b/sigProfilerPlotting/sigProfilerPlotting.py @@ -39,6 +39,7 @@ matplotlib.use("Agg") MUTTYPE = "MutationType" +INDEX_VALS = ["MutationType", "index", "Mutation Types", "classification"] SPP_PATH = spplt.__path__[0] SPP_TEMPLATES = os.path.join(SPP_PATH, "templates/") SPP_FONTS = os.path.join(SPP_PATH, "fonts/") @@ -168,17 +169,26 @@ def process_input(matrix_path, plot_type): if isinstance(matrix_path, pd.DataFrame): # copy dataframe with deepcopy data = matrix_path.copy() - # Index is not MutationType - if MUTTYPE != data.index.name: + # index is a non-standard value + if data.index.name not in INDEX_VALS: if MUTTYPE in data.columns: data = data.set_index(MUTTYPE, drop=True) + # the first column is non-MUTTYPE and non-integer + elif not data.iloc[:, 0].apply(lambda x: isinstance(x, int)).all(): + data.rename(columns={data.columns[0]: MUTTYPE}, inplace=True) + data = data.set_index(data.columns[0], drop=True) else: + data = data.reset_index() data.rename(columns={data.columns[0]: MUTTYPE}, inplace=True) data = data.set_index(MUTTYPE, drop=True) - # input data is a path to a file + else: + # Note: set the index to MUTTYPE for consistency with the rest of the code + data.index.name = MUTTYPE + # input data is a file path elif isinstance(matrix_path, str): data = pd.read_csv(matrix_path, sep="\t", index_col=0) data = data.dropna(axis=1, how="all") + data.index.name = MUTTYPE # input data is a numpy array elif isinstance(matrix_path, np.ndarray): # Note: ndarray does not have index or column names and is not recommended diff --git a/sigProfilerPlotting/version.py b/sigProfilerPlotting/version.py index edf220a..5ac356e 100644 --- a/sigProfilerPlotting/version.py +++ b/sigProfilerPlotting/version.py @@ -1,7 +1,7 @@ # THIS FILE IS GENERATED FROM SIGPROFILERPLOTTING SETUP.PY -short_version = '1.3.20' -version = '1.3.20' -update = 'Upgrade v1.3.20: Add np.ndarray to process_input' +short_version = '1.3.22' +version = '1.3.22' +update = 'Upgrade v1.3.22: Update input_processing index.name handling to be more robust.' \ No newline at end of file diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 2cbca5c..b7cefec 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -22,7 +22,7 @@ def image_difference(img1_path, img2_path): diff = ImageChops.difference(img1, img2) total_difference = sum(abs(p) for p in diff.getdata()) max_difference = img1.size[0] * img1.size[1] * 255 - if total_difference > 1e-4: + if (total_difference / max_difference) > 1e-4: diff.show() return total_difference / max_difference