Skip to content
This repository has been archived by the owner on Oct 21, 2021. It is now read-only.

Commit

Permalink
fix: different image sizes between encoder and decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau committed Dec 11, 2020
1 parent d792dca commit d9ae451
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
35 changes: 19 additions & 16 deletions models/cycle_gan_sty2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from .modules import loss

import torch.nn.functional as nnf

class CycleGANSty2Model(BaseModel):

@staticmethod
Expand Down Expand Up @@ -247,6 +249,7 @@ def __init__(self, opt):
self.niter=0
self.mean_path_length_A = 0
self.mean_path_length_B = 0


def set_input(self, input):
AtoB = self.opt.direction == 'AtoB'
Expand All @@ -264,19 +267,19 @@ def forward(self):
#self.netDecoderG_B.eval()
if self.rec_noise > 0.0:
self.fake_B_noisy1 = self.gaussian(self.fake_B, self.rec_noise)
self.z_rec_A, self.n_rec_A = self.netG_B(self.fake_B_noisy1)
self.z_rec_A, self.n_rec_A = self.netG_B(nnf.interpolate(self.fake_B_noisy1, size=(self.opt.crop_size), mode='bicubic', align_corners=False))
else:
self.z_rec_A, self.n_rec_A = self.netG_B(self.fake_B)
self.z_rec_A, self.n_rec_A = self.netG_B(nnf.interpolate(self.fake_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False))
self.rec_A = self.netDecoderG_B(self.z_rec_A,input_is_latent=True,truncation=self.truncation,truncation_latent=self.mean_latent_B, randomize_noise=False, noise=self.n_rec_A)[0]

self.z_fake_A, self.n_fake_A = self.netG_B(self.real_B)
self.fake_A,self.latent_fake_A = self.netDecoderG_B(self.z_fake_A,input_is_latent=True,truncation=self.truncation,truncation_latent=self.mean_latent_B,randomize_noise=False,return_latents=True,noise=self.n_fake_A)

if self.rec_noise > 0.0:
self.fake_A_noisy1 = self.gaussian(self.fake_A, self.rec_noise)
self.z_rec_B, self.n_rec_B = self.netG_A(self.fake_A_noisy1)
self.z_rec_B, self.n_rec_B = self.netG_A(nnf.interpolate(self.fake_A_noisy1, size=(self.opt.crop_size), mode='bicubic', align_corners=False))
else:
self.z_rec_B, self.n_rec_B = self.netG_A(self.fake_A)
self.z_rec_B, self.n_rec_B = self.netG_A(nnf.interpolate(self.fake_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False))
self.rec_B = self.netDecoderG_A(self.z_rec_B,input_is_latent=True,truncation=self.truncation,truncation_latent=self.mean_latent_A, randomize_noise=False, noise=self.n_rec_B)[0]

def backward_G(self):
Expand All @@ -291,27 +294,27 @@ def backward_G(self):
self.z_idt_A, self.n_idt_A = self.netG_A(self.real_B)
self.idt_A = self.netDecoderG_A(self.z_idt_A,input_is_latent=True,truncation=self.truncation,truncation_latent=self.mean_latent_A,randomize_noise=False,noise=self.n_idt_A)[0]

self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
self.loss_idt_A = self.criterionIdt(nnf.interpolate(self.idt_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_B) * lambda_B * lambda_idt
if self.percept_loss:
self.loss_idt_A += self.criterionIdt2(self.idt_A, self.real_B) * lambda_B * lambda_idt
self.loss_idt_A += self.criterionIdt2(nnf.interpolate(self.idt_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed.
self.z_idt_B, self.n_idt_B = self.netG_B(self.real_A)
self.idt_B = self.netDecoderG_B(self.z_idt_B,input_is_latent=True,truncation=self.truncation,truncation_latent=self.mean_latent_B,randomize_noise=False,noise=self.n_idt_B)[0]
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
self.loss_idt_B = self.criterionIdt(nnf.interpolate(self.idt_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_A) * lambda_A * lambda_idt
if self.percept_loss:
self.loss_idt_B += self.criterionIdt2(self.idt_B, self.real_A) * lambda_A * lambda_idt
self.loss_idt_B += self.criterionIdt2(nnf.interpolate(self.idt_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_A) * lambda_A * lambda_idt
else:
self.loss_idt_A = 0
self.loss_idt_B = 0

# Forward cycle loss
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
self.loss_cycle_A = self.criterionCycle(nnf.interpolate(self.rec_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_A) * lambda_A
if self.percept_loss:
self.loss_cycle_A += self.criterionCycle2(self.rec_A, self.real_A) * lambda_A
self.loss_cycle_A += self.criterionCycle2(nnf.interpolate(self.rec_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_A) * lambda_A
# Backward cycle loss
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
self.loss_cycle_B = self.criterionCycle(nnf.interpolate(self.rec_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_B) * lambda_B
if self.percept_loss:
self.loss_cycle_B += self.criterionCycle2(self.rec_B, self.real_B) * lambda_B
self.loss_cycle_B += self.criterionCycle2(nnf.interpolate(self.rec_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_B) * lambda_B
# combined loss standard cyclegan
self.loss_G = self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B

Expand All @@ -322,7 +325,7 @@ def backward_G(self):
compute_g_regularize = False

#A
self.fake_pred_g_loss_A = self.netDiscriminatorDecoderG_A(self.fake_A)
self.fake_pred_g_loss_A = self.netDiscriminatorDecoderG_A(nnf.interpolate(self.fake_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False))
self.loss_g_nonsaturating_A = self.g_nonsaturating_loss(self.fake_pred_g_loss_A)

if compute_g_regularize:
Expand All @@ -342,7 +345,7 @@ def backward_G(self):
self.loss_weighted_path_A = 0#*self.loss_weighted_path_A

#B
self.fake_pred_g_loss_B = self.netDiscriminatorDecoderG_B(self.fake_B)
self.fake_pred_g_loss_B = self.netDiscriminatorDecoderG_B(nnf.interpolate(self.fake_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False))
self.loss_g_nonsaturating_B = self.g_nonsaturating_loss(self.fake_pred_g_loss_B)

if compute_g_regularize:
Expand Down Expand Up @@ -417,11 +420,11 @@ def backward_G(self):

def backward_discriminator_decoder(self):
real_pred_A = self.netDiscriminatorDecoderG_A(self.real_A)
fake_pred_A = self.netDiscriminatorDecoderG_A(self.fake_A_pool.query(self.fake_A))
fake_pred_A = self.netDiscriminatorDecoderG_A(nnf.interpolate(self.fake_A_pool.query(self.fake_A), size=(self.opt.crop_size), mode='bicubic', align_corners=False))
self.loss_d_dec_A = self.d_logistic_loss(real_pred_A,fake_pred_A).unsqueeze(0)

real_pred_B = self.netDiscriminatorDecoderG_B(self.real_B)
fake_pred_B = self.netDiscriminatorDecoderG_B(self.fake_B_pool.query(self.fake_B))
fake_pred_B = self.netDiscriminatorDecoderG_B(nnf.interpolate(self.fake_B_pool.query(self.fake_B), size=(self.opt.crop_size), mode='bicubic', align_corners=False))
self.loss_d_dec_B = self.d_logistic_loss(real_pred_B,fake_pred_B).unsqueeze(0)
self.loss_d_dec = self.loss_d_dec_A + self.loss_d_dec_B

Expand Down
5 changes: 4 additions & 1 deletion util/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
else:
VisdomExceptionBase = ConnectionError

import torch.nn.functional as nnf

def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
"""Save images to the disk.
Expand Down Expand Up @@ -84,6 +85,8 @@ def __init__(self, opt):
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)

self.crop_size=opt.crop_size

def reset(self):
"""Reset the self.saved status"""
self.saved = False
Expand Down Expand Up @@ -127,7 +130,7 @@ def display_current_results(self, visuals, epoch, save_result,params=[]):
param_html_row = ''

for label, image in visuals.items():
image_numpy = util.tensor2im(image)
image_numpy = util.tensor2im(nnf.interpolate(image, size=(self.crop_size), mode='bicubic', align_corners=False))
label_html_row += '<td>%s</td>' % label
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
Expand Down

0 comments on commit d9ae451

Please sign in to comment.