-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
52 lines (42 loc) · 1.35 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#!/usr/bin/env python3
# Author: Armit
# Create Time: 2023/12/11
from PIL import Image
from PIL.Image import Image as PILImage
from pathlib import Path
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray
torch.set_float32_matmul_precision('medium')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
BASE_PATH = Path(__file__).parent
IMG_PATH = BASE_PATH / 'img'
LOG_PATH = BASE_PATH / 'log'
npimg_u8 = NDArray[np.uint8]
npimg_f32 = NDArray[np.float32]
npimg = Union[npimg_u8, npimg_f32]
exists = lambda x: x is not None
pos_or_none = lambda x: x if x > 0 else None
def load_pil(fp:Path) -> PILImage:
return Image.open(fp).convert('RGB')
def pil_to_npimg(img:PILImage, dtype=np.float32) -> npimg:
assert dtype in [np.uint8, np.float32, 'uint8', 'float32']
im = np.asarray(img, dtype=np.uint8)
if dtype in [np.float32, 'float32']:
im = im.astype(np.float32) / 255.0
return im
def npimg_to_pil(im:npimg) -> PILImage:
assert isinstance(im, ndarray)
assert len(im.shape) == 3 and im.shape[-1] == 3
if im.dtype == np.float32:
assert 0.0 <= im.min() and im.max() <= 1.0
im = np.asarray(im * 255, dtype=np.uint8)
else:
assert im.dtype == np.uint8
return Image.fromarray(im)