-
Notifications
You must be signed in to change notification settings - Fork 22
/
dataset.py
49 lines (35 loc) · 1.46 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
from torch.utils import data
from PIL import Image
class LabeledDataset(data.Dataset):
    """Map-style dataset of ``(image, label)`` pairs read from a text index.

    Each non-blank line of ``path_to_txt_file`` must contain at least two
    whitespace-separated fields: an image path relative to ``data_root`` and
    an integer label. Any further fields on a line are ignored.

    Args:
        data_root: Directory the relative image paths are joined onto.
        path_to_txt_file: Path to the index file.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned; ``None`` returns the image unchanged.
    """

    def __init__(self, data_root, path_to_txt_file, transform):
        self.data_root = data_root
        # ``file_list`` holds the index lines with trailing whitespace removed.
        # Blank lines (e.g. a trailing empty line) are skipped: they cannot be
        # split into a (path, label) pair and would make ``__getitem__`` fail.
        with open(path_to_txt_file, 'r') as f:
            self.file_list = [line for line in (row.rstrip() for row in f) if line]
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(img, target)`` for entry ``idx``.

        ``img`` is the RGB image (after ``transform`` if one was given) and
        ``target`` is the second field of the index line parsed with ``int``.
        """
        fields = self.file_list[idx].split()
        image_path = os.path.join(self.data_root, fields[0])
        # Closing via ``with`` releases the file handle deterministically;
        # ``convert`` returns a new image with its pixels already loaded, so
        # it stays valid after the source file is closed.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        target = int(fields[1])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of non-blank entries in the index file."""
        return len(self.file_list)
class PoisonGenerationDataset(data.Dataset):
    """Map-style dataset of ``(image, image_path)`` pairs from a text index.

    Each non-blank line of ``path_to_txt_file`` is taken in full (trailing
    whitespace removed) as an image path relative to ``data_root``. Unlike
    ``LabeledDataset`` no label is parsed; the joined file path is returned
    in its place so callers can tell which file each sample came from.

    Args:
        data_root: Directory the relative image paths are joined onto.
        path_to_txt_file: Path to the index file.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned; ``None`` returns the image unchanged.
    """

    def __init__(self, data_root, path_to_txt_file, transform):
        self.data_root = data_root
        # Blank lines are skipped: joined onto ``data_root`` they would name
        # the directory itself, which ``Image.open`` cannot read.
        with open(path_to_txt_file, 'r') as f:
            self.file_list = [line for line in (row.rstrip() for row in f) if line]
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(img, image_path)`` for entry ``idx``.

        ``img`` is the RGB image (after ``transform`` if one was given) and
        ``image_path`` is ``os.path.join(data_root, <index line>)``.
        """
        image_path = os.path.join(self.data_root, self.file_list[idx])
        # ``with`` closes the file handle; ``convert`` yields a fully loaded
        # copy that remains usable after the source file is closed.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, image_path

    def __len__(self):
        """Return the number of non-blank entries in the index file."""
        return len(self.file_list)