Skip to content

Commit

Permalink
Merge pull request #124 from justin13601/102_separate_predicates
Browse files (browse the repository at this point in the history)
Overriding predicates and aces-cli changes for separate predicates
  • Loading branch information
mmcdermott authored Sep 1, 2024
2 parents: d2bd0ee + 54502b3; commit 51167ae
Show file tree
Hide file tree
Showing 5 changed files with 474 additions and 20 deletions.
9 changes: 8 additions & 1 deletion src/aces/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,14 @@ def main(cfg: DictConfig):

# load configuration
logger.info(f"Loading config from '{cfg.config_path}'")
task_cfg = config.TaskExtractorConfig.load(Path(cfg.config_path))
if cfg.predicates_path:
logger.info(f"Overriding predicates and/or demographics from '{cfg.predicates_path}'")
predicates_path = Path(cfg.predicates_path)
else:
predicates_path = None
task_cfg = config.TaskExtractorConfig.load(
config_path=Path(cfg.config_path), predicates_path=predicates_path
)

logger.info(f"Attempting to get predicates dataframe given:\n{OmegaConf.to_yaml(cfg.data)}")
predicates_df = predicates.get_predicates_df(task_cfg, cfg.data)
Expand Down
40 changes: 23 additions & 17 deletions src/aces/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,8 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
f"Only supports reading from '.yaml'. Got: '{config_path.suffix}' in '{config_path.name}'."
)

overriding_predicates = {}
overriding_demographics = {}
if predicates_path:
if isinstance(predicates_path, str):
predicates_path = Path(predicates_path)
Expand All @@ -1228,33 +1230,35 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
f"'{predicates_path.name}'."
)

predicates = predicates_dict.pop("predicates")
patient_demographics = predicates_dict.pop("patient_demographics", None)

# Remove the description or metadata keys if they exist - currently unused except for readability
# in the YAML
_ = predicates_dict.pop("description", None)
_ = predicates_dict.pop("metadata", None)
overriding_predicates = predicates_dict.pop("predicates", {})
overriding_demographics = predicates_dict.pop("patient_demographics", {})

if predicates_dict:
raise ValueError(
f"Unrecognized keys in configuration file: '{', '.join(predicates_dict.keys())}'"
)
else:
predicates = loaded_dict.pop("predicates")
patient_demographics = loaded_dict.pop("patient_demographics", None)

trigger = loaded_dict.pop("trigger")
windows = loaded_dict.pop("windows", None)

# Remove the description or metadata keys if they exist - currently unused except for readability
# in the YAML
_ = loaded_dict.pop("description", None)
_ = loaded_dict.pop("metadata", None)

trigger = loaded_dict.pop("trigger")
windows = loaded_dict.pop("windows", None)

predicates = loaded_dict.pop("predicates", {})
patient_demographics = loaded_dict.pop("patient_demographics", {})

if loaded_dict:
raise ValueError(f"Unrecognized keys in configuration file: '{', '.join(loaded_dict.keys())}'")

final_predicates = {**predicates, **overriding_predicates}
final_demographics = {**patient_demographics, **overriding_demographics}

logger.info("Parsing windows...")
if windows is None:
windows = {}
Expand All @@ -1272,15 +1276,17 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
current_predicates = set(referenced_predicates)
special_predicates = {ANY_EVENT_COLUMN, START_OF_RECORD_KEY, END_OF_RECORD_KEY}
for pred in current_predicates - special_predicates:
if pred not in predicates:
if pred not in final_predicates:
raise KeyError(
f"Something referenced predicate {pred} that wasn't defined in the configuration."
)
if "expr" in predicates[pred]:
referenced_predicates.update(DerivedPredicateConfig(**predicates[pred]).input_predicates)
if "expr" in final_predicates[pred]:
referenced_predicates.update(
DerivedPredicateConfig(**final_predicates[pred]).input_predicates
)

logger.info("Parsing predicates...")
predicates_to_parse = {k: v for k, v in predicates.items() if k in referenced_predicates}
predicates_to_parse = {k: v for k, v in final_predicates.items() if k in referenced_predicates}
predicate_objs = {}
for n, p in predicates_to_parse.items():
if "expr" in p:
Expand All @@ -1290,12 +1296,12 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
other_cols = {k: v for k, v in p.items() if k not in config_data}
predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols)

if patient_demographics:
if final_demographics:
logger.info("Parsing patient demographics...")
patient_demographics = {
n: PlainPredicateConfig(**p, static=True) for n, p in patient_demographics.items()
final_demographics = {
n: PlainPredicateConfig(**p, static=True) for n, p in final_demographics.items()
}
predicate_objs.update(patient_demographics)
predicate_objs.update(final_demographics)

return cls(predicates=predicate_objs, trigger=trigger, windows=windows)

Expand Down
1 change: 1 addition & 0 deletions src/aces/configs/_aces.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ cohort_name: ""

# Path to the task configuration file
config_path: ${cohort_dir}/${cohort_name}.yaml
predicates_path: null

# Path to store the output file. The `${data._prefix}` addition allows us to add shard specific prefixes in a
# sharded data mode.
Expand Down
Loading

0 comments on commit 51167ae

Please sign in to comment.