Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the example of super_resolution #2885

Merged
merged 19 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions examples/super_resolution/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Super-Resolution using an efficient sub-pixel convolutional neural network

ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/super_resolution)

This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution.

```
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batchSize BATCHSIZE]
[--testBatchSize TESTBATCHSIZE] [--nEpochs NEPOCHS] [--lr LR]
[--cuda] [--threads THREADS] [--seed SEED]
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

PyTorch Super Res Example

optional arguments:
-h, --help show this help message and exit
--upscale_factor super resolution upscale factor
--batch_size training batch size
--test_batch_size testing batch size
--n_epochs number of epochs to train for
--lr Learning Rate. Default=0.01
--cuda use cuda
--mps enable GPU on macOS
--threads number of threads for data loader to use Default=4
--seed random seed to use. Default=123
```

This example trains a super-resolution network on the [Cifar10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). A snapshot of the model after every epoch with filename `model_epoch_<epoch_number>.pth`

## Example Usage:

### Train

`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`

### Super Resolve

`python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_500.pth --output_filename out.png`
183 changes: 183 additions & 0 deletions examples/super_resolution/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from model import Net
from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor

from ignite.engine import Engine, Events
from ignite.metrics import PSNR

# Training settings
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor")
parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size")
parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=0.01, help="Learning Rate. Default=0.01")
parser.add_argument("--cuda", action="store_true", help="use cuda?")
parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training")
parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use")
parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123")
opt = parser.parse_args()

print(opt)

if opt.cuda and not torch.cuda.is_available():
raise Exception("No GPU found, please run without --cuda")
if not opt.mps and torch.backends.mps.is_available():
raise Exception("Found mps device, please run with --mps to enable macOS GPU")

torch.manual_seed(opt.seed)
use_mps = opt.mps and torch.backends.mps.is_available()

if opt.cuda:
device = torch.device("cuda")
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")

print("===> Loading datasets")


class SRDataset(torch.utils.data.Dataset):
def __init__(self, dataset, scale_factor, input_transform=None, target_transform=None):
self.dataset = dataset
self.input_transform = input_transform
self.target_transform = target_transform

def __getitem__(self, index):
image, _ = self.dataset[index]
img = image.convert("YCbCr")
lr_image, _, _ = img.split()

hr_image = lr_image.copy()
if self.input_transform:
lr_image = self.input_transform(lr_image)
if self.target_transform:
hr_image = self.target_transform(hr_image)
return lr_image, hr_image

def __len__(self):
return len(self.dataset)


def calculate_valid_crop_size(crop_size, upscale_factor):
return crop_size - (crop_size % upscale_factor)


def input_transform(crop_size, upscale_factor):
return Compose(
[
CenterCrop(crop_size),
Resize(crop_size // upscale_factor),
ToTensor(),
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
]
)


def target_transform(crop_size):
return Compose(
[
CenterCrop(crop_size),
ToTensor(),
]
)


crop_size = calculate_valid_crop_size(256, opt.upscale_factor)

trainset = torchvision.datasets.Caltech101(root="./data", download=True)
testset = torchvision.datasets.Caltech101(root="./data", download=False)

trainset_sr = SRDataset(
trainset,
scale_factor=opt.upscale_factor,
input_transform=input_transform(crop_size, opt.upscale_factor),
target_transform=target_transform(crop_size),
)
testset_sr = SRDataset(
testset,
scale_factor=opt.upscale_factor,
input_transform=input_transform(crop_size, opt.upscale_factor),
target_transform=target_transform(crop_size),
)

training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
testing_data_loader = DataLoader(
dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False
dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size

)

print("===> Building model")
model = Net(upscale_factor=opt.upscale_factor).to(device)
criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=opt.lr)


def train_step(engine, batch):
input, target = batch[0].to(device), batch[1].to(device)
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

optimizer.zero_grad()
loss = criterion(model(input), target)
loss.backward()
optimizer.step()

return loss.item()


def validation_step(engine, batch):
model.eval()
with torch.no_grad():
x, y = batch[0].to(device), batch[1].to(device)
y_pred = model(x)

return y_pred, y


trainer = Engine(train_step)
evaluator = Engine(validation_step)
psnr = PSNR(data_range=1)
psnr.attach(evaluator, "psnr")
validate_every = 1
log_interval = 100


@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
print(
"===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
engine.state.epoch, engine.state.iteration, len(training_data_loader), engine.state.output
)
)


@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def log_validation():
evaluator.run(testing_data_loader)
metrics = evaluator.state.metrics
print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB")


@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch_time():
print(f"Epoch {trainer.state.epoch}, Time Taken : {trainer.state.times['EPOCH_COMPLETED']}")


@trainer.on(Events.COMPLETED)
def log_total_time():
print(f"Total Time: {trainer.state.times['COMPLETED']}")


@trainer.on(Events.EPOCH_COMPLETED)
def checkpoint():
model_out_path = "model_epoch_{}.pth".format(trainer.state.epoch)
torch.save(model, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))


trainer.run(training_data_loader, opt.n_epochs)
29 changes: 29 additions & 0 deletions examples/super_resolution/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch.nn as nn
import torch.nn.init as init


class Net(nn.Module):
def __init__(self, upscale_factor):
super(Net, self).__init__()

self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

self._initialize_weights()

def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.pixel_shuffle(self.conv4(x))
return x

def _initialize_weights(self):
init.orthogonal_(self.conv1.weight, init.calculate_gain("relu"))
init.orthogonal_(self.conv2.weight, init.calculate_gain("relu"))
init.orthogonal_(self.conv3.weight, init.calculate_gain("relu"))
init.orthogonal_(self.conv4.weight)
40 changes: 40 additions & 0 deletions examples/super_resolution/super_resolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import argparse

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToTensor

# Training settings
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
parser.add_argument("--input_image", type=str, required=True, help="input image to use")
parser.add_argument("--model", type=str, required=True, help="model file to use")
parser.add_argument("--output_filename", type=str, help="where to save the output image")
parser.add_argument("--cuda", action="store_true", help="use cuda")
opt = parser.parse_args()

print(opt)
img = Image.open(opt.input_image).convert("YCbCr")
y, cb, cr = img.split()

model = torch.load(opt.model)
img_to_tensor = ToTensor()
input = img_to_tensor(y).view(1, -1, y.size[1], y.size[0])
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

if opt.cuda:
model = model.cuda()
input = input.cuda()

out = model(input)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
out = model(input)
model.eval()
with torch.no_grad():
out = model(input)

out = out.cpu()
out_img_y = out[0].detach().numpy()
out_img_y *= 255.0
out_img_y = out_img_y.clip(0, 255)
out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode="L")

out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
out_img = Image.merge("YCbCr", [out_img_y, out_img_cb, out_img_cr]).convert("RGB")

out_img.save(opt.output_filename)
print("output image saved to ", opt.output_filename)