Support derived predicates between static predicates and plain predicates #145

Open
wants to merge 18 commits into
base: main
Changes from 10 commits
Commits
f36821f
Test config for derived predicates based on static predicates
justin13601 Oct 25, 2024
4f8dc20
Pass in static predicates to DerivedPredicateConfig
justin13601 Oct 25, 2024
6495d3c
Logic for derived predicate between plain and static by propagating d…
justin13601 Oct 25, 2024
6b2ec02
Revert back to original config
justin13601 Oct 25, 2024
12e9694
Revert DerivedPredicateConfig static attribute due to parquet issues
justin13601 Oct 25, 2024
2c8304c
Explicitly get list of static predicates from config plain predicates
justin13601 Oct 25, 2024
76cbe91
Freeze pre-commit version and update workflows
justin13601 Oct 25, 2024
c56e60b
Sort first to guarantee null timestamp rows are first per subject_id
justin13601 Oct 25, 2024
2c57b85
Warnings and error messages per #141 #142 #146 (#147)
justin13601 Oct 25, 2024
495f2c0
Add percentage for printing label distribution (code rabbit suggestion)
justin13601 Oct 25, 2024
ec07a94
Remove label distribution printing (can be done post-hoc)
justin13601 Oct 26, 2024
ca85310
Code quality PR workflow not on all files
justin13601 Oct 26, 2024
0cb8b1a
Exclude test coverage for logging statements
justin13601 Oct 26, 2024
7121127
Update docs to clarify supported code fields
justin13601 Oct 26, 2024
220c99a
Polish error message
justin13601 Oct 26, 2024
13eb15d
Switch to warning when no rows returned
justin13601 Oct 26, 2024
f5b0dbc
#94 nested derived predicates (ex: needed when creating different ref…
justin13601 Oct 26, 2024
8902365
Added tests for derived predicates between static and plain as well a…
justin13601 Oct 26, 2024
14 changes: 10 additions & 4 deletions .github/workflows/code-quality-main.yaml
@@ -13,10 +13,16 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v3
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Install packages
run: |
pip install .[dev]

- name: Run pre-commits
uses: pre-commit/action@v3.0.0
uses: pre-commit/action@v3.0.1
16 changes: 11 additions & 5 deletions .github/workflows/code-quality-pr.yaml
@@ -16,10 +16,16 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v3
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Install packages
run: |
pip install .[dev]

- name: Find modified files
id: file_changes
@@ -31,6 +37,6 @@ jobs:
run: echo '${{ steps.file_changes.outputs.files}}'

- name: Run pre-commits
uses: pre-commit/action@v3.0.0
uses: pre-commit/action@v3.0.1
with:
extra_args: --files ${{ steps.file_changes.outputs.files}}
extra_args: --all-files
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
@@ -17,16 +17,16 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python 3.10
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Install packages
run: |
pip install -e .[dev]
pip install .[dev]

#----------------------------------------------
# run test suite
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -5,7 +5,7 @@ exclude: "to_organize"

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v5.0.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
6 changes: 4 additions & 2 deletions README.md
@@ -218,14 +218,16 @@ Fields for a "plain" predicate:

- `code` (required): Must be one of the following:
- a string with `//` sequence separating the column name and column value.
- a list of strings as above in the form of {any: \[???, ???, ...\]}, which will match any of the listed codes.
- a regex in the form of {regex: "???"}, which will match any code that matches that regular expression.
- a list of strings as above in the form of `{any: \[???, ???, ...\]}`, which will match any of the listed codes.
- a regex in the form of `{regex: "???"}`, which will match any code that matches that regular expression.
- `value_min` (optional): Must be float or integer specifying the minimum value of the predicate, if the variable is presented as numerical values.
- `value_max` (optional): Must be float or integer specifying the maximum value of the predicate, if the variable is presented as numerical values.
- `value_min_inclusive` (optional): Must be a boolean specifying whether `value_min` is inclusive or not.
- `value_max_inclusive` (optional): Must be a boolean specifying whether `value_max` is inclusive or not.
- `other_cols` (optional): Must be a 1-to-1 dictionary of column name and column value, which places additional constraints on further columns.

**Note**: For memory optimization, we strongly recommend using either the List of Values or Regular Expression formats whenever possible, especially when needing to match multiple values. Defining each code as an individual string will increase memory usage significantly, as each code generates a separate predicate column. Using a list or regex consolidates multiple matching codes under a single column, reducing the overall memory footprint.
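
To make the memory note concrete, here is a minimal sketch using plain Python dictionaries that mirror the structure of a predicates block. The predicate names and the `LAB//...` codes are hypothetical and chosen only for illustration. It shows that defining each code separately produces one predicate per code, while the `any` and `regex` forms cover the same codes with a single predicate.

```python
# Hypothetical codes, for illustration only.
codes = ["LAB//glucose", "LAB//lactate", "LAB//sodium"]

# One plain predicate per code: three definitions, hence three predicate columns.
separate = {f"lab_{i}": {"code": code} for i, code in enumerate(codes)}

# The same match expressed once with the `any` form: a single predicate column.
consolidated_any = {"any_lab": {"code": {"any": codes}}}

# Or once with the `regex` form (the pattern is an assumption about these codes).
consolidated_regex = {"any_lab": {"code": {"regex": r"^LAB//.*"}}}

print(len(separate), len(consolidated_any), len(consolidated_regex))
```

The consolidated forms are only equivalent when downstream logic does not need to tell the individual codes apart. If a task refers to one of the codes on its own, that code still needs its own predicate.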

#### Derived Predicates

"Derived" predicates combine existing "plain" predicates using `and` / `or` keywords and have exactly 1 required `expr` field: For instance, the following defines a predicate representing either death or discharge (by combining "plain" predicates of `death` and `discharge`):
9 changes: 9 additions & 0 deletions docs/source/configuration.md
@@ -49,20 +49,29 @@ These configs consist of the following four fields:
The field can additionally be a dictionary with either a `regex` key and the value being a regular
expression (satisfied if the regular expression evaluates to True), or an `any` key and the value being a
list of strings (satisfied if there is an occurrence for any code in the list).

**Note**: Each individual definition of `PlainPredicateConfig` and `code` will generate a separate predicate
column. Thus, for memory optimization, it is strongly recommended to match multiple values using either the
List of Values or Regular Expression formats whenever possible.

- `value_min`: If specified, an observation satisfies this predicate only if the underlying `code` occurs
with a reported numerical value that is greater than, or greater than or equal to, `value_min`. Which
comparison applies is decided by `value_min_inclusive`: `value_min_inclusive=True` means a value greater
than or equal to `value_min` satisfies the predicate, and `value_min_inclusive=False` means the value must
be strictly greater than `value_min`.

- `value_max`: If specified, an observation satisfies this predicate only if the underlying `code` occurs
with a reported numerical value that is less than, or less than or equal to, `value_max`. Which
comparison applies is decided by `value_max_inclusive`: `value_max_inclusive=True` means a value less
than or equal to `value_max` satisfies the predicate, and `value_max_inclusive=False` means the value must
be strictly less than `value_max`.

- `value_min_inclusive`: See `value_min`

- `value_max_inclusive`: See `value_max`

- `other_cols`: This optional field accepts a 1-to-1 dictionary of column names to column values, and can be
used to specify further constraints on other columns (i.e., not `code`) for this predicate.
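
The bound semantics described for `value_min`, `value_max`, and their `*_inclusive` flags can be sketched as a small standalone function. This is not the ACES implementation: `satisfies_bounds` is a hypothetical helper, and it requires both inclusivity flags to be passed explicitly, because the text above does not say what happens when they are omitted.

```python
from typing import Optional


def satisfies_bounds(
    value: float,
    *,
    value_min: Optional[float] = None,
    value_max: Optional[float] = None,
    value_min_inclusive: bool,
    value_max_inclusive: bool,
) -> bool:
    """Return True if `value` lies within the configured bounds (hypothetical helper)."""
    if value_min is not None:
        # Inclusive: value >= value_min passes. Exclusive: value must be strictly greater.
        below = value < value_min if value_min_inclusive else value <= value_min
        if below:
            return False
    if value_max is not None:
        # Inclusive: value <= value_max passes. Exclusive: value must be strictly less.
        above = value > value_max if value_max_inclusive else value >= value_max
        if above:
            return False
    return True


# A value exactly on the lower bound passes only when the bound is inclusive.
print(satisfies_bounds(120.0, value_min=120.0, value_min_inclusive=True, value_max_inclusive=True))
print(satisfies_bounds(120.0, value_min=120.0, value_min_inclusive=False, value_max_inclusive=True))
```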

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -37,7 +37,7 @@ expand_shards = "aces.expand_shards:main"

[project.optional-dependencies]
dev = [
"pre-commit", "pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis"
"pre-commit<4", "pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis"
]
profiling = ["psutil"]

23 changes: 23 additions & 0 deletions src/aces/config.py
@@ -1195,6 +1195,23 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
start_inclusive=True, end_inclusive=True, has={},
label=None, index_timestamp=None)},
label_window=None, index_timestamp_window=None)
>>> predicates_dict = {
... "metadata": {'description': 'A test predicates file'},
... "description": 'this is a test',
... "patient_demographics": {"brown_eyes": {"code": "eye_color//BR"}},
... "predicates": {'admission': "invalid"},
... }
>>> with (tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as config_fp,
... tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as pred_fp):
... config_path = Path(config_fp.name)
... pred_path = Path(pred_fp.name)
... yaml.dump(no_predicates_config, config_fp)
... yaml.dump(predicates_dict, pred_fp)
... cfg = TaskExtractorConfig.load(config_path, pred_path) # doctest: +NORMALIZE_WHITESPACE
Traceback (most recent call last):
...
ValueError: Predicate 'admission' is not defined correctly in the configuration file. Currently
defined as the string: invalid. Please refer to the documentation for the supported formats.
"""
if isinstance(config_path, str):
config_path = Path(config_path)
@@ -1295,6 +1312,12 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
if "expr" in p:
predicate_objs[n] = DerivedPredicateConfig(**p)
else:
if isinstance(p, str):
raise ValueError(
f"Predicate '{n}' is not defined correctly in the configuration file. "
f"Currently defined as the string: {p}. "
"Please refer to the documentation for the supported formats."
)
config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__}
other_cols = {k: v for k, v in p.items() if k not in config_data}
predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols)
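
The routing in this hunk (derived if the definition has an `expr` key, plain otherwise, with a clear error for a bare string) can be sketched in isolation. `PlainSketch`, `DerivedSketch`, and `build_predicate` are hypothetical stand-ins, not the ACES classes. The sketch also differs from the diff in one respect: it checks the type before testing for `expr`, because `"expr" in p` on a string is a substring test in Python.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class PlainSketch:  # hypothetical stand-in for a plain predicate config
    code: Any
    value_min: Optional[float] = None
    value_max: Optional[float] = None
    other_cols: dict = field(default_factory=dict)


@dataclass
class DerivedSketch:  # hypothetical stand-in for a derived predicate config
    expr: str


def build_predicate(name: str, spec: Any):
    # Reject anything that is not a mapping before looking for keys.
    if not isinstance(spec, dict):
        raise ValueError(
            f"Predicate '{name}' is not defined correctly in the configuration file. "
            f"Currently defined as: {spec!r}."
        )
    if "expr" in spec:
        return DerivedSketch(**spec)
    # Known dataclass fields go to the constructor; everything else becomes other_cols.
    known = {
        k: v for k, v in spec.items()
        if k in PlainSketch.__dataclass_fields__ and k != "other_cols"
    }
    other_cols = {k: v for k, v in spec.items() if k not in known}
    return PlainSketch(**known, other_cols=other_cols)


print(build_predicate("admission", {"code": "event_type//ADMISSION", "unit": "ICU"}))
print(build_predicate("death_or_discharge", {"expr": "or(death, discharge)"}))
```

Checking the type first is a choice made for this sketch. With the ordering in the diff, a string such as `"my_expr"` satisfies the `"expr" in p` test and reaches the derived-predicate constructor rather than the new error message.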
73 changes: 71 additions & 2 deletions src/aces/predicates.py
@@ -609,6 +609,64 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 0 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │
└────────────┴─────────────────────┴─────┴──────┴────────────┴───────────────┴─────────────┘

>>> data = pl.DataFrame({
... "subject_id": [1, 1, 1, 2, 2],
... "timestamp": [
... None,
... "01/01/2021 00:00",
... "01/01/2021 12:00",
... "01/02/2021 00:00",
... "01/02/2021 12:00"],
... "adm": [0, 1, 0, 1, 0],
... "male": [1, 0, 0, 0, 0],
... })
>>> predicates = {
... "adm": PlainPredicateConfig("adm"),
... "male": PlainPredicateConfig("male", static=True), # predicate match based on name for direct
... "male_adm": DerivedPredicateConfig("and(male, adm)", static=['male']),
... }
>>> trigger = EventConfig("adm")
>>> windows = {
... "input": WindowConfig(
... start=None,
... end="trigger + 24h",
... start_inclusive=True,
... end_inclusive=True,
... has={"_ANY_EVENT": "(32, None)"},
... ),
... "gap": WindowConfig(
... start="input.end",
... end="start + 24h",
... start_inclusive=False,
... end_inclusive=True,
... has={
... "adm": "(None, 0)",
... "male_adm": "(None, 0)",
... },
... ),
... }
>>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows)
>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f:
... data_path = Path(f.name)
... data.write_csv(data_path)
... data_config = DictConfig({
... "path": str(data_path), "standard": "direct", "ts_format": "%m/%d/%Y %H:%M"
... })
... get_predicates_df(config, data_config)
shape: (5, 6)
┌────────────┬─────────────────────┬─────┬──────┬──────────┬────────────┐
│ subject_id ┆ timestamp ┆ adm ┆ male ┆ male_adm ┆ _ANY_EVENT │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞════════════╪═════════════════════╪═════╪══════╪══════════╪════════════╡
│ 1 ┆ null ┆ 0 ┆ 1 ┆ 0 ┆ null │
│ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 1 │
│ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 1 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │
└────────────┴─────────────────────┴─────┴──────┴──────────┴────────────┘

>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f:
... data_path = Path(f.name)
... data.write_csv(data_path)
@@ -638,15 +696,26 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
raise ValueError(f"Invalid data standard: {standard}. Options are 'direct', 'MEDS', 'ESGPT'.")
predicate_cols = list(plain_predicates.keys())

data = data.sort(by=["subject_id", "timestamp"], nulls_last=False)

# derived predicates
logger.info("Loaded plain predicates. Generating derived predicate columns...")
static_variables = [pred for pred in cfg.plain_predicates if cfg.plain_predicates[pred].static]
for name, code in cfg.derived_predicates.items():
if any(x in static_variables for x in code.input_predicates):
data = data.with_columns(
[
pl.col(static_var)
.first()
.over("subject_id") # take the first value in each subject_id group and propagate it
.alias(static_var)
for static_var in static_variables
]
)
data = data.with_columns(code.eval_expr().cast(PRED_CNT_TYPE).alias(name))
logger.info(f"Added predicate column '{name}'.")
predicate_cols.append(name)

data = data.sort(by=["subject_id", "timestamp"], nulls_last=False)

# add special predicates:
# a column of 1s representing any predicate
# a column of 0s with 1 in the first event of each subject_id representing the start of record
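
The static-propagation step in this hunk can be illustrated without polars. The sketch below is a pure-Python analogue, not the code in the diff: it sorts so that each subject's null-timestamp row comes first, copies that first row's static value to every row of the subject, and then evaluates `and(male, adm)` row by row. `propagate_static` is a hypothetical helper, and treating `and` as "both counts greater than zero" is an assumption. The data is taken from the doctest added above.

```python
rows = [
    {"subject_id": 1, "timestamp": None, "adm": 0, "male": 1},
    {"subject_id": 1, "timestamp": "2021-01-01 00:00", "adm": 1, "male": 0},
    {"subject_id": 1, "timestamp": "2021-01-01 12:00", "adm": 0, "male": 0},
    {"subject_id": 2, "timestamp": "2021-01-02 00:00", "adm": 1, "male": 0},
    {"subject_id": 2, "timestamp": "2021-01-02 12:00", "adm": 0, "male": 0},
]


def propagate_static(rows, static_cols):
    """Copy each subject's first-row value of the static columns to all of its rows."""
    # Sort by subject, with null timestamps first, then by timestamp.
    ordered = sorted(
        rows,
        key=lambda r: (r["subject_id"], r["timestamp"] is not None, r["timestamp"] or ""),
    )
    first_row = {}
    out = []
    for row in ordered:
        first = first_row.setdefault(row["subject_id"], row)
        out.append({**row, **{col: first[col] for col in static_cols}})
    return out


propagated = propagate_static(rows, ["male"])
male_adm = [int(r["male"] > 0 and r["adm"] > 0) for r in propagated]
print([r["male"] for r in propagated])  # [1, 1, 1, 0, 0]
print(male_adm)  # [0, 1, 0, 0, 0]
```

Both printed lists match the `male` and `male_adm` columns in the doctest output. The sketch also shows why the sort has to happen before propagation: taking the first value per subject only picks up the static row if null timestamps sort first.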
19 changes: 19 additions & 0 deletions src/aces/query.py
@@ -4,6 +4,8 @@
"""


from collections import Counter

import polars as pl
from bigtree import preorder_iter
from loguru import logger
@@ -137,6 +139,23 @@
)
to_return_cols.insert(1, "label")

if result["label"].n_unique() == 1:
logger.warning(
f"All labels in the extracted cohort are the same: '{result['label'][0]}'. "
"This may indicate an issue with the task logic. "
"Please double-check your configuration file if this is not expected."
)
else:
unique_labels = result["label"].n_unique()
label_distribution = Counter(result["label"])
total_count = sum(label_distribution.values())
distribution_with_pct = {
k: f"{v} ({v/total_count*100:.1f}%)" for k, v in label_distribution.items()
}
logger.info(
f"Found {unique_labels} unique labels in the extracted cohort: " f"{distribution_with_pct}."
)

# add index_timestamp column if specified
if cfg.index_timestamp_window:
logger.info(
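
A later commit in this PR (ec07a94) removes the label-distribution printing on the grounds that it can be done post hoc. As a minimal sketch of such a post hoc summary, the function below applies the same `Counter`-based percentage formatting to a plain list of labels. `summarize_labels` is a hypothetical helper, and it returns the message rather than logging it.

```python
from collections import Counter


def summarize_labels(labels):
    """Return a one-line summary of the label distribution (hypothetical helper)."""
    counts = Counter(labels)
    total = sum(counts.values())
    if len(counts) == 1:
        only = next(iter(counts))
        return f"All labels in the extracted cohort are the same: '{only}'."
    # Format each count with its share of the total, to one decimal place.
    with_pct = {k: f"{v} ({v / total * 100:.1f}%)" for k, v in counts.items()}
    return f"Found {len(counts)} unique labels in the extracted cohort: {with_pct}."


print(summarize_labels([0, 0, 1, 0]))
print(summarize_labels([1, 1, 1]))
```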