-
Notifications
You must be signed in to change notification settings - Fork 200
/
celeba.py
executable file
·143 lines (117 loc) · 4.2 KB
/
celeba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# -*- coding: utf-8 -*-
# @Date : 10/6/19
# @Author : Xinyu Gong ([email protected])
# @Link : None
# @Version : 0.0
from functools import partial
import torch
import os
import PIL
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import download_file_from_google_drive, check_integrity, verify_str_arg
from torch.utils.data import Dataset
import glob
class CelebA(Dataset):
    """PyTorch Dataset wrapper for a flat directory of images.

    Every regular file directly inside ``root`` (no recursion) is one sample.
    Files ending in ``.npy`` are loaded with NumPy and converted to a PIL
    image; every other file is opened with PIL.
    """

    def __setup_files(self):
        """
        private helper that builds the list of sample files
        :return: files => sorted list of paths to the regular files directly
                 inside ``self.data_dir`` (sub-directories are skipped)
        """
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so the idx -> file mapping is the same on every run/machine.
        candidates = (
            os.path.join(self.data_dir, file_name)
            for file_name in sorted(os.listdir(self.data_dir))
        )
        return [path for path in candidates if os.path.isfile(path)]

    def __init__(self, root, transform=None):
        """
        constructor for the class
        :param root: path to the directory containing the images
        :param transform: optional callable applied to each loaded PIL image
        """
        self.data_dir = root
        self.transform = transform
        # the directory is scanned once, at construction time
        self.files = self.__setup_files()

    def __len__(self):
        """
        compute the length of the dataset
        :return: len => number of files found in ``root``
        """
        return len(self.files)

    def __getitem__(self, idx):
        """
        obtain the image (read and transform)
        :param idx: index into ``self.files``
        :return: (img, img) => the loaded (and transformed, if a transform is
                 set) image, returned twice
        """
        from PIL import Image

        img_name = self.files[idx]
        if img_name.endswith(".npy"):
            # numpy is imported here because this is the only place it is
            # needed (the module never imported it, so this branch used to
            # fail with NameError).
            import numpy as np

            array = np.load(img_name)
            # assumes the array is stored as (1, C, H, W) -- TODO confirm;
            # drop the leading axis and move channels last, as PIL expects.
            img = Image.fromarray(array.squeeze(0).transpose(1, 2, 0))
        else:
            # Image.open is lazy and keeps the file open; copy() forces the
            # decode so the handle can be closed immediately by the `with`.
            with Image.open(img_name) as handle:
                img = handle.copy()
        if self.transform is not None:
            img = self.transform(img)
        # the same image twice, presumably as an (input, target) pair --
        # verify against the training loop
        return img, img
class FFHQ(Dataset):
    """PyTorch Dataset wrapper for an FFHQ-style image directory tree.

    Samples are collected from three patterns under ``root``:
    ``*.png`` exactly one directory level down, ``*.jpg`` at the top level,
    and ``*.webp`` at any depth. Files ending in ``.npy`` are loaded with
    NumPy; everything else is opened with PIL.
    """

    def __setup_files(self):
        """
        private helper that builds the list of sample files
        :return: files => sorted list of paths of the matching regular files
        """
        # escape root so characters like "[" in the path are not treated as
        # glob wildcards
        escaped_root = glob.escape(self.data_dir)
        png_files = glob.glob(os.path.join(escaped_root, "./*/*.png"))
        jpg_files = glob.glob(os.path.join(escaped_root, "./*.jpg"))
        webp_files = [
            match
            for dir_path, _, _ in os.walk(self.data_dir)
            for match in glob.glob(os.path.join(glob.escape(dir_path), "*.webp"))
        ]
        # glob and os.walk already yield paths that include data_dir, so use
        # them as-is: joining data_dir onto them again doubles the prefix
        # whenever root is relative, and every file then fails isfile().
        # Sorting makes the idx -> file mapping deterministic.
        return sorted(
            path
            for path in png_files + jpg_files + webp_files
            if os.path.isfile(path)
        )

    def __init__(self, root, transform=None):
        """
        constructor for the class
        :param root: path to the root directory of the image tree
        :param transform: optional callable applied to each loaded PIL image
        """
        self.data_dir = root
        self.transform = transform
        # the directory tree is scanned once, at construction time
        self.files = self.__setup_files()

    def __len__(self):
        """
        compute the length of the dataset
        :return: len => number of matching files found under ``root``
        """
        return len(self.files)

    def __getitem__(self, idx):
        """
        obtain the image (read and transform)
        :param idx: index into ``self.files``
        :return: (img, img) => the loaded (and transformed, if a transform is
                 set) image, returned twice
        """
        from PIL import Image

        img_name = self.files[idx]
        if img_name.endswith(".npy"):
            # numpy is imported here because this is the only place it is
            # needed (the module never imported it, so this branch used to
            # fail with NameError).
            import numpy as np

            array = np.load(img_name)
            # assumes the array is stored as (1, C, H, W) -- TODO confirm;
            # drop the leading axis and move channels last, as PIL expects.
            img = Image.fromarray(array.squeeze(0).transpose(1, 2, 0))
        else:
            # Image.open is lazy and keeps the file open; copy() forces the
            # decode so the handle can be closed immediately by the `with`.
            with Image.open(img_name) as handle:
                img = handle.copy()
        if self.transform is not None:
            img = self.transform(img)
        # the same image twice, presumably as an (input, target) pair --
        # verify against the training loop
        return img, img