-
Notifications
You must be signed in to change notification settings - Fork 0
/
net.py
31 lines (28 loc) · 881 Bytes
/
net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from math import sqrt
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from utils import *
import os
from loss import *
# from TridentUNet.TridentUNet import *
from DNANet.DNANet import *
# NOTE(review): presumably a workaround for the Intel OpenMP "OMP: Error #15"
# abort raised when two copies of the OpenMP runtime end up loaded in one
# process (e.g. via torch and another numeric library). It suppresses the
# check rather than removing the duplicate runtime, and only helps if it takes
# effect before the second runtime initialises — confirm it runs early enough.
os.environ.update(KMP_DUPLICATE_LIB_OK='TRUE')
class Net(nn.Module):
    """Wrap a ``DNANet`` backbone together with a ``SoftIoULoss`` criterion.

    Args:
        model_name: Backbone identifier, stored on ``self.model_name``. Only
            ``'DNANet'`` is recognised; any other value still builds a
            ``DNANet``, always in test mode.
        mode: ``'train'`` builds the backbone with ``mode='train'``; any other
            value builds it with ``mode='test'``. Only honoured when
            ``model_name == 'DNANet'``.
    """

    def __init__(self, model_name, mode):
        super().__init__()
        self.model_name = model_name
        self.cal_loss = SoftIoULoss()
        # NOTE(review): unrecognised model names silently fall back to a
        # test-mode DNANet and ignore ``mode`` — confirm this is intended.
        wants_train = model_name == 'DNANet' and mode == 'train'
        self.model = DNANet(mode='train' if wants_train else 'test')

    def forward(self, img, str1=None, str2=None):
        """Run the wrapped backbone on ``img`` and return its output unchanged.

        ``str1`` and ``str2`` are accepted but never used; presumably they are
        kept so callers sharing a common call signature keep working.
        """
        output = self.model(img)
        return output

    def loss(self, pred, gt_mask):
        """Return ``self.cal_loss(pred, gt_mask)``.

        ``self.cal_loss`` is the ``SoftIoULoss`` instance built in ``__init__``.
        """
        return self.cal_loss(pred, gt_mask)