diff --git a/models/cycle_gan_sty2_model.py b/models/cycle_gan_sty2_model.py index 166d5f06d1b..be6bc18346a 100644 --- a/models/cycle_gan_sty2_model.py +++ b/models/cycle_gan_sty2_model.py @@ -17,6 +17,8 @@ from .modules import loss +import torch.nn.functional as nnf + class CycleGANSty2Model(BaseModel): @staticmethod @@ -247,6 +249,7 @@ def __init__(self, opt): self.niter=0 self.mean_path_length_A = 0 self.mean_path_length_B = 0 + def set_input(self, input): AtoB = self.opt.direction == 'AtoB' @@ -264,9 +267,9 @@ def forward(self): #self.netDecoderG_B.eval() if self.rec_noise > 0.0: self.fake_B_noisy1 = self.gaussian(self.fake_B, self.rec_noise) - self.z_rec_A, self.n_rec_A = self.netG_B(self.fake_B_noisy1) + self.z_rec_A, self.n_rec_A = self.netG_B(nnf.interpolate(self.fake_B_noisy1, size=(self.opt.crop_size), mode='bicubic', align_corners=False)) else: - self.z_rec_A, self.n_rec_A = self.netG_B(self.fake_B) + self.z_rec_A, self.n_rec_A = self.netG_B(nnf.interpolate(self.fake_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False)) self.rec_A = self.netDecoderG_B(self.z_rec_A,input_is_latent=True,truncation=self.truncation,truncation_latent=self.mean_latent_B, randomize_noise=False, noise=self.n_rec_A)[0] self.z_fake_A, self.n_fake_A = self.netG_B(self.real_B) @@ -274,9 +277,9 @@ def forward(self): if self.rec_noise > 0.0: self.fake_A_noisy1 = self.gaussian(self.fake_A, self.rec_noise) - self.z_rec_B, self.n_rec_B = self.netG_A(self.fake_A_noisy1) + self.z_rec_B, self.n_rec_B = self.netG_A(nnf.interpolate(self.fake_A_noisy1, size=(self.opt.crop_size), mode='bicubic', align_corners=False)) else: - self.z_rec_B, self.n_rec_B = self.netG_A(self.fake_A) + self.z_rec_B, self.n_rec_B = self.netG_A(nnf.interpolate(self.fake_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False)) self.rec_B = self.netDecoderG_A(self.z_rec_B,input_is_latent=True,truncation=self.truncation,truncation_latent=self.mean_latent_A, randomize_noise=False, noise=self.n_rec_B)[0] def backward_G(self): @@ -291,27 +294,27 @@ def backward_G(self): self.z_idt_A, self.n_idt_A = self.netG_A(self.real_B) self.idt_A = self.netDecoderG_A(self.z_idt_A,input_is_latent=True,truncation=self.truncation,truncation_latent=self.mean_latent_A,randomize_noise=False,noise=self.n_idt_A)[0] - self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt + self.loss_idt_A = self.criterionIdt(nnf.interpolate(self.idt_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_B) * lambda_B * lambda_idt if self.percept_loss: - self.loss_idt_A += self.criterionIdt2(self.idt_A, self.real_B) * lambda_B * lambda_idt + self.loss_idt_A += self.criterionIdt2(nnf.interpolate(self.idt_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. self.z_idt_B, self.n_idt_B = self.netG_B(self.real_A) self.idt_B = self.netDecoderG_B(self.z_idt_B,input_is_latent=True,truncation=self.truncation,truncation_latent=self.mean_latent_B,randomize_noise=False,noise=self.n_idt_B)[0] - self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt + self.loss_idt_B = self.criterionIdt(nnf.interpolate(self.idt_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_A) * lambda_A * lambda_idt if self.percept_loss: - self.loss_idt_B += self.criterionIdt2(self.idt_B, self.real_A) * lambda_A * lambda_idt + self.loss_idt_B += self.criterionIdt2(nnf.interpolate(self.idt_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_A) * lambda_A * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # Forward cycle loss - self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A + self.loss_cycle_A = self.criterionCycle(nnf.interpolate(self.rec_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_A) * lambda_A if self.percept_loss: - self.loss_cycle_A += self.criterionCycle2(self.rec_A, self.real_A) * lambda_A + self.loss_cycle_A += self.criterionCycle2(nnf.interpolate(self.rec_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_A) * lambda_A # Backward cycle loss - self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B + self.loss_cycle_B = self.criterionCycle(nnf.interpolate(self.rec_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_B) * lambda_B if self.percept_loss: - self.loss_cycle_B += self.criterionCycle2(self.rec_B, self.real_B) * lambda_B + self.loss_cycle_B += self.criterionCycle2(nnf.interpolate(self.rec_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False), self.real_B) * lambda_B # combined loss standard cyclegan self.loss_G = self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B @@ -322,7 +325,7 @@ def backward_G(self): compute_g_regularize = False #A - self.fake_pred_g_loss_A = self.netDiscriminatorDecoderG_A(self.fake_A) + self.fake_pred_g_loss_A = self.netDiscriminatorDecoderG_A(nnf.interpolate(self.fake_A, size=(self.opt.crop_size), mode='bicubic', align_corners=False)) self.loss_g_nonsaturating_A = self.g_nonsaturating_loss(self.fake_pred_g_loss_A) if compute_g_regularize: @@ -342,7 +345,7 @@ def backward_G(self): self.loss_weighted_path_A = 0#*self.loss_weighted_path_A #B - self.fake_pred_g_loss_B = self.netDiscriminatorDecoderG_B(self.fake_B) + self.fake_pred_g_loss_B = self.netDiscriminatorDecoderG_B(nnf.interpolate(self.fake_B, size=(self.opt.crop_size), mode='bicubic', align_corners=False)) self.loss_g_nonsaturating_B = self.g_nonsaturating_loss(self.fake_pred_g_loss_B) if compute_g_regularize: @@ -417,11 +420,11 @@ def backward_G(self): def backward_discriminator_decoder(self): real_pred_A = self.netDiscriminatorDecoderG_A(self.real_A) - fake_pred_A = self.netDiscriminatorDecoderG_A(self.fake_A_pool.query(self.fake_A)) + fake_pred_A = self.netDiscriminatorDecoderG_A(nnf.interpolate(self.fake_A_pool.query(self.fake_A), size=(self.opt.crop_size), mode='bicubic', align_corners=False)) self.loss_d_dec_A = self.d_logistic_loss(real_pred_A,fake_pred_A).unsqueeze(0) real_pred_B = self.netDiscriminatorDecoderG_B(self.real_B) - fake_pred_B = self.netDiscriminatorDecoderG_B(self.fake_B_pool.query(self.fake_B)) + fake_pred_B = self.netDiscriminatorDecoderG_B(nnf.interpolate(self.fake_B_pool.query(self.fake_B), size=(self.opt.crop_size), mode='bicubic', align_corners=False)) self.loss_d_dec_B = self.d_logistic_loss(real_pred_B,fake_pred_B).unsqueeze(0) self.loss_d_dec = self.loss_d_dec_A + self.loss_d_dec_B diff --git a/util/visualizer.py b/util/visualizer.py index b7785149772..f54afadd239 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -12,6 +12,7 @@ else: VisdomExceptionBase = ConnectionError +import torch.nn.functional as nnf def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): """Save images to the disk. @@ -84,6 +85,8 @@ def __init__(self, opt): now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) + self.crop_size=opt.crop_size + def reset(self): """Reset the self.saved status""" self.saved = False @@ -127,7 +130,7 @@ def display_current_results(self, visuals, epoch, save_result,params=[]): param_html_row = '' for label, image in visuals.items(): - image_numpy = util.tensor2im(image) + image_numpy = util.tensor2im(nnf.interpolate(image, size=(self.crop_size), mode='bicubic', align_corners=False)) label_html_row += '%s' % label images.append(image_numpy.transpose([2, 0, 1])) idx += 1