diff --git a/torch_geometric/datasets/opf.py b/torch_geometric/datasets/opf.py index fb91f677e207..d6a47a9670aa 100644 --- a/torch_geometric/datasets/opf.py +++ b/torch_geometric/datasets/opf.py @@ -205,9 +205,11 @@ def process(self) -> None: data = self.pre_transform(data) i = int(name.split('.')[0].split('_')[1]) - if i < 270_000: + train_limit = int(15_000 * self.num_groups * 0.9) + val_limit = train_limit + int(15_000 * self.num_groups * 0.05) + if i < train_limit: train_data_list.append(data) - elif i < 285_000: + elif i < val_limit: val_data_list.append(data) else: test_data_list.append(data)