From 5c568fee046c92bfa73bc9e6d0ea13b166d6adea Mon Sep 17 00:00:00 2001 From: Scott Fleming Date: Sat, 27 Apr 2024 14:30:03 -0400 Subject: [PATCH] Add generic wrapper for dataset splits with user-provided function --- src/femr/models/tokenizer.py | 3 ++- src/femr/splits.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/femr/models/tokenizer.py b/src/femr/models/tokenizer.py index 26922dc..9853668 100644 --- a/src/femr/models/tokenizer.py +++ b/src/femr/models/tokenizer.py @@ -24,6 +24,7 @@ def train_tokenizer( num_numeric: int = 1000, ontology: Optional[femr.ontology.Ontology] = None, num_proc: int = 1, + batch_size: int = 1_000, ) -> FEMRTokenizer: """Train a FEMR tokenizer from the given dataset""" statistics = femr.hf_utils.aggregate_over_dataset( @@ -33,7 +34,7 @@ def train_tokenizer( ), agg_statistics, num_proc=num_proc, - batch_size=1_000, + batch_size=batch_size, ) return FEMRTokenizer( convert_statistics_to_msgpack(statistics, vocab_size, is_hierarchical, num_numeric, ontology), ontology diff --git a/src/femr/splits.py b/src/femr/splits.py index 897e2f3..2938d6c 100644 --- a/src/femr/splits.py +++ b/src/femr/splits.py @@ -4,7 +4,7 @@ import dataclasses import hashlib import struct -from typing import List +from typing import Callable, List import datasets @@ -71,3 +71,30 @@ def generate_hash_split(patient_ids: List[int], seed: int, frac_test: float = 0. train_patient_ids.append(patient_id) return PatientSplit(train_patient_ids=train_patient_ids, test_patient_ids=test_patient_ids) + + +def generate_split(patient_ids: List[int], is_test_set_fn: Callable[[int], bool]) -> PatientSplit: + """Generates a patient split based on a user-defined function. + + This function categorizes each patient ID as either 'test' or 'train' based on + the user-defined function's return value. + + Args: + patient_ids (List[int]): A list of patient IDs. + is_test_set_fn (Callable[[int], bool]): A function that takes a patient ID + and returns True if it belongs to the test set, otherwise False. + + Returns: + PatientSplit: A dataclass instance containing lists of train and test patient IDs. + + """ + train_patient_ids: List[int] = [] + test_patient_ids: List[int] = [] + + for patient_id in patient_ids: + if is_test_set_fn(patient_id): + test_patient_ids.append(patient_id) + else: + train_patient_ids.append(patient_id) + + return PatientSplit(train_patient_ids=train_patient_ids, test_patient_ids=test_patient_ids)