Skip to content

Commit

Permalink
Add generic wrapper for dataset splits with user-provided function
Browse files Browse the repository at this point in the history
  • Loading branch information
scottfleming committed Apr 27, 2024
1 parent 0df7121 commit 5c568fe
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/femr/models/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def train_tokenizer(
num_numeric: int = 1000,
ontology: Optional[femr.ontology.Ontology] = None,
num_proc: int = 1,
batch_size: int = 1_000,
) -> FEMRTokenizer:
"""Train a FEMR tokenizer from the given dataset"""
statistics = femr.hf_utils.aggregate_over_dataset(
Expand All @@ -33,7 +34,7 @@ def train_tokenizer(
),
agg_statistics,
num_proc=num_proc,
batch_size=1_000,
batch_size=batch_size,
)
return FEMRTokenizer(
convert_statistics_to_msgpack(statistics, vocab_size, is_hierarchical, num_numeric, ontology), ontology
Expand Down
29 changes: 28 additions & 1 deletion src/femr/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dataclasses
import hashlib
import struct
from typing import List
from typing import Callable, List

import datasets

Expand Down Expand Up @@ -71,3 +71,30 @@ def generate_hash_split(patient_ids: List[int], seed: int, frac_test: float = 0.
train_patient_ids.append(patient_id)

return PatientSplit(train_patient_ids=train_patient_ids, test_patient_ids=test_patient_ids)


def generate_split(patient_ids: List[int], is_test_set_fn: Callable[[int], bool]) -> PatientSplit:
"""Generates a patient split based on a user-defined function.
This function categorizes each patient ID as either 'test' or 'train' based on
the user-defined function's return value.
Args:
patient_ids (List[int]): A list of patient IDs.
is_test_set_fn (Callable[[int], bool]): A function that takes a patient ID
and returns True if it belongs to the test set, otherwise False.
Returns:
PatientSplit: A dataclass instance containing lists of train and test patient IDs.
"""
train_patient_ids: List[int] = []
test_patient_ids: List[int] = []

for patient_id in patient_ids:
if is_test_set_fn(patient_id):
test_patient_ids.append(patient_id)
else:
train_patient_ids.append(patient_id)

return PatientSplit(train_patient_ids=train_patient_ids, test_patient_ids=test_patient_ids)

0 comments on commit 5c568fe

Please sign in to comment.