Skip to content

Commit

Permalink
ehrshot updates to labeling functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Miking98 committed Apr 6, 2024
1 parent d602c25 commit 98a87b1
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 21 deletions.
89 changes: 70 additions & 19 deletions src/femr/labelers/ehrshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,55 @@
)


def get_icu_visit_detail_care_site_ids(ontology: femr.ontology.Ontology) -> Set[str]:
return ontology.get_all_children([
# All care sites with "ICU" (case insensitive) in the name
"528292",
"528612",
"528604",
"528623",
"528396",
"528377",
"528314",
"528478",
"528112",
"528024",
"527323",
"527858",
])

def get_icu_measurements(
patient: meds.Patient, ontology: femr.ontology.Ontology
) -> List[Tuple[datetime.datetime, meds.Measurement]]:
"""Return all ICU events for this patient.
"""
icu_visit_detail_care_site_ids: Set[str] = get_icu_visit_detail_care_site_ids(ontology)
measurements: List[Tuple[datetime.datetime, meds.Measurement]]= [] # type: ignore
for idx, e in enumerate(patient['events']):
# `visit_detail` is more accurate + comprehensive than `visit_occurrence` for
# ICU measurements for STARR OMOP for some reason
for m in e["measurements"]:
if (
m['metadata']['table'] == "visit_detail"
and 'care_site_id' in m['metadata']
and m['metadata']['care_site_id'] in icu_visit_detail_care_site_ids # no ontology expansion for ICU
):
# Error checking
if isinstance(m['metadata']['end'], str):
m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end'])
if e['time'] is None or m['metadata']['end'] is None:
raise RuntimeError(
f"Event {e} for patient {patient['patient_id']} cannot have `None` as its `start` or `end` attribute."
)
elif e['time'] > m['metadata']['end']:
raise RuntimeError(f"Event {e} for patient {patient['patient_id']} cannot have `start` after `end`.")
# Drop single point in time measurements
if e['time'] == m['metadata']['end']:
continue
measurements.append((e['time'], m)) # type: ignore
return measurements


def get_visit_codes(ontology: femr.ontology.Ontology) -> Set[str]:
return ontology.get_all_children(get_inpatient_admission_codes().union(get_outpatient_visit_codes()))

Expand All @@ -35,16 +84,18 @@ def get_outpatient_visit_measurements(patient: meds.Patient, ontology: femr.onto
for e in patient['events']:
for m in e["measurements"]:
if (
m['metadata']['table'] == "visit_occurrence"
m['metadata']['table'] == "visit"
and (m['code'] in admission_codes or len(ontology.get_parents(m['code']).intersection(admission_codes)) > 0)
):
if isinstance(m['metadata']['end'], str):
m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end'])
# Error checking
if m['start'] is None or m['end'] is None:
if e['time'] is None or m['metadata']['end'] is None:
raise RuntimeError(f"Event {e} cannot have `None` as its `start` or `end` attribute.")
elif m['start'] > m['end']:
elif e['time'] > m['metadata']['end']:
raise RuntimeError(f"Event {e} cannot have `start` after `end`.")
# Drop single point in time events
if m['start'] == m['end']:
if e['time'] == m['metadata']['end']:
continue
measurements.append((e['time'], m))
return measurements
Expand All @@ -57,9 +108,11 @@ def get_inpatient_admission_measurements(patient: meds.Patient,
for e in patient["events"]:
for m in e["measurements"]:
if (
m['metadata']['table'] == "visit_occurrence"
m['metadata']['table'] == "visit"
and (m['code'] in admission_codes or len(ontology.get_parents(m['code']).intersection(admission_codes)) > 0)
):
if isinstance(m['metadata']['end'], str):
m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end'])
# Error checking
if e['time'] is None or m['metadata']['end'] is None:
raise RuntimeError(f"Event {e} cannot have `None` as its `start` or `end` attribute.")
Expand All @@ -79,6 +132,8 @@ def get_inpatient_admission_discharge_times(
measurements: List[Tuple[datetime.datetime, meds.Measurement]] = get_inpatient_admission_measurements(patient, ontology)
times: List[Tuple[datetime.datetime, datetime.datetime]] = []
for (start, m) in measurements:
if isinstance(m['metadata']['end'], str):
m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end'])
if m['metadata']['end'] is None:
raise RuntimeError(f"Event {m} cannot have `None` as its `end` attribute.")
if start > m['metadata']['end']:
Expand Down Expand Up @@ -195,25 +250,27 @@ def __init__(

def get_outcome_times(self, patient: meds.Patient) -> List[datetime.datetime]:
# Return the start times of all ICU admissions -- this is our outcome
return [e.start for e in get_icu_events(patient, self.ontology)] # type: ignore
return [time for time, __ in get_icu_measurements(patient, self.ontology)] # type: ignore

def get_visit_measurements(self, patient: meds.Patient) -> List[meds.Measurement]:
def get_visit_measurements(self, patient: meds.Patient) -> List[Tuple[datetime.datetime, meds.Measurement]]:
"""Return all inpatient visits where ICU transfer does not occur on the same day as admission."""
# Get all inpatient visits -- each visit comprises a prediction (start, end) time horizon
all_visits: List[meds.Measurement] = get_outpatient_visit_measurements(patient, self.ontology)
measurements: List[Tuple[datetime.datetime, meds.Measurement]] = get_inpatient_admission_measurements(patient, self.ontology)
# Exclude visits where ICU admission occurs on the same day as admission
icu_transfer_dates: List[datetime.datetime] = [
x.replace(hour=0, minute=0, second=0, microsecond=0) for x in self.get_outcome_times(patient)
]
valid_visits: List[meds.Measurement] = []
for start, visit in all_visits:
valid_visits: List[Tuple[datetime.datetime, meds.Measurement]] = []
for time, m in measurements:
# If admission and discharge are on the same day, then ignore
if start.date() == visit['metadata']['end'].date():
if isinstance(m['metadata']['end'], str):
m['metadata']['end'] = datetime.datetime.fromisoformat(m['metadata']['end'])
if time.date() == m['metadata']['end'].date():
continue
# If ICU transfer occurs on the same day as admission, then ignore
if start.replace(hour=0, minute=0, second=0, microsecond=0) in icu_transfer_dates:
if time.replace(hour=0, minute=0, second=0, microsecond=0) in icu_transfer_dates:
continue
valid_visits.append(visit)
valid_visits.append((time, m))
return valid_visits


Expand Down Expand Up @@ -480,12 +537,6 @@ class AcuteMyocardialInfarctionCodeLabeler(FirstDiagnosisTimeHorizonCodeLabeler)
# n = 21982
root_concept_code = "SNOMED/57054005"


class CTEPHCodeLabeler(FirstDiagnosisTimeHorizonCodeLabeler):
# n = 1433
root_concept_code = "SNOMED/233947005"


class EssentialHypertensionCodeLabeler(FirstDiagnosisTimeHorizonCodeLabeler):
# n = 4644483
root_concept_code = "SNOMED/59621000"
Expand Down
2 changes: 1 addition & 1 deletion src/femr/labelers/omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_outcome_times(self, patient: meds.Patient) -> List[datetime.datetime]:
return []

@abstractmethod
def get_visit_measurements(self, patient: meds.Patient) -> List[meds.Measurement]:
def get_visit_measurements(self, patient: meds.Patient) -> List[Tuple[datetime.datetime, meds.Measurement]]:
"""Return a list of all visits we want to consider (useful for limiting to inpatient visits)."""
return []

Expand Down
11 changes: 10 additions & 1 deletion src/femr/transforms/stanford.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def move_visit_start_to_first_event_start(patient: meds.Patient) -> meds.Patient

if measurement["metadata"].get("end") is not None:
# Reset the visit end to be ≥ the visit start
if isinstance(measurement["metadata"]["end"], str):
measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"])
measurement["metadata"]["end"] = max(event["time"], measurement["metadata"]["end"])
else:
new_measurements.append(measurement)
Expand All @@ -103,6 +105,8 @@ def move_to_day_end(patient: meds.Patient) -> meds.Patient:
event["time"] = _move_date_to_end(event["time"])
for measurement in event["measurements"]:
if measurement["metadata"].get("end") is not None:
if isinstance(measurement["metadata"]["end"], str):
measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"])
measurement["metadata"]["end"] = _move_date_to_end(measurement["metadata"]["end"])
measurement["metadata"]["end"] = max(measurement["metadata"]["end"], event["time"])

Expand All @@ -128,7 +132,6 @@ def move_pre_birth(patient: meds.Patient) -> meds.Patient:
for measurement in event["measurements"]:
if measurement["code"] == meds.birth_code:
birth_date = event["time"]

assert birth_date is not None

new_events = []
Expand All @@ -141,6 +144,8 @@ def move_pre_birth(patient: meds.Patient) -> meds.Patient:
event["time"] = birth_date

for measurement in event["measurements"]:
if isinstance(measurement["metadata"]["end"], str):
measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"])
if measurement["metadata"].get("end") is not None and measurement["metadata"]["end"] < birth_date:
measurement["metadata"]["end"] = birth_date

Expand Down Expand Up @@ -185,6 +190,8 @@ def move_billing_codes(patient: meds.Patient) -> meds.Patient:

if measurement["metadata"].get("clarity_table") in ("lpch_pat_enc", "shc_pat_enc"):
if measurement["metadata"].get("end") is not None:
if isinstance(measurement["metadata"]["end"], str):
measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"])
if measurement["metadata"]["visit_id"] is None:
# Every event with an end time should have a visit ID associated with it
raise RuntimeError(f"Expected visit id for visit? {patient['patient_id']} {event}")
Expand Down Expand Up @@ -223,6 +230,8 @@ def move_billing_codes(patient: meds.Patient) -> meds.Patient:

# The end time for an event should be no later than its associated visit end time
if measurement["metadata"].get("end") is not None:
if isinstance(measurement["metadata"]["end"], str):
measurement["metadata"]["end"] = datetime.datetime.fromisoformat(measurement["metadata"]["end"])
measurement["metadata"]["end"] = max(measurement["metadata"]["end"], end_visit)

# The start time for an event should be no later than its associated visit end time
Expand Down

0 comments on commit 98a87b1

Please sign in to comment.