From f36821fc2e4fe3bde2d67f1086217a4941c0dfa8 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 25 Oct 2024 02:02:21 +0100 Subject: [PATCH 01/18] Test config for derived predicates based on static predicates --- sample_configs/inhospital_mortality.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sample_configs/inhospital_mortality.yaml b/sample_configs/inhospital_mortality.yaml index 68758f7..5a32fc3 100644 --- a/sample_configs/inhospital_mortality.yaml +++ b/sample_configs/inhospital_mortality.yaml @@ -8,6 +8,8 @@ predicates: code: DEATH discharge_or_death: expr: or(discharge, death) + male_admission: + expr: and(male, admission) patient_demographics: male: @@ -33,6 +35,7 @@ windows: admission: (None, 0) discharge: (None, 0) death: (None, 0) + male_admission: (None, 0) target: start: gap.end end: start -> discharge_or_death From 4f8dc2035912e0c3a50a1e5f29ed4968be109078 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 25 Oct 2024 02:03:26 +0100 Subject: [PATCH 02/18] Pass in static predicates to DerivedPredicateConfig --- src/aces/config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/aces/config.py b/src/aces/config.py index b2a3183..dd93041 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -287,7 +287,7 @@ class DerivedPredicateConfig: """ expr: str - static: bool = False + static: list = field(default_factory=list) def __post_init__(self): if not self.expr: @@ -979,9 +979,9 @@ class TaskExtractorConfig: >>> print(config.index_timestamp_window) # doctest: +NORMALIZE_WHITESPACE input >>> print(config.derived_predicates) # doctest: +NORMALIZE_WHITESPACE - {'death_or_discharge': DerivedPredicateConfig(expr='or(death, discharge)', static=False), - 'diabetes': DerivedPredicateConfig(expr='or(diabetes_icd9, diabetes_icd10)', static=False), - 'diabetes_and_discharge': DerivedPredicateConfig(expr='and(diabetes, discharge)', static=False)} + {'death_or_discharge': DerivedPredicateConfig(expr='or(death, discharge)', static=[]), + 'diabetes': DerivedPredicateConfig(expr='or(diabetes_icd9, diabetes_icd10)', static=[]), + 'diabetes_and_discharge': DerivedPredicateConfig(expr='and(diabetes, discharge)', static=[])} >>> print(nx.write_network_text(config.predicates_DAG)) ╟── death ╎ └─╼ death_or_discharge ╾ discharge @@ -1293,7 +1293,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta predicate_objs = {} for n, p in predicates_to_parse.items(): if "expr" in p: - predicate_objs[n] = DerivedPredicateConfig(**p) + predicate_objs[n] = DerivedPredicateConfig(**p, static=list(final_demographics.keys())) else: config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__} other_cols = {k: v for k, v in p.items() if k not in config_data} From 6495d3c301e71de9b1f88207ed751ae7dcad7ebd Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 25 Oct 2024 02:05:15 +0100 Subject: [PATCH 03/18] Logic for derived predicate between plain and static by propagating down static variable values --- src/aces/predicates.py | 68 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/src/aces/predicates.py b/src/aces/predicates.py index 2f10a76..0ec08d7 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -609,6 +609,64 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 0 │ │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ 
└────────────┴─────────────────────┴─────┴──────┴────────────┴───────────────┴─────────────┘ + + >>> data = pl.DataFrame({ + ... "subject_id": [1, 1, 1, 2, 2], + ... "timestamp": [ + ... None, + ... "01/01/2021 00:00", + ... "01/01/2021 12:00", + ... "01/02/2021 00:00", + ... "01/02/2021 12:00"], + ... "adm": [0, 1, 0, 1, 0], + ... "male": [1, 0, 0, 0, 0], + ... }) + >>> predicates = { + ... "adm": PlainPredicateConfig("adm"), + ... "male": PlainPredicateConfig("male", static=True), # predicate match based on name for direct + ... "male_adm": DerivedPredicateConfig("and(male, adm)", static=['male']), + ... } + >>> trigger = EventConfig("adm") + >>> windows = { + ... "input": WindowConfig( + ... start=None, + ... end="trigger + 24h", + ... start_inclusive=True, + ... end_inclusive=True, + ... has={"_ANY_EVENT": "(32, None)"}, + ... ), + ... "gap": WindowConfig( + ... start="input.end", + ... end="start + 24h", + ... start_inclusive=False, + ... end_inclusive=True, + ... has={ + ... "adm": "(None, 0)", + ... "male_adm": "(None, 0)", + ... }, + ... ), + ... } + >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows) + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f: + ... data_path = Path(f.name) + ... data.write_csv(data_path) + ... data_config = DictConfig({ + ... "path": str(data_path), "standard": "direct", "ts_format": "%m/%d/%Y %H:%M" + ... }) + ... get_predicates_df(config, data_config) + shape: (5, 6) + ┌────────────┬─────────────────────┬─────┬──────┬──────────┬────────────┐ + │ subject_id ┆ timestamp ┆ adm ┆ male ┆ male_adm ┆ _ANY_EVENT │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪═════╪══════╪══════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 1 ┆ 0 ┆ null │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 1 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 1 │ + │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ + └────────────┴─────────────────────┴─────┴──────┴──────────┴────────────┘ + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f: ... data_path = Path(f.name) ... data.write_csv(data_path) @@ -641,6 +699,16 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D # derived predicates logger.info("Loaded plain predicates. 
Generating derived predicate columns...") for name, code in cfg.derived_predicates.items(): + if any(x in code.static for x in code.input_predicates): + data = data.with_columns( + [ + pl.col(static_var) + .first() + .over("subject_id") # take the first value in each subject_id group and propagate it + .alias(static_var) + for static_var in code.static + ] + ) data = data.with_columns(code.eval_expr().cast(PRED_CNT_TYPE).alias(name)) logger.info(f"Added predicate column '{name}'.") predicate_cols.append(name) From 6b2ec02a5ea217eb5e6a70eb4595bb875e4630dd Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 25 Oct 2024 02:05:39 +0100 Subject: [PATCH 04/18] Revert back to original config --- sample_configs/inhospital_mortality.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/sample_configs/inhospital_mortality.yaml b/sample_configs/inhospital_mortality.yaml index 5a32fc3..68758f7 100644 --- a/sample_configs/inhospital_mortality.yaml +++ b/sample_configs/inhospital_mortality.yaml @@ -8,8 +8,6 @@ predicates: code: DEATH discharge_or_death: expr: or(discharge, death) - male_admission: - expr: and(male, admission) patient_demographics: male: @@ -35,7 +33,6 @@ windows: admission: (None, 0) discharge: (None, 0) death: (None, 0) - male_admission: (None, 0) target: start: gap.end end: start -> discharge_or_death From 12e96945a13c24c9fa4611e9a7ef5f0c1f715ee5 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 25 Oct 2024 02:45:31 +0100 Subject: [PATCH 05/18] Revert DerivedPredicateConfig static attribute due to parquet issues --- src/aces/config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/aces/config.py b/src/aces/config.py index dd93041..b2a3183 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -287,7 +287,7 @@ class DerivedPredicateConfig: """ expr: str - static: list = field(default_factory=list) + static: bool = False def __post_init__(self): if not self.expr: @@ -979,9 +979,9 @@ class TaskExtractorConfig: >>> print(config.index_timestamp_window) # doctest: +NORMALIZE_WHITESPACE input >>> print(config.derived_predicates) # doctest: +NORMALIZE_WHITESPACE - {'death_or_discharge': DerivedPredicateConfig(expr='or(death, discharge)', static=[]), - 'diabetes': DerivedPredicateConfig(expr='or(diabetes_icd9, diabetes_icd10)', static=[]), - 'diabetes_and_discharge': DerivedPredicateConfig(expr='and(diabetes, discharge)', static=[])} + {'death_or_discharge': DerivedPredicateConfig(expr='or(death, discharge)', static=False), + 'diabetes': DerivedPredicateConfig(expr='or(diabetes_icd9, diabetes_icd10)', static=False), + 'diabetes_and_discharge': DerivedPredicateConfig(expr='and(diabetes, discharge)', static=False)} >>> print(nx.write_network_text(config.predicates_DAG)) ╟── death ╎ └─╼ death_or_discharge ╾ discharge @@ -1293,7 +1293,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta predicate_objs = {} for n, p in predicates_to_parse.items(): if "expr" in p: - predicate_objs[n] = DerivedPredicateConfig(**p, static=list(final_demographics.keys())) + predicate_objs[n] = DerivedPredicateConfig(**p) else: config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__} other_cols = {k: v for k, v in p.items() if k not in config_data} From 2c8304cc3d858cfba06686f60d88b1b205d2ab27 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 25 Oct 2024 02:45:58 +0100 Subject: [PATCH 06/18] Explicitly get list of static predicates from config plain predicates --- src/aces/predicates.py | 5 +++-- 1 file 
changed, 3 insertions(+), 2 deletions(-) diff --git a/src/aces/predicates.py b/src/aces/predicates.py index 0ec08d7..1eee9ee 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -698,15 +698,16 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D # derived predicates logger.info("Loaded plain predicates. Generating derived predicate columns...") + static_variables = [pred for pred in cfg.plain_predicates if cfg.plain_predicates[pred].static] for name, code in cfg.derived_predicates.items(): - if any(x in code.static for x in code.input_predicates): + if any(x in static_variables for x in code.input_predicates): data = data.with_columns( [ pl.col(static_var) .first() .over("subject_id") # take the first value in each subject_id group and propagate it .alias(static_var) - for static_var in code.static + for static_var in static_variables ] ) data = data.with_columns(code.eval_expr().cast(PRED_CNT_TYPE).alias(name)) From 76cbe919b2b39d9ba32509b8c5caa01fd1d2550f Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 25 Oct 2024 02:47:08 +0100 Subject: [PATCH 07/18] Freeze pre-commit version and update workflows --- .github/workflows/code-quality-main.yaml | 14 ++++++++++---- .github/workflows/code-quality-pr.yaml | 16 +++++++++++----- .github/workflows/tests.yml | 6 +++--- .pre-commit-config.yaml | 2 +- pyproject.toml | 2 +- 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality-main.yaml index d336969..691b47c 100644 --- a/.github/workflows/code-quality-main.yaml +++ b/.github/workflows/code-quality-main.yaml @@ -13,10 +13,16 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install packages + run: | + pip install .[dev] - name: Run pre-commits - uses: pre-commit/action@v3.0.0 + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml index 7ca7753..dfc64e1 100644 --- a/.github/workflows/code-quality-pr.yaml +++ b/.github/workflows/code-quality-pr.yaml @@ -16,10 +16,16 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install packages + run: | + pip install .[dev] - name: Find modified files id: file_changes @@ -31,6 +37,6 @@ jobs: run: echo '${{ steps.file_changes.outputs.files}}' - name: Run pre-commits - uses: pre-commit/action@v3.0.0 + uses: pre-commit/action@v3.0.1 with: - extra_args: --files ${{ steps.file_changes.outputs.files}} + extra_args: --all-files diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c22ef16..4e51b11 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,16 +17,16 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install packages run: | - pip install -e .[dev] + pip install .[dev] #---------------------------------------------- # run test suite diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 591bc53..8210517 100644 --- a/.pre-commit-config.yaml 
+++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ exclude: "to_organize" repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v5.0.0 hooks: # list of supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace diff --git a/pyproject.toml b/pyproject.toml index 0a3399d..4b2b26a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ expand_shards = "aces.expand_shards:main" [project.optional-dependencies] dev = [ - "pre-commit", "pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis" + "pre-commit<4", "pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis" ] profiling = ["psutil"] From c56e60be7aa511c69d87502fd59332ed21ba59e8 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 25 Oct 2024 02:56:21 +0100 Subject: [PATCH 08/18] Sort first to guarantee null timestamp rows are first per subject_id --- src/aces/predicates.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aces/predicates.py b/src/aces/predicates.py index 1eee9ee..e4a1ca0 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -696,6 +696,8 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D raise ValueError(f"Invalid data standard: {standard}. Options are 'direct', 'MEDS', 'ESGPT'.") predicate_cols = list(plain_predicates.keys()) + data = data.sort(by=["subject_id", "timestamp"], nulls_last=False) + # derived predicates logger.info("Loaded plain predicates. Generating derived predicate columns...") static_variables = [pred for pred in cfg.plain_predicates if cfg.plain_predicates[pred].static] @@ -714,8 +716,6 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D logger.info(f"Added predicate column '{name}'.") predicate_cols.append(name) - data = data.sort(by=["subject_id", "timestamp"], nulls_last=False) - # add special predicates: # a column of 1s representing any predicate # a column of 0s with 1 in the first event of each subject_id representing the start of record From 2c57b858ab4fa23c9e1b282c53c19babeba053c9 Mon Sep 17 00:00:00 2001 From: Justin Xu <52216145+justin13601@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:25:14 +0100 Subject: [PATCH 09/18] Warnings and error messages per #141 #142 #146 (#147) * #141 note about memory in README * #141 warning about memory in the docs * #142 add warning messages if labels are all the same * Add error message when predicates are specified using only strings (includes ??? case) Closes #141, #142, and #146 --- README.md | 6 ++++-- docs/source/configuration.md | 9 +++++++++ src/aces/config.py | 23 +++++++++++++++++++++++ src/aces/query.py | 16 ++++++++++++++++ 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index eb335ea..fffb3e6 100644 --- a/README.md +++ b/README.md @@ -218,14 +218,16 @@ Fields for a "plain" predicate: - `code` (required): Must be one of the following: - a string with `//` sequence separating the column name and column value. - - a list of strings as above in the form of {any: \[???, ???, ...\]}, which will match any of the listed codes. - - a regex in the form of {regex: "???"}, which will match any code that matches that regular expression. + - a list of strings as above in the form of `{any: \[???, ???, ...\]}`, which will match any of the listed codes. + - a regex in the form of `{regex: "???"}`, which will match any code that matches that regular expression. 
- `value_min` (optional): Must be float or integer specifying the minimum value of the predicate, if the variable is presented as numerical values. - `value_max` (optional): Must be float or integer specifying the maximum value of the predicate, if the variable is presented as numerical values. - `value_min_inclusive` (optional): Must be a boolean specifying whether `value_min` is inclusive or not. - `value_max_inclusive` (optional): Must be a boolean specifying whether `value_max` is inclusive or not. - `other_cols` (optional): Must be a 1-to-1 dictionary of column name and column value, which places additional constraints on further columns. +**Note**: For memory optimization, we strongly recommend using either the List of Values or Regular Expression formats whenever possible, especially when needing to match multiple values. Defining each code as an individual string will increase memory usage significantly, as each code generates a separate predicate column. Using a list or regex consolidates multiple matching codes under a single column, reducing the overall memory footprint. + #### Derived Predicates "Derived" predicates combine existing "plain" predicates using `and` / `or` keywords and have exactly 1 required `expr` field: For instance, the following defines a predicate representing either death or discharge (by combining "plain" predicates of `death` and `discharge`): diff --git a/docs/source/configuration.md b/docs/source/configuration.md index ca83871..6571ae9 100644 --- a/docs/source/configuration.md +++ b/docs/source/configuration.md @@ -49,20 +49,29 @@ These configs consist of the following four fields: The field can additionally be a dictionary with either a `regex` key and the value being a regular expression (satisfied if the regular expression evaluates to True), or a `any` key and the value being a list of strings (satisfied if there is an occurrence for any code in the list). + + **Note**: Each individual definition of `PlainPredicateConfig` and `code` will generate a separate predicate + column. Thus, for memory optimization, it is strongly recommended to match multiple values using either the + List of Values or Regular Expression formats whenever possible. + - `value_min`: If specified, an observation will only satisfy this predicate if the occurrence of the underlying `code` with a reported numerical value that is either greater than or greater than or equal to `value_min` (with these options being decided on the basis of `value_min_inclusive`, where `value_min_inclusive=True` indicating that an observation satisfies this predicate if its value is greater than or equal to `value_min`, and `value_min_inclusive=False` indicating a greater than but not equal to will be used). + - `value_max`: If specified, an observation will only satisfy this predicate if the occurrence of the underlying `code` with a reported numerical value that is either less than or less than or equal to `value_max` (with these options being decided on the basis of `value_max_inclusive`, where `value_max_inclusive=True` indicating that an observation satisfies this predicate if its value is less than or equal to `value_max`, and `value_max_inclusive=False` indicating a less than but not equal to will be used). + - `value_min_inclusive`: See `value_min` + - `value_max_inclusive`: See `value_max` + - `other_cols`: This optional field accepts a 1-to-1 dictionary of column names to column values, and can be used to specify further constraints on other columns (ie., not `code`) for this predicate. 
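
For concreteness, here is a minimal sketch of plain predicates using the fields described above, written in the task-configuration YAML layout used by the sample configs. The predicate names, codes, and the `unit` column in `other_cols` are illustrative placeholders, not values taken from any shipped configuration:

```yaml
predicates:
  normal_lab_range:
    code: LAB                # satisfied by rows whose code matches "LAB"
    value_min: 0             # and whose numerical value lies in [0, 100]
    value_max: 100
    value_min_inclusive: true
    value_max_inclusive: true
    other_cols:
      unit: mg/dL            # optional extra column constraint (illustrative)
  any_antibiotic:
    code:
      any: [MEDICATION//abx_1, MEDICATION//abx_2]   # matches any listed code
```

Because `normal_lab_range` and `any_antibiotic` each define a single predicate column, they can then be referenced by name in derived predicates (`expr: and(...)` / `expr: or(...)`) or in window `has` constraints, as in the other examples in this file.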
diff --git a/src/aces/config.py b/src/aces/config.py index b2a3183..325f7d0 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -1195,6 +1195,23 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta start_inclusive=True, end_inclusive=True, has={}, label=None, index_timestamp=None)}, label_window=None, index_timestamp_window=None) + >>> predicates_dict = { + ... "metadata": {'description': 'A test predicates file'}, + ... "description": 'this is a test', + ... "patient_demographics": {"brown_eyes": {"code": "eye_color//BR"}}, + ... "predicates": {'admission': "invalid"}, + ... } + >>> with (tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as config_fp, + ... tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as pred_fp): + ... config_path = Path(config_fp.name) + ... pred_path = Path(pred_fp.name) + ... yaml.dump(no_predicates_config, config_fp) + ... yaml.dump(predicates_dict, pred_fp) + ... cfg = TaskExtractorConfig.load(config_path, pred_path) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Predicate 'admission' is not defined correctly in the configuration file. Currently + defined as the string: invalid. Please refer to the documentation for the supported formats. """ if isinstance(config_path, str): config_path = Path(config_path) @@ -1295,6 +1312,12 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta if "expr" in p: predicate_objs[n] = DerivedPredicateConfig(**p) else: + if isinstance(p, str): + raise ValueError( + f"Predicate '{n}' is not defined correctly in the configuration file. " + f"Currently defined as the string: {p}. " + "Please refer to the documentation for the supported formats." + ) config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__} other_cols = {k: v for k, v in p.items() if k not in config_data} predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols) diff --git a/src/aces/query.py b/src/aces/query.py index 701bf8f..70b4569 100644 --- a/src/aces/query.py +++ b/src/aces/query.py @@ -4,6 +4,8 @@ """ +from collections import Counter + import polars as pl from bigtree import preorder_iter from loguru import logger @@ -137,6 +139,20 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame ) to_return_cols.insert(1, "label") + if result["label"].n_unique() == 1: + logger.warning( + f"All labels in the extracted cohort are the same: '{result['label'][0]}'. " + "This may indicate an issue with the task logic. " + "Please double-check your configuration file if this is not expected." + ) + else: + unique_labels = result["label"].n_unique() + label_distribution = Counter(result["label"]) + logger.info( + f"Found {unique_labels} unique labels in the extracted cohort: " + f"{dict(label_distribution)}." 
+ ) + # add index_timestamp column if specified if cfg.index_timestamp_window: logger.info( From 495f2c0edb347655074ca0f4e681ec15a7614274 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 25 Oct 2024 15:37:16 +0100 Subject: [PATCH 10/18] Add percentage for printing label distribution (code rabbit suggestion) --- src/aces/query.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/aces/query.py b/src/aces/query.py index 70b4569..9903a21 100644 --- a/src/aces/query.py +++ b/src/aces/query.py @@ -148,9 +148,12 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame else: unique_labels = result["label"].n_unique() label_distribution = Counter(result["label"]) + total_count = sum(label_distribution.values()) + distribution_with_pct = { + k: f"{v} ({v/total_count*100:.1f}%)" for k, v in label_distribution.items() + } logger.info( - f"Found {unique_labels} unique labels in the extracted cohort: " - f"{dict(label_distribution)}." + f"Found {unique_labels} unique labels in the extracted cohort: " f"{distribution_with_pct}." ) # add index_timestamp column if specified From ec07a944919dd266fc040bd98da6fb4aaf7d26d3 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Sat, 26 Oct 2024 15:20:42 +0100 Subject: [PATCH 11/18] Remove label distribution printing (can be done post-hoc) --- src/aces/query.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/aces/query.py b/src/aces/query.py index 9903a21..b5a99e8 100644 --- a/src/aces/query.py +++ b/src/aces/query.py @@ -4,8 +4,6 @@ """ -from collections import Counter - import polars as pl from bigtree import preorder_iter from loguru import logger @@ -145,16 +143,6 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame "This may indicate an issue with the task logic. " "Please double-check your configuration file if this is not expected." ) - else: - unique_labels = result["label"].n_unique() - label_distribution = Counter(result["label"]) - total_count = sum(label_distribution.values()) - distribution_with_pct = { - k: f"{v} ({v/total_count*100:.1f}%)" for k, v in label_distribution.items() - } - logger.info( - f"Found {unique_labels} unique labels in the extracted cohort: " f"{distribution_with_pct}." 
- ) # add index_timestamp column if specified if cfg.index_timestamp_window: From ca85310fabc334355b62320940ca58363708c0bf Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Sat, 26 Oct 2024 15:21:47 +0100 Subject: [PATCH 12/18] Code quality PR workflow not on all files --- .github/workflows/code-quality-pr.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml index dfc64e1..bee2e11 100644 --- a/.github/workflows/code-quality-pr.yaml +++ b/.github/workflows/code-quality-pr.yaml @@ -39,4 +39,4 @@ jobs: - name: Run pre-commits uses: pre-commit/action@v3.0.1 with: - extra_args: --all-files + extra_args: --files ${{ steps.file_changes.outputs.files}} From 0cb8b1a3a435a28243253a784a368bef41410e60 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Sat, 26 Oct 2024 15:28:13 +0100 Subject: [PATCH 13/18] Exclude test coverage for logging statements --- src/aces/query.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/aces/query.py b/src/aces/query.py index b5a99e8..591823a 100644 --- a/src/aces/query.py +++ b/src/aces/query.py @@ -106,7 +106,7 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame return pl.DataFrame() result = extract_subtree(cfg.window_tree, prospective_root_anchors, predicates_df) - if result.is_empty(): + if result.is_empty(): # pragma: no cover logger.info("No valid rows found.") else: # number of patients @@ -125,7 +125,7 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame # add label column if specified if cfg.label_window: - logger.info( + logger.info( # pragma: no cover f"Extracting label '{cfg.windows[cfg.label_window].label}' from window " f"'{cfg.label_window}'..." ) @@ -137,7 +137,7 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame ) to_return_cols.insert(1, "label") - if result["label"].n_unique() == 1: + if result["label"].n_unique() == 1: # pragma: no cover logger.warning( f"All labels in the extracted cohort are the same: '{result['label'][0]}'. " "This may indicate an issue with the task logic. " @@ -146,7 +146,7 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame # add index_timestamp column if specified if cfg.index_timestamp_window: - logger.info( + logger.info( # pragma: no cover f"Setting index timestamp as '{cfg.windows[cfg.index_timestamp_window].index_timestamp}' " f"of window '{cfg.index_timestamp_window}'..." ) From 71211279aad79a92749f83e011730be94c6208f7 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Sat, 26 Oct 2024 16:45:02 +0100 Subject: [PATCH 14/18] Update docs to clarify supported code fields --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index fffb3e6..a375aef 100644 --- a/README.md +++ b/README.md @@ -217,9 +217,10 @@ normal_spo2: Fields for a "plain" predicate: - `code` (required): Must be one of the following: - - a string with `//` sequence separating the column name and column value. - - a list of strings as above in the form of `{any: \[???, ???, ...\]}`, which will match any of the listed codes. - - a regex in the form of `{regex: "???"}`, which will match any code that matches that regular expression. + - a string matching values in a column named `code` (for `MEDS` only). + - a string with a `//` sequence separating the column name and the matching column value (for `ESGPT` only). 
+ - a list of strings as above in the form of `{any: \[???, ???, ...\]}` (or the corresponding expanded indented `YAML` format), which will match any of the listed codes. + - a regex in the form of `{regex: "???"}` (or the corresponding expanded indented `YAML` format), which will match any code that matches that regular expression. - `value_min` (optional): Must be float or integer specifying the minimum value of the predicate, if the variable is presented as numerical values. - `value_max` (optional): Must be float or integer specifying the maximum value of the predicate, if the variable is presented as numerical values. - `value_min_inclusive` (optional): Must be a boolean specifying whether `value_min` is inclusive or not. From 220c99afb97a04ec7fac6b9f1b90704ff7eaf5f6 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Sat, 26 Oct 2024 19:44:49 +0100 Subject: [PATCH 15/18] Polish error message --- src/aces/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aces/config.py b/src/aces/config.py index 325f7d0..54a0bfa 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -1074,7 +1074,7 @@ class TaskExtractorConfig: >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={}) Traceback (most recent call last): ... - KeyError: "Missing 1 relationships:\\nDerived predicate 'foobar' references undefined predicate 'bar'" + KeyError: "Missing 1 relationships: Derived predicate 'foobar' references undefined predicate 'bar'" >>> predicates = {"foo": PlainPredicateConfig("foo")} >>> trigger = EventConfig("foo") @@ -1367,7 +1367,7 @@ def _initialize_predicates(self): ) if missing_predicates: raise KeyError( - f"Missing {len(missing_predicates)} relationships:\n" + "\n".join(missing_predicates) + f"Missing {len(missing_predicates)} relationships: " + "; ".join(missing_predicates) ) self._predicate_dag_graph = nx.DiGraph(dag_relationships) From 13eb15d0baa6ce79c81164fdb2b1d9a6d43b81bf Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Sat, 26 Oct 2024 19:45:19 +0100 Subject: [PATCH 16/18] Switch to warning when no rows returned --- src/aces/query.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/aces/query.py b/src/aces/query.py index 591823a..07cef47 100644 --- a/src/aces/query.py +++ b/src/aces/query.py @@ -107,7 +107,8 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame result = extract_subtree(cfg.window_tree, prospective_root_anchors, predicates_df) if result.is_empty(): # pragma: no cover - logger.info("No valid rows found.") + logger.warning("No valid rows found.") + return pl.DataFrame() else: # number of patients logger.info( From f5b0dbcd3a7d68f5016e42a921c71500623a2d86 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Sat, 26 Oct 2024 21:47:23 +0100 Subject: [PATCH 17/18] #94 nested derived predicates (ex: needed when creating different reference ranges for male/female) --- src/aces/config.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/aces/config.py b/src/aces/config.py index 54a0bfa..9540b0f 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -1275,6 +1275,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta final_predicates = {**predicates, **overriding_predicates} final_demographics = {**patient_demographics, **overriding_demographics} + all_predicates = {**final_predicates, **final_demographics} logger.info("Parsing windows...") if windows is None: @@ -1288,23 +1289,45 @@ 
def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta logger.info("Parsing trigger event...") trigger = EventConfig(trigger) + # add window referenced predicates referenced_predicates = {pred for w in windows.values() for pred in w.referenced_predicates} + + # add trigger predicate referenced_predicates.add(trigger.predicate) + + # add label predicate if it exists and not already added label_reference = [w.label for w in windows.values() if w.label] if label_reference: referenced_predicates.update(set(label_reference)) - current_predicates = set(referenced_predicates) + special_predicates = {ANY_EVENT_COLUMN, START_OF_RECORD_KEY, END_OF_RECORD_KEY} - for pred in current_predicates - special_predicates: - if pred not in final_predicates: + for pred in set(referenced_predicates) - special_predicates: + if pred not in all_predicates: raise KeyError( - f"Something referenced predicate {pred} that wasn't defined in the configuration." - ) - if "expr" in final_predicates[pred]: - referenced_predicates.update( - DerivedPredicateConfig(**final_predicates[pred]).input_predicates + f"Something referenced predicate '{pred}' that wasn't defined in the configuration." ) + if "expr" in all_predicates[pred]: + stack = list(DerivedPredicateConfig(**all_predicates[pred]).input_predicates) + + while stack: + nested_pred = stack.pop() + + if nested_pred not in all_predicates: + raise KeyError( + f"Predicate '{nested_pred}' referenced in '{pred}' is not defined in the " + "configuration." + ) + + # if nested_pred is a DerivedPredicateConfig, unpack input_predicates and add to stack + if "expr" in all_predicates[nested_pred]: + derived_config = DerivedPredicateConfig(**all_predicates[nested_pred]) + stack.extend(derived_config.input_predicates) + referenced_predicates.add(nested_pred) # also add itself to referenced_predicates + else: + # if nested_pred is a PlainPredicateConfig, only add it to referenced_predicates + referenced_predicates.add(nested_pred) + logger.info("Parsing predicates...") predicates_to_parse = {k: v for k, v in final_predicates.items() if k in referenced_predicates} predicate_objs = {} From 89023651b5d57e5faca0b85607083a4d2dc9783b Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Sat, 26 Oct 2024 22:34:02 +0100 Subject: [PATCH 18/18] Added tests for derived predicates between static and plain as well as nested derived predicates --- src/aces/config.py | 61 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/src/aces/config.py b/src/aces/config.py index 9540b0f..415954a 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -1166,6 +1166,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta start_inclusive=True, end_inclusive=True, has={}, label=None, index_timestamp=None)}, label_window=None, index_timestamp_window=None) + >>> predicates_dict = { ... "metadata": {'description': 'A test predicates file'}, ... "description": 'this is a test', @@ -1195,6 +1196,66 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta start_inclusive=True, end_inclusive=True, has={}, label=None, index_timestamp=None)}, label_window=None, index_timestamp_window=None) + + >>> config_dict = { + ... "metadata": {'description': 'A test configuration file'}, + ... "description": 'this is a test for joining static and plain predicates', + ... "patient_demographics": {"male": {"code": "MALE"}, "female": {"code": "FEMALE"}}, + ... 
"predicates": {"normal_male_lab_range": {"code": "LAB", "value_min": 0, "value_max": 100, + ... "value_min_inclusive": True, "value_max_inclusive": True}, + ... "normal_female_lab_range": {"code": "LAB", "value_min": 0, "value_max": 90, + ... "value_min_inclusive": True, "value_max_inclusive": True}, + ... "normal_lab_male": {"expr": "and(normal_male_lab_range, male)"}, + ... "normal_lab_female": {"expr": "and(normal_female_lab_range, female)"}}, + ... "trigger": "_ANY_EVENT", + ... "windows": { + ... "start": { + ... "start": None, "end": "trigger + 24h", "start_inclusive": True, + ... "end_inclusive": True, "has": {"normal_lab_male": "(1, None)"}, + ... } + ... }, + ... } + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: + ... config_path = Path(f.name) + ... yaml.dump(config_dict, f) + ... cfg = TaskExtractorConfig.load(config_path) + >>> cfg.predicates.keys() # doctest: +NORMALIZE_WHITESPACE + dict_keys(['normal_lab_male', 'normal_male_lab_range', 'female', 'male']) + + >>> config_dict = { + ... "metadata": {'description': 'A test configuration file'}, + ... "description": 'this is a test for nested derived predicates', + ... "patient_demographics": {"male": {"code": "MALE"}, "female": {"code": "FEMALE"}}, + ... "predicates": {"abnormally_low_male_lab_range": {"code": "LAB", "value_max": 90, + ... "value_max_inclusive": False}, + ... "abnormally_low_female_lab_range": {"code": "LAB", "value_max": 80, + ... "value_max_inclusive": False}, + ... "abnormally_high_lab_range": {"code": "LAB", "value_min": 120, + ... "value_min_inclusive": False}, + ... "abnormal_lab_male_range": {"expr": + ... "or(abnormally_low_male_lab_range, abnormally_high_lab_range)"}, + ... "abnormal_lab_female_range": {"expr": + ... "or(abnormally_low_female_lab_range, abnormally_high_lab_range)"}, + ... "abnormal_lab_male": {"expr": "and(abnormal_lab_male_range, male)"}, + ... "abnormal_lab_female": {"expr": "and(abnormal_lab_female_range, female)"}, + ... "abnormal_labs": {"expr": "or(abnormal_lab_male, abnormal_lab_female)"}}, + ... "trigger": "_ANY_EVENT", + ... "windows": { + ... "start": { + ... "start": None, "end": "trigger + 24h", "start_inclusive": True, + ... "end_inclusive": True, "has": {"abnormal_labs": "(1, None)"}, + ... } + ... }, + ... } + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: + ... config_path = Path(f.name) + ... yaml.dump(config_dict, f) + ... cfg = TaskExtractorConfig.load(config_path) + >>> cfg.predicates.keys() # doctest: +NORMALIZE_WHITESPACE + dict_keys(['abnormal_lab_female', 'abnormal_lab_female_range', 'abnormal_lab_male', + 'abnormal_lab_male_range', 'abnormal_labs', 'abnormally_high_lab_range', + 'abnormally_low_female_lab_range', 'abnormally_low_male_lab_range', 'female', 'male']) + >>> predicates_dict = { ... "metadata": {'description': 'A test predicates file'}, ... "description": 'this is a test',