LiteFlowNet with cupy instead of FlowNet2 with compiled modules #63

Open · wants to merge 11 commits into base: master
19 changes: 9 additions & 10 deletions README.md
@@ -4,22 +4,21 @@
<img src="https://github.com/nbei/Deep-Flow-Guided-Video-Inpainting/blob/master/gif/captain.gif" width="860"/>

## Install & Requirements
The code has been tested on pytorch=0.4.0 and python3.6. Please refer to `requirements.txt` for detailed information.
The code has been tested on pytorch=1.3.0 and python3.6. Please refer to `requirements.txt` for detailed information.

Alternatively, you can run it with the provided [Docker image](docker/README.md).

**To Install python packages**
```
pip install -r requirements.txt
```
**To Install flownet2 modules**
```
bash install_scripts.sh
```
**To Install liteflownet modules**

The correlation layer for LiteFlowNet is implemented in CUDA using CuPy. Install it using `pip install cupy` or install one of the provided binaries (listed [here](https://docs-cupy.chainer.org/en/stable/install.html#install-cupy)).
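
A quick way to confirm that the CuPy install can actually reach your GPU before running inference (a minimal sketch; the exact wheel you need, e.g. a `cupy-cudaXXX` binary matching your local CUDA toolkit, depends on your environment):
```
# Minimal CuPy sanity check: confirms a GPU is visible and a kernel runs.
import cupy

print("CUDA devices visible to CuPy:", cupy.cuda.runtime.getDeviceCount())
x = cupy.arange(6, dtype=cupy.float32).reshape(2, 3)
print(float((x * 2).sum()))  # forces a kernel launch; fails if CuPy/CUDA mismatch
```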
## Components
There are three components in this repo:
* Video Inpainting Tool: DFVI
* Extract Flow: FlowNet2(modified by [Nvidia official version](https://github.com/NVIDIA/flownet2-pytorch/tree/python36-PyTorch0.4))
* Extract Flow: LiteFlowNet([Pytorch version](https://github.com/sniklaus/pytorch-liteflownet) reimplemented from [LiteFlowNet](https://github.com/twhui/LiteFlowNet))
* Image Inpainting(reimplemented from [Deepfillv1](https://github.com/JiahuiYu/generative_inpainting))

## Usage
@@ -28,19 +27,19 @@ and the mask of each frame should be put into `xxx/video_name/masks`.
And please download the resources of the demo and model weights from [here](https://drive.google.com/drive/folders/1a2FrHIQGExJTHXxSIibZOGMukNrypr_g?usp=sharing).
An example containing frames and masks has been put into the demo folder; running the following command will produce the result:
```
python tools/video_inpaint.py --frame_dir ./demo/frames --MASK_ROOT ./demo/masks --img_size 512 832 --FlowNet2 --DFC --ResNet101 --Propagation
python tools/video_inpaint.py --frame_dir ./demo/frames --MASK_ROOT ./demo/masks --img_size 512 832 --LiteFlowNet --DFC --ResNet101 --Propagation
```
<img src="https://github.com/nbei/Deep-Flow-Guided-Video-Inpainting/blob/master/gif/flamingo.gif" width="850"/>

We provide the original model weight used in our movie demo which use ResNet101 as backbone and other related weights pls download from [here](https://drive.google.com/drive/folders/1a2FrHIQGExJTHXxSIibZOGMukNrypr_g?usp=sharing).
We provide the original model weights used in our movie demo, which use ResNet101 as the backbone; please download these and other related weights from [here](https://drive.google.com/drive/folders/1a2FrHIQGExJTHXxSIibZOGMukNrypr_g?usp=sharing). Weights for LiteFlowNet are hosted by [sniklaus](https://github.com/sniklaus): [default](http://content.sniklaus.com/github/pytorch-liteflownet/network-default.pytorch), [kitti](http://content.sniklaus.com/github/pytorch-liteflownet/network-kitti.pytorch), [sintel](http://content.sniklaus.com/github/pytorch-liteflownet/network-sintel.pytorch).
Please refer to [tools](https://github.com/nbei/Deep-Flow-Guided-Video-Inpainting/tree/master/tools) for detailed use and training settings.

* For fixed region inpainting, we provide the model weights of refined stages in DAVIS. Please download the lady-running resources [link](https://drive.google.com/drive/folders/1GHV1g1IkpGa2qhRnZE2Fv30RXrbHPH0O?usp=sharing) and
model weights [link](https://drive.google.com/drive/folders/1zIamN-DzvknZLf5QAGCfvWs7a6qUqaaC?usp=sharing). The following command can help you to get the result:
```
CUDA_VISIBLE_DEVICES=0 python tools/video_inpaint.py --frame_dir ./demo/lady-running/frames \
--MASK_ROOT ./demo/lady-running/mask_bbox.png \
--img_size 448 896 --DFC --FlowNet2 --Propagation \
--img_size 448 896 --DFC --LiteFlowNet --Propagation \
--PRETRAINED_MODEL_1 ./pretrained_models/resnet50_stage1.pth \
--PRETRAINED_MODEL_2 ./pretrained_models/DAVIS_model/davis_stage2.pth \
--PRETRAINED_MODEL_3 ./pretrained_models/DAVIS_model/davis_stage3.pth \
@@ -51,7 +50,7 @@ You can just change the **th_warp** param to get better results on your videos.

* To extract flow for videos:
```
python tools/infer_flownet2.py --frame_dir xxx/video_name/frames
python tools/infer_liteflownet.py --frame_dir xxx/video_name/frames
```
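
Under the hood such a script just walks consecutive frame pairs. A minimal sketch of that loop is below; `estimate` is a hypothetical stand-in for the LiteFlowNet forward wrapper (not this repo's exact API), and the `.npy` output is only illustrative:
```
# Hedged sketch of pair-wise flow extraction; `estimate` is a placeholder name.
import glob
import numpy as np

def estimate(frame1_path, frame2_path):
    # Placeholder: load both frames, run LiteFlowNet, return an HxWx2 flow array.
    raise NotImplementedError

frames = sorted(glob.glob('xxx/video_name/frames/*'))
for first, second in zip(frames[:-1], frames[1:]):
    flow = estimate(first, second)      # forward flow from `first` to `second`
    np.save(first + '.flow.npy', flow)  # illustrative; the real script uses its own format/location
```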

* To use the Deepfillv1-Pytorch model for image inpainting,
21 changes: 10 additions & 11 deletions dataset/FlowInfer.py
@@ -2,6 +2,7 @@
import cv2
import numpy as np
import torch.utils.data
from PIL import Image


class FlowInfer(torch.utils.data.Dataset):
@@ -34,22 +35,20 @@ def __len__(self):
return len(self.frame1_list)

def __getitem__(self, idx):
frame1 = cv2.imread(self.frame1_list[idx])
frame2 = cv2.imread(self.frame2_list[idx])
frame1 = np.array(self._img_tf(Image.open(self.frame1_list[idx])))/255
frame2 = np.array(self._img_tf(Image.open(self.frame2_list[idx])))/255

output_path = self.output_list[idx]

if self.isRGB:
frame1 = frame1[:, :, ::-1]
frame2 = frame2[:, :, ::-1]
output_path = self.output_list[idx]

frame1 = self._img_tf(frame1)
frame2 = self._img_tf(frame2)

frame1_tensor = torch.from_numpy(frame1).permute(2, 0, 1).contiguous().float()
frame2_tensor = torch.from_numpy(frame2).permute(2, 0, 1).contiguous().float()
frame1_tensor = torch.from_numpy(frame1.transpose(2, 0, 1).copy()).contiguous().float()
frame2_tensor = torch.from_numpy(frame2.transpose(2, 0, 1).copy()).contiguous().float()

return frame1_tensor, frame2_tensor, output_path

def _img_tf(self, img):
img = cv2.resize(img, (self.size[1], self.size[0]))

return img
#img = cv2.resize(img, (self.size[1], self.size[0]))
return img.resize((self.size[1], self.size[0]), Image.BILINEAR)  # PIL resize takes (width, height), matching the old cv2 dsize
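
For context on the loader change above: the old path read frames with cv2 (BGR, 0-255) while the new one uses PIL (RGB) plus a /255 scale, and both resize to the same target as long as the (width, height) argument order is respected. A minimal side-by-side sketch, assuming `size` is stored as `(height, width)` as elsewhere in this repo:
```
# Sketch: cv2-based vs PIL-based frame loading as used around FlowInfer.
import cv2
import numpy as np
import torch
from PIL import Image

size = (512, 832)  # (height, width), matching --img_size 512 832

def load_cv2(path):
    img = cv2.imread(path)                         # HxWx3, BGR, uint8
    img = cv2.resize(img, (size[1], size[0]))      # cv2 dsize is (width, height)
    return torch.from_numpy(img.transpose(2, 0, 1).copy()).float()

def load_pil(path):
    img = Image.open(path).resize((size[1], size[0]), Image.BILINEAR)  # also (width, height)
    arr = np.asarray(img, dtype=np.float32) / 255.0                    # RGB, scaled to [0, 1]
    return torch.from_numpy(arr.transpose(2, 0, 1).copy())

# cv2 yields BGR in [0, 255], PIL yields RGB in [0, 1]; downstream code that
# cares about channel order or scale has to account for the difference.
```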
4 changes: 2 additions & 2 deletions dataset/FlowInitial.py
@@ -74,7 +74,7 @@ def __getitem__(self, idx):
tmp_flow = cvb.read_flow(flow_dir[i])
if self.config.get_mask:
tmp_mask = cv2.imread(mask_dir[i],
cv2.IMREAD_UNCHANGED)
cv2.IMREAD_COLOR)
tmp_mask = self._mask_tf(tmp_mask)
else:
if self.config.FIX_MASK:
@@ -89,7 +89,7 @@

if self.config.INITIAL_HOLE:
tmp_flow_resized = cv2.resize(tmp_flow, (self.size[1] // 2, self.size[0] // 2))
tmp_mask_resized = cv2.resize(tmp_mask, (self.size[1] // 2, self.size[0] // 2), cv2.INTER_NEAREST)
tmp_mask_resized = cv2.resize(tmp_mask, (self.size[1] // 2, self.size[0] // 2), interpolation=cv2.INTER_NEAREST)
tmp_flow_masked_small = tmp_flow_resized
tmp_flow_masked_small[:, :, 0] = rf.regionfill(tmp_flow_resized[:, :, 0], tmp_mask_resized)
tmp_flow_masked_small[:, :, 1] = rf.regionfill(tmp_flow_resized[:, :, 1], tmp_mask_resized)
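
The `interpolation=` keyword in the hunk above is the real fix: in `cv2.resize(src, dsize[, dst[, fx[, fy[, interpolation]]]])` a flag passed as the third positional argument fills the optional `dst` slot rather than `interpolation`, so the nearest-neighbour request is effectively ignored and the default bilinear filter blurs the binary mask. A small illustration:
```
# Why the interpolation flag must be passed by keyword to cv2.resize.
import cv2
import numpy as np

mask = np.zeros((100, 160), dtype=np.uint8)
mask[41:61, 61:101] = 255  # binary mask with values {0, 255}

blurred = cv2.resize(mask, (80, 50), cv2.INTER_NEAREST)                # flag lands in the `dst` slot
nearest = cv2.resize(mask, (80, 50), interpolation=cv2.INTER_NEAREST)  # flag actually applied

print(np.unique(blurred))  # intermediate grey values sneak in
print(np.unique(nearest))  # stays {0, 255}
```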
4 changes: 2 additions & 2 deletions docker/Dockerfile
@@ -30,8 +30,8 @@ scipy==1.2.1 \
six==1.12.0 \
tensorboardX==1.8 \
terminaltables==3.1.0 \
torch==0.4.0 \
torchvision==0.2.1 \
torch==1.3.0 \
torchvision==0.4.0 \
tqdm==4.32.1 \
urllib3==1.25.3

8 changes: 0 additions & 8 deletions install_scripts.sh

This file was deleted.

55 changes: 31 additions & 24 deletions models/DeepFill_Models/ops.py
@@ -252,7 +252,7 @@ def forward(self, f, b, mask=None, ksize=3, stride=1,
kernel = 2*self.rate
raw_w = self.extract_patches(b, kernel=kernel, stride=self.rate)
raw_w = raw_w.permute(0, 2, 3, 4, 5, 1)
raw_w = raw_w.contiguous().view(raw_int_bs[0], raw_int_bs[2] / self.rate, raw_int_bs[3] / self.rate, -1)
raw_w = raw_w.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1)
raw_w = raw_w.contiguous().view(raw_int_bs[0], -1, kernel, kernel, raw_int_bs[1])
raw_w = raw_w.permute(0, 1, 4, 2, 3)

@@ -268,7 +268,7 @@ def forward(self, f, b, mask=None, ksize=3, stride=1,
int_bs = list(b.size())
w = self.extract_patches(b)
w = w.permute(0, 2, 3, 4, 5, 1)
w = w.contiguous().view(raw_int_bs[0], raw_int_bs[2] / self.rate, raw_int_bs[3] / self.rate, -1)
w = w.contiguous().view(raw_int_bs[0], raw_int_bs[2] // self.rate, raw_int_bs[3] // self.rate, -1)
w = w.contiguous().view(raw_int_bs[0], -1, ksize, ksize, raw_int_bs[1])
w = w.permute(0, 1, 4, 2, 3)
# process mask
@@ -282,7 +282,7 @@ def forward(self, f, b, mask=None, ksize=3, stride=1,
m = self.extract_patches(mask)

m = m.permute(0, 2, 3, 4, 5, 1)
m = m.contiguous().view(raw_int_bs[0], raw_int_bs[2]/self.rate, raw_int_bs[3]/self.rate, -1)
m = m.contiguous().view(raw_int_bs[0], raw_int_bs[2]//self.rate, raw_int_bs[3]//self.rate, -1)
m = m.contiguous().view(raw_int_bs[0], -1, ksize, ksize, 1)
m = m.permute(0, 4, 1, 2, 3)

@@ -398,28 +398,35 @@ def reduce_sum(x):
x = reduce_sum(x)
return torch.sqrt(x)



# PyTorch>=0.4.1: when using the old down_sample func I found some problems in the results;
# still need to check the reason
def down_sample(x, size=None, scale_factor=None, mode='nearest', device=None):
# define size if user has specified scale_factor
if size is None: size = (int(scale_factor*x.size(2)), int(scale_factor*x.size(3)))
# create coordinates
# size_origin = [x.size[2], x.size[3]]
h = torch.arange(0, size[0]) / (size[0]) * 2 - 1
w = torch.arange(0, size[1]) / (size[1]) * 2 - 1
# create grid
grid =torch.zeros(size[0],size[1],2)
grid[:,:,0] = w.unsqueeze(0).repeat(size[0],1)
grid[:,:,1] = h.unsqueeze(0).repeat(size[1],1).transpose(0,1)
# expand to match batch size
grid = grid.unsqueeze(0).repeat(x.size(0),1,1,1)
if x.is_cuda:
if device:
grid = Variable(grid).cuda(device)
else:
grid = Variable(grid).cuda()
# do sampling

return F.grid_sample(x, grid, mode=mode)
res = F.interpolate(x, scale_factor=scale_factor, mode=mode)
return res


# def down_sample(x, size=None, scale_factor=None, mode='nearest', device=None):
# # define size if user has specified scale_factor
# if size is None: size = (int(scale_factor*x.size(2)), int(scale_factor*x.size(3)))
# # create coordinates
# # size_origin = [x.size[2], x.size[3]]
# h = torch.arange(0, size[0]) / (size[0]) * 2 - 1
# w = torch.arange(0, size[1]) / (size[1]) * 2 - 1
# # create grid
# grid =torch.zeros(size[0],size[1],2)
# grid[:,:,0] = w.unsqueeze(0).repeat(size[0],1)
# grid[:,:,1] = h.unsqueeze(0).repeat(size[1],1).transpose(0,1)
# # expand to match batch size
# grid = grid.unsqueeze(0).repeat(x.size(0),1,1,1)
# if x.is_cuda:
# if device:
# grid = Variable(grid).cuda(device)
# else:
# grid = Variable(grid).cuda()
# # do sampling

# return F.grid_sample(x, grid, mode=mode)


def to_var(x, volatile=False, device=None):
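
Two recurring changes in this file are Python-3/PyTorch-1.x housekeeping: `/` on tensor shapes became `//` so that `view` receives ints, and the hand-rolled `grid_sample` downsampler was replaced by a plain `F.interpolate` call, which sidesteps the grid-alignment issue the original comment complains about. Note that the new body only forwards `scale_factor`, so a caller that passed `size` would instead need `F.interpolate(x, size=size, mode=mode)` (whether any caller does is not checked here). A minimal sketch of both points, with illustrative shapes:
```
# Minimal sketch of the new down_sample behaviour and the // fix.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 64, 96)

# Replacement downsampler: interpolate by scale factor (mode='nearest' here).
y = F.interpolate(x, scale_factor=0.5, mode='nearest')
print(y.shape)  # torch.Size([1, 3, 32, 48])

# Python 3: shape arithmetic fed to view() must stay integral, hence
# raw_int_bs[2] // self.rate rather than raw_int_bs[2] / self.rate.
h, rate = x.size(2), 2
print(h // rate, type(h // rate))  # 32 <class 'int'>
```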