diff --git a/.mypy.ini b/.mypy.ini index cb5991d..13186d4 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -53,3 +53,6 @@ ignore_missing_imports = True [mypy-msgpack.*] ignore_missing_imports = True + +[mypy-xformers.*] +ignore_missing_imports = True diff --git a/src/femr/featurizers/core.py b/src/femr/featurizers/core.py index 7a6ea72..361475e 100644 --- a/src/femr/featurizers/core.py +++ b/src/femr/featurizers/core.py @@ -13,7 +13,6 @@ import numpy as np import scipy.sparse -import femr.mr import femr.ontology @@ -30,12 +29,12 @@ class ColumnValue(NamedTuple): def _preprocess_map_func( patients: Iterator[meds_reader.Patient], *, label_map: Mapping[int, List[meds.Label]], featurizers: List[Featurizer] ) -> List[List[Any]]: - result = [] - patients_list = list(patients) - for featurizer in featurizers: - result.append(featurizer.generate_preprocess_data(iter(patients_list), label_map)) + initial_data = [featurizer.get_initial_preprocess_data() for featurizer in featurizers] + for patient in patients: + for data, featurizer in zip(initial_data, featurizers): + featurizer.add_preprocess_data(data, patient, label_map) - return result + return initial_data def _features_map_func( @@ -43,9 +42,11 @@ def _features_map_func( ) -> Mapping[str, Any]: # Construct CSR sparse matrix # non-zero entries in sparse matrix - data: List[Any] = [] - # maps each element in `data`` to its column in the sparse matrix - indices: List[int] = [] + data_and_indices = np.zeros((1024, 2), np.float64) + data_and_indices_arrays = [] + + current_index = 0 + # maps each element in `data` and `indices` to the rows of the sparse matrix indptr: List[int] = [] @@ -73,7 +74,7 @@ def _features_map_func( a.append(b) for features in features_per_label: - indptr.append(len(indices)) + indptr.append(current_index + len(data_and_indices_arrays) * 1024) # Keep track of starting column for each successive featurizer as we # combine their features into one large matrix @@ -85,14 +86,20 @@ def _features_map_func( f"{column} on patient {patient.patient_id} ({column} must be between 0 and " f"{featurizer.get_num_columns()})" ) - indices.append(column_offset + column) - data.append(value) + data_and_indices[current_index, 0] = value + data_and_indices[current_index, 1] = column_offset + column + + current_index += 1 + + if current_index == 1024: + current_index = 0 + data_and_indices_arrays.append(data_and_indices.copy()) # Record what the starting column should be for the next featurizer column_offset += featurizer.get_num_columns() # Need one last `indptr` for end of last row in CSR sparse matrix - indptr.append(len(indices)) + indptr.append(current_index + len(data_and_indices_arrays) * 1024) # n_rows = number of Labels across all Patients total_rows: int = len(indptr) - 1 @@ -100,8 +107,13 @@ def _features_map_func( total_columns: int = sum(x.get_num_columns() for x in featurizers) # Explanation of CSR Matrix: https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr - np_data: np.ndarray = np.array(data, dtype=np.float32) - np_indices: np.ndarray = np.array(indices, dtype=np.int64) + data_and_indices_arrays.append(data_and_indices[:current_index, :]) + + np_data_and_indices: np.ndarray = np.concatenate(data_and_indices_arrays) + + np_data = np_data_and_indices[:, 0].astype(np.float32) + np_indices = np_data_and_indices[:, 1].astype(np.int64) + np_indptr: np.ndarray = np.array(indptr, dtype=np.int64) assert ( @@ -118,22 +130,19 @@ def _features_map_func( return {"patient_ids": np_patient_ids, "feature_times": 
np_feature_times, "features": data_matrix} -def _features_agg_func(first_result: Any, second_result: Any) -> Any: - for k in first_result: - first_result[k].extend(second_result[k]) - - return first_result - - class Featurizer(ABC): """A Featurizer takes a Patient and a list of Labels, then returns a row for each timepoint. Featurizers must be preprocessed before they are used to compute normalization statistics. A sparse representation named ColumnValue is used to represent the values returned by a Featurizer. """ - def generate_preprocess_data( - self, patients: Iterator[meds_reader.Patient], label_map: Mapping[int, List[meds.Label]] - ) -> Any: + def get_initial_preprocess_data(self) -> Any: + """ + Get the initial preprocess data + """ + pass + + def add_preprocess_data(self, data: Any, patient: meds_reader.Patient, label_map: Mapping[int, List[meds.Label]]): """ Some featurizers need to do some preprocessing in order to prepare for featurization. This function performs that preprocessing on the given patients and labels, and returns some state. @@ -230,7 +239,7 @@ def __init__(self, featurizers: List[Featurizer]): def preprocess_featurizers( self, - pool: femr.mr.Pool, + db: meds_reader.PatientDatabase, labels: List[meds.Label], ) -> None: """Preprocess `self.featurizers` on the provided set of labels.""" @@ -245,13 +254,12 @@ def preprocess_featurizers( for label in labels: label_map[label["patient_id"]].append(label) # Split patients across multiple threads - patient_ids: List[int] = sorted(list({label["patient_id"] for label in labels})) + patient_ids: List[int] = list({label["patient_id"] for label in labels}) featurize_stats: List[List[Any]] = [[] for _ in self.featurizers] - for chunk_stats in pool.map( - functools.partial(_preprocess_map_func, label_map=label_map, featurizers=self.featurizers), - patient_ids=patient_ids, + for chunk_stats in db.filter(patient_ids).map( + functools.partial(_preprocess_map_func, label_map=label_map, featurizers=self.featurizers) ): for a, b in zip(featurize_stats, chunk_stats): a.append(b) @@ -263,7 +271,7 @@ def preprocess_featurizers( def featurize( self, - pool: femr.mr.Pool, + db: meds_reader.PatientDatabase, labels: List[meds.Label], ) -> Mapping[str, np.ndarray]: """ @@ -288,8 +296,8 @@ def featurize( features = collections.defaultdict(list) - for feat_chunk in pool.map( - functools.partial(_features_map_func, label_map=label_map, featurizers=self.featurizers), patient_ids + for feat_chunk in db.filter(patient_ids).map( + functools.partial(_features_map_func, label_map=label_map, featurizers=self.featurizers) ): for k, v in feat_chunk.items(): features[k].append(v) @@ -313,30 +321,30 @@ def join_labels(features: Mapping[str, np.ndarray], labels: List[meds.Label]) -> labels = list(labels) labels.sort(key=lambda a: (a["patient_id"], a["prediction_time"])) - label_index = 0 - indices = [] label_values = [] order = np.lexsort((features["feature_times"], features["patient_ids"])) - for i, patient_id, feature_time in zip(order, features["patient_ids"][order], features["feature_times"][order]): - if label_index == len(labels): - break - - assert patient_id <= labels[label_index]["patient_id"], f"Missing features for label {labels[label_index]}" - if patient_id < labels[label_index]["patient_id"]: - continue + feature_index = 0 + for label in labels: + while ( + (feature_index + 1) < len(order) + and features["patient_ids"][order[feature_index + 1]] <= label["patient_id"] + and features["feature_times"][order[feature_index + 1]] <= 
label["prediction_time"] + ): + feature_index += 1 + is_valid = ( + feature_index < len(order) + and features["patient_ids"][order[feature_index]] == label["patient_id"] + and features["feature_times"][order[feature_index]] <= label["prediction_time"] + ) assert ( - feature_time <= labels[label_index]["prediction_time"] - ), f"Missing features for label {labels[label_index]}" - if feature_time < labels[label_index]["prediction_time"]: - continue - - indices.append(i) - label_values.append(labels[label_index]["boolean_value"]) - label_index += 1 + is_valid + ), f'{feature_index} {label} {features["patient_ids"][order[feature_index]]} {features["feature_times"][order[feature_index]]}' + indices.append(order[feature_index]) + label_values.append(label["boolean_value"]) return { "boolean_values": np.array(label_values), diff --git a/src/femr/featurizers/featurizers.py b/src/femr/featurizers/featurizers.py index 07ce942..74d9adf 100644 --- a/src/femr/featurizers/featurizers.py +++ b/src/femr/featurizers/featurizers.py @@ -35,28 +35,22 @@ def __init__(self, is_normalize: bool = True): def get_num_columns(self) -> int: return 1 - def generate_preprocess_data( - self, patients: Iterator[meds_reader.Patient], label_map: Mapping[int, List[meds.Label]] + def get_initial_preprocess_data(self) -> OnlineStatistics: + return OnlineStatistics() + + def add_preprocess_data( + self, age_statistics: OnlineStatistics, patient: meds_reader.Patient, label_map: Mapping[int, List[meds.Label]] ) -> OnlineStatistics: """Save the age of this patient (in years) at each label, to use for normalization.""" - if not self.is_needs_preprocessing(): - return OnlineStatistics() - - age_statistics: OnlineStatistics = OnlineStatistics() - - for patient in patients: - patient_birth_date: Optional[datetime.datetime] = get_patient_birthdate(patient) - assert patient_birth_date, "Patients must have a birth date" - - for label in label_map[patient.patient_id]: - age_in_yrs: float = (label["prediction_time"] - patient_birth_date).days / 365 - age_statistics.add(age_in_yrs) + patient_birth_date: Optional[datetime.datetime] = get_patient_birthdate(patient) + assert patient_birth_date, "Patients must have a birth date" - return age_statistics + for label in label_map[patient.patient_id]: + age_in_yrs: float = (label["prediction_time"] - patient_birth_date).days / 365 + age_statistics.add(age_in_yrs) def encorperate_prepreprocessed_data(self, data_elements: List[OnlineStatistics]) -> None: self.age_statistics = OnlineStatistics.merge(data_elements) - print("What", self.age_statistics, data_elements) def featurize( self, @@ -256,8 +250,15 @@ def get_columns(self, event: meds_reader.Event) -> Iterator[int]: if code in self.code_to_column_index: yield self.code_to_column_index[code] - def generate_preprocess_data( - self, patients: Iterator[meds_reader.Patient], label_map: Mapping[int, List[meds.Label]] + def get_initial_preprocess_data(self) -> Any: + return { + "observed_codes": set(), + "observed_string_value": collections.defaultdict(int), + "observed_numeric_value": collections.defaultdict(functools.partial(ReservoirSampler, 10000, 100)), + } + + def add_preprocess_data( + self, data: Any, patient: meds_reader.Patient, label_map: Mapping[int, List[meds.Label]] ) -> Any: """ Some featurizers need to do some preprocessing in order to prepare for featurization. @@ -266,34 +267,25 @@ def generate_preprocess_data( Note that this function shouldn't mutate the Featurizer as it will be sharded. 
""" - observed_codes: Set[str] = set() - observed_string_value: Dict[Tuple[str, str], int] = collections.defaultdict(int) - observed_numeric_value: Dict[str, ReservoirSampler] = collections.defaultdict( - functools.partial(ReservoirSampler, 10000, 100) - ) - - for patient in patients: - for event in patient.events: - # Check for excluded events - if self.excluded_event_filter is not None and self.excluded_event_filter(event): - continue - - if event.text_value is not None: - if self.string_value_combination: - observed_string_value[(event.code, event.text_value[: self.characters_for_string_values])] += 1 - elif event.numeric_value is not None: - if self.numeric_value_decile: - observed_numeric_value[event.code].add(event.numeric_value) - else: - for code in self.get_codes(event.code): - # If we haven't seen this code before, then add it to our list of included codes - observed_codes.add(code) - - return { - "observed_codes": observed_codes, - "observed_string_value": observed_string_value, - "observed_numeric_value": observed_numeric_value, - } + observed_codes: Set[str] = data["observed_codes"] + observed_string_value: Dict[Tuple[str, str], int] = data["observed_string_value"] + observed_numeric_value: Dict[str, ReservoirSampler] = data["observed_numeric_value"] + + for event in patient.events: + # Check for excluded events + if self.excluded_event_filter is not None and self.excluded_event_filter(event): + continue + + if event.text_value is not None: + if self.string_value_combination: + observed_string_value[(event.code, event.text_value[: self.characters_for_string_values])] += 1 + elif event.numeric_value is not None: + if self.numeric_value_decile: + observed_numeric_value[event.code].add(event.numeric_value) + else: + for code in self.get_codes(event.code): + # If we haven't seen this code before, then add it to our list of included codes + observed_codes.add(code) def encorperate_prepreprocessed_data(self, data_elements: List[Any]) -> None: """ diff --git a/src/femr/labelers/core.py b/src/femr/labelers/core.py index 8f5dc1e..d008877 100644 --- a/src/femr/labelers/core.py +++ b/src/femr/labelers/core.py @@ -15,8 +15,6 @@ import meds import meds_reader -import femr.mr - @dataclass(frozen=True) class TimeHorizon: @@ -63,7 +61,7 @@ def label(self, patient: meds_reader.Patient) -> List[meds.Label]: def apply( self, - pool: femr.mr.Pool, + db: meds_reader.PatientDatabase, ) -> List[meds.Label]: """Apply the `label()` function one-by-one to each Patient in a sequence of Patients. 
@@ -75,7 +73,7 @@ def apply( A list of labels """ - return list(itertools.chain.from_iterable(pool.map(functools.partial(_label_map_func, labeler=self)))) + return list(itertools.chain.from_iterable(db.map(functools.partial(_label_map_func, labeler=self)))) ########################################################## diff --git a/src/femr/models/processor.py b/src/femr/models/processor.py index abe1744..bb1c972 100644 --- a/src/femr/models/processor.py +++ b/src/femr/models/processor.py @@ -3,18 +3,21 @@ import collections import datetime import functools -from typing import Any, Dict, List, Mapping, Optional, Tuple +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple +import datasets import meds +import meds_reader import numpy as np import torch.utils.data -import femr.hf_utils import femr.models.tokenizer import femr.pat_utils -def map_preliminary_batch_stats(batch, indices, *, processor: FEMRBatchProcessor, max_length: int): +def map_preliminary_batch_stats( + patients: Iterable[meds_reader.Patient], *, processor: FEMRBatchProcessor, max_length: int +): """ This function creates preliminary batch statistics, to be used for final batching. @@ -39,11 +42,7 @@ def map_preliminary_batch_stats(batch, indices, *, processor: FEMRBatchProcessor """ lengths = [] - for patient_index, patient_id, events in zip(indices, batch["patient_id"], batch["events"]): - patient = { - "patient_id": patient_id, - "events": events, - } + for patient in patients: data = processor.convert_patient(patient) # There are no labels for this patient @@ -57,28 +56,21 @@ def map_preliminary_batch_stats(batch, indices, *, processor: FEMRBatchProcessor for label_index in data["transformer"]["label_indices"]: if (label_index - current_start + 1) >= max_length: if current_start != current_end: - lengths.append((patient_index, current_start, current_end - current_start + 1)) + lengths.append((patient.patient_id, current_start, current_end - current_start + 1)) current_start = label_index - max_length + 1 current_end = label_index else: current_end = label_index - lengths.append((patient_index, current_start, current_end - current_start + 1)) + lengths.append((patient.patient_id, current_start, current_end - current_start + 1)) else: last_index = data["transformer"]["label_indices"][-1] length = min(max_length, last_index + 1) - lengths.append((patient_index, last_index + 1 - length, length)) + lengths.append((patient.patient_id, last_index + 1 - length, length)) if len(lengths) > 0: - return [np.array(lengths, dtype=np.int64)] + return np.array(lengths, dtype=np.int64) else: - return [] - - -def agg_preliminary_batch_stats(lengths1, lengths2): - """Aggregate preliminary length statistics from the map_preliminary_batch_stats""" - lengths1.extend(lengths2) - - return lengths1 + return np.zeros(shape=(0, 3), dtype=np.int64) class BatchCreator: @@ -168,41 +160,40 @@ def add_patient(self, patient: meds_reader.Patient, offset: int = 0, max_length: current_date = event.time.date() codes_seen_today = set() - for measurement in event["measurements"]: - # Get features and weights for the current event - features, weights = self.tokenizer.get_feature_codes(event.time, measurement) + # Get features and weights for the current event + features, weights = self.tokenizer.get_feature_codes(event) - # Ignore events with no features - if len(features) == 0: - continue + # Ignore events with no features + if len(features) == 0: + continue - # Ignore events where all features have already occurred - if all(feature in 
codes_seen_today for feature in features): - continue + # Ignore events where all features have already occurred + if all(feature in codes_seen_today for feature in features): + continue - codes_seen_today |= set(features) + codes_seen_today |= set(features) - if (self.task is not None) and (last_time is not None): - # Now we have to consider whether or not to have labels for this time step - # The add_event function returns how many labels to assign for this time - num_added = self.task.add_event(last_time, event.time, features) - for _ in range(num_added): - per_patient_label_indices.append(len(per_patient_ages) - 1) + if (self.task is not None) and (last_time is not None): + # Now we have to consider whether or not to have labels for this time step + # The add_event function returns how many labels to assign for this time + num_added = self.task.add_event(last_time, event.time, features) + for _ in range(num_added): + per_patient_label_indices.append(len(per_patient_ages) - 1) - if not self.tokenizer.is_hierarchical: - assert len(features) == 1 - per_patient_tokens.append(features[0]) - else: - assert weights is not None - per_patient_hierarchical_tokens.extend(features) - per_patient_hierarchical_weights.extend(weights) - per_patient_token_indices.append(len(per_patient_hierarchical_tokens)) + if not self.tokenizer.is_hierarchical: + assert len(features) == 1 + per_patient_tokens.append(features[0]) + else: + assert weights is not None + per_patient_hierarchical_tokens.extend(features) + per_patient_hierarchical_weights.extend(weights) + per_patient_token_indices.append(len(per_patient_hierarchical_tokens)) - per_patient_ages.append((event.time - birth) / datetime.timedelta(days=1)) - per_patient_normalized_ages.append(self.tokenizer.normalize_age(event.time - birth)) - per_patient_timestamps.append(event.time.replace(tzinfo=datetime.timezone.utc).timestamp()) + per_patient_ages.append((event.time - birth) / datetime.timedelta(days=1)) + per_patient_normalized_ages.append(self.tokenizer.normalize_age(event.time - birth)) + per_patient_timestamps.append(event.time.replace(tzinfo=datetime.timezone.utc).timestamp()) - last_time = event.time + last_time = event.time if self.task is not None and last_time is not None: num_added = self.task.add_event(last_time, None, None) @@ -327,18 +318,19 @@ def cleanup_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: return batch -def _batch_generator(batch_data: Tuple[np.ndarray, np.ndarray], *, creator: BatchCreator, dataset: datasets.Dataset): - for lengths, offsets in batch_data: - offsets = list(offsets) - for start, end in zip(offsets, offsets[1:]): - creator.start_batch() - for patient_index, offset, length in lengths[start:end, :]: - creator.add_patient(dataset[patient_index.item()], offset, length) +def _batch_generator(batch_data: Tuple[np.ndarray, np.ndarray], *, creator: BatchCreator, path_to_database: str): + with meds_reader.PatientDatabase(path_to_database) as database: + for lengths, offsets in batch_data: + offsets = list(offsets) + for start, end in zip(offsets, offsets[1:]): + creator.start_batch() + for patient_index, offset, length in lengths[start:end, :]: + creator.add_patient(database[patient_index.item()], offset, length) - result = creator.get_batch_data() - assert "task" in result, f"No task present in {lengths[start:end,:]}" + result = creator.get_batch_data() + assert "task" in result, f"No task present in {lengths[start:end,:]}" - yield result + yield result def _add_dimension(data: Any) -> Any: @@ -399,7 +391,9 @@ def 
collate(self, batches: List[Mapping[str, Any]]) -> Mapping[str, Any]: assert len(batches) == 1, "Can only have one batch when collating" return {"batch": _add_dimension(self.creator.cleanup_batch(batches[0]))} - def convert_dataset(self, dataset, tokens_per_batch: int, min_patients_per_batch: int = 4, num_proc: int = 1): + def convert_dataset( + self, db: meds_reader.PatientDatabase, tokens_per_batch: int, min_patients_per_batch: int = 4, num_proc: int = 1 + ): """Convert an entire dataset to batches. Arguments: @@ -411,25 +405,16 @@ def convert_dataset(self, dataset, tokens_per_batch: int, min_patients_per_batch Returns: A huggingface dataset object containing batches """ - if isinstance(dataset, datasets.DatasetDict): - return datasets.DatasetDict( - { - k: self.convert_dataset(v, tokens_per_batch, min_patients_per_batch, num_proc) - for k, v in dataset.items() - } - ) max_length = tokens_per_batch // min_patients_per_batch - lengths = femr.hf_utils.aggregate_over_dataset( - dataset, - functools.partial(map_preliminary_batch_stats, processor=self, max_length=max_length), - agg_preliminary_batch_stats, - num_proc=num_proc, - batch_size=1_000, - with_indices=True, + + length_chunks = tuple( + db.map( + functools.partial(map_preliminary_batch_stats, processor=self, max_length=max_length), + ) ) - lengths = np.concatenate(lengths) + lengths = np.concatenate(length_chunks) rng = np.random.default_rng() rng.shuffle(lengths) @@ -468,12 +453,10 @@ def convert_dataset(self, dataset, tokens_per_batch: int, min_patients_per_batch ) ) - print("Creating batches", len(batches)) - batch_func = functools.partial( _batch_generator, creator=self.creator, - dataset=dataset, + path_to_database=db.path_to_database, ) batch_dataset = datasets.Dataset.from_generator( diff --git a/src/femr/models/tasks.py b/src/femr/models/tasks.py index d67cd88..bfa24b4 100644 --- a/src/femr/models/tasks.py +++ b/src/femr/models/tasks.py @@ -4,15 +4,16 @@ import collections import datetime import functools -from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple import meds +import meds_reader import numpy as np import scipy.sparse import torch -import femr.index import femr.models.config +import femr.ontology import femr.pat_utils import femr.stat_utils @@ -67,10 +68,6 @@ def __init__(self, labels: Sequence[meds.Label]): def get_task_config(self) -> femr.models.config.FEMRTaskConfig: return femr.models.config.FEMRTaskConfig(task_type="labeled_patients") - def filter_dataset(self, dataset: datasets.Dataset, index: femr.index.PatientIndex) -> datasets.Dataset: - indices = [index.get_index(patient_id) for patient_id in self.label_map] - return dataset.select(indices) - def start_patient(self, patient: meds_reader.Patient, _ontology: Optional[femr.ontology.Ontology]) -> None: self.current_labels = self.label_map[patient.patient_id] self.current_label_index = 0 @@ -184,15 +181,14 @@ def __init__( self, ontology: femr.ontology.Ontology, patient: meds_reader.Patient, code_whitelist: Optional[Set[str]] = None ): self.survival_events = [] - self.final_date = patient.events[-1]["time"] + self.final_date = patient.events[-1].time self.future_times = collections.defaultdict(list) for event in patient.events: codes = set() - for measurement in event["measurements"]: - for parent in ontology.get_all_parents(event.code): - if code_whitelist is None or parent in code_whitelist: - codes.add(parent) + for parent in 
ontology.get_all_parents(event.code): + if code_whitelist is None or parent in code_whitelist: + codes.add(parent) for code in codes: self.future_times[code].append(event.time) @@ -219,14 +215,14 @@ def get_future_events_for_time( return (delta, {k: v[-1] - time for k, v in self.future_times.items()}) -def _prefit_motor_map(batch, *, tasks: List[str], ontology: femr.ontology.Ontology) -> Any: +def _prefit_motor_map( + patients: Iterator[meds_reader.Patient], *, tasks: List[str], ontology: femr.ontology.Ontology +) -> Any: task_time_stats: List[Any] = [[0, 0, femr.stat_utils.OnlineStatistics()] for _ in range(len(tasks))] event_times = femr.stat_utils.ReservoirSampler(100_000) task_set = set(tasks) - for patient_id, events in zip(batch["patient_id"], batch["events"]): - patient = {"patient_id": patient_id, "events": events} - + for patient in patients: calculator = SurvivalCalculator(ontology, patient, task_set) birth = femr.pat_utils.get_patient_birthdate(patient) @@ -268,7 +264,7 @@ class MOTORTask(Task): @classmethod def fit_pretraining_task_info( cls, - dataset: datasets.Dataset, + db: meds_reader.PatientDatabase, tokenizer: femr.models.tokenizer.FEMRTokenizer, num_tasks: int, num_bins: int, @@ -284,12 +280,8 @@ def fit_pretraining_task_info( assert len(tasks) == num_tasks, "Could not find enough tasks in the provided tokenizer" - length_samples, stats = femr.hf_utils.aggregate_over_dataset( - dataset, - functools.partial(_prefit_motor_map, tasks=tasks, ontology=tokenizer.ontology), - _prefit_motor_agg, - 1_000, - num_proc=num_proc, + length_samples, stats = functools.reduce( + _prefit_motor_agg, db.map(functools.partial(_prefit_motor_map, tasks=tasks, ontology=tokenizer.ontology)) ) time_bins = np.percentile(length_samples.samples, np.linspace(0, 100, num_bins + 1)) diff --git a/src/femr/models/tokenizer.py b/src/femr/models/tokenizer.py index f7004b9..45f5347 100644 --- a/src/femr/models/tokenizer.py +++ b/src/femr/models/tokenizer.py @@ -6,35 +6,39 @@ import functools import math import os -from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union import meds +import meds_reader import msgpack import numpy as np import transformers -import femr.hf_utils +import femr.ontology import femr.stat_utils def train_tokenizer( - dataset, + db: meds_reader.PatientDatabase, vocab_size: int, is_hierarchical: bool = False, num_numeric: int = 1000, ontology: Optional[femr.ontology.Ontology] = None, - num_proc: int = 1, ) -> FEMRTokenizer: """Train a FEMR tokenizer from the given dataset""" - statistics = femr.hf_utils.aggregate_over_dataset( - dataset, - functools.partial( - map_statistics, num_patients=len(dataset), is_hierarchical=is_hierarchical, ontology=ontology - ), + + statistics = functools.reduce( agg_statistics, - num_proc=num_proc, - batch_size=1_000, + db.map( + functools.partial( + map_statistics, + num_patients=len(db), + is_hierarchical=is_hierarchical, + ontology=ontology, + ) + ), ) + return FEMRTokenizer( convert_statistics_to_msgpack(statistics, vocab_size, is_hierarchical, num_numeric, ontology), ontology ) @@ -64,7 +68,12 @@ def normalize_unit(unit): def map_statistics( - batch, *, num_patients: int, is_hierarchical: bool, frac_values=0.05, ontology: Optional[femr.ontology.Ontology] + patients: Iterator[meds_reader.Patient], + *, + num_patients: int, + is_hierarchical: bool, + frac_values=0.05, + ontology: Optional[femr.ontology.Ontology], ) -> Mapping[str, Any]: age_stats = 
femr.stat_utils.OnlineStatistics() code_counts: Dict[str, float] = collections.defaultdict(float) @@ -81,43 +90,39 @@ def map_statistics( text_counts: Dict[Any, float] = collections.defaultdict(float) - for events in batch["events"]: - total_events = 0 - for event in events: - for measurement in event["measurements"]: - total_events += 1 + for patient in patients: + total_events = len(patient.events) if total_events == 0: continue weight = 1.0 / (num_patients * total_events) - birth_date = events[0]["time"] + birth_date = patient.events[0].time code_set = set() text_set = set() pat_numeric_samples = [] - for event in events: - for measurement in event["measurements"]: - if event.time != birth_date: - age_stats.add(weight, (event.time - birth_date).total_seconds()) - if not is_hierarchical: - assert numeric_samples_by_lab is not None - if event.numeric_value is not None: - numeric_samples_by_lab[event.code].add(event.numeric_value, weight) - elif event.text_value is not None: - text_counts[(event.code, event.text_value)] += weight - else: - code_counts[event.code] += weight + for event in patient.events: + if event.time != birth_date: + age_stats.add(weight, (event.time - birth_date).total_seconds()) + if not is_hierarchical: + assert numeric_samples_by_lab is not None + if event.numeric_value is not None: + numeric_samples_by_lab[event.code].add(event.numeric_value, weight) + elif event.text_value is not None: + text_counts[(event.code, event.text_value)] += weight else: - code_set.add(event.code) + code_counts[event.code] += weight + else: + code_set.add(event.code) - if event.text_value is not None and event.text_value != "": - text_set.add(event.text_value) + if event.text_value is not None and event.text_value != "": + text_set.add(event.text_value) - if measurement.get("metadata") and normalize_unit(measurement["metadata"].get("unit")) is not None: - text_set.add(normalize_unit(measurement["metadata"]["unit"])) + if getattr(event, "unit", None) is not None: + text_set.add(normalize_unit(event.unit)) - if event.numeric_value is not None: - pat_numeric_samples.append(event.numeric_value) + if event.numeric_value is not None: + pat_numeric_samples.append(event.numeric_value) if is_hierarchical: assert numeric_samples is not None @@ -390,9 +395,7 @@ def start_patient(self): # This is currently a null-op, but is required for cost featurization pass - def get_feature_codes( - self, _time: datetime.datetime, measurement: meds_reader.Event - ) -> Tuple[List[int], Optional[List[float]]]: + def get_feature_codes(self, event: meds_reader.Event) -> Tuple[List[int], Optional[List[float]]]: """Get codes for the provided measurement and time""" # Note that time is currently not used in this code, but it is required for cost featurization @@ -404,15 +407,15 @@ def get_feature_codes( if parent in self.code_lookup ] weights = [1 / len(codes) for _ in codes] - if measurement.get("metadata") and normalize_unit(measurement["metadata"].get("unit")) is not None: - value = self.string_lookup.get(normalize_unit(measurement["metadata"]["unit"])) + if getattr(event, "unit", None) is not None: + value = self.string_lookup.get(normalize_unit(event.unit)) if value is not None: codes.append(value) weights.append(1) - if measurement.get("numeric_value") is not None and len(self.numeric_indices) > 0: + if event.numeric_value is not None and len(self.numeric_indices) > 0: codes.append(self.numeric_indices[bisect.bisect(self.numeric_values, event.numeric_value)]) weights.append(1) - if measurement.get("text_value") is 
not None: + if event.text_value is not None: value = self.string_lookup.get(event.text_value) if value is not None: codes.append(value) @@ -420,13 +423,13 @@ def get_feature_codes( return codes, weights else: - if measurement.get("numeric_value") is not None: + if event.numeric_value is not None: for start, end, i in self.numeric_lookup.get(event.code, []): if start <= event.numeric_value < end: return [i], None else: return [], None - elif measurement.get("text_value") is not None: + elif event.text_value is not None: value = self.string_lookup.get((event.code, event.text_value)) if value is not None: return [value], None diff --git a/src/femr/models/transformer.py b/src/femr/models/transformer.py index 4c529b6..b0c0e58 100644 --- a/src/femr/models/transformer.py +++ b/src/femr/models/transformer.py @@ -1,10 +1,12 @@ from __future__ import annotations import collections +import datetime import math from typing import Any, Dict, List, Mapping, Optional, Tuple import meds +import meds_reader import numpy as np import torch import torch.nn.functional as F @@ -282,6 +284,8 @@ def create_task_head(self) -> nn.Module: return LabeledPatientTaskHead(hidden_size, **task_kwargs) elif task_type == "motor": return MOTORTaskHead(hidden_size, **task_kwargs) + else: + raise RuntimeError("Could not determine head for task " + task_type) def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=False, return_reprs=False): # Need a return_loss parameter for transformers.Trainer to work properly @@ -321,7 +325,7 @@ def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=Fals def compute_features( - dataset: datasets.Dataset, + db: meds_reader.PatientDatabase, model_path: str, labels: List[meds.Label], num_proc: int = 1, @@ -347,13 +351,11 @@ def compute_features( """ task = femr.models.tasks.LabeledPatientTask(labels) - index = femr.index.PatientIndex(dataset, num_proc=num_proc) - model = femr.models.transformer.FEMRModel.from_pretrained(model_path, task_config=task.get_task_config()) tokenizer = femr.models.tokenizer.FEMRTokenizer.from_pretrained(model_path, ontology=ontology) processor = femr.models.processor.FEMRBatchProcessor(tokenizer, task=task) - filtered_data = task.filter_dataset(dataset, index) + filtered_data = db.filter(list(task.label_map.keys())) if device: model = model.to(device) diff --git a/src/femr/mr.py b/src/femr/mr.py deleted file mode 100644 index b387c78..0000000 --- a/src/femr/mr.py +++ /dev/null @@ -1,97 +0,0 @@ -from __future__ import annotations - -import multiprocessing -import pickle -from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, TypeVar - -import meds_reader - -A = TypeVar("A") - -WorkEntry = Tuple[Callable[[Iterable[meds_reader.Patient]], Any], List[int]] - - -def _runner( - database: meds_reader.PatientDatabase, - input_queue: multiprocessing.Queue[Optional[WorkEntry]], - result_queue: multiprocessing.Queue[Any], -) -> None: - while True: - next_work = input_queue.get() - if next_work is None: - break - - map_func, patient_ids = next_work - - map_func = pickle.loads(map_func) - - result = map_func(database[patient_id] for patient_id in patient_ids) - result_queue.put(result) - - -class Pool: - def __init__(self, database: meds_reader.PatientDatabase, num_threads: int) -> None: - self._all_patient_ids = list(database) - self._num_threads = num_threads - - if num_threads != 1: - self._processes = [] - mp = multiprocessing.get_context("spawn") - - self._input_queue: multiprocessing.Queue[Optional[WorkEntry]] = 
mp.Queue() - self._result_queue: multiprocessing.Queue[Any] = mp.Queue() - - for _ in range(num_threads): - process = mp.Process( - target=_runner, - kwargs={"database": database, "input_queue": self._input_queue, "result_queue": self._result_queue}, - ) - process.start() - self._processes.append(process) - else: - self._database = database - - def map( - self, map_func: Callable[[Iterable[meds_reader.Patient]], A], patient_ids: Optional[List[int]] = None - ) -> Iterator[A]: - """Apply the provided map function to the database""" - - if patient_ids is None: - patient_ids = self._all_patient_ids - - if self._num_threads != 1: - patients_per_part = (len(patient_ids) + len(self._processes) - 1) // len(self._processes) - - num_work_entries = 0 - - map_func_p = pickle.dumps(map_func) - - for i in range(len(self._processes)): - patient_ids_for_thread = patient_ids[i * patients_per_part : (i + 1) * patients_per_part] - - if len(patient_ids_for_thread) == 0: - continue - - num_work_entries += 1 - self._input_queue.put((map_func_p, patient_ids_for_thread)) - - return (self._result_queue.get() for _ in range(num_work_entries)) - else: - return (map_func(self._database[patient_id] for patient_id in patient_ids),) - - def terminate(self) -> None: - """Close the pool""" - if self._num_threads != 1: - for _ in self._processes: - self._input_queue.put(None) - for process in self._processes: - process.join() - - def __enter__(self) -> Pool: - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self.terminate() - - -__all__ = ["Pool"] diff --git a/src/femr/ontology.py b/src/femr/ontology.py index 0019927..6a42b6d 100644 --- a/src/femr/ontology.py +++ b/src/femr/ontology.py @@ -3,28 +3,21 @@ import collections import functools import os -from typing import Any, Dict, Iterable, Optional, Set +from typing import Any, Dict, Iterable, Iterator, Optional, Set import meds +import meds_reader import polars as pl -import femr.mr - -def _get_all_codes_map(batch) -> Set[str]: +def _get_all_codes_map(patients: Iterator[meds_reader.Patient]) -> Set[str]: result = set() - for events in batch["events"]: - for event in events: - for measurement in event["measurements"]: - result.add(event.code) + for patient in patients: + for event in patient.events: + result.add(event.code) return result -def _get_all_codes_agg(first: Set[str], second: Set[str]) -> Set[str]: - first |= second - return first - - class Ontology: def __init__(self, athena_path: str, code_metadata: meds.CodeMetadata = {}): """Create an Ontology from an Athena download and an optional meds Code Metadata structure. 
@@ -105,18 +98,13 @@ def __init__(self, athena_path: str, code_metadata: meds.CodeMetadata = {}): def prune_to_dataset( self, - dataset: datasets.Dataset, - num_proc: int = 1, + data_pool: meds_reader.PatientDatabase, prune_all_descriptions: bool = False, remove_ontologies: Set[str] = set(), ) -> None: - valid_codes = femr.hf_utils.aggregate_over_dataset( - dataset, - functools.partial(_get_all_codes_map), - _get_all_codes_agg, - num_proc=num_proc, - batch_size=1_000, - ) + valid_codes = set() + for chunk_codes in data_pool.map(_get_all_codes_map): + valid_codes |= chunk_codes if prune_all_descriptions: self.description_map = {} diff --git a/src/femr/post_etl_pipelines/stanford.py b/src/femr/post_etl_pipelines/stanford.py index 18ce91c..62d0ee9 100644 --- a/src/femr/post_etl_pipelines/stanford.py +++ b/src/femr/post_etl_pipelines/stanford.py @@ -7,6 +7,7 @@ from typing import Callable, Sequence import meds +import meds_reader from femr.transforms import delta_encode, remove_nones from femr.transforms.stanford import ( @@ -19,7 +20,7 @@ def _is_visit_measurement(e: meds_reader.Event) -> bool: - return e["metadata"]["table"] == "visit" + return e.table == "visit" def _get_stanford_transformations() -> Callable[[meds_reader.Patient], meds_reader.Patient]: diff --git a/src/femr/splits.py b/src/femr/splits.py index deb533f..1398942 100644 --- a/src/femr/splits.py +++ b/src/femr/splits.py @@ -6,8 +6,6 @@ import struct from typing import List -import femr.index - @dataclasses.dataclass class PatientSplit: diff --git a/src/femr/transforms/stanford.py b/src/femr/transforms/stanford.py index f139eb0..79932f3 100644 --- a/src/femr/transforms/stanford.py +++ b/src/femr/transforms/stanford.py @@ -1,3 +1,5 @@ +# mypy: disable-error-code="attr-defined" + """Transforms that are unique to STARR OMOP.""" import datetime @@ -158,7 +160,7 @@ def move_billing_codes(patient: meds_reader.Patient) -> meds_reader.Patient: if event.end is not None: if event.visit_id is None: # Every event with an end time should have a visit ID associated with it - raise RuntimeError(f"Expected visit id for visit? {patient['patient_id']} {event}") + raise RuntimeError(f"Expected visit id for visit? 
{patient.patient_id} {event}") if end_visits.get(event.visit_id, event.end) != event.end: # Every event associated with a visit should have an end time that matches the visit end time # Also the end times of all events associated with a visit should have the same end time diff --git a/tests/featurizers/test_featurizers.py b/tests/featurizers/test_featurizers.py index 714be9c..2eaa53f 100644 --- a/tests/featurizers/test_featurizers.py +++ b/tests/featurizers/test_featurizers.py @@ -7,7 +7,6 @@ import scipy.sparse import femr -import femr.mr from femr.featurizers import FeaturizerList from femr.featurizers.featurizers import AgeFeaturizer, CountFeaturizer from femr.labelers import TimeHorizon @@ -31,7 +30,6 @@ def test_age_featurizer() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) @@ -44,12 +42,12 @@ def test_age_featurizer() -> None: assert patient_features[1] == [(0, 17.767123287671232)] assert patient_features[-1] == [(0, 20.46027397260274)] - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) featurizer = AgeFeaturizer(is_normalize=True) featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -58,14 +56,14 @@ def test_count_featurizer() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) patient: meds_reader.Patient = dataset[0] labels = labeler.label(patient) featurizer = CountFeaturizer() - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) patient_features = featurizer.featurize(patient, labels) @@ -89,12 +87,12 @@ def test_count_featurizer() -> None: ("2", 4), } - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) featurizer = CountFeaturizer() featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -103,7 +101,6 @@ def test_count_featurizer_with_ontology() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) @@ -118,7 +115,8 @@ def get_all_parents(self, code): return {code} featurizer = CountFeaturizer(is_ontology_expansion=True, ontology=cast(femr.ontology.Ontology, DummyOntology())) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + 
featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) patient_features = featurizer.featurize(patient, labels) @@ -145,12 +143,12 @@ def get_all_parents(self, code): ("2", 4), } - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) featurizer = CountFeaturizer(is_ontology_expansion=True, ontology=cast(femr.ontology.Ontology, DummyOntology())) featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -159,14 +157,14 @@ def test_count_featurizer_with_values() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) patient: meds_reader.Patient = dataset[0] labels = labeler.label(patient) featurizer = CountFeaturizer(numeric_value_decile=True, string_value_combination=True) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) patient_features = featurizer.featurize(patient, labels) @@ -197,12 +195,12 @@ def test_count_featurizer_with_values() -> None: ("1 test_value", 2), } - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) featurizer = CountFeaturizer(numeric_value_decile=True, string_value_combination=True) featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -211,7 +209,6 @@ def test_count_featurizer_exclude_filter() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) @@ -220,21 +217,24 @@ def test_count_featurizer_exclude_filter() -> None: # Test filtering all codes featurizer = CountFeaturizer(excluded_event_filter=lambda _: True) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) assert featurizer.get_num_columns() == 0 # Test filtering no codes featurizer = CountFeaturizer(excluded_event_filter=lambda _: False) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) assert featurizer.get_num_columns() == 4 # Test filtering single code featurizer = CountFeaturizer(excluded_event_filter=lambda e: e.code == "3") - data = 
featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) assert featurizer.get_num_columns() == 3 @@ -244,7 +244,6 @@ def test_count_bins_featurizer() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) @@ -258,7 +257,8 @@ def test_count_bins_featurizer() -> None: featurizer = CountFeaturizer( time_bins=time_bins, ) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) patient_features = featurizer.featurize(patient, labels) @@ -285,7 +285,7 @@ def test_count_bins_featurizer() -> None: ("3_70000 days, 0:00:00", 2), } - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) time_bins = [ datetime.timedelta(days=90), @@ -296,8 +296,8 @@ def test_count_bins_featurizer() -> None: time_bins=time_bins, ) featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -306,16 +306,15 @@ def test_complete_featurization() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) age_featurizer = AgeFeaturizer(is_normalize=True) age_featurizer_list = FeaturizerList([age_featurizer]) - age_featurizer_list.preprocess_featurizers(pool, all_labels) - age_featurized_patients = age_featurizer_list.featurize(pool, all_labels) + age_featurizer_list.preprocess_featurizers(dataset, all_labels) + age_featurized_patients = age_featurizer_list.featurize(dataset, all_labels) time_bins = [ datetime.timedelta(days=90), @@ -324,8 +323,8 @@ def test_complete_featurization() -> None: ] count_featurizer = CountFeaturizer(time_bins=time_bins) count_featurizer_list = FeaturizerList([count_featurizer]) - count_featurizer_list.preprocess_featurizers(pool, all_labels) - count_featurized_patients = count_featurizer_list.featurize(pool, all_labels) + count_featurizer_list.preprocess_featurizers(dataset, all_labels) + count_featurized_patients = count_featurizer_list.featurize(dataset, all_labels) age_featurizer = AgeFeaturizer(is_normalize=True) time_bins = [ @@ -335,8 +334,8 @@ def test_complete_featurization() -> None: ] count_featurizer = CountFeaturizer(time_bins=time_bins) featurizer_list = FeaturizerList([age_featurizer, count_featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) assert featurized_patients["patient_ids"].shape == 
count_featurized_patients["patient_ids"].shape diff --git a/tests/femr_test_tools.py b/tests/femr_test_tools.py index d673a3c..db4f2d6 100644 --- a/tests/femr_test_tools.py +++ b/tests/femr_test_tools.py @@ -7,7 +7,6 @@ import meds import meds_reader -import femr.mr from femr.labelers import Labeler # 2nd elem of tuple -- 'skip' means no label, None means censored @@ -56,7 +55,14 @@ class DummyPatient: class DummyDatabase(dict): - pass + def filter(self, patient_ids): + return DummyDatabase({p: self[p] for p in patient_ids}) + + def map( + self, + map_func, + ) -> Iterator[A]: + return [map_func(self.values())] def create_patients_dataset( @@ -117,21 +123,21 @@ def run_test_for_labeler( help_text: str = "", ) -> None: patients: meds_reader.PatientDatabase = create_patients_dataset(10, [x[0] for x in events_with_labels]) - with femr.mr.Pool(patients, num_threads=1) as pool: - true_labels: List[Tuple[datetime.datetime, Optional[bool]]] = [ - (datetime.datetime(*x[0][0]), x[1]) for x in events_with_labels if isinstance(x[1], bool) - ] - if true_prediction_times is not None: - # If manually specified prediction times, adjust labels from occurring at `event.start` - # e.g. we may make predictions at `event.end` or `event.start + 1 day` - true_labels = [(tp, tl[1]) for (tl, tp) in zip(true_labels, true_prediction_times)] - labeled_patients: List[meds.Label] = labeler.apply(pool) - - # Check accuracy of Labels - for patient_id in patients: - assert_labels_are_accurate( - labeled_patients, - patient_id, - true_labels, - help_text=help_text, - ) + + true_labels: List[Tuple[datetime.datetime, Optional[bool]]] = [ + (datetime.datetime(*x[0][0]), x[1]) for x in events_with_labels if isinstance(x[1], bool) + ] + if true_prediction_times is not None: + # If manually specified prediction times, adjust labels from occurring at `event.start` + # e.g. 
we may make predictions at `event.end` or `event.start + 1 day` + true_labels = [(tp, tl[1]) for (tl, tp) in zip(true_labels, true_prediction_times)] + labeled_patients: List[meds.Label] = labeler.apply(patients) + + # Check accuracy of Labels + for patient_id in patients: + assert_labels_are_accurate( + labeled_patients, + patient_id, + true_labels, + help_text=help_text, + ) diff --git a/tests/models/test_batch_creator.py b/tests/models/test_batch_creator.py index a82ace9..fd5c1d7 100644 --- a/tests/models/test_batch_creator.py +++ b/tests/models/test_batch_creator.py @@ -15,7 +15,7 @@ def __init__(self, is_hierarchical: bool = False): def start_patient(self): pass - def get_feature_codes(self, time, measurement): + def get_feature_codes(self, event): if event.code == "SNOMED/184099003": return [1], None else: diff --git a/tests/models/test_survival_calculator.py b/tests/models/test_survival_calculator.py index b08e557..9273e52 100644 --- a/tests/models/test_survival_calculator.py +++ b/tests/models/test_survival_calculator.py @@ -1,6 +1,8 @@ import datetime from typing import Set +from femr_test_tools import DummyEvent, DummyPatient + import femr.models.tasks @@ -13,35 +15,15 @@ def get_all_parents(self, code: str) -> Set[str]: def test_calculator(): - patient = { - "patient_id": 100, - "events": [ - { - "time": datetime.datetime(1990, 1, 10), - "measurements": [ - {"code": "1"}, - ], - }, - { - "time": datetime.datetime(1990, 1, 20), - "measurements": [ - {"code": "2"}, - ], - }, - { - "time": datetime.datetime(1990, 1, 25), - "measurements": [ - {"code": "3"}, - ], - }, - { - "time": datetime.datetime(1990, 1, 25), - "measurements": [ - {"code": "1"}, - ], - }, + patient = DummyPatient( + patient_id=100, + events=[ + DummyEvent(time=datetime.datetime(1990, 1, 10), code="1"), + DummyEvent(time=datetime.datetime(1990, 1, 20), code="2"), + DummyEvent(time=datetime.datetime(1990, 1, 25), code="3"), + DummyEvent(time=datetime.datetime(1990, 1, 25), code="1"), ], - } + ) calculator = femr.models.tasks.SurvivalCalculator(DummyOntology(), patient) diff --git a/tutorials/1_Ontology.ipynb b/tutorials/1_Ontology.ipynb index 4e92790..0a21642 100644 --- a/tutorials/1_Ontology.ipynb +++ b/tutorials/1_Ontology.ipynb @@ -44,17 +44,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -68,7 +60,7 @@ "\n", "# You can load / save ontology objects with pickle\n", "\n", - "with open('input/meds/ontology.pkl', 'rb') as f:\n", + "with open('input/ontology.pkl', 'rb') as f:\n", " ontology = pickle.load(f)\n", "\n", "print(\"Loaded ontology\")" @@ -76,32 +68,25 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Generating train split: 200 examples [00:00, 34972.93 examples/s]\n", - "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 3282.29 examples/s]\n" - ] - } - ], + "outputs": [], "source": [ "# Ontology datasets downloaded by Athena tend to be very large as they contain many codes, including several that are no longer used.\n", "# We therefore provide a function to prune ontologies to a particular dataset of interest.\n", "# This makes it much cheaper to store and use an ontology object, both in terms of disk space and RAM\n", "\n", "\n", - "dataset = datasets.Dataset.from_parquet(\"input/meds/data/*\")\n", + "import meds_reader\n", + "\n", + "database = meds_reader.PatientDatabase(\"input/meds_reader\")\n", "\n", - "ontology.prune_to_dataset(dataset)" + "ontology.prune_to_dataset(database)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -111,8 +96,8 @@ "Description DRUGS FOR PEPTIC ULCER AND GASTRO-OESOPHAGEAL REFLUX DISEASE (GORD)\n", "Parents {'ATC/A02'}\n", "Children {'ATC/A02BX'}\n", - "All children {'RxNorm/2344', 'ATC/A02BX', 'RxNorm/4501', 'ATC/A02BX71', 'ATC/A02B', 'RxNorm/7815', 'RxNorm/7019', 'ATC/A02BX77', 'RxNorm/2353', 'RxNorm/8705', 'RxNorm/38574', 'RxNorm/2620', 'RxNorm/2018', 'RxNorm/8704', 'RxNorm/8730', 'RxNorm/6852', 'RxNorm/2017', 'RxNorm/2403'}\n", - "All parents {'ATC/A', 'ATC/A02', 'ATC/A02B'}\n" + "All children {'RxNorm/38574', 'ATC/A02BX71', 'RxNorm/2017', 'ATC/A02B', 'RxNorm/2018', 'RxNorm/6852', 'RxNorm/8705', 'RxNorm/8704', 'RxNorm/2344', 'RxNorm/2403', 'RxNorm/2353', 'RxNorm/8730', 'RxNorm/4501', 'ATC/A02BX77', 'RxNorm/7815', 'RxNorm/2620', 'RxNorm/7019', 'ATC/A02BX'}\n", + "All parents {'ATC/A02B', 'ATC/A', 'ATC/A02'}\n" ] } ], @@ -148,7 +133,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/2_Labeling.ipynb b/tutorials/2_Labeling.ipynb index 08095f1..af93b76 100644 --- a/tutorials/2_Labeling.ipynb +++ b/tutorials/2_Labeling.ipynb @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "id": "c6ac5c41-bc99-4731-ad82-7152274c67e1", "metadata": {}, "outputs": [], @@ -48,19 +48,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 11, "id": "8d9e2ccd-71c2-4ae0-897b-7ec022f9fdf4", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "# We can construct these labels manually\n", "\n", @@ -104,17 +95,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 12, "id": "9ac22dbe-ef34-468a-8ab3-673e58e5a920", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 3040.98 examples/s]" - ] - }, { "name": "stdout", "output_type": "stream", @@ -130,35 +114,31 @@ "{'patient_id': 108, 'prediction_time': datetime.datetime(1991, 10, 20, 0, 0), 'boolean_value': True}\n", "{'patient_id': 109, 'prediction_time': datetime.datetime(1991, 6, 25, 0, 0), 'boolean_value': True}\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] } ], "source": [ "from typing import List\n", "import femr.pat_utils\n", + "import meds_reader\n", + "import meds\n", + "import femr.labelers\n", "\n", "\n", "class IsMaleLabeler(femr.labelers.Labeler):\n", " # Dummy labeler to predict gender at birth\n", " \n", " def label(self, patient: meds_reader.Patient) -> List[meds.Label]:\n", - " is_male = any('Gender/M' == measurement['code'] for event in patient['events'] for measurement in event['measurements'])\n", + " is_male = any('Gender/M' == event.code for event in patient.events)\n", " return [{\n", - " 'patient_id': patient['patient_id'], \n", + " 'patient_id': patient.patient_id, \n", " 'prediction_time': femr.pat_utils.get_patient_birthdate(patient),\n", " 'boolean_value': is_male,\n", " }]\n", " \n", - "dataset = datasets.Dataset.from_parquet(\"input/meds/data/*\")\n", + "database = meds_reader.PatientDatabase(\"input/meds_reader\")\n", "\n", "labeler = IsMaleLabeler()\n", - "labeled_patients = labeler.apply(dataset)\n", + "labeled_patients = labeler.apply(database)\n", "\n", "for i in range(10):\n", " print(labeled_patients[100 + i])\n", @@ -167,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 13, "id": "20bd7859", "metadata": {}, "outputs": [], @@ -197,7 +177,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/3_Count Featurization And Modeling.ipynb b/tutorials/3_Count Featurization And Modeling.ipynb index e9dc61c..b0ad3ff 100644 --- a/tutorials/3_Count Featurization And Modeling.ipynb +++ b/tutorials/3_Count Featurization And Modeling.ipynb @@ -16,39 +16,23 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "892ab2d5-0c5a-43c9-a210-9201f775e4fb", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 36758.28 examples/s]\n", - "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 3295.78 examples/s]\n", - "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 2998.20 examples/s]\n" - ] - } - ], + "outputs": [], "source": [ "import pickle\n", "import femr.featurizers\n", "import femr.labelers\n", "import meds\n", "import pyarrow.csv\n", - "\n", - "import femr.index\n", + "import meds_reader\n", "\n", "# Load some labels\n", "labels = pyarrow.csv.read_csv('input/labels.csv').to_pylist()\n", "\n", "# Load our data\n", - "dataset = datasets.Dataset.from_parquet(\"input/meds/data/*\")\n", - "\n", - "# We need to create an index to allow us to find patients quickly\n", - "index = femr.index.PatientIndex(dataset)\n", + "database = meds_reader.PatientDatabase(\"input/meds_reader\")\n", " \n", "# Define our featurizer\n", "\n", @@ -58,15 +42,15 @@ "featurizer_age_count = femr.featurizers.FeaturizerList([age, count])\n", "\n", "# Preprocessing the featurizers, which includes processes such as normalizing age.\n", - "featurizer_age_count.preprocess_featurizers(dataset, index, labels)\n", + "featurizer_age_count.preprocess_featurizers(database, labels)\n", "\n", "# Actually do the featurization\n", - "features = featurizer_age_count.featurize(dataset, index, labels)" + "features = featurizer_age_count.featurize(database, labels)" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "112fe99d", "metadata": {}, "outputs": [ @@ -101,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "cd0f43fd", "metadata": {}, "outputs": [ @@ -137,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "01acd922-668b-481b-8dbb-54ab6ae433af", "metadata": {}, "outputs": [], @@ -176,36 +160,19 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "caae3126-1437-408e-b25f-04568e15c96a", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "---- Logistic Regression ----\n", - "Train:\n", - "\tAUROC: 1.0\n", - "\tAPS: 1.0\n", - "\tAccuracy: 1.0\n", - "\tF1 Score: 1.0\n", - "Test:\n", - "\tAUROC: 1.0\n", - "\tAPS: 1.0\n", - "\tAccuracy: 1.0\n", - "\tF1 Score: 1.0\n", - "---- XGBoost ----\n", - "Train:\n", - "\tAUROC: 1.0\n", - "\tAPS: 1.0\n", - "\tAccuracy: 1.0\n", - "\tF1 Score: 1.0\n", - "Test:\n", - "\tAUROC: 1.0\n", - "\tAPS: 1.0\n", - "\tAccuracy: 1.0\n", - "\tF1 Score: 1.0\n" + "ename": "ModuleNotFoundError", + "evalue": "No module named 'xgboost'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mxgboost\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m 
\u001b[38;5;21;01mxgb\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlinear_model\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'xgboost'" ] } ], @@ -270,7 +237,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/4_Train CLMBR.ipynb b/tutorials/4_Train CLMBR.ipynb index 8048e25..3b31733 100644 --- a/tutorials/4_Train CLMBR.ipynb +++ b/tutorials/4_Train CLMBR.ipynb @@ -378,7 +378,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/6_Train MOTOR.ipynb b/tutorials/6_Train MOTOR.ipynb index 1628fa4..ebc64ef 100644 --- a/tutorials/6_Train MOTOR.ipynb +++ b/tutorials/6_Train MOTOR.ipynb @@ -27,9 +27,6 @@ "import shutil\n", "import os\n", "\n", - "# os.environ[\"HF_DATASETS_CACHE\"] = '/share/pi/nigam/ethanid/cache_dir'\n", - "\n", - "\n", "TARGET_DIR = 'trash/tutorial_6'\n", "\n", "if os.path.exists(TARGET_DIR):\n", @@ -43,62 +40,17 @@ "execution_count": 2, "id": "646f7590", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1395.58 examples/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0, 1, 2, 4, 6, 7, 10, 11, 12, 13, 14, 15, 18, 20, 21, 23, 24, 26, 27, 28, 29, 30, 31, 33, 36, 37, 38, 40, 42, 44, 45, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 69, 70, 73, 74, 75, 76, 77, 79, 80, 83, 85, 86, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 100, 101, 102, 103, 104, 105, 107, 109, 110, 112, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 133, 134, 135, 136, 137, 139, 141, 142, 143, 144, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 165, 166, 168, 169, 171, 172, 173, 174, 178, 181, 182, 183, 184, 185, 186, 187, 189, 192, 193, 195, 196, 197, 198, 199]\n", - "[19, 22, 25, 39, 46, 71, 82, 84, 87, 92, 106, 108, 113, 131, 132, 138, 146, 147, 148, 155, 177, 179, 180, 188, 190, 191]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1316.29 examples/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "DatasetDict({\n", - " train: Dataset({\n", - " features: ['patient_id', 'events'],\n", - " num_rows: 144\n", - " })\n", - " test: Dataset({\n", - " features: ['patient_id', 'events'],\n", - " num_rows: 26\n", - " })\n", - "})\n" - ] - } - ], + "outputs": [], "source": [ - "\n", - "import femr.index\n", + "import meds_reader\n", "import femr.splits\n", "\n", "# 
First, we want to split our dataset into train, valid, and test\n", "# We do this by calling our split functionality twice\n", "\n", - "dataset = datasets.Dataset.from_parquet('input/meds/data/*')\n", - "\n", + "database = meds_reader.PatientDatabase('input/meds_reader')\n", "\n", - "index = femr.index.PatientIndex(dataset, num_proc=4)\n", - "main_split = femr.splits.generate_hash_split(index.get_patient_ids(), 97, frac_test=0.15)\n", + "main_split = femr.splits.generate_hash_split(list(database), 97, frac_test=0.15)\n", "\n", "os.mkdir(os.path.join(TARGET_DIR, 'motor_model'))\n", "# Note that we want to save this to the target directory since this is important information\n", @@ -107,13 +59,9 @@ "\n", "train_split = femr.splits.generate_hash_split(main_split.train_patient_ids, 87, frac_test=0.15)\n", "\n", - "print(train_split.train_patient_ids)\n", - "print(train_split.test_patient_ids)\n", - "\n", - "main_dataset = main_split.split_dataset(dataset, index)\n", - "train_dataset = train_split.split_dataset(main_dataset['train'], femr.index.PatientIndex(main_dataset['train'], num_proc=4))\n", - "\n", - "print(train_dataset)" + "main_database = database.filter(main_split.train_patient_ids)\n", + "train_database = main_database.filter(train_split.train_patient_ids)\n", + "val_database = main_database.filter(train_split.test_patient_ids)\n" ] }, { @@ -126,7 +74,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 331.19 examples/s]\n" + "/home/ethanid/health_research/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -137,12 +86,12 @@ "# First, we need to train a tokenizer\n", "# Note, we need to use a hierarchical tokenizer for MOTOR\n", "\n", - "with open('input/meds/ontology.pkl', 'rb') as f:\n", + "with open('input/ontology.pkl', 'rb') as f:\n", " ontology = pickle.load(f)\n", "\n", "# NOTE: A vocab size of 128 is probably too low for a real model. 128 was chosen to make this tutorial quick to run\n", "tokenizer = femr.models.tokenizer.train_tokenizer(\n", - " main_dataset['train'], vocab_size=128, is_hierarchical=True, num_proc=4, ontology=ontology)\n", + " train_database, vocab_size=128, is_hierarchical=True, ontology=ontology)\n", "\n", "# Save the tokenizer to the same directory as the model\n", "tokenizer.save_pretrained(os.path.join(TARGET_DIR, \"motor_model\"))" @@ -153,15 +102,7 @@ "execution_count": 4, "id": "69b60daa", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 249.31 examples/s]\n" - ] - } - ], + "outputs": [], "source": [ "\n", "import femr.models.tasks\n", @@ -169,15 +110,14 @@ "# Second, we need to prefit the MOTOR model. 
This is necessary because piecewise exponential models are unstable without an initial fit\n", "\n", "motor_task = femr.models.tasks.MOTORTask.fit_pretraining_task_info(\n", - " main_dataset['train'], tokenizer, num_tasks=64, num_bins=4, final_layer_size=32, num_proc=4)\n", - "\n", + " train_database, tokenizer, num_tasks=64, num_bins=4, final_layer_size=32)\n", "\n", "# It's recommended to save this with pickle to avoid recomputing since it's an expensive operation\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "id": "89611ba9-a242-4b87-9b8f-25670d838fc6", "metadata": {}, "outputs": [ @@ -186,35 +126,23 @@ "output_type": "stream", "text": [ "Convert a single patient\n", - "Convert batches\n" + "Convert batches\n", + "Creating batches 8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:00<00:00, 261.72 examples/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Creating batches 7\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Generating train split: 7 examples [00:00, 12.06 examples/s]\n", - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 50.63 examples/s]\n" + "Generating train split: 8 examples [00:00, 26.46 examples/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "Convert batches to pytorch\n", + "Done\n", "Creating batches 1\n" ] }, @@ -223,22 +151,7 @@ "output_type": "stream", "text": [ "Setting num_proc from 4 back to 1 for the train split to disable multiprocessing as it only contains one shard.\n", - "Generating train split: 1 examples [00:00, 57.97 examples/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Convert batches to pytorch\n", - "Done\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "Generating train split: 1 examples [00:00, 172.15 examples/s]\n" ] } ], @@ -250,23 +163,30 @@ "\n", "processor = femr.models.processor.FEMRBatchProcessor(tokenizer, motor_task)\n", "\n", + "example_patient_id = list(train_database)[0]\n", + "example_patient = train_database[example_patient_id]\n", + "\n", "# We can do this one patient at a time\n", "print(\"Convert a single patient\")\n", - "example_batch = processor.collate([processor.convert_patient(train_dataset['train'][0], tensor_type='pt')])\n", + "example_batch = processor.collate([processor.convert_patient(example_patient, tensor_type='pt')])\n", "\n", "print(\"Convert batches\")\n", "# But generally we want to convert entire datasets\n", - "train_batches = processor.convert_dataset(train_dataset, tokens_per_batch=32, num_proc=4)\n", + "train_batches = processor.convert_dataset(train_database, tokens_per_batch=32, num_proc=4)\n", "\n", "print(\"Convert batches to pytorch\")\n", "# Convert our batches to pytorch tensors\n", "train_batches.set_format(\"pt\")\n", - "print(\"Done\")" + "print(\"Done\")\n", + "\n", + "val_batches = processor.convert_dataset(val_database, tokens_per_batch=32, num_proc=4)\n", + "# Convert our batches to pytorch tensors\n", + "val_batches.set_format(\"pt\")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "id": "f654a46c-5aa7-465c-b6c5-73d8ba26ed67", "metadata": {}, "outputs": [ @@ -274,10 +194,7 @@ "name": "stderr", "output_type": "stream", "text": [ 
- "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/torch/cuda/__init__.py:628: UserWarning: Can't initialize NVML\n", - " warnings.warn(\"Can't initialize NVML\")\n", - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n", - "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n", + "/home/ethanid/health_research/venv/lib/python3.12/site-packages/transformers/training_args.py:1494: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n", "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n" ] @@ -288,8 +205,8 @@ "\n", "
20 | \n", - "0.855400 | \n", - "0.506942 | \n", + "0.830100 | \n", + "0.523818 | \n", "
40 | \n", - "0.871100 | \n", - "0.506998 | \n", + "0.810400 | \n", + "0.523785 | \n", "
60 | \n", - "0.826900 | \n", - "0.507056 | \n", + "0.809800 | \n", + "0.523750 | \n", "
80 | \n", - "0.856700 | \n", - "0.507116 | \n", + "0.829200 | \n", + "0.523720 | \n", "
100 | \n", - "0.856200 | \n", - "0.507181 | \n", + "0.827000 | \n", + "0.523700 | \n", "
120 | \n", - "0.829800 | \n", - "0.507251 | \n", + "0.810600 | \n", + "0.523686 | \n", "
140 | \n", - "0.859700 | \n", - "0.507321 | \n", + "0.798000 | \n", + "0.523673 | \n", "
160 | \n", - "0.851600 | \n", - "0.507393 | \n", + "0.838300 | \n", + "0.523665 | \n", "
180 | \n", - "0.852500 | \n", - "0.507467 | \n", + "0.827100 | \n", + "0.523665 | \n", "
200 | \n", - "0.868400 | \n", - "0.507540 | \n", + "0.808000 | \n", + "0.523664 | \n", "
220 | \n", - "0.850800 | \n", - "0.507617 | \n", + "0.818500 | \n", + "0.523664 | \n", "
240 | \n", - "0.835900 | \n", - "0.507696 | \n", + "0.815400 | \n", + "0.523672 | \n", "
260 | \n", - "0.850000 | \n", - "0.507768 | \n", + "0.839000 | \n", + "0.523685 | \n", "
280 | \n", - "0.831500 | \n", - "0.507841 | \n", + "0.793800 | \n", + "0.523696 | \n", "
300 | \n", - "0.860700 | \n", - "0.507915 | \n", + "0.815600 | \n", + "0.523708 | \n", "
320 | \n", - "0.846000 | \n", - "0.507988 | \n", + "0.816300 | \n", + "0.523721 | \n", "
340 | \n", - "0.826800 | \n", - "0.508055 | \n", + "0.824800 | \n", + "0.523741 | \n", "
360 | \n", - "0.830600 | \n", - "0.508123 | \n", + "0.806100 | \n", + "0.523758 | \n", "
380 | \n", - "0.884700 | \n", - "0.508188 | \n", + "0.836500 | \n", + "0.523773 | \n", "
400 | \n", - "0.823900 | \n", - "0.508248 | \n", + "0.793600 | \n", + "0.523792 | \n", "
420 | \n", - "0.856200 | \n", - "0.508309 | \n", + "0.782700 | \n", + "0.523814 | \n", "
440 | \n", - "0.848400 | \n", - "0.508360 | \n", + "0.846600 | \n", + "0.523835 | \n", "
460 | \n", - "0.855900 | \n", - "0.508413 | \n", + "0.813100 | \n", + "0.523853 | \n", "
480 | \n", - "0.849500 | \n", - "0.508458 | \n", + "0.815500 | \n", + "0.523872 | \n", "
500 | \n", - "0.831200 | \n", - "0.508502 | \n", + "0.846000 | \n", + "0.523890 | \n", "
520 | \n", - "0.848300 | \n", - "0.508542 | \n", + "0.781900 | \n", + "0.523907 | \n", "
540 | \n", - "0.858700 | \n", - "0.508577 | \n", + "0.802900 | \n", + "0.523925 | \n", "
560 | \n", - "0.829200 | \n", - "0.508608 | \n", + "0.824500 | \n", + "0.523942 | \n", "
580 | \n", - "0.858500 | \n", - "0.508636 | \n", + "0.803500 | \n", + "0.523959 | \n", "
600 | \n", - "0.825800 | \n", - "0.508659 | \n", + "0.823400 | \n", + "0.523972 | \n", "
620 | \n", - "0.878200 | \n", - "0.508677 | \n", + "0.804000 | \n", + "0.523985 | \n", "
640 | \n", - "0.839500 | \n", - "0.508692 | \n", + "0.822600 | \n", + "0.523996 | \n", "
660 | \n", - "0.813000 | \n", - "0.508703 | \n", + "0.794200 | \n", + "0.524006 | \n", "
680 | \n", - "0.854800 | \n", - "0.508709 | \n", + "0.832000 | \n", + "0.524016 | \n", "
700 | \n", - "0.847300 | \n", - "0.508711 | \n", + "0.830300 | \n", + "0.524024 | \n", + "
720 | \n", + "0.795700 | \n", + "0.524031 | \n", + "||
740 | \n", + "0.802900 | \n", + "0.524037 | \n", + "||
760 | \n", + "0.823000 | \n", + "0.524040 | \n", + "||
780 | \n", + "0.790700 | \n", + "0.524042 | \n", + "||
800 | \n", + "0.835000 | \n", + "0.524043 | \n", "
" @@ -529,8 +471,8 @@ "trainer = transformers.Trainer(\n", " model=model,\n", " data_collator=processor.collate,\n", - " train_dataset=train_batches['train'],\n", - " eval_dataset=train_batches['test'],\n", + " train_dataset=train_batches,\n", + " eval_dataset=val_batches,\n", " args=trainer_config,\n", ")\n", "\n", @@ -556,7 +498,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/7_MOTOR Featurization And Modeling.ipynb b/tutorials/7_MOTOR Featurization And Modeling.ipynb index 4a56512..64100c0 100644 --- a/tutorials/7_MOTOR Featurization And Modeling.ipynb +++ b/tutorials/7_MOTOR Featurization And Modeling.ipynb @@ -254,7 +254,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/input/meds/data/patients.parquet b/tutorials/input/meds/data/patients.parquet deleted file mode 100644 index 7036eaf..0000000 Binary files a/tutorials/input/meds/data/patients.parquet and /dev/null differ diff --git a/tutorials/input/meds/ontology.pkl b/tutorials/input/meds/ontology.pkl deleted file mode 100644 index ed89da9..0000000 Binary files a/tutorials/input/meds/ontology.pkl and /dev/null differ diff --git a/tutorials/input/meds_reader/code/data b/tutorials/input/meds_reader/code/data new file mode 100644 index 0000000..1a80704 Binary files /dev/null and b/tutorials/input/meds_reader/code/data differ diff --git a/tutorials/input/meds_reader/code/dictionary b/tutorials/input/meds_reader/code/dictionary new file mode 100644 index 0000000..4b1e01c Binary files /dev/null and b/tutorials/input/meds_reader/code/dictionary differ diff --git a/tutorials/input/meds_reader/code/zdict b/tutorials/input/meds_reader/code/zdict new file mode 100644 index 0000000..11a0a4b Binary files /dev/null and b/tutorials/input/meds_reader/code/zdict differ diff --git a/tutorials/input/meds_reader/datetime_value/data b/tutorials/input/meds_reader/datetime_value/data new file mode 100644 index 0000000..ec689d3 Binary files /dev/null and b/tutorials/input/meds_reader/datetime_value/data differ diff --git a/tutorials/input/meds_reader/datetime_value/zdict b/tutorials/input/meds_reader/datetime_value/zdict new file mode 100644 index 0000000..7718a83 Binary files /dev/null and b/tutorials/input/meds_reader/datetime_value/zdict differ diff --git a/tutorials/input/meds_reader/length b/tutorials/input/meds_reader/length new file mode 100644 index 0000000..94c7400 Binary files /dev/null and b/tutorials/input/meds_reader/length differ diff --git a/tutorials/input/meds/metadata.json b/tutorials/input/meds_reader/metadata.json similarity index 100% rename from tutorials/input/meds/metadata.json rename to tutorials/input/meds_reader/metadata.json diff --git a/tutorials/input/meds_reader/numeric_value/data b/tutorials/input/meds_reader/numeric_value/data new file mode 100644 index 0000000..ec689d3 Binary files /dev/null and b/tutorials/input/meds_reader/numeric_value/data differ diff --git a/tutorials/input/meds_reader/numeric_value/zdict b/tutorials/input/meds_reader/numeric_value/zdict new file mode 100644 index 0000000..7718a83 Binary files /dev/null and b/tutorials/input/meds_reader/numeric_value/zdict differ diff --git a/tutorials/input/meds_reader/patient_id b/tutorials/input/meds_reader/patient_id new file mode 100644 index 0000000..54def8b Binary files /dev/null and 
b/tutorials/input/meds_reader/patient_id differ diff --git a/tutorials/input/meds_reader/properties b/tutorials/input/meds_reader/properties new file mode 100644 index 0000000..5cb0a0e Binary files /dev/null and b/tutorials/input/meds_reader/properties differ diff --git a/tutorials/input/meds_reader/text_value/data b/tutorials/input/meds_reader/text_value/data new file mode 100644 index 0000000..ea3a50e Binary files /dev/null and b/tutorials/input/meds_reader/text_value/data differ diff --git a/tutorials/input/meds_reader/text_value/dictionary b/tutorials/input/meds_reader/text_value/dictionary new file mode 100644 index 0000000..e69de29 diff --git a/tutorials/input/meds_reader/text_value/zdict b/tutorials/input/meds_reader/text_value/zdict new file mode 100644 index 0000000..e8abfcb Binary files /dev/null and b/tutorials/input/meds_reader/text_value/zdict differ diff --git a/tutorials/input/meds_reader/time b/tutorials/input/meds_reader/time new file mode 100644 index 0000000..695edb7 Binary files /dev/null and b/tutorials/input/meds_reader/time differ diff --git a/tutorials/input/ontology.pkl b/tutorials/input/ontology.pkl new file mode 100644 index 0000000..49aa11c Binary files /dev/null and b/tutorials/input/ontology.pkl differ diff --git a/tutorials/synthetic_data_generation/generate_patients.py b/tutorials/synthetic_data_generation/generate_patients.py index b30bf36..46bb2e4 100644 --- a/tutorials/synthetic_data_generation/generate_patients.py +++ b/tutorials/synthetic_data_generation/generate_patients.py @@ -7,87 +7,91 @@ import jsonschema import meds +import meds_reader import pyarrow import pyarrow.parquet import femr.ontology import femr.transforms -parser = argparse.ArgumentParser(prog="generate_patients", description="Create synthetic data") -parser.add_argument("athena", type=str) -parser.add_argument("destination", type=str) -args = parser.parse_args() - -random.seed(4533) - - -def get_random_patient(patient_id): - epoch = datetime.datetime(1990, 1, 1) - birth = epoch + datetime.timedelta(days=random.randint(100, 1000)) - current_date = birth - - gender = "Gender/" + random.choice(["F", "M"]) - race = "Race/" + random.choice(["White", "Non-White"]) - - patient = { - "patient_id": patient_id, - "events": [ - { - "time": birth, - "measurements": [ - {"code": meds.birth_code}, - {"code": gender}, - {"code": race}, - ], - }, - ], - } - code_cats = ["ICD9CM", "RxNorm"] - for code in range(random.randint(1, 10 + (20 if gender == "Gender/F" else 0))): - code_cat = random.choice(code_cats) - if code_cat == "RxNorm": - code = str(random.randint(0, 10000)) - else: - code = str(random.randint(0, 10000)) - if len(code) > 3: - code = code[:3] + "." 
+ code[3:]
-        current_date = current_date + datetime.timedelta(days=random.randint(1, 100))
-        code = code_cat + "/" + code
-        patient.events.append({"time": current_date, "measurements": [{"code": code}]})
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(prog="generate_patients", description="Create synthetic data")
+    parser.add_argument("athena", type=str)
+    parser.add_argument("destination", type=str)
+    args = parser.parse_args()
+
+    random.seed(4533)
+
+    def get_random_patient(patient_id):
+        epoch = datetime.datetime(1990, 1, 1)
+        birth = epoch + datetime.timedelta(days=random.randint(100, 1000))
+        current_date = birth
+
+        gender = "Gender/" + random.choice(["F", "M"])
+        race = "Race/" + random.choice(["White", "Non-White"])
-    return patient
+        patient = {
+            "patient_id": patient_id,
+            "events": [],
+        }
+        birth_codes = [meds.birth_code, gender, race]
-patients = []
-for i in range(200):
-    patients.append(get_random_patient(i))
+        for birth_code in birth_codes:
+            patient["events"].append({"time": birth, "code": birth_code})
-patient_schema = meds_reader.Patient_schema()
+        code_cats = ["ICD9CM", "RxNorm"]
+        for code in range(random.randint(1, 10 + (20 if gender == "Gender/F" else 0))):
+            code_cat = random.choice(code_cats)
+            if code_cat == "RxNorm":
+                code = str(random.randint(0, 10000))
+            else:
+                code = str(random.randint(0, 10000))
+                if len(code) > 3:
+                    code = code[:3] + "." + code[3:]
+            current_date = current_date + datetime.timedelta(days=random.randint(1, 100))
+            code = code_cat + "/" + code
+            patient["events"].append({"time": current_date, "code": code})
-patient_table = pyarrow.Table.from_pylist(patients, patient_schema)
+        return patient
-os.makedirs(os.path.join(args.destination, "data"), exist_ok=True)
+    patients = []
+    for i in range(200):
+        patients.append(get_random_patient(i))
-pyarrow.parquet.write_table(patient_table, os.path.join(args.destination, "data", "patients.parquet"))
+    patient_schema = meds.schema.patient_schema()
+
+    patient_table = pyarrow.Table.from_pylist(patients, patient_schema)
+
+    os.makedirs(os.path.join(args.destination, "data"), exist_ok=True)
+
+    pyarrow.parquet.write_table(patient_table, os.path.join(args.destination, "data", "patients.parquet"))
+
+    metadata = {
+        "dataset_name": "femr synthetic data",
+        "dataset_version": "1",
+        "etl_name": "synthetic data",
+        "etl_version": "1",
+        "code_metadata": {},
+    }
-metadata = {
-    "dataset_name": "femr synthetic datata",
-    "dataset_version": "1",
-    "etl_name": "synthetic data",
-    "etl_version": "1",
-    "code_metadata": {},
-}
+    jsonschema.validate(instance=metadata, schema=meds.dataset_metadata)
-jsonschema.validate(instance=metadata, schema=meds.dataset_metadata)
+    with open(os.path.join(args.destination, "metadata.json"), "w") as f:
+        json.dump(metadata, f)
-with open(os.path.join(args.destination, "metadata.json"), "w") as f:
-    json.dump(metadata, f)
+    print("Converting")
+    os.system(f"convert_to_meds_reader {args.destination} {args.destination}_meds")
-dataset = datasets.Dataset.from_parquet(os.path.join(args.destination, "data", "*"))
+    print("Opening database")
-ontology = femr.ontology.Ontology(args.athena)
+    with meds_reader.PatientDatabase(args.destination + "_meds", num_threads=6) as database:
+        print("Creating ontology")
+        ontology = femr.ontology.Ontology(args.athena)
-ontology.prune_to_dataset(dataset, remove_ontologies=("SNOMED"))
+        print("Pruning ontology")
+        ontology.prune_to_dataset(database, remove_ontologies=("SNOMED",))
-with open(os.path.join(args.destination, 
"ontology.pkl"), "wb") as f: - pickle.dump(ontology, f) + with open(os.path.join(args.destination, "ontology.pkl"), "wb") as f: + pickle.dump(ontology, f)