Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruohan Gao authored Oct 12, 2020
1 parent d83fd5b commit bd4f4fd
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ def get_coseparation_loss(output, opt, loss_coseparation):
opt = TrainOptions().parse()
opt.device = torch.device("cuda")

if opt.with_additional_scene_image:
opt.number_of_classes = opt.number_of_classes + 1

#construct data loader
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
Expand Down Expand Up @@ -249,8 +252,6 @@ def get_coseparation_loss(output, opt, loss_coseparation):
input_nc=opt.unet_input_nc,
output_nc=opt.unet_output_nc,
weights=opt.weights_unet)
if opt.with_additional_scene_image:
opt.number_of_classes = opt.number_of_classes + 1
net_classifier = builder.build_classifier(
pool_type=opt.classifier_pool,
num_of_classes=opt.number_of_classes,
Expand Down

0 comments on commit bd4f4fd

Please sign in to comment.