Skip to content

Commit

Permalink
Fix tests and config parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
justin13601 committed Aug 8, 2024
1 parent 9c6cb3d commit e209857
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 1 addition & 3 deletions sample_configs/inhospital_mortality.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Task: 24-hour In-hospital Mortality Prediction
predicates:
admission:
code:
regex:
any: [event_type//ADMISSION, event_type//EMERGENCY, event_type//ELECTIVE]
code: event_type//ADMISSION
random_col: foo
discharge:
code: event_type//DISCHARGE
Expand Down
12 changes: 7 additions & 5 deletions src/aces/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,21 +1096,21 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
raise ValueError(f"Unrecognized keys in configuration file: '{', '.join(loaded_dict.keys())}'")

logger.info("Parsing predicates...")
predicates = {}
predicate_objs = {}
for n, p in predicates.items():
if "expr" in p:
predicates[n] = DerivedPredicateConfig(**p)
predicate_objs[n] = DerivedPredicateConfig(**p)
else:
config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__}
other_cols = {k: v for k, v in p.items() if k not in config_data.keys()}
predicates[n] = PlainPredicateConfig(**p, other_cols=other_cols)
predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols)

if patient_demographics:
logger.info("Parsing patient demographics...")
patient_demographics = {
n: PlainPredicateConfig(**p, static=True) for n, p in patient_demographics.items()
}
predicates.update(patient_demographics)
predicate_objs.update(patient_demographics)

logger.info("Parsing trigger event...")
trigger = EventConfig(trigger)
Expand All @@ -1124,7 +1124,9 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
else:
windows = {n: WindowConfig(**w) for n, w in windows.items()}

return cls(predicates=predicates, trigger=trigger, windows=windows)
print(predicate_objs)

return cls(predicates=predicate_objs, trigger=trigger, windows=windows)

def save(self, config_path: str | Path, do_overwrite: bool = False):
"""Load a configuration file from the given path and return it as a dict.
Expand Down

0 comments on commit e209857

Please sign in to comment.