
Splits for HSCDataSet #105

Merged · 3 commits · Oct 25, 2024
10 changes: 8 additions & 2 deletions src/fibad/data_sets/example_cifar_data_set.py
@@ -11,11 +11,17 @@
FIBAD config with a transformation that works well for example code.
"""

def __init__(self, config):
def __init__(self, config, split: str):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
super().__init__(root=config["general"]["data_dir"], train=True, download=True, transform=transform)

if split not in ["train", "test"]:
raise RuntimeError("CIFAR10 dataset only supports 'train' and 'test' splits.")

train = split == "train"

super().__init__(root=config["general"]["data_dir"], train=train, download=True, transform=transform)

def shape(self):
return (3, 32, 32)
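
A minimal usage sketch of the new split-aware constructor (illustrative only; CifarDataSet is an assumed name for the class in this file, and the config dict is hand-built to match the [general] section of fibad_default_config.toml):

# Hypothetical usage -- CifarDataSet is an assumed name, config is hand-built
config = {"general": {"data_dir": "./data"}}
train_set = CifarDataSet(config, split="train")  # wraps torchvision CIFAR10 with train=True
test_set = CifarDataSet(config, split="test")    # train=False
assert train_set.shape() == (3, 32, 32)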
227 changes: 225 additions & 2 deletions src/fibad/data_sets/hsc_data_set.py
@@ -2,8 +2,9 @@

import logging
import re
from copy import copy, deepcopy
from pathlib import Path
from typing import Optional, Union
from typing import Literal, Optional, Union

import numpy as np
import torch
@@ -18,6 +19,228 @@

@fibad_data_set
class HSCDataSet(Dataset):
"""Interface object to allow simple access to splits on a corpus of HSC data files

f/s operations and management are handled in HSCDatSetContainer
splits on the dataset and their generation are handled by HSCDataSetSplit

"""

def __init__(self, config, split: Union[str, None]):
# initialize the filesystem references
self.container = HSCDataSetContainer(config)

# initialize our splits from configuration
self._create_splits(config)

# Set the split to what was requested.
self._set_split(split)

def _create_splits(self, config):
seed = config["prepare"]["seed"] if config["prepare"]["seed"] else None

# Init the splits based on config values
train_size = config["prepare"]["train_size"] if config["prepare"]["train_size"] else None
test_size = config["prepare"]["test_size"] if config["prepare"]["test_size"] else None
validate_size = config["prepare"]["validate_size"] if config["prepare"]["validate_size"] else None

# Convert all values specified as counts into ratios of the underlying container
if isinstance(train_size, int):
train_size = train_size / len(self.container)
if isinstance(test_size, int):
test_size = test_size / len(self.container)
if isinstance(validate_size, int):
validate_size = validate_size / len(self.container)

# Fill in any values not provided
if test_size is None:
if train_size is None:
train_size = 0.25
test_size = 1.0 - train_size
elif train_size is None:
train_size = 1.0 - test_size
elif validate_size is None:
validate_size = 1.0 - (train_size + test_size)

# Generate splits
self.splits = {}
self.splits["test"] = HSCDataSetSplit(self.container, test_size, seed=seed)
rest = copy(self.splits["test"]).complement()
self.splits["train"] = HSCDataSetSplit(rest, train_size, seed=seed)

# Validate is only generated if it is provided, or if both test and train are provided.
if validate_size:
rest = rest.logical_and(copy(self.splits["train"]).complement())
self.splits["validate"] = HSCDataSetSplit(rest, validate_size, seed=seed)

logger.info("HSC Data Set Splits loaded are:")
for key, value in self.splits.items():
logger.info(f"{key} split contains {len(value)} items")

def _set_split(self, split: Union[str, None] = None):
self.current_split = self.splits.get(split, self.container)

if split is not None and self.current_split == self.container:
splits = list(self.splits.keys())
raise RuntimeError(f"Split {split} does not exist. valid split names are {splits}")

def shape(self) -> tuple[int, int, int]:
return self.container.shape()

def __getitem__(self, idx: int) -> torch.Tensor:
return self.current_split[idx]

def __len__(self) -> int:
return len(self.current_split)
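
For reference, a usage sketch of the split selection implemented above (illustrative; config is assumed to follow fibad_default_config.toml):

data_set = HSCDataSet(config, split="train")  # __getitem__/__len__ index only the train split
full_set = HSCDataSet(config, split=None)     # no split requested: falls back to the whole container
HSCDataSet(config, split="bogus")             # raises RuntimeError listing the valid split names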


class HSCDataSetSplit(Dataset):
Review comment from @mtauraso (Collaborator, Author), Oct 23, 2024:

This class had more functionality in the first draft, but it now looks a lot like a glorified numpy masked array.

I ended up writing it this way to ensure there's only ever one HSCDataSetContainer. I expect that object to have bookkeeping for ~10M files at runtime, and I essentially never want a copy made of that object.

If anyone has suggestions to preserve object lifetimes while rewriting this to be shorter with numpy.ma, I'm interested.
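
For discussion's sake, a minimal sketch of what a numpy.ma rewrite might look like — untested, not part of this PR, and MaskedSplit is a hypothetical name. The single HSCDataSetContainer is still shared by reference, never copied:

import numpy as np

class MaskedSplit:
    def __init__(self, container, selected: np.ndarray):
        # Keep a reference to the one shared container; never copy it.
        self.container = container
        # np.ma convention: mask=True means "hidden", the inverse of the
        # "selected" convention used by HSCDataSetSplit.
        self.view = np.ma.masked_array(np.arange(len(container)), mask=~selected)

    def __len__(self) -> int:
        return int(self.view.count())  # number of unmasked entries

    def __getitem__(self, idx: int):
        # compressed() yields only the unmasked container indexes
        return self.container[self.view.compressed()[idx]]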

def __init__(
self,
data: Union["HSCDataSetContainer", "HSCDataSetSplit"],
ratio: float,
seed: Union[int, None] = None,
):
"""
This class represents a split of an HSCDataset.

It should only be created by passing in an existing HSCDataSetContainer (or HSCDataSetSplit)
and splitting it according to train_test_split-like parameters. When you split a split,
all splits end up referring to the same underlying HSCDataSetContainer object.

Each encodes a subset of the underlying HSCDataSetContainer by keeping a list of boolean values.

Parameters
----------
data : Union[HSCDataSetContainer, "HSCDataSetSplit"]
The underlying HSCDataSet or split to operate on. Creating a split from an existing split ends up
referring to a subset of the data selected by the original split, but the new object only refers
to an underlying HSCDataSet object, not any other split object.
ratio : float
Ratio of the underlying data source to use for this split. This is expressed as a fraction of the
HSCDataSetContainer even when an HSCDataSetSplit is passed.
seed : Union[int, None] , optional
The seed value to provide to the random number generator, or None if you would like to use system
entropy to generate a seed. None by default.
"""
self.rng = np.random.default_rng(seed)

if ratio > 1.0 or ratio < 0.0:
msg = f"Split provided for HSCDatSetSplit as a ratio is {ratio}, which is not between 0.0 and 1.0"
raise RuntimeError(msg)

self.data = data.data if isinstance(data, HSCDataSetSplit) else data

# The length of this split once constructed
length = int(np.round(len(self.data) * ratio))

if isinstance(data, HSCDataSetSplit):
# If we're splitting a split we need to modify the existing mask of the prior split
# Namely we switch some true values to false so the new split selects less of the underlying dataset
split = data
self.mask = copy(split.mask)
remove_count = len(split) - length
self._flip_mask_values(remove_count, "true_to_false")

else:
# If we're splitting an HSCDataSetContainer directly we generate a single mask with the appropriate values
self.mask = np.zeros(len(data), dtype=bool)
self._flip_mask_values(length, "false_to_true")

self.indexes = np.nonzero(self.mask)[0]

def _flip_mask_values(self, num: int, mode: Literal["false_to_true", "true_to_false"]):
"""
Private helper that flips some values of self.mask. The direction of the flip is controlled by the
mode parameter: either the function randomly finds `num` true values to flip to false, or `num` false
values to flip to true.

This function is used during object construction to create a set number of randomly selected true
values in the mask.

Parameters
----------
num : int
The number of values to flip
mode : Literal["false_to_true", "true_to_false"]
The mode to work in, either flipping True values to False or the reverse

Raises
------
RuntimeError
It is a RuntimeError to try to flip more values than the mask has of that type.

"""
mask_tmp = np.logical_not(self.mask) if mode == "false_to_true" else self.mask
target_val = mode == "false_to_true"
target_indexes = np.nonzero(mask_tmp)[0]

if num > len(target_indexes):
msg_mode = mode.replace("_", " ")
num_tgt = len(target_indexes)
msg = f"Cannot flip {num} values {msg_mode} when only {num_tgt} {target_val} values exist in mask"
raise RuntimeError(msg)

change_indexes = self.rng.permutation(target_indexes)[:num]
for i in change_indexes:
self.mask[i] = target_val
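
To make the helper concrete, a standalone sketch of the same operation (illustrative, outside the class):

import numpy as np

rng = np.random.default_rng(42)
mask = np.zeros(5, dtype=bool)
# Equivalent of _flip_mask_values(3, "false_to_true"): pick 3 random
# currently-False positions and set them True
target_indexes = np.nonzero(np.logical_not(mask))[0]
mask[rng.permutation(target_indexes)[:3]] = True
assert mask.sum() == 3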

def complement(self) -> "HSCDataSetSplit":
"""Mutates the split by inverting it with respect to the underlying dataset.

e.g. if you have an underlying dataset with 5 members, and indexes 1, 2, and 4 are part of this split,
the complement would be a dataset selecting indexes 0 and 3.
"""
self.mask = np.logical_not(self.mask)
self.indexes = np.nonzero(self.mask)[0]
return self

def logical_and(self, obj: "HSCDataSetSplit") -> "HSCDataSetSplit":
"""Takes the logical and of this object and the passed in object. self is modified, the passed in
object is not

If the self object selects indicies 1,2 and 4 and the passed in object selects indicies 2, 4, and 0
the self object would be modified to select indicies 2, and 4 only.

It is a RuntimeError to and two split objects that do not reference the same underlying HSCDataSet

Parameters
----------
obj : HSCDataSetSplit
The object to take the logical and with
"""
if self.data != obj.data:
msg = "Tried to take logical and of two HSCDataSetSplits with different HSCDataSet objects"
raise RuntimeError(msg)

self.mask = np.logical_and(self.mask, obj.mask)
self.indexes = np.nonzero(self.mask)[0]
return self

def __copy__(self) -> "HSCDataSetSplit":
# Create a HSCDataSetSplit with no data selected, but the same data source as self
copy_object = HSCDataSetSplit(self.data, 0.0)

# Copy mask and indexes over
copy_object.mask = self.mask.copy()
copy_object.indexes = self.indexes.copy()

# Copy RNG state over.
copy_object.rng = deepcopy(self.rng)

return copy_object

def __len__(self) -> int:
return len(self.indexes)

def __getitem__(self, idx: int) -> torch.Tensor:
return self.data[self.indexes[idx]]
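
To make the mask algebra concrete, a worked sketch of the composition HSCDataSet._create_splits performs — numbers are illustrative, assuming a container of 10 objects with test_size=0.3, train_size=0.6, validate_size=0.1:

from copy import copy

test = HSCDataSetSplit(container, 0.3, seed=42)    # 3 of 10 objects selected
rest = copy(test).complement()                     # the other 7 objects
train = HSCDataSetSplit(rest, 0.6, seed=42)        # 6 of 10, drawn from rest (ratio is of the container)
rest = rest.logical_and(copy(train).complement())  # the 1 object in neither split
validate = HSCDataSetSplit(rest, 0.1, seed=42)     # 1 of 10
assert len(test) + len(train) + len(validate) == 10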


class HSCDataSetContainer(Dataset):
def __init__(self, config):
# TODO: What will be a reasonable set of transformations?
# For now tanh all the values so they end up in [-1,1]
@@ -280,7 +503,7 @@
return self._object_id_to_tensor(object_id)

def __contains__(self, object_id: str) -> bool:
"""Allows you to do `object_id in dataset` queries
"""Allows you to do `object_id in dataset` queries. Used by testing code.

Parameters
----------
42 changes: 39 additions & 3 deletions src/fibad/fibad_default_config.toml
@@ -55,15 +55,16 @@ mask = false
# e.g. "user_package.submodule.ExternalModel" or "ExampleAutoencoder"
name = "ExampleAutoencoder"

weights_filepath = "example_model.pth"
epochs = 10

base_channel_size = 32
latent_dim = 64

[train]
weights_filepath = "example_model.pth"
epochs = 10
# Set this to the path of a checkpoint file to resume, or continue training,
# from a checkpoint. Otherwise, set to false to start from scratch.
resume = false
split = "train"

[data_set]
# Name of the built-in data loader to use or the libpath to an external data loader
@@ -92,6 +93,41 @@ batch_size = 32
shuffle = true
num_workers = 2

[prepare]
# How to split the data between training and eval sets.
# The semantics are borrowed from scikit-learn's train_test_split and HF Dataset's train_test_split.
# It is an error for these values to add up to more than 1.0 when given as ratios, or to more than the
# size of the dataset when given as integer counts.

# train_size: Size of the train split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# train split.
# If `int`, represents the absolute number of train samples.
# If `false`, the value is automatically set to the complement of the test size.
train_size = 0.6

# validate_size: Size of the validation split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# validation split.
# If `int`, represents the absolute number of validation samples.
# If `false`, and both train_size and test_size are defined, the value is automatically set to the complement
# of the other two sizes summed.
# If `false`, and only one of the other sizes is defined, no validate split is created
validate_size = 0.2

# test_size: Size of the test split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# test split.
# If `int`, represents the absolute number of test samples.
# If `false`, the value is set to the complement of the train size.
# If `train_size` is also `false`, it will be set to `0.25`.
test_size = 0.2

# Seed for generating a random split. False means the random number generator will be seeded
# from a system entropy source at runtime.
seed = false
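
For a concrete sense of how these three values resolve, a sketch (not part of the PR) mirroring the fill-in logic of HSCDataSet._create_splits, assuming a 1000-object dataset:

def resolve_sizes(n, train_size=None, test_size=None, validate_size=None):
    # Integer counts are first converted to ratios of the dataset
    to_ratio = lambda s: s / n if isinstance(s, int) else s
    train_size, test_size, validate_size = map(to_ratio, (train_size, test_size, validate_size))
    if test_size is None:
        if train_size is None:
            train_size = 0.25
        test_size = 1.0 - train_size
    elif train_size is None:
        train_size = 1.0 - test_size
    elif validate_size is None:
        validate_size = 1.0 - (train_size + test_size)
    return train_size, test_size, validate_size

resolve_sizes(1000, train_size=600, test_size=0.2)  # -> (0.6, 0.2, ~0.2)
resolve_sizes(1000)                                 # -> (0.25, 0.75, None): no validate split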

[predict]
model_weights_file = false
batch_size = 32
split = "test"
3 changes: 2 additions & 1 deletion src/fibad/predict.py
@@ -19,7 +19,8 @@
The parsed config file as a nested dict
"""

model, data_set = setup_model_and_dataset(config)
model, data_set = setup_model_and_dataset(config, split=config["predict"]["split"])
logger.info(f"data set has length {len(data_set)}")

data_loader = dist_data_loader(data_set, config)

# Create a results directory and dump our config there
10 changes: 6 additions & 4 deletions src/fibad/pytorch_ignite.py
@@ -17,7 +17,7 @@
logger = logging.getLogger(__name__)


def setup_model_and_dataset(config: ConfigDict) -> tuple:
def setup_model_and_dataset(config: ConfigDict, split: str) -> tuple:

"""
Construct the dataset and the model according to configuration.

@@ -27,6 +27,8 @@
----------
config : ConfigDict
The entire runtime config
split : str
The name of the split we want to use from the data set.

Returns
-------
@@ -35,7 +37,7 @@
"""
# Fetch data loader class specified in config and create an instance of it
data_set_cls = fetch_data_set_class(config)
data_set = data_set_cls(config)
data_set = data_set_cls(config, split)

# Fetch model class specified in config and create an instance of it
model_cls = fetch_model_class(config)
@@ -210,8 +212,8 @@
greater_or_equal=True,
)

if config["model"]["resume"]:
prev_checkpoint = torch.load(config["model"]["resume"], map_location=device)
if config["train"]["resume"]:
prev_checkpoint = torch.load(config["train"]["resume"], map_location=device)
Checkpoint.load_objects(to_load=to_save, checkpoint=prev_checkpoint)

@trainer.on(Events.STARTED)
6 changes: 3 additions & 3 deletions src/fibad/train.py
@@ -19,16 +19,16 @@
results_dir = create_results_dir(config, "train")
log_runtime_config(config, results_dir)

model, data_set = setup_model_and_dataset(config)
model, data_set = setup_model_and_dataset(config, split=config["train"]["split"])

data_loader = dist_data_loader(data_set, config)

# Create trainer, a pytorch-ignite `Engine` object
trainer = create_trainer(model, config, results_dir)

# Run the training process
trainer.run(data_loader, max_epochs=config["model"]["epochs"])
trainer.run(data_loader, max_epochs=config["train"]["epochs"])

# Save the trained model
model.save(results_dir / config["model"]["weights_filepath"])
model.save(results_dir / config["train"]["weights_filepath"])

logger.info("Finished Training")