-
Notifications
You must be signed in to change notification settings - Fork 0
/
augmentations.py
64 lines (52 loc) · 2.11 KB
/
augmentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from torchvision import transforms
# Deterministic preprocessing for prediction / evaluation.
# Steps, in order: resize to 256x256, convert to grayscale replicated into
# 3 channels, convert to a tensor, then normalize each channel with the
# standard ImageNet mean/std (presumably to match an ImageNet-pretrained
# backbone — confirm against the model in use).
# There is no ToPILImage step here, so the input is presumably already a PIL
# image (contrast get_transforms(prediction=True), which prepends one).
preprocess_test = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Stochastic preprocessing for training.
# Identical to preprocess_test except that random augmentations are inserted
# between the resize and the grayscale conversion:
#   * horizontal and vertical flips (torchvision default probability),
#   * a random rotation drawn from [-30, +30] degrees,
#   * a 3x3 Gaussian blur applied with probability 0.3.
preprocess_train = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        # --- random augmentations (applied to the resized image) ---
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=30),
        transforms.RandomApply(
            [transforms.GaussianBlur(kernel_size=(3, 3))],
            p=0.3,
        ),
        # --- deterministic tail, same as preprocess_test ---
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def get_transforms(resize=(256, 256), single_channel=False, training=True, prediction=False):
    """Build a configurable torchvision preprocessing pipeline.

    The pipeline is assembled in this order:
    [ToPILImage] -> Resize -> [random augmentations] -> Grayscale -> ToTensor -> Normalize

    Args:
        resize: Target size handed to ``transforms.Resize``.
        single_channel: If True, emit a 1-channel grayscale tensor normalized
            with a single mean/std pair. Otherwise the grayscale image is
            replicated into 3 channels and normalized with 3-channel stats.
        training: If True, insert random horizontal/vertical flips, a random
            rotation within +/-30 degrees and a 3x3 Gaussian blur (applied
            with probability 0.3) right after resizing.
        prediction: If True, prepend ``ToPILImage`` so tensor / ndarray inputs
            can be passed in directly.

    Returns:
        A ``transforms.Compose`` wrapping the selected steps.

    NOTE(review): the 1-channel mean/std are just the first (red) entries of
    the 3-channel stats; confirm that is intended rather than luminance stats.
    """
    steps = [transforms.ToPILImage()] if prediction else []
    steps.append(transforms.Resize(resize))

    if training:
        steps += [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(degrees=30),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=(3, 3))], p=0.3),
        ]

    # Channel count and normalization stats always travel together.
    if single_channel:
        channels, mean, std = 1, [0.485], [0.229]
    else:
        channels, mean, std = 3, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    steps += [
        transforms.Grayscale(num_output_channels=channels),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)