diff --git a/src/aces/predicates.py b/src/aces/predicates.py index 1eee9ee..e4a1ca0 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -696,6 +696,8 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D raise ValueError(f"Invalid data standard: {standard}. Options are 'direct', 'MEDS', 'ESGPT'.") predicate_cols = list(plain_predicates.keys()) + data = data.sort(by=["subject_id", "timestamp"], nulls_last=False) + # derived predicates logger.info("Loaded plain predicates. Generating derived predicate columns...") static_variables = [pred for pred in cfg.plain_predicates if cfg.plain_predicates[pred].static] @@ -714,8 +716,6 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D logger.info(f"Added predicate column '{name}'.") predicate_cols.append(name) - data = data.sort(by=["subject_id", "timestamp"], nulls_last=False) - # add special predicates: # a column of 1s representing any predicate # a column of 0s with 1 in the first event of each subject_id representing the start of record