Skip to content

Commit

Permalink
Addressing comments and fixing CI
Browse files Browse the repository at this point in the history
  • Loading branch information
mtauraso committed Oct 25, 2024
1 parent 8ee1d9b commit 352481a
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 19 deletions.
9 changes: 1 addition & 8 deletions src/fibad/data_sets/hsc_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(

else:
# If we're splitting a normal hscdataset we generate a single mask with the appropriate values
self.mask = np.zeros(len(data), dtype=np.bool)
self.mask = np.zeros(len(data), dtype=bool)
self._flip_mask_values(length, "false_to_true")

self.indexes = np.nonzero(self.mask)[0]
Expand Down Expand Up @@ -239,13 +239,6 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> torch.Tensor:
    """Return the element of ``self.data`` found at position ``self.indexes[idx]``.

    ``idx`` is first translated through ``self.indexes``; the translated value is
    then used to index ``self.data``.
    """
    # NOTE(review): presumably self.indexes holds the positions selected by the
    # mask (built with np.nonzero in __init__) -- confirm against the constructor.
    return self.data[self.indexes[idx]]

def ids(self):
    """Yield the object ID of every entry whose mask value is set.

    Walks ``self.data.ids()`` in lockstep with the positions
    ``0 .. len(self.data) - 1`` and yields an ID only when ``self.mask`` is
    truthy at that position.

    Yields:
        The IDs produced by ``self.data.ids()`` that ``self.mask`` selects.
    """
    for obj_id, index in zip(self.data.ids(), range(len(self.data))):
        # Test truthiness, not identity: elements of a NumPy boolean mask are
        # ``np.bool_`` objects, which are never the ``True`` singleton, so an
        # ``is True`` comparison would silently skip every entry.
        if self.mask[index]:
            yield obj_id


class HSCDataSetContainer(Dataset):
def __init__(self, config):
Expand Down
22 changes: 11 additions & 11 deletions src/fibad/fibad_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,15 @@ num_workers = 2
# The semantics are borrowed from scikit-learn's train-test-split, and HF Dataset's train-test-split function
# It is an error for these values to sum to more than 1.0 when expressed as ratios, or to more than the
# size of the dataset when expressed as integers.
#
# test_size: Size of the test split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# test split.
# If `int`, represents the absolute number of test samples.
# If `false`, the value is set to the complement of the train size.
# If `train_size` is also `false`, it will be set to `0.25`.
test_size = 0.6

# train_size: Size of the train split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# train split.
# If `int`, represents the absolute number of train samples.
# If `false`, the value is automatically set to the complement of the test size.
train_size = 0.2
train_size = 0.6

# validation_size: Size of the validation split
# validate_size: Size of the validation split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# validate split.
# If `int`, represents the absolute number of validate samples.
Expand All @@ -123,7 +115,15 @@ train_size = 0.2
# If `false`, and only one of the other sizes is defined, no validate split is created
validate_size = 0.2

# Random number to seed with for generating a random split. False means the data will be seeded from
# test_size: Size of the test split
# If `float`, should be between `0.0` and `1.0` and represent the proportion of the dataset to include in the
# test split.
# If `int`, represents the absolute number of test samples.
# If `false`, the value is set to the complement of the train size.
# If `train_size` is also `false`, it will be set to `0.25`.
test_size = 0.6

# Number to seed with for generating a random split. False means the data will be seeded from
# a system source at runtime.
seed = false

Expand Down
55 changes: 55 additions & 0 deletions tests/fibad/test_hsc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,35 @@ def test_split_no_validate_no_test():
a = HSCDataSet(config, split="validate")


def test_split_no_validate_no_train():
    """Split with validate_size and train_size both set to False.

    Expects 60 items in the test split, 40 in the train split, and a
    RuntimeError when the validate split is requested.
    """
    fake_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100))
    with FakeFitsFS(fake_files):
        config = mkconfig(filters=["HSC-G", "HSC-R", "HSC-I"], validate_size=False, train_size=False)

        # Checked in the same order as before: test first, then train.
        expected_lengths = {"test": 60, "train": 40}
        for split_name, expected in expected_lengths.items():
            assert len(HSCDataSet(config, split=split_name)) == expected

        # No validate split exists in this configuration, so asking for it must fail.
        with pytest.raises(RuntimeError):
            HSCDataSet(config, split="validate")


def test_split_invalid_ratio():
    """Constructing the dataset must raise RuntimeError for an out-of-range train ratio.

    Covers one ratio above 1.0 and one negative ratio.
    """
    fake_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100))
    with FakeFitsFS(fake_files):
        for bad_ratio in (1.1, -0.1):
            config = mkconfig(
                filters=["HSC-G", "HSC-R", "HSC-I"], validate_size=False, train_size=bad_ratio
            )
            with pytest.raises(RuntimeError):
                HSCDataSet(config, split=None)


def test_split_no_splits_configured():
"""Test splitting when all splits are overriden, and nothing is specified."""
test_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100))
Expand Down Expand Up @@ -448,6 +477,19 @@ def test_split_values_configured_no_validate():
assert len(a) == 22


def test_split_invalid_configured():
    """Constructing the dataset must raise RuntimeError for an invalid absolute train count.

    Covers a count larger than the 100 generated objects and a negative count.
    """
    fake_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100))
    with FakeFitsFS(fake_files):
        for bad_count in (120, -10):
            config = mkconfig(
                filters=["HSC-G", "HSC-R", "HSC-I"], validate_size=False, train_size=bad_count
            )
            with pytest.raises(RuntimeError):
                HSCDataSet(config, split=None)


def test_split_values_rng():
"""Generate twice with the same RNG seed, verify same values are selected."""
test_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100))
Expand Down Expand Up @@ -509,3 +551,16 @@ def test_split_and():
and_mask = np.logical_and(test_split.mask, train_split.mask)

assert all([a == b for a, b in zip(and_split.mask, and_mask)])


def test_split_and_conflicting_datasets():
    """Combining splits that belong to two separate dataset objects must raise RuntimeError.

    Two datasets are built from the same config, and the current split of one is
    combined with the current split of the other via ``logical_and``.
    """
    fake_files = generate_files(num_objects=100, num_filters=3, shape=(100, 100))
    with FakeFitsFS(fake_files):
        config = mkconfig(filters=["HSC-G", "HSC-R", "HSC-I"], validate_size=False, test_size=False)

        first_dataset = HSCDataSet(config, split="test")
        second_dataset = HSCDataSet(config, split="test")

        left_split = first_dataset.current_split
        right_split = second_dataset.current_split
        with pytest.raises(RuntimeError):
            left_split.logical_and(right_split)

0 comments on commit 352481a

Please sign in to comment.