diff --git a/.mypy.ini b/.mypy.ini
index cb5991d..13186d4 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -53,3 +53,6 @@ ignore_missing_imports = True
 
 [mypy-msgpack.*]
 ignore_missing_imports = True
+
+[mypy-xformers.*]
+ignore_missing_imports = True
diff --git a/src/femr/featurizers/core.py b/src/femr/featurizers/core.py
index 7a6ea72..361475e 100644
--- a/src/femr/featurizers/core.py
+++ b/src/femr/featurizers/core.py
@@ -13,7 +13,6 @@
 import numpy as np
 import scipy.sparse
 
-import femr.mr
 import femr.ontology
 
 
@@ -30,12 +29,12 @@ class ColumnValue(NamedTuple):
 def _preprocess_map_func(
     patients: Iterator[meds_reader.Patient], *, label_map: Mapping[int, List[meds.Label]], featurizers: List[Featurizer]
 ) -> List[List[Any]]:
-    result = []
-    patients_list = list(patients)
-    for featurizer in featurizers:
-        result.append(featurizer.generate_preprocess_data(iter(patients_list), label_map))
+    initial_data = [featurizer.get_initial_preprocess_data() for featurizer in featurizers]
+    for patient in patients:
+        for data, featurizer in zip(initial_data, featurizers):
+            featurizer.add_preprocess_data(data, patient, label_map)
 
-    return result
+    return initial_data
 
 
 def _features_map_func(
@@ -43,9 +42,11 @@
 ) -> Mapping[str, Any]:
     # Construct CSR sparse matrix
     # non-zero entries in sparse matrix
-    data: List[Any] = []
-    # maps each element in `data`` to its column in the sparse matrix
-    indices: List[int] = []
+    data_and_indices = np.zeros((1024, 2), np.float64)
+    data_and_indices_arrays = []
+
+    current_index = 0
+
     # maps each element in `data` and `indices` to the rows of the sparse matrix
     indptr: List[int] = []
@@ -73,7 +74,7 @@
                 a.append(b)
 
         for features in features_per_label:
-            indptr.append(len(indices))
+            indptr.append(current_index + len(data_and_indices_arrays) * 1024)
 
             # Keep track of starting column for each successive featurizer as we
             # combine their features into one large matrix
@@ -85,14 +86,20 @@
                         f"{column} on patient {patient.patient_id} ({column} must be between 0 and "
                         f"{featurizer.get_num_columns()})"
                     )
-                    indices.append(column_offset + column)
-                    data.append(value)
+                    data_and_indices[current_index, 0] = value
+                    data_and_indices[current_index, 1] = column_offset + column
+
+                    current_index += 1
+
+                    if current_index == 1024:
+                        current_index = 0
+                        data_and_indices_arrays.append(data_and_indices.copy())
 
                 # Record what the starting column should be for the next featurizer
                 column_offset += featurizer.get_num_columns()
 
     # Need one last `indptr` for end of last row in CSR sparse matrix
-    indptr.append(len(indices))
+    indptr.append(current_index + len(data_and_indices_arrays) * 1024)
 
     # n_rows = number of Labels across all Patients
     total_rows: int = len(indptr) - 1
@@ -100,8 +107,13 @@
     total_columns: int = sum(x.get_num_columns() for x in featurizers)
 
     # Explanation of CSR Matrix: https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr
-    np_data: np.ndarray = np.array(data, dtype=np.float32)
-    np_indices: np.ndarray = np.array(indices, dtype=np.int64)
+    data_and_indices_arrays.append(data_and_indices[:current_index, :])
+
+    np_data_and_indices: np.ndarray = np.concatenate(data_and_indices_arrays)
+
+    np_data = np_data_and_indices[:, 0].astype(np.float32)
+    np_indices = np_data_and_indices[:, 1].astype(np.int64)
+
     np_indptr: np.ndarray = np.array(indptr, dtype=np.int64)
 
     assert (
@@ -118,22 +130,19 @@
     return {"patient_ids": np_patient_ids, "feature_times": np_feature_times, "features": data_matrix}
 
 
-def _features_agg_func(first_result: Any, second_result: Any) -> Any:
-    for k in first_result:
-        first_result[k].extend(second_result[k])
-
-    return first_result
-
-
 class Featurizer(ABC):
     """A Featurizer takes a Patient and a list of Labels, then returns a row for each timepoint.
     Featurizers must be preprocessed before they are used to compute normalization statistics.
     A sparse representation named ColumnValue is used to represent the values returned by a Featurizer.
     """
 
-    def generate_preprocess_data(
-        self, patients: Iterator[meds_reader.Patient], label_map: Mapping[int, List[meds.Label]]
-    ) -> Any:
+    def get_initial_preprocess_data(self) -> Any:
+        """
+        Get the initial preprocess data
+        """
+        pass
+
+    def add_preprocess_data(self, data: Any, patient: meds_reader.Patient, label_map: Mapping[int, List[meds.Label]]):
         """
         Some featurizers need to do some preprocessing in order to prepare for featurization.
         This function performs that preprocessing on the given patients and labels, and returns some state.
@@ -230,7 +239,7 @@ def __init__(self, featurizers: List[Featurizer]):
 
     def preprocess_featurizers(
         self,
-        pool: femr.mr.Pool,
+        db: meds_reader.PatientDatabase,
         labels: List[meds.Label],
     ) -> None:
         """Preprocess `self.featurizers` on the provided set of labels."""
@@ -245,13 +254,12 @@ def preprocess_featurizers(
         for label in labels:
             label_map[label["patient_id"]].append(label)
 
         # Split patients across multiple threads
-        patient_ids: List[int] = sorted(list({label["patient_id"] for label in labels}))
+        patient_ids: List[int] = list({label["patient_id"] for label in labels})
 
         featurize_stats: List[List[Any]] = [[] for _ in self.featurizers]
 
-        for chunk_stats in pool.map(
-            functools.partial(_preprocess_map_func, label_map=label_map, featurizers=self.featurizers),
-            patient_ids=patient_ids,
+        for chunk_stats in db.filter(patient_ids).map(
+            functools.partial(_preprocess_map_func, label_map=label_map, featurizers=self.featurizers)
         ):
             for a, b in zip(featurize_stats, chunk_stats):
                 a.append(b)
@@ -263,7 +271,7 @@ def preprocess_featurizers(
 
     def featurize(
         self,
-        pool: femr.mr.Pool,
+        db: meds_reader.PatientDatabase,
        labels: List[meds.Label],
     ) -> Mapping[str, np.ndarray]:
         """
@@ -288,8 +296,8 @@ def featurize(
         features = collections.defaultdict(list)
 
-        for feat_chunk in pool.map(
-            functools.partial(_features_map_func, label_map=label_map, featurizers=self.featurizers), patient_ids
+        for feat_chunk in db.filter(patient_ids).map(
+            functools.partial(_features_map_func, label_map=label_map, featurizers=self.featurizers)
         ):
             for k, v in feat_chunk.items():
                 features[k].append(v)
@@ -313,30 +321,30 @@ def join_labels(features: Mapping[str, np.ndarray], labels: List[meds.Label]) ->
     labels = list(labels)
     labels.sort(key=lambda a: (a["patient_id"], a["prediction_time"]))
 
-    label_index = 0
-
     indices = []
     label_values = []
 
     order = np.lexsort((features["feature_times"], features["patient_ids"]))
 
-    for i, patient_id, feature_time in zip(order, features["patient_ids"][order], features["feature_times"][order]):
-        if label_index == len(labels):
-            break
-
-        assert patient_id <= labels[label_index]["patient_id"], f"Missing features for label {labels[label_index]}"
-        if patient_id < labels[label_index]["patient_id"]:
-            continue
+    feature_index = 0
+    for label in labels:
+        while (
+            (feature_index + 1) < len(order)
+            and features["patient_ids"][order[feature_index + 1]] <= label["patient_id"]
+            and features["feature_times"][order[feature_index + 1]] <= label["prediction_time"]
+        ):
+            feature_index += 1
+
+        is_valid = (
+            feature_index < len(order)
+            and features["patient_ids"][order[feature_index]] == label["patient_id"]
+            and features["feature_times"][order[feature_index]] <= label["prediction_time"]
+        )
 
         assert (
-            feature_time <= labels[label_index]["prediction_time"]
-        ), f"Missing features for label {labels[label_index]}"
-        if feature_time < labels[label_index]["prediction_time"]:
-            continue
-
-        indices.append(i)
-        label_values.append(labels[label_index]["boolean_value"])
-        label_index += 1
+            is_valid
+        ), f'{feature_index} {label} {features["patient_ids"][order[feature_index]]} {features["feature_times"][order[feature_index]]}'
+        indices.append(order[feature_index])
+        label_values.append(label["boolean_value"])
 
     return {
         "boolean_values": np.array(label_values),
diff --git a/src/femr/featurizers/featurizers.py b/src/femr/featurizers/featurizers.py
index 07ce942..74d9adf 100644
--- a/src/femr/featurizers/featurizers.py
+++ b/src/femr/featurizers/featurizers.py
@@ -35,28 +35,22 @@ def __init__(self, is_normalize: bool = True):
     def get_num_columns(self) -> int:
         return 1
 
-    def generate_preprocess_data(
-        self, patients: Iterator[meds_reader.Patient], label_map: Mapping[int, List[meds.Label]]
+    def get_initial_preprocess_data(self) -> OnlineStatistics:
+        return OnlineStatistics()
+
+    def add_preprocess_data(
+        self, age_statistics: OnlineStatistics, patient: meds_reader.Patient, label_map: Mapping[int, List[meds.Label]]
     ) -> OnlineStatistics:
         """Save the age of this patient (in years) at each label, to use for normalization."""
-        if not self.is_needs_preprocessing():
-            return OnlineStatistics()
-
-        age_statistics: OnlineStatistics = OnlineStatistics()
-
-        for patient in patients:
-            patient_birth_date: Optional[datetime.datetime] = get_patient_birthdate(patient)
-            assert patient_birth_date, "Patients must have a birth date"
-
-            for label in label_map[patient.patient_id]:
-                age_in_yrs: float = (label["prediction_time"] - patient_birth_date).days / 365
-                age_statistics.add(age_in_yrs)
+        patient_birth_date: Optional[datetime.datetime] = get_patient_birthdate(patient)
+        assert patient_birth_date, "Patients must have a birth date"
 
-        return age_statistics
+        for label in label_map[patient.patient_id]:
+            age_in_yrs: float = (label["prediction_time"] - patient_birth_date).days / 365
+            age_statistics.add(age_in_yrs)
 
     def encorperate_prepreprocessed_data(self, data_elements: List[OnlineStatistics]) -> None:
         self.age_statistics = OnlineStatistics.merge(data_elements)
-        print("What", self.age_statistics, data_elements)
 
     def featurize(
         self,
@@ -256,8 +250,15 @@ def get_columns(self, event: meds_reader.Event) -> Iterator[int]:
             if code in self.code_to_column_index:
                 yield self.code_to_column_index[code]
 
-    def generate_preprocess_data(
-        self, patients: Iterator[meds_reader.Patient], label_map: Mapping[int, List[meds.Label]]
+    def get_initial_preprocess_data(self) -> Any:
+        return {
+            "observed_codes": set(),
+            "observed_string_value": collections.defaultdict(int),
+            "observed_numeric_value": collections.defaultdict(functools.partial(ReservoirSampler, 10000, 100)),
+        }
+
+    def add_preprocess_data(
+        self, data: Any, patient: meds_reader.Patient, label_map: Mapping[int, List[meds.Label]]
     ) -> Any:
         """
         Some featurizers need to do some preprocessing in order to prepare for featurization.
@@ -266,34 +267,25 @@ generate_preprocess_data(
         Note that this function shouldn't mutate the Featurizer as it will be sharded.
""" - observed_codes: Set[str] = set() - observed_string_value: Dict[Tuple[str, str], int] = collections.defaultdict(int) - observed_numeric_value: Dict[str, ReservoirSampler] = collections.defaultdict( - functools.partial(ReservoirSampler, 10000, 100) - ) - - for patient in patients: - for event in patient.events: - # Check for excluded events - if self.excluded_event_filter is not None and self.excluded_event_filter(event): - continue - - if event.text_value is not None: - if self.string_value_combination: - observed_string_value[(event.code, event.text_value[: self.characters_for_string_values])] += 1 - elif event.numeric_value is not None: - if self.numeric_value_decile: - observed_numeric_value[event.code].add(event.numeric_value) - else: - for code in self.get_codes(event.code): - # If we haven't seen this code before, then add it to our list of included codes - observed_codes.add(code) - - return { - "observed_codes": observed_codes, - "observed_string_value": observed_string_value, - "observed_numeric_value": observed_numeric_value, - } + observed_codes: Set[str] = data["observed_codes"] + observed_string_value: Dict[Tuple[str, str], int] = data["observed_string_value"] + observed_numeric_value: Dict[str, ReservoirSampler] = data["observed_numeric_value"] + + for event in patient.events: + # Check for excluded events + if self.excluded_event_filter is not None and self.excluded_event_filter(event): + continue + + if event.text_value is not None: + if self.string_value_combination: + observed_string_value[(event.code, event.text_value[: self.characters_for_string_values])] += 1 + elif event.numeric_value is not None: + if self.numeric_value_decile: + observed_numeric_value[event.code].add(event.numeric_value) + else: + for code in self.get_codes(event.code): + # If we haven't seen this code before, then add it to our list of included codes + observed_codes.add(code) def encorperate_prepreprocessed_data(self, data_elements: List[Any]) -> None: """ diff --git a/src/femr/labelers/core.py b/src/femr/labelers/core.py index 8f5dc1e..d008877 100644 --- a/src/femr/labelers/core.py +++ b/src/femr/labelers/core.py @@ -15,8 +15,6 @@ import meds import meds_reader -import femr.mr - @dataclass(frozen=True) class TimeHorizon: @@ -63,7 +61,7 @@ def label(self, patient: meds_reader.Patient) -> List[meds.Label]: def apply( self, - pool: femr.mr.Pool, + db: meds_reader.PatientDatabase, ) -> List[meds.Label]: """Apply the `label()` function one-by-one to each Patient in a sequence of Patients. 
@@ -75,7 +73,7 @@ def apply( A list of labels """ - return list(itertools.chain.from_iterable(pool.map(functools.partial(_label_map_func, labeler=self)))) + return list(itertools.chain.from_iterable(db.map(functools.partial(_label_map_func, labeler=self)))) ########################################################## diff --git a/src/femr/models/processor.py b/src/femr/models/processor.py index abe1744..bb1c972 100644 --- a/src/femr/models/processor.py +++ b/src/femr/models/processor.py @@ -3,18 +3,21 @@ import collections import datetime import functools -from typing import Any, Dict, List, Mapping, Optional, Tuple +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple +import datasets import meds +import meds_reader import numpy as np import torch.utils.data -import femr.hf_utils import femr.models.tokenizer import femr.pat_utils -def map_preliminary_batch_stats(batch, indices, *, processor: FEMRBatchProcessor, max_length: int): +def map_preliminary_batch_stats( + patients: Iterable[meds_reader.Patient], *, processor: FEMRBatchProcessor, max_length: int +): """ This function creates preliminary batch statistics, to be used for final batching. @@ -39,11 +42,7 @@ def map_preliminary_batch_stats(batch, indices, *, processor: FEMRBatchProcessor """ lengths = [] - for patient_index, patient_id, events in zip(indices, batch["patient_id"], batch["events"]): - patient = { - "patient_id": patient_id, - "events": events, - } + for patient in patients: data = processor.convert_patient(patient) # There are no labels for this patient @@ -57,28 +56,21 @@ def map_preliminary_batch_stats(batch, indices, *, processor: FEMRBatchProcessor for label_index in data["transformer"]["label_indices"]: if (label_index - current_start + 1) >= max_length: if current_start != current_end: - lengths.append((patient_index, current_start, current_end - current_start + 1)) + lengths.append((patient.patient_id, current_start, current_end - current_start + 1)) current_start = label_index - max_length + 1 current_end = label_index else: current_end = label_index - lengths.append((patient_index, current_start, current_end - current_start + 1)) + lengths.append((patient.patient_id, current_start, current_end - current_start + 1)) else: last_index = data["transformer"]["label_indices"][-1] length = min(max_length, last_index + 1) - lengths.append((patient_index, last_index + 1 - length, length)) + lengths.append((patient.patient_id, last_index + 1 - length, length)) if len(lengths) > 0: - return [np.array(lengths, dtype=np.int64)] + return np.array(lengths, dtype=np.int64) else: - return [] - - -def agg_preliminary_batch_stats(lengths1, lengths2): - """Aggregate preliminary length statistics from the map_preliminary_batch_stats""" - lengths1.extend(lengths2) - - return lengths1 + return np.zeros(shape=(0, 3), dtype=np.int64) class BatchCreator: @@ -168,41 +160,40 @@ def add_patient(self, patient: meds_reader.Patient, offset: int = 0, max_length: current_date = event.time.date() codes_seen_today = set() - for measurement in event["measurements"]: - # Get features and weights for the current event - features, weights = self.tokenizer.get_feature_codes(event.time, measurement) + # Get features and weights for the current event + features, weights = self.tokenizer.get_feature_codes(event) - # Ignore events with no features - if len(features) == 0: - continue + # Ignore events with no features + if len(features) == 0: + continue - # Ignore events where all features have already occurred - if all(feature in 
codes_seen_today for feature in features): - continue + # Ignore events where all features have already occurred + if all(feature in codes_seen_today for feature in features): + continue - codes_seen_today |= set(features) + codes_seen_today |= set(features) - if (self.task is not None) and (last_time is not None): - # Now we have to consider whether or not to have labels for this time step - # The add_event function returns how many labels to assign for this time - num_added = self.task.add_event(last_time, event.time, features) - for _ in range(num_added): - per_patient_label_indices.append(len(per_patient_ages) - 1) + if (self.task is not None) and (last_time is not None): + # Now we have to consider whether or not to have labels for this time step + # The add_event function returns how many labels to assign for this time + num_added = self.task.add_event(last_time, event.time, features) + for _ in range(num_added): + per_patient_label_indices.append(len(per_patient_ages) - 1) - if not self.tokenizer.is_hierarchical: - assert len(features) == 1 - per_patient_tokens.append(features[0]) - else: - assert weights is not None - per_patient_hierarchical_tokens.extend(features) - per_patient_hierarchical_weights.extend(weights) - per_patient_token_indices.append(len(per_patient_hierarchical_tokens)) + if not self.tokenizer.is_hierarchical: + assert len(features) == 1 + per_patient_tokens.append(features[0]) + else: + assert weights is not None + per_patient_hierarchical_tokens.extend(features) + per_patient_hierarchical_weights.extend(weights) + per_patient_token_indices.append(len(per_patient_hierarchical_tokens)) - per_patient_ages.append((event.time - birth) / datetime.timedelta(days=1)) - per_patient_normalized_ages.append(self.tokenizer.normalize_age(event.time - birth)) - per_patient_timestamps.append(event.time.replace(tzinfo=datetime.timezone.utc).timestamp()) + per_patient_ages.append((event.time - birth) / datetime.timedelta(days=1)) + per_patient_normalized_ages.append(self.tokenizer.normalize_age(event.time - birth)) + per_patient_timestamps.append(event.time.replace(tzinfo=datetime.timezone.utc).timestamp()) - last_time = event.time + last_time = event.time if self.task is not None and last_time is not None: num_added = self.task.add_event(last_time, None, None) @@ -327,18 +318,19 @@ def cleanup_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: return batch -def _batch_generator(batch_data: Tuple[np.ndarray, np.ndarray], *, creator: BatchCreator, dataset: datasets.Dataset): - for lengths, offsets in batch_data: - offsets = list(offsets) - for start, end in zip(offsets, offsets[1:]): - creator.start_batch() - for patient_index, offset, length in lengths[start:end, :]: - creator.add_patient(dataset[patient_index.item()], offset, length) +def _batch_generator(batch_data: Tuple[np.ndarray, np.ndarray], *, creator: BatchCreator, path_to_database: str): + with meds_reader.PatientDatabase(path_to_database) as database: + for lengths, offsets in batch_data: + offsets = list(offsets) + for start, end in zip(offsets, offsets[1:]): + creator.start_batch() + for patient_index, offset, length in lengths[start:end, :]: + creator.add_patient(database[patient_index.item()], offset, length) - result = creator.get_batch_data() - assert "task" in result, f"No task present in {lengths[start:end,:]}" + result = creator.get_batch_data() + assert "task" in result, f"No task present in {lengths[start:end,:]}" - yield result + yield result def _add_dimension(data: Any) -> Any: @@ -399,7 +391,9 @@ def 
collate(self, batches: List[Mapping[str, Any]]) -> Mapping[str, Any]: assert len(batches) == 1, "Can only have one batch when collating" return {"batch": _add_dimension(self.creator.cleanup_batch(batches[0]))} - def convert_dataset(self, dataset, tokens_per_batch: int, min_patients_per_batch: int = 4, num_proc: int = 1): + def convert_dataset( + self, db: meds_reader.PatientDatabase, tokens_per_batch: int, min_patients_per_batch: int = 4, num_proc: int = 1 + ): """Convert an entire dataset to batches. Arguments: @@ -411,25 +405,16 @@ def convert_dataset(self, dataset, tokens_per_batch: int, min_patients_per_batch Returns: A huggingface dataset object containing batches """ - if isinstance(dataset, datasets.DatasetDict): - return datasets.DatasetDict( - { - k: self.convert_dataset(v, tokens_per_batch, min_patients_per_batch, num_proc) - for k, v in dataset.items() - } - ) max_length = tokens_per_batch // min_patients_per_batch - lengths = femr.hf_utils.aggregate_over_dataset( - dataset, - functools.partial(map_preliminary_batch_stats, processor=self, max_length=max_length), - agg_preliminary_batch_stats, - num_proc=num_proc, - batch_size=1_000, - with_indices=True, + + length_chunks = tuple( + db.map( + functools.partial(map_preliminary_batch_stats, processor=self, max_length=max_length), + ) ) - lengths = np.concatenate(lengths) + lengths = np.concatenate(length_chunks) rng = np.random.default_rng() rng.shuffle(lengths) @@ -468,12 +453,10 @@ def convert_dataset(self, dataset, tokens_per_batch: int, min_patients_per_batch ) ) - print("Creating batches", len(batches)) - batch_func = functools.partial( _batch_generator, creator=self.creator, - dataset=dataset, + path_to_database=db.path_to_database, ) batch_dataset = datasets.Dataset.from_generator( diff --git a/src/femr/models/tasks.py b/src/femr/models/tasks.py index d67cd88..bfa24b4 100644 --- a/src/femr/models/tasks.py +++ b/src/femr/models/tasks.py @@ -4,15 +4,16 @@ import collections import datetime import functools -from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple import meds +import meds_reader import numpy as np import scipy.sparse import torch -import femr.index import femr.models.config +import femr.ontology import femr.pat_utils import femr.stat_utils @@ -67,10 +68,6 @@ def __init__(self, labels: Sequence[meds.Label]): def get_task_config(self) -> femr.models.config.FEMRTaskConfig: return femr.models.config.FEMRTaskConfig(task_type="labeled_patients") - def filter_dataset(self, dataset: datasets.Dataset, index: femr.index.PatientIndex) -> datasets.Dataset: - indices = [index.get_index(patient_id) for patient_id in self.label_map] - return dataset.select(indices) - def start_patient(self, patient: meds_reader.Patient, _ontology: Optional[femr.ontology.Ontology]) -> None: self.current_labels = self.label_map[patient.patient_id] self.current_label_index = 0 @@ -184,15 +181,14 @@ def __init__( self, ontology: femr.ontology.Ontology, patient: meds_reader.Patient, code_whitelist: Optional[Set[str]] = None ): self.survival_events = [] - self.final_date = patient.events[-1]["time"] + self.final_date = patient.events[-1].time self.future_times = collections.defaultdict(list) for event in patient.events: codes = set() - for measurement in event["measurements"]: - for parent in ontology.get_all_parents(event.code): - if code_whitelist is None or parent in code_whitelist: - codes.add(parent) + for parent in 
ontology.get_all_parents(event.code): + if code_whitelist is None or parent in code_whitelist: + codes.add(parent) for code in codes: self.future_times[code].append(event.time) @@ -219,14 +215,14 @@ def get_future_events_for_time( return (delta, {k: v[-1] - time for k, v in self.future_times.items()}) -def _prefit_motor_map(batch, *, tasks: List[str], ontology: femr.ontology.Ontology) -> Any: +def _prefit_motor_map( + patients: Iterator[meds_reader.Patient], *, tasks: List[str], ontology: femr.ontology.Ontology +) -> Any: task_time_stats: List[Any] = [[0, 0, femr.stat_utils.OnlineStatistics()] for _ in range(len(tasks))] event_times = femr.stat_utils.ReservoirSampler(100_000) task_set = set(tasks) - for patient_id, events in zip(batch["patient_id"], batch["events"]): - patient = {"patient_id": patient_id, "events": events} - + for patient in patients: calculator = SurvivalCalculator(ontology, patient, task_set) birth = femr.pat_utils.get_patient_birthdate(patient) @@ -268,7 +264,7 @@ class MOTORTask(Task): @classmethod def fit_pretraining_task_info( cls, - dataset: datasets.Dataset, + db: meds_reader.PatientDatabase, tokenizer: femr.models.tokenizer.FEMRTokenizer, num_tasks: int, num_bins: int, @@ -284,12 +280,8 @@ def fit_pretraining_task_info( assert len(tasks) == num_tasks, "Could not find enough tasks in the provided tokenizer" - length_samples, stats = femr.hf_utils.aggregate_over_dataset( - dataset, - functools.partial(_prefit_motor_map, tasks=tasks, ontology=tokenizer.ontology), - _prefit_motor_agg, - 1_000, - num_proc=num_proc, + length_samples, stats = functools.reduce( + _prefit_motor_agg, db.map(functools.partial(_prefit_motor_map, tasks=tasks, ontology=tokenizer.ontology)) ) time_bins = np.percentile(length_samples.samples, np.linspace(0, 100, num_bins + 1)) diff --git a/src/femr/models/tokenizer.py b/src/femr/models/tokenizer.py index f7004b9..45f5347 100644 --- a/src/femr/models/tokenizer.py +++ b/src/femr/models/tokenizer.py @@ -6,35 +6,39 @@ import functools import math import os -from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union import meds +import meds_reader import msgpack import numpy as np import transformers -import femr.hf_utils +import femr.ontology import femr.stat_utils def train_tokenizer( - dataset, + db: meds_reader.PatientDatabase, vocab_size: int, is_hierarchical: bool = False, num_numeric: int = 1000, ontology: Optional[femr.ontology.Ontology] = None, - num_proc: int = 1, ) -> FEMRTokenizer: """Train a FEMR tokenizer from the given dataset""" - statistics = femr.hf_utils.aggregate_over_dataset( - dataset, - functools.partial( - map_statistics, num_patients=len(dataset), is_hierarchical=is_hierarchical, ontology=ontology - ), + + statistics = functools.reduce( agg_statistics, - num_proc=num_proc, - batch_size=1_000, + db.map( + functools.partial( + map_statistics, + num_patients=len(db), + is_hierarchical=is_hierarchical, + ontology=ontology, + ) + ), ) + return FEMRTokenizer( convert_statistics_to_msgpack(statistics, vocab_size, is_hierarchical, num_numeric, ontology), ontology ) @@ -64,7 +68,12 @@ def normalize_unit(unit): def map_statistics( - batch, *, num_patients: int, is_hierarchical: bool, frac_values=0.05, ontology: Optional[femr.ontology.Ontology] + patients: Iterator[meds_reader.Patient], + *, + num_patients: int, + is_hierarchical: bool, + frac_values=0.05, + ontology: Optional[femr.ontology.Ontology], ) -> Mapping[str, Any]: age_stats = 
femr.stat_utils.OnlineStatistics() code_counts: Dict[str, float] = collections.defaultdict(float) @@ -81,43 +90,39 @@ def map_statistics( text_counts: Dict[Any, float] = collections.defaultdict(float) - for events in batch["events"]: - total_events = 0 - for event in events: - for measurement in event["measurements"]: - total_events += 1 + for patient in patients: + total_events = len(patient.events) if total_events == 0: continue weight = 1.0 / (num_patients * total_events) - birth_date = events[0]["time"] + birth_date = patient.events[0].time code_set = set() text_set = set() pat_numeric_samples = [] - for event in events: - for measurement in event["measurements"]: - if event.time != birth_date: - age_stats.add(weight, (event.time - birth_date).total_seconds()) - if not is_hierarchical: - assert numeric_samples_by_lab is not None - if event.numeric_value is not None: - numeric_samples_by_lab[event.code].add(event.numeric_value, weight) - elif event.text_value is not None: - text_counts[(event.code, event.text_value)] += weight - else: - code_counts[event.code] += weight + for event in patient.events: + if event.time != birth_date: + age_stats.add(weight, (event.time - birth_date).total_seconds()) + if not is_hierarchical: + assert numeric_samples_by_lab is not None + if event.numeric_value is not None: + numeric_samples_by_lab[event.code].add(event.numeric_value, weight) + elif event.text_value is not None: + text_counts[(event.code, event.text_value)] += weight else: - code_set.add(event.code) + code_counts[event.code] += weight + else: + code_set.add(event.code) - if event.text_value is not None and event.text_value != "": - text_set.add(event.text_value) + if event.text_value is not None and event.text_value != "": + text_set.add(event.text_value) - if measurement.get("metadata") and normalize_unit(measurement["metadata"].get("unit")) is not None: - text_set.add(normalize_unit(measurement["metadata"]["unit"])) + if getattr(event, "unit", None) is not None: + text_set.add(normalize_unit(event.unit)) - if event.numeric_value is not None: - pat_numeric_samples.append(event.numeric_value) + if event.numeric_value is not None: + pat_numeric_samples.append(event.numeric_value) if is_hierarchical: assert numeric_samples is not None @@ -390,9 +395,7 @@ def start_patient(self): # This is currently a null-op, but is required for cost featurization pass - def get_feature_codes( - self, _time: datetime.datetime, measurement: meds_reader.Event - ) -> Tuple[List[int], Optional[List[float]]]: + def get_feature_codes(self, event: meds_reader.Event) -> Tuple[List[int], Optional[List[float]]]: """Get codes for the provided measurement and time""" # Note that time is currently not used in this code, but it is required for cost featurization @@ -404,15 +407,15 @@ def get_feature_codes( if parent in self.code_lookup ] weights = [1 / len(codes) for _ in codes] - if measurement.get("metadata") and normalize_unit(measurement["metadata"].get("unit")) is not None: - value = self.string_lookup.get(normalize_unit(measurement["metadata"]["unit"])) + if getattr(event, "unit", None) is not None: + value = self.string_lookup.get(normalize_unit(event.unit)) if value is not None: codes.append(value) weights.append(1) - if measurement.get("numeric_value") is not None and len(self.numeric_indices) > 0: + if event.numeric_value is not None and len(self.numeric_indices) > 0: codes.append(self.numeric_indices[bisect.bisect(self.numeric_values, event.numeric_value)]) weights.append(1) - if measurement.get("text_value") is 
not None: + if event.text_value is not None: value = self.string_lookup.get(event.text_value) if value is not None: codes.append(value) @@ -420,13 +423,13 @@ def get_feature_codes( return codes, weights else: - if measurement.get("numeric_value") is not None: + if event.numeric_value is not None: for start, end, i in self.numeric_lookup.get(event.code, []): if start <= event.numeric_value < end: return [i], None else: return [], None - elif measurement.get("text_value") is not None: + elif event.text_value is not None: value = self.string_lookup.get((event.code, event.text_value)) if value is not None: return [value], None diff --git a/src/femr/models/transformer.py b/src/femr/models/transformer.py index 4c529b6..b0c0e58 100644 --- a/src/femr/models/transformer.py +++ b/src/femr/models/transformer.py @@ -1,10 +1,12 @@ from __future__ import annotations import collections +import datetime import math from typing import Any, Dict, List, Mapping, Optional, Tuple import meds +import meds_reader import numpy as np import torch import torch.nn.functional as F @@ -282,6 +284,8 @@ def create_task_head(self) -> nn.Module: return LabeledPatientTaskHead(hidden_size, **task_kwargs) elif task_type == "motor": return MOTORTaskHead(hidden_size, **task_kwargs) + else: + raise RuntimeError("Could not determine head for task " + task_type) def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=False, return_reprs=False): # Need a return_loss parameter for transformers.Trainer to work properly @@ -321,7 +325,7 @@ def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=Fals def compute_features( - dataset: datasets.Dataset, + db: meds_reader.PatientDatabase, model_path: str, labels: List[meds.Label], num_proc: int = 1, @@ -347,13 +351,11 @@ def compute_features( """ task = femr.models.tasks.LabeledPatientTask(labels) - index = femr.index.PatientIndex(dataset, num_proc=num_proc) - model = femr.models.transformer.FEMRModel.from_pretrained(model_path, task_config=task.get_task_config()) tokenizer = femr.models.tokenizer.FEMRTokenizer.from_pretrained(model_path, ontology=ontology) processor = femr.models.processor.FEMRBatchProcessor(tokenizer, task=task) - filtered_data = task.filter_dataset(dataset, index) + filtered_data = db.filter(list(task.label_map.keys())) if device: model = model.to(device) diff --git a/src/femr/mr.py b/src/femr/mr.py deleted file mode 100644 index b387c78..0000000 --- a/src/femr/mr.py +++ /dev/null @@ -1,97 +0,0 @@ -from __future__ import annotations - -import multiprocessing -import pickle -from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, TypeVar - -import meds_reader - -A = TypeVar("A") - -WorkEntry = Tuple[Callable[[Iterable[meds_reader.Patient]], Any], List[int]] - - -def _runner( - database: meds_reader.PatientDatabase, - input_queue: multiprocessing.Queue[Optional[WorkEntry]], - result_queue: multiprocessing.Queue[Any], -) -> None: - while True: - next_work = input_queue.get() - if next_work is None: - break - - map_func, patient_ids = next_work - - map_func = pickle.loads(map_func) - - result = map_func(database[patient_id] for patient_id in patient_ids) - result_queue.put(result) - - -class Pool: - def __init__(self, database: meds_reader.PatientDatabase, num_threads: int) -> None: - self._all_patient_ids = list(database) - self._num_threads = num_threads - - if num_threads != 1: - self._processes = [] - mp = multiprocessing.get_context("spawn") - - self._input_queue: multiprocessing.Queue[Optional[WorkEntry]] = 
mp.Queue() - self._result_queue: multiprocessing.Queue[Any] = mp.Queue() - - for _ in range(num_threads): - process = mp.Process( - target=_runner, - kwargs={"database": database, "input_queue": self._input_queue, "result_queue": self._result_queue}, - ) - process.start() - self._processes.append(process) - else: - self._database = database - - def map( - self, map_func: Callable[[Iterable[meds_reader.Patient]], A], patient_ids: Optional[List[int]] = None - ) -> Iterator[A]: - """Apply the provided map function to the database""" - - if patient_ids is None: - patient_ids = self._all_patient_ids - - if self._num_threads != 1: - patients_per_part = (len(patient_ids) + len(self._processes) - 1) // len(self._processes) - - num_work_entries = 0 - - map_func_p = pickle.dumps(map_func) - - for i in range(len(self._processes)): - patient_ids_for_thread = patient_ids[i * patients_per_part : (i + 1) * patients_per_part] - - if len(patient_ids_for_thread) == 0: - continue - - num_work_entries += 1 - self._input_queue.put((map_func_p, patient_ids_for_thread)) - - return (self._result_queue.get() for _ in range(num_work_entries)) - else: - return (map_func(self._database[patient_id] for patient_id in patient_ids),) - - def terminate(self) -> None: - """Close the pool""" - if self._num_threads != 1: - for _ in self._processes: - self._input_queue.put(None) - for process in self._processes: - process.join() - - def __enter__(self) -> Pool: - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self.terminate() - - -__all__ = ["Pool"] diff --git a/src/femr/ontology.py b/src/femr/ontology.py index 0019927..6a42b6d 100644 --- a/src/femr/ontology.py +++ b/src/femr/ontology.py @@ -3,28 +3,21 @@ import collections import functools import os -from typing import Any, Dict, Iterable, Optional, Set +from typing import Any, Dict, Iterable, Iterator, Optional, Set import meds +import meds_reader import polars as pl -import femr.mr - -def _get_all_codes_map(batch) -> Set[str]: +def _get_all_codes_map(patients: Iterator[meds_reader.Patient]) -> Set[str]: result = set() - for events in batch["events"]: - for event in events: - for measurement in event["measurements"]: - result.add(event.code) + for patient in patients: + for event in patient.events: + result.add(event.code) return result -def _get_all_codes_agg(first: Set[str], second: Set[str]) -> Set[str]: - first |= second - return first - - class Ontology: def __init__(self, athena_path: str, code_metadata: meds.CodeMetadata = {}): """Create an Ontology from an Athena download and an optional meds Code Metadata structure. 
@@ -105,18 +98,13 @@ def __init__(self, athena_path: str, code_metadata: meds.CodeMetadata = {}): def prune_to_dataset( self, - dataset: datasets.Dataset, - num_proc: int = 1, + data_pool: meds_reader.PatientDatabase, prune_all_descriptions: bool = False, remove_ontologies: Set[str] = set(), ) -> None: - valid_codes = femr.hf_utils.aggregate_over_dataset( - dataset, - functools.partial(_get_all_codes_map), - _get_all_codes_agg, - num_proc=num_proc, - batch_size=1_000, - ) + valid_codes = set() + for chunk_codes in data_pool.map(_get_all_codes_map): + valid_codes |= chunk_codes if prune_all_descriptions: self.description_map = {} diff --git a/src/femr/post_etl_pipelines/stanford.py b/src/femr/post_etl_pipelines/stanford.py index 18ce91c..62d0ee9 100644 --- a/src/femr/post_etl_pipelines/stanford.py +++ b/src/femr/post_etl_pipelines/stanford.py @@ -7,6 +7,7 @@ from typing import Callable, Sequence import meds +import meds_reader from femr.transforms import delta_encode, remove_nones from femr.transforms.stanford import ( @@ -19,7 +20,7 @@ def _is_visit_measurement(e: meds_reader.Event) -> bool: - return e["metadata"]["table"] == "visit" + return e.table == "visit" def _get_stanford_transformations() -> Callable[[meds_reader.Patient], meds_reader.Patient]: diff --git a/src/femr/splits.py b/src/femr/splits.py index deb533f..1398942 100644 --- a/src/femr/splits.py +++ b/src/femr/splits.py @@ -6,8 +6,6 @@ import struct from typing import List -import femr.index - @dataclasses.dataclass class PatientSplit: diff --git a/src/femr/transforms/stanford.py b/src/femr/transforms/stanford.py index f139eb0..79932f3 100644 --- a/src/femr/transforms/stanford.py +++ b/src/femr/transforms/stanford.py @@ -1,3 +1,5 @@ +# mypy: disable-error-code="attr-defined" + """Transforms that are unique to STARR OMOP.""" import datetime @@ -158,7 +160,7 @@ def move_billing_codes(patient: meds_reader.Patient) -> meds_reader.Patient: if event.end is not None: if event.visit_id is None: # Every event with an end time should have a visit ID associated with it - raise RuntimeError(f"Expected visit id for visit? {patient['patient_id']} {event}") + raise RuntimeError(f"Expected visit id for visit? 
{patient.patient_id} {event}") if end_visits.get(event.visit_id, event.end) != event.end: # Every event associated with a visit should have an end time that matches the visit end time # Also the end times of all events associated with a visit should have the same end time diff --git a/tests/featurizers/test_featurizers.py b/tests/featurizers/test_featurizers.py index 714be9c..2eaa53f 100644 --- a/tests/featurizers/test_featurizers.py +++ b/tests/featurizers/test_featurizers.py @@ -7,7 +7,6 @@ import scipy.sparse import femr -import femr.mr from femr.featurizers import FeaturizerList from femr.featurizers.featurizers import AgeFeaturizer, CountFeaturizer from femr.labelers import TimeHorizon @@ -31,7 +30,6 @@ def test_age_featurizer() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) @@ -44,12 +42,12 @@ def test_age_featurizer() -> None: assert patient_features[1] == [(0, 17.767123287671232)] assert patient_features[-1] == [(0, 20.46027397260274)] - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) featurizer = AgeFeaturizer(is_normalize=True) featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -58,14 +56,14 @@ def test_count_featurizer() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) patient: meds_reader.Patient = dataset[0] labels = labeler.label(patient) featurizer = CountFeaturizer() - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) patient_features = featurizer.featurize(patient, labels) @@ -89,12 +87,12 @@ def test_count_featurizer() -> None: ("2", 4), } - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) featurizer = CountFeaturizer() featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -103,7 +101,6 @@ def test_count_featurizer_with_ontology() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) @@ -118,7 +115,8 @@ def get_all_parents(self, code): return {code} featurizer = CountFeaturizer(is_ontology_expansion=True, ontology=cast(femr.ontology.Ontology, DummyOntology())) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + 
featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) patient_features = featurizer.featurize(patient, labels) @@ -145,12 +143,12 @@ def get_all_parents(self, code): ("2", 4), } - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) featurizer = CountFeaturizer(is_ontology_expansion=True, ontology=cast(femr.ontology.Ontology, DummyOntology())) featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -159,14 +157,14 @@ def test_count_featurizer_with_values() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) patient: meds_reader.Patient = dataset[0] labels = labeler.label(patient) featurizer = CountFeaturizer(numeric_value_decile=True, string_value_combination=True) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) patient_features = featurizer.featurize(patient, labels) @@ -197,12 +195,12 @@ def test_count_featurizer_with_values() -> None: ("1 test_value", 2), } - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) featurizer = CountFeaturizer(numeric_value_decile=True, string_value_combination=True) featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -211,7 +209,6 @@ def test_count_featurizer_exclude_filter() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) @@ -220,21 +217,24 @@ def test_count_featurizer_exclude_filter() -> None: # Test filtering all codes featurizer = CountFeaturizer(excluded_event_filter=lambda _: True) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) assert featurizer.get_num_columns() == 0 # Test filtering no codes featurizer = CountFeaturizer(excluded_event_filter=lambda _: False) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) assert featurizer.get_num_columns() == 4 # Test filtering single code featurizer = CountFeaturizer(excluded_event_filter=lambda e: e.code == "3") - data = 
featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) assert featurizer.get_num_columns() == 3 @@ -244,7 +244,6 @@ def test_count_bins_featurizer() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) @@ -258,7 +257,8 @@ def test_count_bins_featurizer() -> None: featurizer = CountFeaturizer( time_bins=time_bins, ) - data = featurizer.generate_preprocess_data([patient], {patient.patient_id: labels}) + data = featurizer.get_initial_preprocess_data() + featurizer.add_preprocess_data(data, patient, {patient.patient_id: labels}) featurizer.encorperate_prepreprocessed_data([data]) patient_features = featurizer.featurize(patient, labels) @@ -285,7 +285,7 @@ def test_count_bins_featurizer() -> None: ("3_70000 days, 0:00:00", 2), } - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) time_bins = [ datetime.timedelta(days=90), @@ -296,8 +296,8 @@ def test_count_bins_featurizer() -> None: time_bins=time_bins, ) featurizer_list = FeaturizerList([featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) _assert_featurized_patients_structure(all_labels, featurized_patients) @@ -306,16 +306,15 @@ def test_complete_featurization() -> None: time_horizon = TimeHorizon(datetime.timedelta(days=0), datetime.timedelta(days=180)) dataset = femr_test_tools.create_patients_dataset(100) - pool = femr.mr.Pool(dataset, num_threads=1) labeler = CodeLabeler(["2"], time_horizon, ["3"]) - all_labels = labeler.apply(pool) + all_labels = labeler.apply(dataset) age_featurizer = AgeFeaturizer(is_normalize=True) age_featurizer_list = FeaturizerList([age_featurizer]) - age_featurizer_list.preprocess_featurizers(pool, all_labels) - age_featurized_patients = age_featurizer_list.featurize(pool, all_labels) + age_featurizer_list.preprocess_featurizers(dataset, all_labels) + age_featurized_patients = age_featurizer_list.featurize(dataset, all_labels) time_bins = [ datetime.timedelta(days=90), @@ -324,8 +323,8 @@ def test_complete_featurization() -> None: ] count_featurizer = CountFeaturizer(time_bins=time_bins) count_featurizer_list = FeaturizerList([count_featurizer]) - count_featurizer_list.preprocess_featurizers(pool, all_labels) - count_featurized_patients = count_featurizer_list.featurize(pool, all_labels) + count_featurizer_list.preprocess_featurizers(dataset, all_labels) + count_featurized_patients = count_featurizer_list.featurize(dataset, all_labels) age_featurizer = AgeFeaturizer(is_normalize=True) time_bins = [ @@ -335,8 +334,8 @@ def test_complete_featurization() -> None: ] count_featurizer = CountFeaturizer(time_bins=time_bins) featurizer_list = FeaturizerList([age_featurizer, count_featurizer]) - featurizer_list.preprocess_featurizers(pool, all_labels) - featurized_patients = featurizer_list.featurize(pool, all_labels) + featurizer_list.preprocess_featurizers(dataset, all_labels) + featurized_patients = featurizer_list.featurize(dataset, all_labels) assert featurized_patients["patient_ids"].shape == 
count_featurized_patients["patient_ids"].shape diff --git a/tests/femr_test_tools.py b/tests/femr_test_tools.py index d673a3c..db4f2d6 100644 --- a/tests/femr_test_tools.py +++ b/tests/femr_test_tools.py @@ -7,7 +7,6 @@ import meds import meds_reader -import femr.mr from femr.labelers import Labeler # 2nd elem of tuple -- 'skip' means no label, None means censored @@ -56,7 +55,14 @@ class DummyPatient: class DummyDatabase(dict): - pass + def filter(self, patient_ids): + return DummyDatabase({p: self[p] for p in patient_ids}) + + def map( + self, + map_func, + ) -> Iterator[A]: + return [map_func(self.values())] def create_patients_dataset( @@ -117,21 +123,21 @@ def run_test_for_labeler( help_text: str = "", ) -> None: patients: meds_reader.PatientDatabase = create_patients_dataset(10, [x[0] for x in events_with_labels]) - with femr.mr.Pool(patients, num_threads=1) as pool: - true_labels: List[Tuple[datetime.datetime, Optional[bool]]] = [ - (datetime.datetime(*x[0][0]), x[1]) for x in events_with_labels if isinstance(x[1], bool) - ] - if true_prediction_times is not None: - # If manually specified prediction times, adjust labels from occurring at `event.start` - # e.g. we may make predictions at `event.end` or `event.start + 1 day` - true_labels = [(tp, tl[1]) for (tl, tp) in zip(true_labels, true_prediction_times)] - labeled_patients: List[meds.Label] = labeler.apply(pool) - - # Check accuracy of Labels - for patient_id in patients: - assert_labels_are_accurate( - labeled_patients, - patient_id, - true_labels, - help_text=help_text, - ) + + true_labels: List[Tuple[datetime.datetime, Optional[bool]]] = [ + (datetime.datetime(*x[0][0]), x[1]) for x in events_with_labels if isinstance(x[1], bool) + ] + if true_prediction_times is not None: + # If manually specified prediction times, adjust labels from occurring at `event.start` + # e.g. 
we may make predictions at `event.end` or `event.start + 1 day` + true_labels = [(tp, tl[1]) for (tl, tp) in zip(true_labels, true_prediction_times)] + labeled_patients: List[meds.Label] = labeler.apply(patients) + + # Check accuracy of Labels + for patient_id in patients: + assert_labels_are_accurate( + labeled_patients, + patient_id, + true_labels, + help_text=help_text, + ) diff --git a/tests/models/test_batch_creator.py b/tests/models/test_batch_creator.py index a82ace9..fd5c1d7 100644 --- a/tests/models/test_batch_creator.py +++ b/tests/models/test_batch_creator.py @@ -15,7 +15,7 @@ def __init__(self, is_hierarchical: bool = False): def start_patient(self): pass - def get_feature_codes(self, time, measurement): + def get_feature_codes(self, event): if event.code == "SNOMED/184099003": return [1], None else: diff --git a/tests/models/test_survival_calculator.py b/tests/models/test_survival_calculator.py index b08e557..9273e52 100644 --- a/tests/models/test_survival_calculator.py +++ b/tests/models/test_survival_calculator.py @@ -1,6 +1,8 @@ import datetime from typing import Set +from femr_test_tools import DummyEvent, DummyPatient + import femr.models.tasks @@ -13,35 +15,15 @@ def get_all_parents(self, code: str) -> Set[str]: def test_calculator(): - patient = { - "patient_id": 100, - "events": [ - { - "time": datetime.datetime(1990, 1, 10), - "measurements": [ - {"code": "1"}, - ], - }, - { - "time": datetime.datetime(1990, 1, 20), - "measurements": [ - {"code": "2"}, - ], - }, - { - "time": datetime.datetime(1990, 1, 25), - "measurements": [ - {"code": "3"}, - ], - }, - { - "time": datetime.datetime(1990, 1, 25), - "measurements": [ - {"code": "1"}, - ], - }, + patient = DummyPatient( + patient_id=100, + events=[ + DummyEvent(time=datetime.datetime(1990, 1, 10), code="1"), + DummyEvent(time=datetime.datetime(1990, 1, 20), code="2"), + DummyEvent(time=datetime.datetime(1990, 1, 25), code="3"), + DummyEvent(time=datetime.datetime(1990, 1, 25), code="1"), ], - } + ) calculator = femr.models.tasks.SurvivalCalculator(DummyOntology(), patient) diff --git a/tutorials/1_Ontology.ipynb b/tutorials/1_Ontology.ipynb index 4e92790..0a21642 100644 --- a/tutorials/1_Ontology.ipynb +++ b/tutorials/1_Ontology.ipynb @@ -44,17 +44,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -68,7 +60,7 @@ "\n", "# You can load / save ontology objects with pickle\n", "\n", - "with open('input/meds/ontology.pkl', 'rb') as f:\n", + "with open('input/ontology.pkl', 'rb') as f:\n", " ontology = pickle.load(f)\n", "\n", "print(\"Loaded ontology\")" @@ -76,32 +68,25 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Generating train split: 200 examples [00:00, 34972.93 examples/s]\n", - "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 3282.29 examples/s]\n" - ] - } - ], + "outputs": [], "source": [ "# Ontology datasets downloaded by Athena tend to be very large as they contain many codes, including several that are no longer used.\n", "# We therefore provide a function to prune ontologies to a particular dataset of interest.\n", "# This makes it much cheaper to store and use an ontology object, both in terms of disk space and RAM\n", "\n", "\n", - "dataset = datasets.Dataset.from_parquet(\"input/meds/data/*\")\n", + "import meds_reader\n", + "\n", + "database = meds_reader.PatientDatabase(\"input/meds_reader\")\n", "\n", - "ontology.prune_to_dataset(dataset)" + "ontology.prune_to_dataset(database)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -111,8 +96,8 @@ "Description DRUGS FOR PEPTIC ULCER AND GASTRO-OESOPHAGEAL REFLUX DISEASE (GORD)\n", "Parents {'ATC/A02'}\n", "Children {'ATC/A02BX'}\n", - "All children {'RxNorm/2344', 'ATC/A02BX', 'RxNorm/4501', 'ATC/A02BX71', 'ATC/A02B', 'RxNorm/7815', 'RxNorm/7019', 'ATC/A02BX77', 'RxNorm/2353', 'RxNorm/8705', 'RxNorm/38574', 'RxNorm/2620', 'RxNorm/2018', 'RxNorm/8704', 'RxNorm/8730', 'RxNorm/6852', 'RxNorm/2017', 'RxNorm/2403'}\n", - "All parents {'ATC/A', 'ATC/A02', 'ATC/A02B'}\n" + "All children {'RxNorm/38574', 'ATC/A02BX71', 'RxNorm/2017', 'ATC/A02B', 'RxNorm/2018', 'RxNorm/6852', 'RxNorm/8705', 'RxNorm/8704', 'RxNorm/2344', 'RxNorm/2403', 'RxNorm/2353', 'RxNorm/8730', 'RxNorm/4501', 'ATC/A02BX77', 'RxNorm/7815', 'RxNorm/2620', 'RxNorm/7019', 'ATC/A02BX'}\n", + "All parents {'ATC/A02B', 'ATC/A', 'ATC/A02'}\n" ] } ], @@ -148,7 +133,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/2_Labeling.ipynb b/tutorials/2_Labeling.ipynb index 08095f1..af93b76 100644 --- a/tutorials/2_Labeling.ipynb +++ b/tutorials/2_Labeling.ipynb @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "id": "c6ac5c41-bc99-4731-ad82-7152274c67e1", "metadata": {}, "outputs": [], @@ -48,19 +48,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 11, "id": "8d9e2ccd-71c2-4ae0-897b-7ec022f9fdf4", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "# We can construct these labels manually\n", "\n", @@ -104,17 +95,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 12, "id": "9ac22dbe-ef34-468a-8ab3-673e58e5a920", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 3040.98 examples/s]" - ] - }, { "name": "stdout", "output_type": "stream", @@ -130,35 +114,31 @@ "{'patient_id': 108, 'prediction_time': datetime.datetime(1991, 10, 20, 0, 0), 'boolean_value': True}\n", "{'patient_id': 109, 'prediction_time': datetime.datetime(1991, 6, 25, 0, 0), 'boolean_value': True}\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] } ], "source": [ "from typing import List\n", "import femr.pat_utils\n", + "import meds_reader\n", + "import meds\n", + "import femr.labelers\n", "\n", "\n", "class IsMaleLabeler(femr.labelers.Labeler):\n", " # Dummy labeler to predict gender at birth\n", " \n", " def label(self, patient: meds_reader.Patient) -> List[meds.Label]:\n", - " is_male = any('Gender/M' == measurement['code'] for event in patient['events'] for measurement in event['measurements'])\n", + " is_male = any('Gender/M' == event.code for event in patient.events)\n", " return [{\n", - " 'patient_id': patient['patient_id'], \n", + " 'patient_id': patient.patient_id, \n", " 'prediction_time': femr.pat_utils.get_patient_birthdate(patient),\n", " 'boolean_value': is_male,\n", " }]\n", " \n", - "dataset = datasets.Dataset.from_parquet(\"input/meds/data/*\")\n", + "database = meds_reader.PatientDatabase(\"input/meds_reader\")\n", "\n", "labeler = IsMaleLabeler()\n", - "labeled_patients = labeler.apply(dataset)\n", + "labeled_patients = labeler.apply(database)\n", "\n", "for i in range(10):\n", " print(labeled_patients[100 + i])\n", @@ -167,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 13, "id": "20bd7859", "metadata": {}, "outputs": [], @@ -197,7 +177,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/3_Count Featurization And Modeling.ipynb b/tutorials/3_Count Featurization And Modeling.ipynb index e9dc61c..b0ad3ff 100644 --- a/tutorials/3_Count Featurization And Modeling.ipynb +++ b/tutorials/3_Count Featurization And Modeling.ipynb @@ -16,39 +16,23 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "892ab2d5-0c5a-43c9-a210-9201f775e4fb", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 36758.28 examples/s]\n", - "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 3295.78 examples/s]\n", - "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 2998.20 examples/s]\n" - ] - } - ], + "outputs": [], "source": [ "import pickle\n", "import femr.featurizers\n", "import femr.labelers\n", "import meds\n", "import pyarrow.csv\n", - "\n", - "import femr.index\n", + "import meds_reader\n", "\n", "# Load some labels\n", "labels = pyarrow.csv.read_csv('input/labels.csv').to_pylist()\n", "\n", "# Load our data\n", - "dataset = datasets.Dataset.from_parquet(\"input/meds/data/*\")\n", - "\n", - "# We need to create an index to allow us to find patients quickly\n", - "index = femr.index.PatientIndex(dataset)\n", + "database = meds_reader.PatientDatabase(\"input/meds_reader\")\n", " \n", "# Define our featurizer\n", "\n", @@ -58,15 +42,15 @@ "featurizer_age_count = femr.featurizers.FeaturizerList([age, count])\n", "\n", "# Preprocessing the featurizers, which includes processes such as normalizing age.\n", - "featurizer_age_count.preprocess_featurizers(dataset, index, labels)\n", + "featurizer_age_count.preprocess_featurizers(database, labels)\n", "\n", "# Actually do the featurization\n", - "features = featurizer_age_count.featurize(dataset, index, labels)" + "features = featurizer_age_count.featurize(database, labels)" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "112fe99d", "metadata": {}, "outputs": [ @@ -101,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "cd0f43fd", "metadata": {}, "outputs": [ @@ -137,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "01acd922-668b-481b-8dbb-54ab6ae433af", "metadata": {}, "outputs": [], @@ -176,36 +160,19 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "caae3126-1437-408e-b25f-04568e15c96a", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "---- Logistic Regression ----\n", - "Train:\n", - "\tAUROC: 1.0\n", - "\tAPS: 1.0\n", - "\tAccuracy: 1.0\n", - "\tF1 Score: 1.0\n", - "Test:\n", - "\tAUROC: 1.0\n", - "\tAPS: 1.0\n", - "\tAccuracy: 1.0\n", - "\tF1 Score: 1.0\n", - "---- XGBoost ----\n", - "Train:\n", - "\tAUROC: 1.0\n", - "\tAPS: 1.0\n", - "\tAccuracy: 1.0\n", - "\tF1 Score: 1.0\n", - "Test:\n", - "\tAUROC: 1.0\n", - "\tAPS: 1.0\n", - "\tAccuracy: 1.0\n", - "\tF1 Score: 1.0\n" + "ename": "ModuleNotFoundError", + "evalue": "No module named 'xgboost'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mxgboost\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m 
\u001b[38;5;21;01mxgb\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mlinear_model\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'xgboost'" ] } ], @@ -270,7 +237,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/4_Train CLMBR.ipynb b/tutorials/4_Train CLMBR.ipynb index 8048e25..3b31733 100644 --- a/tutorials/4_Train CLMBR.ipynb +++ b/tutorials/4_Train CLMBR.ipynb @@ -378,7 +378,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/6_Train MOTOR.ipynb b/tutorials/6_Train MOTOR.ipynb index 1628fa4..ebc64ef 100644 --- a/tutorials/6_Train MOTOR.ipynb +++ b/tutorials/6_Train MOTOR.ipynb @@ -27,9 +27,6 @@ "import shutil\n", "import os\n", "\n", - "# os.environ[\"HF_DATASETS_CACHE\"] = '/share/pi/nigam/ethanid/cache_dir'\n", - "\n", - "\n", "TARGET_DIR = 'trash/tutorial_6'\n", "\n", "if os.path.exists(TARGET_DIR):\n", @@ -43,62 +40,17 @@ "execution_count": 2, "id": "646f7590", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1395.58 examples/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0, 1, 2, 4, 6, 7, 10, 11, 12, 13, 14, 15, 18, 20, 21, 23, 24, 26, 27, 28, 29, 30, 31, 33, 36, 37, 38, 40, 42, 44, 45, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 69, 70, 73, 74, 75, 76, 77, 79, 80, 83, 85, 86, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 100, 101, 102, 103, 104, 105, 107, 109, 110, 112, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 133, 134, 135, 136, 137, 139, 141, 142, 143, 144, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 165, 166, 168, 169, 171, 172, 173, 174, 178, 181, 182, 183, 184, 185, 186, 187, 189, 192, 193, 195, 196, 197, 198, 199]\n", - "[19, 22, 25, 39, 46, 71, 82, 84, 87, 92, 106, 108, 113, 131, 132, 138, 146, 147, 148, 155, 177, 179, 180, 188, 190, 191]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1316.29 examples/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "DatasetDict({\n", - " train: Dataset({\n", - " features: ['patient_id', 'events'],\n", - " num_rows: 144\n", - " })\n", - " test: Dataset({\n", - " features: ['patient_id', 'events'],\n", - " num_rows: 26\n", - " })\n", - "})\n" - ] - } - ], + "outputs": [], "source": [ - "\n", - "import femr.index\n", + "import meds_reader\n", "import femr.splits\n", "\n", "# 
First, we want to split our dataset into train, valid, and test\n", "# We do this by calling our split functionality twice\n", "\n", - "dataset = datasets.Dataset.from_parquet('input/meds/data/*')\n", - "\n", + "database = meds_reader.PatientDatabase('input/meds_reader')\n", "\n", - "index = femr.index.PatientIndex(dataset, num_proc=4)\n", - "main_split = femr.splits.generate_hash_split(index.get_patient_ids(), 97, frac_test=0.15)\n", + "main_split = femr.splits.generate_hash_split(list(database), 97, frac_test=0.15)\n", "\n", "os.mkdir(os.path.join(TARGET_DIR, 'motor_model'))\n", "# Note that we want to save this to the target directory since this is important information\n", @@ -107,13 +59,9 @@ "\n", "train_split = femr.splits.generate_hash_split(main_split.train_patient_ids, 87, frac_test=0.15)\n", "\n", - "print(train_split.train_patient_ids)\n", - "print(train_split.test_patient_ids)\n", - "\n", - "main_dataset = main_split.split_dataset(dataset, index)\n", - "train_dataset = train_split.split_dataset(main_dataset['train'], femr.index.PatientIndex(main_dataset['train'], num_proc=4))\n", - "\n", - "print(train_dataset)" + "main_database = database.filter(main_split.train_patient_ids)\n", + "train_database = main_database.filter(train_split.train_patient_ids)\n", + "val_database = main_database.filter(train_split.test_patient_ids)\n" ] }, { @@ -126,7 +74,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 331.19 examples/s]\n" + "/home/ethanid/health_research/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -137,12 +86,12 @@ "# First, we need to train a tokenizer\n", "# Note, we need to use a hierarchical tokenizer for MOTOR\n", "\n", - "with open('input/meds/ontology.pkl', 'rb') as f:\n", + "with open('input/ontology.pkl', 'rb') as f:\n", " ontology = pickle.load(f)\n", "\n", "# NOTE: A vocab size of 128 is probably too low for a real model. 128 was chosen to make this tutorial quick to run\n", "tokenizer = femr.models.tokenizer.train_tokenizer(\n", - " main_dataset['train'], vocab_size=128, is_hierarchical=True, num_proc=4, ontology=ontology)\n", + " train_database, vocab_size=128, is_hierarchical=True, ontology=ontology)\n", "\n", "# Save the tokenizer to the same directory as the model\n", "tokenizer.save_pretrained(os.path.join(TARGET_DIR, \"motor_model\"))" @@ -153,15 +102,7 @@ "execution_count": 4, "id": "69b60daa", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 249.31 examples/s]\n" - ] - } - ], + "outputs": [], "source": [ "\n", "import femr.models.tasks\n", @@ -169,15 +110,14 @@ "# Second, we need to prefit the MOTOR model. 
This is necessary because piecewise exponential models are unstable without an initial fit\n", "\n", "motor_task = femr.models.tasks.MOTORTask.fit_pretraining_task_info(\n", - " main_dataset['train'], tokenizer, num_tasks=64, num_bins=4, final_layer_size=32, num_proc=4)\n", - "\n", + " train_database, tokenizer, num_tasks=64, num_bins=4, final_layer_size=32)\n", "\n", "# It's recommended to save this with pickle to avoid recomputing since it's an expensive operation\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "id": "89611ba9-a242-4b87-9b8f-25670d838fc6", "metadata": {}, "outputs": [ @@ -186,35 +126,23 @@ "output_type": "stream", "text": [ "Convert a single patient\n", - "Convert batches\n" + "Convert batches\n", + "Creating batches 8\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:00<00:00, 261.72 examples/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Creating batches 7\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Generating train split: 7 examples [00:00, 12.06 examples/s]\n", - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 50.63 examples/s]\n" + "Generating train split: 8 examples [00:00, 26.46 examples/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ + "Convert batches to pytorch\n", + "Done\n", "Creating batches 1\n" ] }, @@ -223,22 +151,7 @@ "output_type": "stream", "text": [ "Setting num_proc from 4 back to 1 for the train split to disable multiprocessing as it only contains one shard.\n", - "Generating train split: 1 examples [00:00, 57.97 examples/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Convert batches to pytorch\n", - "Done\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "Generating train split: 1 examples [00:00, 172.15 examples/s]\n" ] } ], @@ -250,23 +163,30 @@ "\n", "processor = femr.models.processor.FEMRBatchProcessor(tokenizer, motor_task)\n", "\n", + "example_patient_id = list(train_database)[0]\n", + "example_patient = train_database[example_patient_id]\n", + "\n", "# We can do this one patient at a time\n", "print(\"Convert a single patient\")\n", - "example_batch = processor.collate([processor.convert_patient(train_dataset['train'][0], tensor_type='pt')])\n", + "example_batch = processor.collate([processor.convert_patient(example_patient, tensor_type='pt')])\n", "\n", "print(\"Convert batches\")\n", "# But generally we want to convert entire datasets\n", - "train_batches = processor.convert_dataset(train_dataset, tokens_per_batch=32, num_proc=4)\n", + "train_batches = processor.convert_dataset(train_database, tokens_per_batch=32, num_proc=4)\n", "\n", "print(\"Convert batches to pytorch\")\n", "# Convert our batches to pytorch tensors\n", "train_batches.set_format(\"pt\")\n", - "print(\"Done\")" + "print(\"Done\")\n", + "\n", + "val_batches = processor.convert_dataset(val_database, tokens_per_batch=32, num_proc=4)\n", + "# Convert our batches to pytorch tensors\n", + "val_batches.set_format(\"pt\")" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "id": "f654a46c-5aa7-465c-b6c5-73d8ba26ed67", "metadata": {}, "outputs": [ @@ -274,10 +194,7 @@ "name": "stderr", "output_type": "stream", "text": [ 
- "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/torch/cuda/__init__.py:628: UserWarning: Can't initialize NVML\n", - " warnings.warn(\"Can't initialize NVML\")\n", - "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n", - "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n", + "/home/ethanid/health_research/venv/lib/python3.12/site-packages/transformers/training_args.py:1494: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n", "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n" ] @@ -288,8 +205,8 @@ "\n", "
\n", " \n", - " \n", - " [700/700 00:10, Epoch 100/100]\n", + " \n", + " [800/800 00:04, Epoch 100/100]\n", "
\n", " \n", " \n", @@ -302,178 +219,203 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
200.8554000.5069420.8301000.523818
400.8711000.5069980.8104000.523785
600.8269000.5070560.8098000.523750
800.8567000.5071160.8292000.523720
1000.8562000.5071810.8270000.523700
1200.8298000.5072510.8106000.523686
1400.8597000.5073210.7980000.523673
1600.8516000.5073930.8383000.523665
1800.8525000.5074670.8271000.523665
2000.8684000.5075400.8080000.523664
2200.8508000.5076170.8185000.523664
2400.8359000.5076960.8154000.523672
2600.8500000.5077680.8390000.523685
2800.8315000.5078410.7938000.523696
3000.8607000.5079150.8156000.523708
3200.8460000.5079880.8163000.523721
3400.8268000.5080550.8248000.523741
3600.8306000.5081230.8061000.523758
3800.8847000.5081880.8365000.523773
4000.8239000.5082480.7936000.523792
4200.8562000.5083090.7827000.523814
4400.8484000.5083600.8466000.523835
4600.8559000.5084130.8131000.523853
4800.8495000.5084580.8155000.523872
5000.8312000.5085020.8460000.523890
5200.8483000.5085420.7819000.523907
5400.8587000.5085770.8029000.523925
5600.8292000.5086080.8245000.523942
5800.8585000.5086360.8035000.523959
6000.8258000.5086590.8234000.523972
6200.8782000.5086770.8040000.523985
6400.8395000.5086920.8226000.523996
6600.8130000.5087030.7942000.524006
6800.8548000.5087090.8320000.524016
7000.8473000.5087110.8303000.524024
7200.7957000.524031
7400.8029000.524037
7600.8230000.524040
7800.7907000.524042
8000.8350000.524043

" @@ -529,8 +471,8 @@ "trainer = transformers.Trainer(\n", " model=model,\n", " data_collator=processor.collate,\n", - " train_dataset=train_batches['train'],\n", - " eval_dataset=train_batches['test'],\n", + " train_dataset=train_batches,\n", + " eval_dataset=val_batches,\n", " args=trainer_config,\n", ")\n", "\n", @@ -556,7 +498,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/7_MOTOR Featurization And Modeling.ipynb b/tutorials/7_MOTOR Featurization And Modeling.ipynb index 4a56512..64100c0 100644 --- a/tutorials/7_MOTOR Featurization And Modeling.ipynb +++ b/tutorials/7_MOTOR Featurization And Modeling.ipynb @@ -254,7 +254,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/input/meds/data/patients.parquet b/tutorials/input/meds/data/patients.parquet deleted file mode 100644 index 7036eaf..0000000 Binary files a/tutorials/input/meds/data/patients.parquet and /dev/null differ diff --git a/tutorials/input/meds/ontology.pkl b/tutorials/input/meds/ontology.pkl deleted file mode 100644 index ed89da9..0000000 Binary files a/tutorials/input/meds/ontology.pkl and /dev/null differ diff --git a/tutorials/input/meds_reader/code/data b/tutorials/input/meds_reader/code/data new file mode 100644 index 0000000..1a80704 Binary files /dev/null and b/tutorials/input/meds_reader/code/data differ diff --git a/tutorials/input/meds_reader/code/dictionary b/tutorials/input/meds_reader/code/dictionary new file mode 100644 index 0000000..4b1e01c Binary files /dev/null and b/tutorials/input/meds_reader/code/dictionary differ diff --git a/tutorials/input/meds_reader/code/zdict b/tutorials/input/meds_reader/code/zdict new file mode 100644 index 0000000..11a0a4b Binary files /dev/null and b/tutorials/input/meds_reader/code/zdict differ diff --git a/tutorials/input/meds_reader/datetime_value/data b/tutorials/input/meds_reader/datetime_value/data new file mode 100644 index 0000000..ec689d3 Binary files /dev/null and b/tutorials/input/meds_reader/datetime_value/data differ diff --git a/tutorials/input/meds_reader/datetime_value/zdict b/tutorials/input/meds_reader/datetime_value/zdict new file mode 100644 index 0000000..7718a83 Binary files /dev/null and b/tutorials/input/meds_reader/datetime_value/zdict differ diff --git a/tutorials/input/meds_reader/length b/tutorials/input/meds_reader/length new file mode 100644 index 0000000..94c7400 Binary files /dev/null and b/tutorials/input/meds_reader/length differ diff --git a/tutorials/input/meds/metadata.json b/tutorials/input/meds_reader/metadata.json similarity index 100% rename from tutorials/input/meds/metadata.json rename to tutorials/input/meds_reader/metadata.json diff --git a/tutorials/input/meds_reader/numeric_value/data b/tutorials/input/meds_reader/numeric_value/data new file mode 100644 index 0000000..ec689d3 Binary files /dev/null and b/tutorials/input/meds_reader/numeric_value/data differ diff --git a/tutorials/input/meds_reader/numeric_value/zdict b/tutorials/input/meds_reader/numeric_value/zdict new file mode 100644 index 0000000..7718a83 Binary files /dev/null and b/tutorials/input/meds_reader/numeric_value/zdict differ diff --git a/tutorials/input/meds_reader/patient_id b/tutorials/input/meds_reader/patient_id new file mode 100644 index 0000000..54def8b Binary files /dev/null and 
b/tutorials/input/meds_reader/patient_id differ diff --git a/tutorials/input/meds_reader/properties b/tutorials/input/meds_reader/properties new file mode 100644 index 0000000..5cb0a0e Binary files /dev/null and b/tutorials/input/meds_reader/properties differ diff --git a/tutorials/input/meds_reader/text_value/data b/tutorials/input/meds_reader/text_value/data new file mode 100644 index 0000000..ea3a50e Binary files /dev/null and b/tutorials/input/meds_reader/text_value/data differ diff --git a/tutorials/input/meds_reader/text_value/dictionary b/tutorials/input/meds_reader/text_value/dictionary new file mode 100644 index 0000000..e69de29 diff --git a/tutorials/input/meds_reader/text_value/zdict b/tutorials/input/meds_reader/text_value/zdict new file mode 100644 index 0000000..e8abfcb Binary files /dev/null and b/tutorials/input/meds_reader/text_value/zdict differ diff --git a/tutorials/input/meds_reader/time b/tutorials/input/meds_reader/time new file mode 100644 index 0000000..695edb7 Binary files /dev/null and b/tutorials/input/meds_reader/time differ diff --git a/tutorials/input/ontology.pkl b/tutorials/input/ontology.pkl new file mode 100644 index 0000000..49aa11c Binary files /dev/null and b/tutorials/input/ontology.pkl differ diff --git a/tutorials/synthetic_data_generation/generate_patients.py b/tutorials/synthetic_data_generation/generate_patients.py index b30bf36..46bb2e4 100644 --- a/tutorials/synthetic_data_generation/generate_patients.py +++ b/tutorials/synthetic_data_generation/generate_patients.py @@ -7,87 +7,91 @@ import jsonschema import meds +import meds_reader import pyarrow import pyarrow.parquet import femr.ontology import femr.transforms -parser = argparse.ArgumentParser(prog="generate_patients", description="Create synthetic data") -parser.add_argument("athena", type=str) -parser.add_argument("destination", type=str) -args = parser.parse_args() - -random.seed(4533) - - -def get_random_patient(patient_id): - epoch = datetime.datetime(1990, 1, 1) - birth = epoch + datetime.timedelta(days=random.randint(100, 1000)) - current_date = birth - - gender = "Gender/" + random.choice(["F", "M"]) - race = "Race/" + random.choice(["White", "Non-White"]) - - patient = { - "patient_id": patient_id, - "events": [ - { - "time": birth, - "measurements": [ - {"code": meds.birth_code}, - {"code": gender}, - {"code": race}, - ], - }, - ], - } - code_cats = ["ICD9CM", "RxNorm"] - for code in range(random.randint(1, 10 + (20 if gender == "Gender/F" else 0))): - code_cat = random.choice(code_cats) - if code_cat == "RxNorm": - code = str(random.randint(0, 10000)) - else: - code = str(random.randint(0, 10000)) - if len(code) > 3: - code = code[:3] + "." 
+ code[3:] - current_date = current_date + datetime.timedelta(days=random.randint(1, 100)) - code = code_cat + "/" + code - patient.events.append({"time": current_date, "measurements": [{"code": code}]}) +if __name__ == "__main__": + + parser = argparse.ArgumentParser(prog="generate_patients", description="Create synthetic data") + parser.add_argument("athena", type=str) + parser.add_argument("destination", type=str) + args = parser.parse_args() + + random.seed(4533) + + def get_random_patient(patient_id): + epoch = datetime.datetime(1990, 1, 1) + birth = epoch + datetime.timedelta(days=random.randint(100, 1000)) + current_date = birth + + gender = "Gender/" + random.choice(["F", "M"]) + race = "Race/" + random.choice(["White", "Non-White"]) - return patient + patient = { + "patient_id": patient_id, + "events": [], + } + birth_codes = [meds.birth_code, gender, race] -patients = [] -for i in range(200): - patients.append(get_random_patient(i)) + for birth_code in birth_codes: + patient["events"].append({"time": birth, "code": birth_code}) -patient_schema = meds_reader.Patient_schema() + code_cats = ["ICD9CM", "RxNorm"] + for code in range(random.randint(1, 10 + (20 if gender == "Gender/F" else 0))): + code_cat = random.choice(code_cats) + if code_cat == "RxNorm": + code = str(random.randint(0, 10000)) + else: + code = str(random.randint(0, 10000)) + if len(code) > 3: + code = code[:3] + "." + code[3:] + current_date = current_date + datetime.timedelta(days=random.randint(1, 100)) + code = code_cat + "/" + code + patient["events"].append({"time": current_date, "code": code}) -patient_table = pyarrow.Table.from_pylist(patients, patient_schema) + return patient -os.makedirs(os.path.join(args.destination, "data"), exist_ok=True) + patients = [] + for i in range(200): + patients.append(get_random_patient(i)) -pyarrow.parquet.write_table(patient_table, os.path.join(args.destination, "data", "patients.parquet")) + patient_schema = meds.schema.patient_schema() + + patient_table = pyarrow.Table.from_pylist(patients, patient_schema) + + os.makedirs(os.path.join(args.destination, "data"), exist_ok=True) + + pyarrow.parquet.write_table(patient_table, os.path.join(args.destination, "data", "patients.parquet")) + + metadata = { + "dataset_name": "femr synthetic datata", + "dataset_version": "1", + "etl_name": "synthetic data", + "etl_version": "1", + "code_metadata": {}, + } -metadata = { - "dataset_name": "femr synthetic datata", - "dataset_version": "1", - "etl_name": "synthetic data", - "etl_version": "1", - "code_metadata": {}, -} + jsonschema.validate(instance=metadata, schema=meds.dataset_metadata) -jsonschema.validate(instance=metadata, schema=meds.dataset_metadata) + with open(os.path.join(args.destination, "metadata.json"), "w") as f: + json.dump(metadata, f) -with open(os.path.join(args.destination, "metadata.json"), "w") as f: - json.dump(metadata, f) + print("Converting") + os.system(f"convert_to_meds_reader {args.destination} {args.destination}_meds") -dataset = datasets.Dataset.from_parquet(os.path.join(args.destination, "data", "*")) + print("Opening database") -ontology = femr.ontology.Ontology(args.athena) + with meds_reader.PatientDatabase(args.destination + "_meds", num_threads=6) as database: + print("Creating ontology") + ontology = femr.ontology.Ontology(args.athena) -ontology.prune_to_dataset(dataset, remove_ontologies=("SNOMED")) + print("Pruning ontology") + ontology.prune_to_dataset(database, remove_ontologies=("SNOMED")) -with open(os.path.join(args.destination, 
"ontology.pkl"), "wb") as f: - pickle.dump(ontology, f) + with open(os.path.join(args.destination, "ontology.pkl"), "wb") as f: + pickle.dump(ontology, f)