|
import torch |
|
from collections import OrderedDict |
|
from torch.nn import utils, functional as F |
|
from torch.optim import Adam, SGD |
|
from torch.autograd import Variable |
|
from torch.backends import cudnn |
|
from model import build_model, weights_init |
|
import scipy.misc as sm |
|
import numpy as np |
|
import os |
|
import torchvision.utils as vutils |
|
import cv2 |
|
import torch.nn.functional as F |
|
import math |
|
import time |
|
import sys |
|
import PIL.Image |
|
import scipy.io |
|
import os |
|
import logging |
|
# Small constant; not referenced by the module-level code visible here.
EPSILON = 1e-8

# Hyper-parameter table read by Solver (filled in below, insertion order kept).
p = OrderedDict()

from dataset import get_loader

# Backbone selector handed to build_model(); Solver.build_model branches on
# the values 'vgg' and 'resnet'.
base_model_cfg = 'resnet'

p.update([
    ('lr_bone', 5e-5),      # learning rate of the Adam optimizer
    ('lr_branch', 0.025),   # learning rate of the 'loss_weight' parameter group
    ('wd', 0.0005),         # Adam weight decay
    ('momentum', 0.90),     # not read by the code visible in this file
])

# Epoch indices after which Solver.train multiplies the learning rate by 0.1.
lr_decay_epoch = [15, 24]

# nAveGrad: batches accumulated per optimizer step.
# showEvery: iteration interval between loss reports.
nAveGrad, showEvery = 10, 50

# Directory that receives the periodic training snapshots.
tmp_path = 'tmp_see'
|
|
|
|
|
class Solver(object):
    """Train and evaluate the network returned by the module-level ``build_model``.

    The solver owns the network (``self.net_bone``), its Adam optimizer
    (``self.optimizer_bone``) and the two data loaders. ``train()`` and
    ``test()`` are the entry points.
    """

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        """Build the network and prepare it for ``config.mode``.

        Args:
            train_loader: iterable of training batches; ``train()`` reads the
                ``'sal_image'``, ``'sal_label'`` and ``'sal_edge'`` entries and
                ``train_loader.dataset``.
            test_loader: iterable of test batches; ``test()`` reads the
                ``'image'`` and ``'name'`` entries.
            config: options object. Attributes read by this class: ``visdom``,
                ``pre_trained``, ``mode``, ``save_fold``, ``model``, ``cuda``,
                ``load_bone``, ``vgg``, ``resnet``, ``batch_size``, ``epoch``
                and ``epoch_save``.
            save_fold: directory under which ``test()`` writes its predictions.
        """
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = config
        self.save_fold = save_fold
        # Three-channel constant of shape (3, 1, 1) scaled by 1/255; presumably
        # a per-channel image mean. It is not read anywhere else in this class.
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255.

        if config.visdom:
            # NOTE(review): Viz_visdom is neither defined nor imported in the
            # visible part of this file; enabling config.visdom will raise
            # NameError unless it is provided elsewhere.
            self.visual = Viz_visdom("trueUnify", 1)
        self.build_model()
        if self.config.pre_trained:
            # The network lives in self.net_bone; self.net is never assigned.
            self.net_bone.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            # NOTE(review): this handle is never written to or closed by the
            # code in this class.
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            print('Loading pre-trained model from %s...' % self.config.model)
            self.net_bone.load_state_dict(torch.load(self.config.model))
            self.net_bone.eval()

    def print_network(self, model, name):
        """Print ``name``, the module structure and its total parameter count."""
        num_params = sum(param.numel() for param in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def get_params(self, base_lr):
        """Return one optimizer parameter group per direct child of the network.

        The child named ``'loss_weight'`` gets its own learning rate
        (``p['lr_branch']``); every other group carries no explicit rate.
        ``base_lr`` is accepted for interface compatibility but is not used.
        """
        ml = []
        for name, module in self.net_bone.named_children():
            print(name)
            group = {'params': module.parameters()}
            if name == 'loss_weight':
                group['lr'] = p['lr_branch']
            ml.append(group)
        return ml

    def _make_optimizer(self):
        """Create a fresh Adam optimizer over the trainable parameters at ``self.lr_bone``."""
        trainable = (param for param in self.net_bone.parameters() if param.requires_grad)
        return Adam(trainable, lr=self.lr_bone, weight_decay=p['wd'])

    def build_model(self):
        """Instantiate the network, load initial weights and create the optimizer."""
        self.net_bone = build_model(base_model_cfg)
        if self.config.cuda:
            self.net_bone = self.net_bone.cuda()

        # The network is put in eval mode here and train() never switches it
        # back, so mode-dependent layers run in inference mode during training
        # -- presumably intentional (frozen normalisation statistics); confirm.
        self.net_bone.eval()
        self.net_bone.apply(weights_init)
        if self.config.mode == 'train':
            if self.config.load_bone == '':
                # No full checkpoint given: initialise only the backbone.
                if base_model_cfg == 'vgg':
                    self.net_bone.base.load_pretrained_model(torch.load(self.config.vgg))
                elif base_model_cfg == 'resnet':
                    self.net_bone.base.load_state_dict(torch.load(self.config.resnet))
            else:
                self.net_bone.load_state_dict(torch.load(self.config.load_bone))

        self.lr_bone = p['lr_bone']
        self.lr_branch = p['lr_branch']
        self.optimizer_bone = self._make_optimizer()

        self.print_network(self.net_bone, 'trueUnify bone part')

    def update_lr(self, rate):
        """Multiply the learning rate of every optimizer parameter group by ``rate``."""
        # The optimizer attribute is optimizer_bone; self.optimizer never exists.
        for param_group in self.optimizer_bone.param_groups:
            param_group['lr'] = param_group['lr'] * rate

    def test(self, test_mode=0):
        """Run the network over ``self.test_loader`` and save one PNG per batch.

        Predictions are the sigmoid of the last ``up_sal_f`` output scaled by
        255 and are written to ``<save_fold>/EGNet_ResNet50/<stem>.png``, where
        ``<stem>`` is the first name of the batch without its extension.
        ``test_mode`` is accepted for interface compatibility but is not used.
        """
        out_dir = os.path.join(self.save_fold, 'EGNet_ResNet50')
        if not os.path.exists(out_dir):
            # makedirs also creates missing parents of save_fold.
            os.makedirs(out_dir)

        elapsed = 0.0
        for data_batch in self.test_loader:
            # Side effect kept from the original: publish the output root on config.
            self.config.test_fold = self.save_fold
            print(self.config.test_fold)
            images, name = data_batch['image'], data_batch['name'][0]

            with torch.no_grad():
                images = Variable(images)
                if self.config.cuda:
                    images = images.cuda()
                print(images.size())
                time_start = time.time()
                up_edge, up_sal, up_sal_f = self.net_bone(images)
                if self.config.cuda:
                    # Wait for queued GPU work so the timing is meaningful;
                    # skipped on CPU, where the call would fail.
                    torch.cuda.synchronize()
                time_end = time.time()
                print(time_end - time_start)
                elapsed += time_end - time_start
                pred = np.squeeze(torch.sigmoid(up_sal_f[-1]).cpu().data.numpy())

            # splitext copes with extensions of any length (e.g. '.jpeg').
            stem = os.path.splitext(name)[0]
            cv2.imwrite(os.path.join(out_dir, stem + '.png'), 255 * pred)

        print("--- %s seconds ---" % (elapsed))
        print('Test Done!')

    def train(self):
        """Train ``self.net_bone`` for ``config.epoch`` epochs.

        Each batch contributes an edge loss (``bce2d_new`` on every ``up_edge``
        output) and a saliency loss (BCE-with-logits on every ``up_sal`` and
        ``up_sal_f`` output). Gradients are accumulated and the optimizer steps
        every ``nAveGrad`` processed batches. Running losses are printed every
        ``showEvery`` iterations, image snapshots are written to ``tmp_path``
        every 200 iterations, checkpoints are saved every ``config.epoch_save``
        epochs and once more at the end.
        """
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        # Summed losses are divided by this factor before backward().
        loss_norm = nAveGrad * self.config.batch_size
        # NOTE(review): this counter is not reset at epoch boundaries although
        # gradients are zeroed there, so the first step of an epoch can cover
        # fewer than nAveGrad batches. Kept as in the original.
        aveGrad = 0
        if not os.path.exists(tmp_path):
            os.makedirs(tmp_path)

        for epoch in range(self.config.epoch):
            r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0
            self.net_bone.zero_grad()
            for i, data_batch in enumerate(self.train_loader):
                sal_image = data_batch['sal_image']
                sal_label = data_batch['sal_label']
                sal_edge = data_batch['sal_edge']
                if sal_image.size()[2:] != sal_label.size()[2:]:
                    # Image and label disagree on their trailing dimensions.
                    print("Skip this batch")
                    continue
                sal_image, sal_label, sal_edge = Variable(sal_image), Variable(sal_label), Variable(sal_edge)
                if self.config.cuda:
                    sal_image, sal_label, sal_edge = sal_image.cuda(), sal_label.cuda(), sal_edge.cuda()

                up_edge, up_sal, up_sal_f = self.net_bone(sal_image)

                # Edge supervision on every edge output.
                edge_loss = sum([bce2d_new(ix, sal_edge, reduction='sum') for ix in up_edge]) / loss_norm
                r_edge_loss += edge_loss.data

                # Saliency supervision on both groups of saliency outputs.
                sal_loss1 = [F.binary_cross_entropy_with_logits(ix, sal_label, reduction='sum') for ix in up_sal]
                sal_loss2 = [F.binary_cross_entropy_with_logits(ix, sal_label, reduction='sum') for ix in up_sal_f]
                sal_loss = (sum(sal_loss1) + sum(sal_loss2)) / loss_norm
                r_sal_loss += sal_loss.data

                loss = sal_loss + edge_loss
                r_sum_loss += loss.data
                loss.backward()
                aveGrad += 1

                if aveGrad % nAveGrad == 0:
                    self.optimizer_bone.step()
                    self.optimizer_bone.zero_grad()
                    aveGrad = 0

                if i % showEvery == 0:
                    print('epoch: [%2d/%2d], iter: [%5d/%5d] || Edge : %10.4f || Sal : %10.4f || Sum : %10.4f' % (
                        epoch, self.config.epoch, i, iter_num,
                        r_edge_loss * loss_norm / showEvery,
                        r_sal_loss * loss_norm / showEvery,
                        r_sum_loss * loss_norm / showEvery))
                    print('Learning rate: ' + str(self.lr_bone))
                    r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0

                if i % 200 == 0:
                    vutils.save_image(torch.sigmoid(up_sal_f[-1].data), tmp_path + '/iter%d-sal-0.jpg' % i,
                                      normalize=True, padding=0)
                    vutils.save_image(sal_image.data, tmp_path + '/iter%d-sal-data.jpg' % i, padding=0)
                    vutils.save_image(sal_label.data, tmp_path + '/iter%d-sal-target.jpg' % i, padding=0)

            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net_bone.state_dict(),
                           '%s/models/epoch_%d_bone.pth' % (self.config.save_fold, epoch + 1))

            if epoch in lr_decay_epoch:
                # Decay by 10x; the optimizer is rebuilt, which also discards
                # its accumulated Adam state (same as the original code).
                self.lr_bone = self.lr_bone * 0.1
                self.optimizer_bone = self._make_optimizer()

        torch.save(self.net_bone.state_dict(), '%s/models/final_bone.pth' % self.config.save_fold)
|
|
|
def bce2d_new(input, target, reduction=None):
    """Class-balanced binary cross-entropy on logits.

    Elements of ``target`` equal to 1 are weighted by the fraction of elements
    equal to 0, and elements equal to 0 are weighted by 1.1 times the fraction
    of elements equal to 1. Elements that are neither exactly 0 nor exactly 1
    receive weight 0 and do not contribute to the loss.

    Args:
        input: logits, same size as ``target``.
        target: labels; only exact 0 and 1 values are counted.
        reduction: forwarded to ``F.binary_cross_entropy_with_logits``.
            ``None`` (the default) selects ``'mean'``, the PyTorch default,
            because the underlying function rejects ``None``.

    Returns:
        The weighted loss tensor produced by
        ``F.binary_cross_entropy_with_logits``.
    """
    assert(input.size() == target.size())
    pos = torch.eq(target, 1).float()
    neg = torch.eq(target, 0).float()

    num_pos = torch.sum(pos)
    num_neg = torch.sum(neg)
    # Clamp the denominator: when no element is exactly 0 or 1 both counts are
    # zero and the weights become 0 instead of 0/0 = NaN.
    num_total = torch.clamp(num_pos + num_neg, min=1.0)

    alpha = num_neg / num_total
    beta = 1.1 * num_pos / num_total

    weights = alpha * pos + beta * neg

    if reduction is None:
        reduction = 'mean'
    return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction)
|
|
|
|