Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the example of super_resolution #2885

Merged
merged 19 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions examples/super_resolution/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Super-Resolution using an efficient sub-pixel convolutional neural network

ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/super_resolution)

This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution.

```
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batch_size BATCHSIZE]
[--test_batch_size TESTBATCHSIZE] [--n_epochs NEPOCHS] [--lr LR]
[--cuda] [--threads THREADS] [--seed SEED]

PyTorch Super Res Example

optional arguments:
-h, --help show this help message and exit
--upscale_factor super resolution upscale factor
--batch_size training batch size
--test_batch_size testing batch size
--n_epochs number of epochs to train for
--lr Learning Rate. Default=0.01
--cuda use cuda
--mps enable GPU on macOS
--threads number of threads for data loader to use Default=4
--seed random seed to use. Default=123
```

This example trains a super-resolution network on the [Caltech101 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Caltech101.html). A snapshot of the model after every epoch with filename `model_epoch_<epoch_number>.pth`

## Example Usage:

### Train

`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`

### Super Resolve

`python super_resolve.py --input_image <in>.jpg --model model_epoch_500.pth --output_filename out.png`
148 changes: 148 additions & 0 deletions examples/super_resolution/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from model import Net
from torch.utils.data import DataLoader
from torchvision.transforms.functional import center_crop, resize, to_tensor

from ignite.engine import Engine, Events
from ignite.metrics import PSNR

# Training settings
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor")
parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size")
parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=0.01, help="Learning Rate. Default=0.01")
parser.add_argument("--cuda", action="store_true", help="use cuda?")
parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training")
parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use")
parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123")
opt = parser.parse_args()

print(opt)

if opt.cuda and not torch.cuda.is_available():
raise Exception("No GPU found, please run without --cuda")
if not opt.mps and torch.backends.mps.is_available():
raise Exception("Found mps device, please run with --mps to enable macOS GPU")

torch.manual_seed(opt.seed)
use_mps = opt.mps and torch.backends.mps.is_available()

if opt.cuda:
device = torch.device("cuda")
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")

print("===> Loading datasets")


class SRDataset(torch.utils.data.Dataset):
def __init__(self, dataset, scale_factor, crop_size=256):
self.dataset = dataset
self.scale_factor = scale_factor
self.crop_size = crop_size

def __getitem__(self, index):
image, _ = self.dataset[index]
img = image.convert("YCbCr")
hr_image, _, _ = img.split()
hr_image = center_crop(hr_image, self.crop_size)
lr_image = hr_image.copy()
if self.scale_factor != 1:
size = self.crop_size // self.scale_factor
lr_image = resize(lr_image, [size, size])
hr_image = to_tensor(hr_image)
lr_image = to_tensor(lr_image)
return lr_image, hr_image

def __len__(self):
return len(self.dataset)


trainset = torchvision.datasets.Caltech101(root="./data", download=True)
testset = torchvision.datasets.Caltech101(root="./data", download=False)

trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor)
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor)

training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
testing_data_loader = DataLoader(dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size)

print("===> Building model")
model = Net(upscale_factor=opt.upscale_factor).to(device)
criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=opt.lr)


def train_step(engine, batch):
model.train()
input, target = batch[0].to(device), batch[1].to(device)
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

optimizer.zero_grad()
loss = criterion(model(input), target)
loss.backward()
optimizer.step()

return loss.item()


def validation_step(engine, batch):
model.eval()
with torch.no_grad():
x, y = batch[0].to(device), batch[1].to(device)
y_pred = model(x)

return y_pred, y


trainer = Engine(train_step)
evaluator = Engine(validation_step)
psnr = PSNR(data_range=1)
psnr.attach(evaluator, "psnr")
validate_every = 1
log_interval = 100


@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
print(
"===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
engine.state.epoch, engine.state.iteration, len(training_data_loader), engine.state.output
)
)


@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def log_validation():
evaluator.run(testing_data_loader)
metrics = evaluator.state.metrics
print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB")


@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch_time():
print(f"Epoch {trainer.state.epoch}, Time Taken : {trainer.state.times['EPOCH_COMPLETED']}")


@trainer.on(Events.COMPLETED)
def log_total_time():
print(f"Total Time: {trainer.state.times['COMPLETED']}")


@trainer.on(Events.EPOCH_COMPLETED)
def checkpoint():
model_out_path = "model_epoch_{}.pth".format(trainer.state.epoch)
torch.save(model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))


trainer.run(training_data_loader, opt.n_epochs)
29 changes: 29 additions & 0 deletions examples/super_resolution/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch.nn as nn
import torch.nn.init as init


class Net(nn.Module):
def __init__(self, upscale_factor):
super(Net, self).__init__()

self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

self._initialize_weights()

def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.pixel_shuffle(self.conv4(x))
return x

def _initialize_weights(self):
init.orthogonal_(self.conv1.weight, init.calculate_gain("relu"))
init.orthogonal_(self.conv2.weight, init.calculate_gain("relu"))
init.orthogonal_(self.conv3.weight, init.calculate_gain("relu"))
init.orthogonal_(self.conv4.weight)
41 changes: 41 additions & 0 deletions examples/super_resolution/super_resolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import argparse

import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

# Training settings
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
parser.add_argument("--input_image", type=str, required=True, help="input image to use")
parser.add_argument("--model", type=str, required=True, help="model file to use")
parser.add_argument("--output_filename", type=str, help="where to save the output image")
parser.add_argument("--cuda", action="store_true", help="use cuda")
opt = parser.parse_args()

print(opt)
img = Image.open(opt.input_image).convert("YCbCr")
y, cb, cr = img.split()

model = torch.load(opt.model)
input = to_tensor(y).view(1, -1, y.size[1], y.size[0])

if opt.cuda:
model = model.cuda()
input = input.cuda()

model.eval()
with torch.no_grad():
out = model(input)
out = out.cpu()
out_img_y = out[0].detach().numpy()
out_img_y *= 255.0
out_img_y = out_img_y.clip(0, 255)
out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode="L")

out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
out_img = Image.merge("YCbCr", [out_img_y, out_img_cb, out_img_cr]).convert("RGB")

out_img.save(opt.output_filename)
print("output image saved to ", opt.output_filename)