diff --git a/docs/source/notebooks/tutorial.ipynb b/docs/source/notebooks/tutorial.ipynb
deleted file mode 100644
index 67feef4..0000000
--- a/docs/source/notebooks/tutorial.ipynb
+++ /dev/null
@@ -1,261 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Code Example with Synthetic Data"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Set-up"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Imports\n",
-    "\n",
-    "First, let's import ACES! Three modules - `config`, `predicates`, and `query` - are required to execute an end-to-end cohort extraction. `omegaconf` is also required to express our data config parameters in order to load our `EventStream` dataset. Other imports are only needed for visualization!"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import json\n",
-    "from pathlib import Path\n",
-    "\n",
-    "import yaml\n",
-    "from bigtree import print_tree\n",
-    "from EventStream.data.dataset_polars import Dataset\n",
-    "from IPython.display import display\n",
-    "from omegaconf import DictConfig\n",
-    "\n",
-    "from aces import config, predicates, query"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Directories\n",
-    "\n",
-    "Next, let's specify our paths and directories. In this tutorial, we will extract a cohort for a typical in-hospital mortality prediction task from the ESGPT synthetic sample dataset. The task configuration file and sample data are shipped with the repository in the [sample_configs/](https://github.com/justin13601/ACES/tree/main/sample_configs) and [sample_data/](https://github.com/justin13601/ACES/tree/main/sample_data) folders in the project root, respectively."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "config_path = \"../../../sample_configs/inhospital_mortality.yaml\"\n",
-    "data_path = \"../../../sample_data/esgpt_sample\""
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Configuration File"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "The task configuration file is the core configuration language that ACES uses to extract cohorts. Details about this configuration language are available in [Configuration Language](https://eventstreamaces.readthedocs.io/en/latest/configuration.html). In brief, the configuration file contains `predicates`, `trigger`, and `windows` sections.\n",
-    "\n",
-    "The `predicates` section is used to define dataset-specific concepts that are needed for the task. In our case of binary mortality prediction, we are interested in extracting a cohort of patients who were admitted to the hospital and subsequently discharged or died. As such, `admission`, `discharge`, `death`, and `discharge_or_death` would be handy predicates.\n",
-    "\n",
-    "We'd also like to make a prediction of mortality for each admission. Hence, a reasonable `trigger` event would be the `admission` predicate.\n",
-    "\n",
-    "Suppose that in our task, we'd like to set a constraint that the admission must have been more than 48 hours long. Additionally, for our prediction inputs, we'd like to use all information in the patient record up until 24 hours after admission, which must contain at least 5 event records (to ensure there is sufficient input data). These clauses are captured in the `windows` section, where each window is defined relative to another."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "with open(config_path, \"r\") as stream:\n",
-    "    data_loaded = yaml.safe_load(stream)\n",
-    "    print(json.dumps(data_loaded, indent=4))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "We can see that the `input` window begins at `null` (the start of the patient record) and ends 24 hours after the `trigger` (`admission`). A `gap` window is defined for 24 hours after the end of the `input` window, constraining the admission to be at least 48 hours long. Finally, a `target` window is specified from the end of the `gap` window to either the next `discharge` or `death` event (i.e., `discharge_or_death`). This allows us to extract a binary label for each patient in our cohort to be used in the prediction task (i.e., the `label` field in the `target` window, which will extract `0`: discharged, `1`: died). Additionally, an `index_timestamp` field is set as the `end` of the `input` window to denote when a prediction is made (i.e., at the end of the `input` window, when all input data is fed into the model), and can be used to index extraction results."
-   ]
-  },
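-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "To make this structure concrete, a configuration of the shape described above might look roughly like the following sketch. This is an illustration reconstructed from the description, not a verbatim copy of the shipped file - please consult `sample_configs/inhospital_mortality.yaml` in the repository for the authoritative version:\n",
-    "\n",
-    "```yaml\n",
-    "predicates:\n",
-    "  admission:\n",
-    "    code: event_type//ADMISSION\n",
-    "  discharge:\n",
-    "    code: event_type//DISCHARGE\n",
-    "  death:\n",
-    "    code: event_type//DEATH\n",
-    "  discharge_or_death:\n",
-    "    expr: or(discharge, death)\n",
-    "\n",
-    "trigger: admission\n",
-    "\n",
-    "windows:\n",
-    "  input:\n",
-    "    start: null\n",
-    "    end: trigger + 24h\n",
-    "    has:\n",
-    "      _ANY_EVENT: (5, None)\n",
-    "    index_timestamp: end\n",
-    "  gap:\n",
-    "    start: input.end\n",
-    "    end: start + 24h\n",
-    "    has:\n",
-    "      admission: (None, 0)\n",
-    "      discharge: (None, 0)\n",
-    "      death: (None, 0)\n",
-    "  target:\n",
-    "    start: gap.end\n",
-    "    end: start -> discharge_or_death\n",
-    "    label: death\n",
-    "```"
-   ]
-  },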
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "We now load our configuration file by passing its path (`str`) into `config.TaskExtractorConfig.load()`. This parses the three key sections indicated above and prepares ACES for extraction based on our defined constraints (inclusion/exclusion criteria for each window)."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "cfg = config.TaskExtractorConfig.load(config_path=config_path)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Task Tree\n",
-    "\n",
-    "With the configuration file loaded and parsed, we can access a visualization of a tree structure that represents our task of interest. As shown, the tree nodes are the `start` and `end` time points of the windows defined in the configuration file, and the tree edges express the relationships between these windows. ACES will traverse this tree and recursively compute aggregated predicate counts for each subtree. This allows us to filter our dataset to valid realizations of the task tree, which make up our task cohort."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "tree = cfg.window_tree\n",
-    "print_tree(tree)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Data"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "This tutorial uses synthetic data for 100 patients stored in the ESGPT standard. For more information about this data, please refer to the [ESGPT Documentation](https://eventstreamml.readthedocs.io/en/latest/_collections/local_tutorial_notebook.html).\n",
-    "\n",
-    "We first load the dataset by passing the path (a `Path` object) of the directory containing the ESGPT dataset into `EventStream`. This creates an `ESD` object, allowing us to access the relevant dataframes. While ESGPT contains a wealth of other functionality, we are particularly interested in loading `events_df` and `dynamic_measurements_df`.\n",
-    "\n",
-    "`events_df` consists of unique (`subject_id`, `timestamp`) pairs, each mapped to a unique `event_id`. For each `event_id`, the `event_type` column contains `&`-delimited sequences, such as `ADMISSION`, `DEATH`, `LAB`, and `VITALS`, specifying the type(s) of event that occurred at that `event_id`.\n",
-    "\n",
-    "`dynamic_measurements_df` consists of various values in the electronic health record and can be linked to the `events_df` table via the `event_id`."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "ESD = Dataset.load(Path(data_path))\n",
-    "events_df = ESD.events_df\n",
-    "dynamic_measurements_df = ESD.dynamic_measurements_df\n",
-    "\n",
-    "display(events_df)\n",
-    "display(dynamic_measurements_df)"
-   ]
-  },
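-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "As a quick illustration of how the two tables relate, each measurement can be attached to its parent event by joining on `event_id`. This join is purely for intuition and is not required by ACES, which performs the equivalent plumbing internally; it assumes both tables are Polars dataframes, as loaded above:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Illustrative only: attach each measurement to its parent event via event_id.\n",
-    "joined_df = dynamic_measurements_df.join(events_df, on=\"event_id\", how=\"left\")\n",
-    "display(joined_df.head())"
-   ]
-  },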
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Predicate Columns\n",
-    "\n",
-    "The next step in our cohort extraction is the generation of predicate columns. Our dataset-agnostic windows (i.e., complex task logic) are linked to dataset-specific predicates (i.e., dataset observations and concepts), which facilitates the sharing of tasks across datasets. As such, the predicates dataframe is the foundational unit on which ACES operates.\n",
-    "\n",
-    "A predicate column is simply a column containing numerical counts (often just `0`'s and `1`'s), representing the number of times a given predicate (concept) occurs at a given timestamp for a given patient.\n",
-    "\n",
-    "In the case of ESGPT, ACES supports the automatic generation of these predicate columns from the configuration file. However, some fields need to be provided via a `DictConfig` object. These include the path to the directory of the ESGPT dataset (`str`) and the data standard (`esgpt` in this case).\n",
-    "\n",
-    "Given this data configuration, we then call `predicates.get_predicates_df()` to generate the relevant predicate columns for our task. Due to the nature of the specified predicates, the resulting dataframe simply contains the unique (`subject_id`, `timestamp`) pairs and binary columns for each predicate. An additional predicate `_ANY_EVENT` is also generated - this will be used to enforce our constraint on the number of events in the `input` window."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data_config = DictConfig({\"path\": data_path, \"standard\": \"esgpt\"})\n",
-    "\n",
-    "predicates_df = predicates.get_predicates_df(cfg=cfg, data_config=data_config)\n",
-    "display(predicates_df)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## End-to-End Query"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Finally, with our task configuration object and the computed predicates dataframe, we can call `query.query()` to execute the extraction of our cohort.\n",
-    "\n",
-    "Each row of the resulting dataframe is a valid realization of our task tree. Hence, each instance can be included in our cohort for the prediction of in-hospital mortality, as defined in our task configuration file. The output contains:\n",
-    "\n",
-    "- `subject_id`: subject IDs of our cohort (since we'd like to treat individual admissions as separate samples, there will be duplicate subject IDs)\n",
-    "- `index_timestamp`: timestamp of when a prediction is made, which coincides with the `end` timestamp of the `input` window (as specified in our task configuration)\n",
-    "- `label`: binary label of mortality, which is derived from the `death` predicate of the `target` window (as specified in our task configuration)\n",
-    "- `trigger`: timestamp of the `trigger` event, which is the `admission` predicate (as specified in our task configuration)\n",
-    "\n",
-    "Additionally, the output includes a column for each node of our task tree, in pre-order traversal order. Each of these columns contains a `pl.Struct` object holding the name of the node, the start and end times of the window it represents, and the counts of all defined predicates in that window."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "df_result = query.query(cfg=cfg, predicates_df=predicates_df)\n",
-    "display(df_result)"
-   ]
-  },
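-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "For many downstream modeling workflows, only the subject identifier, the prediction time, and the label are needed. A minimal post-processing step might therefore look like the following sketch, which uses only the output columns described above:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Illustrative only: keep just the columns a downstream model typically needs.\n",
-    "task_df = df_result.select([\"subject_id\", \"index_timestamp\", \"label\"])\n",
-    "display(task_df)"
-   ]
-  },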
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "... and that's a wrap! We have used ACES to perform an end-to-end extraction on an ESGPT dataset for a cohort that can be used to predict in-hospital mortality. Similar pipelines can be built for other tasks, as well as for datasets in the MEDS standard. You may also pre-compute predicate columns and use the `direct` flag when loading `.csv` or `.parquet` data files. More information about this is available in [Predicates DataFrame](https://eventstreamaces.readthedocs.io/en/latest/notebooks/predicates.html).\n",
-    "\n",
-    "As always, please don't hesitate to reach out should you have any questions about ACES!\n"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "esgpt",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.13"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/sample_data/meds_sample/held_out/0.parquet b/sample_data/meds_sample/held_out/0.parquet
index 5c71d98..e7a31e5 100644
Binary files a/sample_data/meds_sample/held_out/0.parquet and b/sample_data/meds_sample/held_out/0.parquet differ
diff --git a/sample_data/meds_sample/sample_shard.parquet b/sample_data/meds_sample/sample_shard.parquet
index 5c71d98..b69f923 100644
Binary files a/sample_data/meds_sample/sample_shard.parquet and b/sample_data/meds_sample/sample_shard.parquet differ
diff --git a/sample_data/meds_sample/train/0.parquet b/sample_data/meds_sample/train/0.parquet
index 2f90ac3..d3f6f2b 100644
Binary files a/sample_data/meds_sample/train/0.parquet and b/sample_data/meds_sample/train/0.parquet differ
diff --git a/sample_data/meds_sample/train/1.parquet b/sample_data/meds_sample/train/1.parquet
index 98e4ee7..dd79221 100644
Binary files a/sample_data/meds_sample/train/1.parquet and b/sample_data/meds_sample/train/1.parquet differ
diff --git a/sample_data/meds_sample/tuning/0.parquet b/sample_data/meds_sample/tuning/0.parquet
new file mode 100644
index 0000000..ad3f7ca
Binary files /dev/null and b/sample_data/meds_sample/tuning/0.parquet differ