Skip to content

Commit

Permalink
Switch to meds_reader
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanSteinberg committed Jul 4, 2024
1 parent 26303de commit c54f6ad
Show file tree
Hide file tree
Showing 42 changed files with 564 additions and 826 deletions.
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,6 @@ ignore_missing_imports = True

[mypy-msgpack.*]
ignore_missing_imports = True

[mypy-xformers.*]
ignore_missing_imports = True
108 changes: 58 additions & 50 deletions src/femr/featurizers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import numpy as np
import scipy.sparse

import femr.mr
import femr.ontology


Expand All @@ -30,22 +29,24 @@ class ColumnValue(NamedTuple):
def _preprocess_map_func(
patients: Iterator[meds_reader.Patient], *, label_map: Mapping[int, List[meds.Label]], featurizers: List[Featurizer]
) -> List[List[Any]]:
result = []
patients_list = list(patients)
for featurizer in featurizers:
result.append(featurizer.generate_preprocess_data(iter(patients_list), label_map))
initial_data = [featurizer.get_initial_preprocess_data() for featurizer in featurizers]
for patient in patients:
for data, featurizer in zip(initial_data, featurizers):
featurizer.add_preprocess_data(data, patient, label_map)

return result
return initial_data


def _features_map_func(
patients: Iterator[meds_reader.Patient], *, label_map: Mapping[int, List[meds.Label]], featurizers: List[Featurizer]
) -> Mapping[str, Any]:
# Construct CSR sparse matrix
# non-zero entries in sparse matrix
data: List[Any] = []
# maps each element in `data`` to its column in the sparse matrix
indices: List[int] = []
data_and_indices = np.zeros((1024, 2), np.float64)
data_and_indices_arrays = []

current_index = 0

# maps each element in `data` and `indices` to the rows of the sparse matrix
indptr: List[int] = []

Expand Down Expand Up @@ -73,7 +74,7 @@ def _features_map_func(
a.append(b)

for features in features_per_label:
indptr.append(len(indices))
indptr.append(current_index + len(data_and_indices_arrays) * 1024)

# Keep track of starting column for each successive featurizer as we
# combine their features into one large matrix
Expand All @@ -85,23 +86,34 @@ def _features_map_func(
f"{column} on patient {patient.patient_id} ({column} must be between 0 and "
f"{featurizer.get_num_columns()})"
)
indices.append(column_offset + column)
data.append(value)
data_and_indices[current_index, 0] = value
data_and_indices[current_index, 1] = column_offset + column

current_index += 1

if current_index == 1024:
current_index = 0
data_and_indices_arrays.append(data_and_indices.copy())

# Record what the starting column should be for the next featurizer
column_offset += featurizer.get_num_columns()

# Need one last `indptr` for end of last row in CSR sparse matrix
indptr.append(len(indices))
indptr.append(current_index + len(data_and_indices_arrays) * 1024)

# n_rows = number of Labels across all Patients
total_rows: int = len(indptr) - 1
# n_cols = sum of number of columns output by each Featurizer
total_columns: int = sum(x.get_num_columns() for x in featurizers)

# Explanation of CSR Matrix: https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr
np_data: np.ndarray = np.array(data, dtype=np.float32)
np_indices: np.ndarray = np.array(indices, dtype=np.int64)
data_and_indices_arrays.append(data_and_indices[:current_index, :])

np_data_and_indices: np.ndarray = np.concatenate(data_and_indices_arrays)

np_data = np_data_and_indices[:, 0].astype(np.float32)
np_indices = np_data_and_indices[:, 1].astype(np.int64)

np_indptr: np.ndarray = np.array(indptr, dtype=np.int64)

assert (
Expand All @@ -118,22 +130,19 @@ def _features_map_func(
return {"patient_ids": np_patient_ids, "feature_times": np_feature_times, "features": data_matrix}


def _features_agg_func(first_result: Any, second_result: Any) -> Any:
for k in first_result:
first_result[k].extend(second_result[k])

return first_result


class Featurizer(ABC):
"""A Featurizer takes a Patient and a list of Labels, then returns a row for each timepoint.
Featurizers must be preprocessed before they are used to compute normalization statistics.
A sparse representation named ColumnValue is used to represent the values returned by a Featurizer.
"""

def generate_preprocess_data(
self, patients: Iterator[meds_reader.Patient], label_map: Mapping[int, List[meds.Label]]
) -> Any:
def get_initial_preprocess_data(self) -> Any:
"""
Get the initial preprocess data
"""
pass

def add_preprocess_data(self, data: Any, patient: meds_reader.Patient, label_map: Mapping[int, List[meds.Label]]):
"""
Some featurizers need to do some preprocessing in order to prepare for featurization.
This function performs that preprocessing on the given patients and labels, and returns some state.
Expand Down Expand Up @@ -230,7 +239,7 @@ def __init__(self, featurizers: List[Featurizer]):

def preprocess_featurizers(
self,
pool: femr.mr.Pool,
db: meds_reader.PatientDatabase,
labels: List[meds.Label],
) -> None:
"""Preprocess `self.featurizers` on the provided set of labels."""
Expand All @@ -245,13 +254,12 @@ def preprocess_featurizers(
for label in labels:
label_map[label["patient_id"]].append(label)
# Split patients across multiple threads
patient_ids: List[int] = sorted(list({label["patient_id"] for label in labels}))
patient_ids: List[int] = list({label["patient_id"] for label in labels})

featurize_stats: List[List[Any]] = [[] for _ in self.featurizers]

for chunk_stats in pool.map(
functools.partial(_preprocess_map_func, label_map=label_map, featurizers=self.featurizers),
patient_ids=patient_ids,
for chunk_stats in db.filter(patient_ids).map(
functools.partial(_preprocess_map_func, label_map=label_map, featurizers=self.featurizers)
):
for a, b in zip(featurize_stats, chunk_stats):
a.append(b)
Expand All @@ -263,7 +271,7 @@ def preprocess_featurizers(

def featurize(
self,
pool: femr.mr.Pool,
db: meds_reader.PatientDatabase,
labels: List[meds.Label],
) -> Mapping[str, np.ndarray]:
"""
Expand All @@ -288,8 +296,8 @@ def featurize(

features = collections.defaultdict(list)

for feat_chunk in pool.map(
functools.partial(_features_map_func, label_map=label_map, featurizers=self.featurizers), patient_ids
for feat_chunk in db.filter(patient_ids).map(
functools.partial(_features_map_func, label_map=label_map, featurizers=self.featurizers)
):
for k, v in feat_chunk.items():
features[k].append(v)
Expand All @@ -313,30 +321,30 @@ def join_labels(features: Mapping[str, np.ndarray], labels: List[meds.Label]) ->
labels = list(labels)
labels.sort(key=lambda a: (a["patient_id"], a["prediction_time"]))

label_index = 0

indices = []
label_values = []

order = np.lexsort((features["feature_times"], features["patient_ids"]))

for i, patient_id, feature_time in zip(order, features["patient_ids"][order], features["feature_times"][order]):
if label_index == len(labels):
break

assert patient_id <= labels[label_index]["patient_id"], f"Missing features for label {labels[label_index]}"
if patient_id < labels[label_index]["patient_id"]:
continue
feature_index = 0

for label in labels:
while (
(feature_index + 1) < len(order)
and features["patient_ids"][order[feature_index + 1]] <= label["patient_id"]
and features["feature_times"][order[feature_index + 1]] <= label["prediction_time"]
):
feature_index += 1
is_valid = (
feature_index < len(order)
and features["patient_ids"][order[feature_index]] == label["patient_id"]
and features["feature_times"][order[feature_index]] <= label["prediction_time"]
)
assert (
feature_time <= labels[label_index]["prediction_time"]
), f"Missing features for label {labels[label_index]}"
if feature_time < labels[label_index]["prediction_time"]:
continue

indices.append(i)
label_values.append(labels[label_index]["boolean_value"])
label_index += 1
is_valid
), f'{feature_index} {label} {features["patient_ids"][order[feature_index]]} {features["feature_times"][order[feature_index]]}'
indices.append(order[feature_index])
label_values.append(label["boolean_value"])

return {
"boolean_values": np.array(label_values),
Expand Down
84 changes: 38 additions & 46 deletions src/femr/featurizers/featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,22 @@ def __init__(self, is_normalize: bool = True):
def get_num_columns(self) -> int:
return 1

def generate_preprocess_data(
self, patients: Iterator[meds_reader.Patient], label_map: Mapping[int, List[meds.Label]]
def get_initial_preprocess_data(self) -> OnlineStatistics:
return OnlineStatistics()

def add_preprocess_data(
self, age_statistics: OnlineStatistics, patient: meds_reader.Patient, label_map: Mapping[int, List[meds.Label]]
) -> OnlineStatistics:
"""Save the age of this patient (in years) at each label, to use for normalization."""
if not self.is_needs_preprocessing():
return OnlineStatistics()

age_statistics: OnlineStatistics = OnlineStatistics()

for patient in patients:
patient_birth_date: Optional[datetime.datetime] = get_patient_birthdate(patient)
assert patient_birth_date, "Patients must have a birth date"

for label in label_map[patient.patient_id]:
age_in_yrs: float = (label["prediction_time"] - patient_birth_date).days / 365
age_statistics.add(age_in_yrs)
patient_birth_date: Optional[datetime.datetime] = get_patient_birthdate(patient)
assert patient_birth_date, "Patients must have a birth date"

return age_statistics
for label in label_map[patient.patient_id]:
age_in_yrs: float = (label["prediction_time"] - patient_birth_date).days / 365
age_statistics.add(age_in_yrs)

def encorperate_prepreprocessed_data(self, data_elements: List[OnlineStatistics]) -> None:
self.age_statistics = OnlineStatistics.merge(data_elements)
print("What", self.age_statistics, data_elements)

def featurize(
self,
Expand Down Expand Up @@ -256,8 +250,15 @@ def get_columns(self, event: meds_reader.Event) -> Iterator[int]:
if code in self.code_to_column_index:
yield self.code_to_column_index[code]

def generate_preprocess_data(
self, patients: Iterator[meds_reader.Patient], label_map: Mapping[int, List[meds.Label]]
def get_initial_preprocess_data(self) -> Any:
return {
"observed_codes": set(),
"observed_string_value": collections.defaultdict(int),
"observed_numeric_value": collections.defaultdict(functools.partial(ReservoirSampler, 10000, 100)),
}

def add_preprocess_data(
self, data: Any, patient: meds_reader.Patient, label_map: Mapping[int, List[meds.Label]]
) -> Any:
"""
Some featurizers need to do some preprocessing in order to prepare for featurization.
Expand All @@ -266,34 +267,25 @@ def generate_preprocess_data(
Note that this function shouldn't mutate the Featurizer as it will be sharded.
"""
observed_codes: Set[str] = set()
observed_string_value: Dict[Tuple[str, str], int] = collections.defaultdict(int)
observed_numeric_value: Dict[str, ReservoirSampler] = collections.defaultdict(
functools.partial(ReservoirSampler, 10000, 100)
)

for patient in patients:
for event in patient.events:
# Check for excluded events
if self.excluded_event_filter is not None and self.excluded_event_filter(event):
continue

if event.text_value is not None:
if self.string_value_combination:
observed_string_value[(event.code, event.text_value[: self.characters_for_string_values])] += 1
elif event.numeric_value is not None:
if self.numeric_value_decile:
observed_numeric_value[event.code].add(event.numeric_value)
else:
for code in self.get_codes(event.code):
# If we haven't seen this code before, then add it to our list of included codes
observed_codes.add(code)

return {
"observed_codes": observed_codes,
"observed_string_value": observed_string_value,
"observed_numeric_value": observed_numeric_value,
}
observed_codes: Set[str] = data["observed_codes"]
observed_string_value: Dict[Tuple[str, str], int] = data["observed_string_value"]
observed_numeric_value: Dict[str, ReservoirSampler] = data["observed_numeric_value"]

for event in patient.events:
# Check for excluded events
if self.excluded_event_filter is not None and self.excluded_event_filter(event):
continue

if event.text_value is not None:
if self.string_value_combination:
observed_string_value[(event.code, event.text_value[: self.characters_for_string_values])] += 1
elif event.numeric_value is not None:
if self.numeric_value_decile:
observed_numeric_value[event.code].add(event.numeric_value)
else:
for code in self.get_codes(event.code):
# If we haven't seen this code before, then add it to our list of included codes
observed_codes.add(code)

def encorperate_prepreprocessed_data(self, data_elements: List[Any]) -> None:
"""
Expand Down
6 changes: 2 additions & 4 deletions src/femr/labelers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import meds
import meds_reader

import femr.mr


@dataclass(frozen=True)
class TimeHorizon:
Expand Down Expand Up @@ -63,7 +61,7 @@ def label(self, patient: meds_reader.Patient) -> List[meds.Label]:

def apply(
self,
pool: femr.mr.Pool,
db: meds_reader.PatientDatabase,
) -> List[meds.Label]:
"""Apply the `label()` function one-by-one to each Patient in a sequence of Patients.
Expand All @@ -75,7 +73,7 @@ def apply(
A list of labels
"""

return list(itertools.chain.from_iterable(pool.map(functools.partial(_label_map_func, labeler=self))))
return list(itertools.chain.from_iterable(db.map(functools.partial(_label_map_func, labeler=self))))


##########################################################
Expand Down
Loading

0 comments on commit c54f6ad

Please sign in to comment.