From 98a87b1e3b058c6722a9cb09b10d6649d1009251 Mon Sep 17 00:00:00 2001 From: Michael Wornow Date: Fri, 5 Apr 2024 18:15:19 -0700 Subject: [PATCH] ehrshot updates to labeling functions --- src/femr/labelers/ehrshot.py | 89 ++++++++++++++++++++++++++------- src/femr/labelers/omop.py | 2 +- src/femr/transforms/stanford.py | 11 +++- 3 files changed, 81 insertions(+), 21 deletions(-) diff --git a/src/femr/labelers/ehrshot.py b/src/femr/labelers/ehrshot.py index f270bca..db0ec60 100644 --- a/src/femr/labelers/ehrshot.py +++ b/src/femr/labelers/ehrshot.py @@ -18,6 +18,55 @@ ) +def get_icu_visit_detail_care_site_ids(ontology: femr.ontology.Ontology) -> Set[str]: + return ontology.get_all_children([ + # All care sites with "ICU" (case insensitive) in the name + "528292", + "528612", + "528604", + "528623", + "528396", + "528377", + "528314", + "528478", + "528112", + "528024", + "527323", + "527858", + ]) + +def get_icu_measurements( + patient: meds.Patient, ontology: femr.ontology.Ontology +) -> List[Tuple[datetime.datetime, meds.Measurement]]: + """Return all ICU events for this patient. + """ + icu_visit_detail_care_site_ids: Set[str] = get_icu_visit_detail_care_site_ids(ontology) + measurements: List[Tuple[datetime.datetime, meds.Measurement]]= [] # type: ignore + for idx, e in enumerate(patient['events']): + # `visit_detail` is more accurate + comprehensive than `visit_occurrence` for + # ICU measurements for STARR OMOP for some reason + for m in e["measurements"]: + if ( + m['metadata']['table'] == "visit_detail" + and 'care_site_id' in m['metadata'] + and m['metadata']['care_site_id'] in icu_visit_detail_care_site_ids # no ontology expansion for ICU + ): + # Error checking + if isinstance(m['metadata']['end'], str): + m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end']) + if e['time'] is None or m['metadata']['end'] is None: + raise RuntimeError( + f"Event {e} for patient {patient['patient_id']} cannot have `None` as its `start` or `end` attribute." + ) + elif e['time'] > m['metadata']['end']: + raise RuntimeError(f"Event {e} for patient {patient['patient_id']} cannot have `start` after `end`.") + # Drop single point in time measurements + if e['time'] == m['metadata']['end']: + continue + measurements.append((e['time'], m)) # type: ignore + return measurements + + def get_visit_codes(ontology: femr.ontology.Ontology) -> Set[str]: return ontology.get_all_children(get_inpatient_admission_codes().union(get_outpatient_visit_codes())) @@ -35,16 +84,18 @@ def get_outpatient_visit_measurements(patient: meds.Patient, ontology: femr.onto for e in patient['events']: for m in e["measurements"]: if ( - m['metadata']['table'] == "visit_occurrence" + m['metadata']['table'] == "visit" and (m['code'] in admission_codes or len(ontology.get_parents(m['code']).intersection(admission_codes)) > 0) ): + if isinstance(m['metadata']['end'], str): + m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end']) # Error checking - if m['start'] is None or m['end'] is None: + if e['time'] is None or m['metadata']['end'] is None: raise RuntimeError(f"Event {e} cannot have `None` as its `start` or `end` attribute.") - elif m['start'] > m['end']: + elif e['time'] > m['metadata']['end']: raise RuntimeError(f"Event {e} cannot have `start` after `end`.") # Drop single point in time events - if m['start'] == m['end']: + if e['time'] == m['metadata']['end']: continue measurements.append((e['time'], m)) return measurements @@ -57,9 +108,11 @@ def get_inpatient_admission_measurements(patient: meds.Patient, for e in patient["events"]: for m in e["measurements"]: if ( - m['metadata']['table'] == "visit_occurrence" + m['metadata']['table'] == "visit" and (m['code'] in admission_codes or len(ontology.get_parents(m['code']).intersection(admission_codes)) > 0) ): + if isinstance(m['metadata']['end'], str): + m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end']) # Error checking if e['time'] is None or m['metadata']['end'] is None: raise RuntimeError(f"Event {e} cannot have `None` as its `start` or `end` attribute.") @@ -79,6 +132,8 @@ def get_inpatient_admission_discharge_times( measurements: List[Tuple[datetime.datetime, meds.Measurement]] = get_inpatient_admission_measurements(patient, ontology) times: List[Tuple[datetime.datetime, datetime.datetime]] = [] for (start, m) in measurements: + if isinstance(m['metadata']['end'], str): + m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end']) if m['metadata']['end'] is None: raise RuntimeError(f"Event {m} cannot have `None` as its `end` attribute.") if start > m['metadata']['end']: @@ -195,25 +250,27 @@ def __init__( def get_outcome_times(self, patient: meds.Patient) -> List[datetime.datetime]: # Return the start times of all ICU admissions -- this is our outcome - return [e.start for e in get_icu_events(patient, self.ontology)] # type: ignore + return [time for time, __ in get_icu_measurements(patient, self.ontology)] # type: ignore - def get_visit_measurements(self, patient: meds.Patient) -> List[meds.Measurement]: + def get_visit_measurements(self, patient: meds.Patient) -> List[Tuple[datetime.datetime, meds.Measurement]]: """Return all inpatient visits where ICU transfer does not occur on the same day as admission.""" # Get all inpatient visits -- each visit comprises a prediction (start, end) time horizon - all_visits: List[meds.Measurement] = get_outpatient_visit_measurements(patient, self.ontology) + measurements: List[Tuple[datetime.datetime, meds.Measurement]] = get_inpatient_admission_measurements(patient, self.ontology) # Exclude visits where ICU admission occurs on the same day as admission icu_transfer_dates: List[datetime.datetime] = [ x.replace(hour=0, minute=0, second=0, microsecond=0) for x in self.get_outcome_times(patient) ] - valid_visits: List[meds.Measurement] = [] - for start, visit in all_visits: + valid_visits: List[Tuple[datetime.datetime, meds.Measurement]] = [] + for time, m in measurements: # If admission and discharge are on the same day, then ignore - if start.date() == visit['metadata']['end'].date(): + if isinstance(m['metadata']['end'], str): + m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end']) + if time.date() == m['metadata']['end'].date(): continue # If ICU transfer occurs on the same day as admission, then ignore - if start.replace(hour=0, minute=0, second=0, microsecond=0) in icu_transfer_dates: + if time.replace(hour=0, minute=0, second=0, microsecond=0) in icu_transfer_dates: continue - valid_visits.append(visit) + valid_visits.append((time, m)) return valid_visits @@ -480,12 +537,6 @@ class AcuteMyocardialInfarctionCodeLabeler(FirstDiagnosisTimeHorizonCodeLabeler) # n = 21982 root_concept_code = "SNOMED/57054005" - -class CTEPHCodeLabeler(FirstDiagnosisTimeHorizonCodeLabeler): - # n = 1433 - root_concept_code = "SNOMED/233947005" - - class EssentialHypertensionCodeLabeler(FirstDiagnosisTimeHorizonCodeLabeler): # n = 4644483 root_concept_code = "SNOMED/59621000" diff --git a/src/femr/labelers/omop.py b/src/femr/labelers/omop.py index 9f3c356..58dd484 100644 --- a/src/femr/labelers/omop.py +++ b/src/femr/labelers/omop.py @@ -64,7 +64,7 @@ def get_outcome_times(self, patient: meds.Patient) -> List[datetime.datetime]: return [] @abstractmethod - def get_visit_measurements(self, patient: meds.Patient) -> List[meds.Measurement]: + def get_visit_measurements(self, patient: meds.Patient) -> List[Tuple[datetime.datetime, meds.Measurement]]: """Return a list of all visits we want to consider (useful for limiting to inpatient visits).""" return [] diff --git a/src/femr/transforms/stanford.py b/src/femr/transforms/stanford.py index e7661d4..1ca4ed1 100644 --- a/src/femr/transforms/stanford.py +++ b/src/femr/transforms/stanford.py @@ -84,6 +84,8 @@ def move_visit_start_to_first_event_start(patient: meds.Patient) -> meds.Patient if measurement["metadata"].get("end") is not None: # Reset the visit end to be ≥ the visit start + if isinstance(measurement["metadata"]["end"], str): + measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"]) measurement["metadata"]["end"] = max(event["time"], measurement["metadata"]["end"]) else: new_measurements.append(measurement) @@ -103,6 +105,8 @@ def move_to_day_end(patient: meds.Patient) -> meds.Patient: event["time"] = _move_date_to_end(event["time"]) for measurement in event["measurements"]: if measurement["metadata"].get("end") is not None: + if isinstance(measurement["metadata"]["end"], str): + measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"]) measurement["metadata"]["end"] = _move_date_to_end(measurement["metadata"]["end"]) measurement["metadata"]["end"] = max(measurement["metadata"]["end"], event["time"]) @@ -128,7 +132,6 @@ def move_pre_birth(patient: meds.Patient) -> meds.Patient: for measurement in event["measurements"]: if measurement["code"] == meds.birth_code: birth_date = event["time"] - assert birth_date is not None new_events = [] @@ -141,6 +144,8 @@ def move_pre_birth(patient: meds.Patient) -> meds.Patient: event["time"] = birth_date for measurement in event["measurements"]: + if isinstance(measurement["metadata"]["end"], str): + measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"]) if measurement["metadata"].get("end") is not None and measurement["metadata"]["end"] < birth_date: measurement["metadata"]["end"] = birth_date @@ -185,6 +190,8 @@ def move_billing_codes(patient: meds.Patient) -> meds.Patient: if measurement["metadata"].get("clarity_table") in ("lpch_pat_enc", "shc_pat_enc"): if measurement["metadata"].get("end") is not None: + if isinstance(measurement["metadata"]["end"], str): + measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"]) if measurement["metadata"]["visit_id"] is None: # Every event with an end time should have a visit ID associated with it raise RuntimeError(f"Expected visit id for visit? {patient['patient_id']} {event}") @@ -223,6 +230,8 @@ def move_billing_codes(patient: meds.Patient) -> meds.Patient: # The end time for an event should be no later than its associated visit end time if measurement["metadata"].get("end") is not None: + if isinstance(measurement["metadata"]["end"], str): + measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"]) measurement["metadata"]["end"] = max(measurement["metadata"]["end"], end_visit) # The start time for an event should be no later than its associated visit end time