Skip to content

Commit

Permalink
Sort first to guarantee null timestamp rows are first per subject_id
Browse files Browse the repository at this point in the history
  • Loading branch information
justin13601 committed Oct 25, 2024
1 parent 76cbe91 commit c56e60b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/aces/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,8 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
raise ValueError(f"Invalid data standard: {standard}. Options are 'direct', 'MEDS', 'ESGPT'.")
predicate_cols = list(plain_predicates.keys())

data = data.sort(by=["subject_id", "timestamp"], nulls_last=False)

# derived predicates
logger.info("Loaded plain predicates. Generating derived predicate columns...")
static_variables = [pred for pred in cfg.plain_predicates if cfg.plain_predicates[pred].static]
Expand All @@ -714,8 +716,6 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
logger.info(f"Added predicate column '{name}'.")
predicate_cols.append(name)

data = data.sort(by=["subject_id", "timestamp"], nulls_last=False)

# add special predicates:
# a column of 1s representing any predicate
# a column of 0s with 1 in the first event of each subject_id representing the start of record
Expand Down

0 comments on commit c56e60b

Please sign in to comment.