Skip to content

Commit

Permalink
Bug fix in OPFDataset parsing (#9623)
Browse files Browse the repository at this point in the history
In the split of train, test and val, some numbers are hard-coded. This
had to be changed due to introduction of ``num_groups`` that limits the
number of samples that are downloaded.

I completely missed this and forgot to commit this change into the PR.
  • Loading branch information
kaarthiksundar authored Aug 23, 2024
1 parent 9250cba commit d2b175a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torch_geometric/datasets/opf.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,11 @@ def process(self) -> None:
data = self.pre_transform(data)

i = int(name.split('.')[0].split('_')[1])
if i < 270_000:
train_limit = int(15_000 * self.num_groups * 0.9)
val_limit = train_limit + int(15_000 * self.num_groups * 0.05)
if i < train_limit:
train_data_list.append(data)
elif i < 285_000:
elif i < val_limit:
val_data_list.append(data)
else:
test_data_list.append(data)
Expand Down

0 comments on commit d2b175a

Please sign in to comment.