Commit 213c7f8a authored by 魏博昱

1

parent b934062a
*.JPG
*.jpg
*.txt
datasets/
checkpoints/
output/
results/
.vscode/
log/
logs/
*.swp
*.pth
*.pyc
.idea/
*-checkpoint.py
*.ipynb_checkpoints/
masks/
resized_paris/
fakeB/
shifted/
MIT License
Copyright (c) 2018 Zhaoyi-Yan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
\ No newline at end of file
This diff is collapsed.
#-*-coding:utf-8-*-
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image
class AlignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.dir_A = opt.dataroot
self.A_paths = sorted(make_dataset(self.dir_A))
if self.opt.offline_loading_mask:
self.mask_folder = self.opt.training_mask_folder if self.opt.isTrain else self.opt.testing_mask_folder
self.mask_paths = sorted(make_dataset(self.mask_folder))
assert(opt.resize_or_crop == 'resize_and_crop')
transform_list = [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
self.transform = transforms.Compose(transform_list)
def __getitem__(self, index):
A_path = self.A_paths[index]
A = Image.open(A_path).convert('RGB')
w, h = A.size
### Crop only: tile the full image into overlapping fineSize patches ###
A = self.transform(A)
nw = int(w / self.opt.fineSize * 2)
nh = int(h / self.opt.fineSize * 2)
nw0 = int(w % self.opt.fineSize)
nh0 = int(h % self.opt.fineSize)
step = int(self.opt.fineSize / 2)
# A is a (C, H, W) tensor after the transform, so dim 1 is indexed by ih (height) and dim 2 by iw (width).
A_temp = torch.FloatTensor(nw * nh, 3, self.opt.fineSize, self.opt.fineSize).zero_()
for iw in range(nw):
for ih in range(nh):
if iw == nw-1 and ih == nh-1:
A_temp[iw * nh + ih, :, :, :] = A[:, h - self.opt.fineSize:h, w - self.opt.fineSize:w]
continue
if iw == nw-1 and ih != nh-1:
A_temp[iw * nh + ih, :, :, :] = A[:, ih * step:ih * step + self.opt.fineSize, w - self.opt.fineSize:w]
continue
if iw != nw-1 and ih == nh-1:
A_temp[iw * nh + ih, :, :, :] = A[:, h - self.opt.fineSize:h, iw * step:iw * step + self.opt.fineSize]
continue
A_temp[iw * nh + ih, :, :, :] = A[:, ih * step:ih * step + self.opt.fineSize, iw * step:iw * step + self.opt.fineSize]
A = A_temp
###end####
"重置大小,切割图像 bg"
'''
if w < h:
ht_1 = self.opt.loadSize * h // w
wd_1 = self.opt.loadSize
A = A.resize((wd_1, ht_1), Image.BICUBIC)
else:
wd_1 = self.opt.loadSize * w // h
ht_1 = self.opt.loadSize
A = A.resize((wd_1, ht_1), Image.BICUBIC)
A = self.transform(A)
h = A.size(1)
w = A.size(2)
w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
A = A[:, h_offset:h_offset + self.opt.fineSize,
w_offset:w_offset + self.opt.fineSize]
'''
"重置大小,切割图像 end"
if (not self.opt.no_flip) and random.random() < 0.5:
A = torch.flip(A, [2])
# Let B directly equal A
B = A.clone()
A_flip = torch.flip(A, [2])
B_flip = A_flip.clone()
# A zero mask is fine if offline_loading_mask is not set.
mask = A.clone().zero_()
if self.opt.offline_loading_mask:
if self.opt.isTrain:
mask = Image.open(self.mask_paths[random.randint(0, len(self.mask_paths)-1)])
else:
mask = Image.open(self.mask_paths[index % len(self.mask_paths)])
mask = mask.resize((self.opt.fineSize, self.opt.fineSize), Image.NEAREST)
mask = transforms.ToTensor()(mask)
return {'A': A, 'B': B, 'A_F': A_flip, 'B_F': B_flip, 'M': mask,
'A_paths': A_path, 'im_size': [w, h]}
def __len__(self):
return len(self.A_paths)
def name(self):
return 'AlignedDataset'
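# A minimal, self-contained sketch (for illustration, not part of the original file) of the
# overlapping tiling performed in AlignedDataset.__getitem__ above: patches of size fineSize are
# taken with a stride of fineSize // 2, clamping the last row/column to the image border.
# `tile_image` is a hypothetical helper name, not something defined elsewhere in this repo.
import torch

def tile_image(img, fine_size):
    # img: a (C, H, W) tensor; returns (n_patches, C, fine_size, fine_size)
    c, h, w = img.size()
    step = fine_size // 2
    nh, nw = h // step, w // step
    patches = torch.zeros(nw * nh, c, fine_size, fine_size)
    for iw in range(nw):
        for ih in range(nh):
            top = min(ih * step, h - fine_size)    # clamp so the patch stays inside the image
            left = min(iw * step, w - fine_size)
            patches[iw * nh + ih] = img[:, top:top + fine_size, left:left + fine_size]
    return patches

# Example: a 256x256 image with fineSize = 128 yields 4 * 4 = 16 overlapping patches,
# i.e. tile_image(torch.randn(3, 256, 256), 128).shape -> torch.Size([16, 3, 128, 128]).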
#-*-coding:utf-8-*-
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image
class AlignedDatasetResized(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_A = opt.dataroot # More Flexible for users
self.A_paths = sorted(make_dataset(self.dir_A))
assert(opt.resize_or_crop == 'resize_and_crop')
transform_list = [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
self.transform = transforms.Compose(transform_list)
def __getitem__(self, index):
A_path = self.A_paths[index]
A = Image.open(A_path).convert('RGB')
A = A.resize ((self.opt.fineSize, self.opt.fineSize), Image.BICUBIC)
A = self.transform(A)
#if (not self.opt.no_flip) and random.random() < 0.5:
# idx = [i for i in range(A.size(2) - 1, -1, -1)] # size(2)-1, size(2)-2, ... , 0
# idx = torch.LongTensor(idx)
# A = A.index_select(2, idx)
# Let B directly equal A
B = A.clone()
return {'A': A, 'B': B,
'A_paths': A_path}
def __len__(self):
return len(self.A_paths)
def name(self):
return 'AlignedDatasetResized'
class BaseDataLoader():
def __init__(self):
pass
def initialize(self, opt):
self.opt = opt
pass
def load_data(self):
return None
import torch.utils.data as data
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
#-*-coding:utf-8-*-
import torch.utils.data
from data.base_data_loader import BaseDataLoader
def CreateDataset(opt):
dataset = None
if opt.dataset_mode == 'aligned':
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
elif opt.dataset_mode == 'aligned_resized':
from data.aligned_dataset_resized import AlignedDatasetResized
dataset = AlignedDatasetResized()
elif opt.dataset_mode == 'single':
from data.single_dataset import SingleDataset
dataset = SingleDataset()
else:
raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)
def __iter__(self):
for i, data in enumerate(self.dataloader):
if i*self.opt.batchSize >= self.opt.max_dataset_size:
break
yield data
\ No newline at end of file
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
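# A minimal usage sketch (illustrative, not part of the original commit), assuming the repo's
# TrainOptions entry point and that this module is importable; the loop body is a placeholder.
if __name__ == '__main__':
    from options.train_options import TrainOptions
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)   # builds CustomDatasetDataLoader and the underlying dataset
    dataset = data_loader.load_data()     # the loader object itself is iterable
    for i, data in enumerate(dataset):
        # `data` is the dict returned by the dataset, e.g. keys 'A', 'B', 'M' for AlignedDataset
        print(i, list(data.keys()))
        break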
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as its subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os
import os.path
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
return images
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)
import os.path
import torchvision.transforms as transforms
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image
class SingleDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_A = os.path.join(opt.dataroot)
# make_dataset returns paths of all images in one folder
self.A_paths = make_dataset(self.dir_A)
self.A_paths = sorted(self.A_paths)
transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
transform_list.append(transforms.Scale(opt.loadSize))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.RandomHorizontalFlip())
if opt.resize_or_crop != 'no_resize':
transform_list.append(transforms.RandomCrop(opt.fineSize))
# Map values to [-1, 1], because (0 - 0.5) / 0.5 = -1 and (1 - 0.5) / 0.5 = 1
transform_list += [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
self.transform = transforms.Compose(transform_list)
def __getitem__(self, index):
A_path = self.A_paths[index]
A_img = Image.open(A_path).convert('RGB')
A = self.transform(A_img)
if self.opt.which_direction == 'BtoA':
input_nc = self.opt.output_nc
else:
input_nc = self.opt.input_nc
return {'A': A, 'A_paths': A_path}
def __len__(self):
return len(self.A_paths)
def name(self):
return 'SingleImageDataset'
# face model (trained on CelebA-HQ 256; the first 2k images are for testing, the rest for training)
wget -c https://drive.google.com/open?id=1qvsWHVO9iXpEAPtwyRB25mklTmD0jgPV
# face random mask model
wget -c https://drive.google.com/open?id=1Pz9gkm2VYaEK3qMXnszJufsvRqbcXrjS
# paris random mask model
wget -c https://drive.google.com/open?id=14MzixaqYUdJNL5xGdVhSKI9jOfvGdr3M
# paris center mask model
wget -c https://drive.google.com/open?id=1nDkCdsqUdiEXfSjZ_P915gWeZELK0fo_
import torch
# import numpy as np
from options.train_options import TrainOptions
import util.util as util
import os
from PIL import Image
import glob
mask_folder = 'masks/testing_masks'
test_folder = './datasets/Paris/test'
util.mkdir(mask_folder)
opt = TrainOptions().parse()
f = glob.glob(test_folder+'/*.png')
print(f)
for fl in f:
mask = torch.zeros(opt.fineSize, opt.fineSize)
if opt.mask_sub_type == 'fractal':
assert 1==2, "It is broken now..."
mask = util.create_walking_mask() # create an initial random mask.
elif opt.mask_sub_type == 'rect':
mask, rand_t, rand_l = util.create_rand_mask(opt)
elif opt.mask_sub_type == 'island':
mask = util.wrapper_gmask(opt)
print('Generating mask for test image: '+os.path.basename(fl))
util.save_image(mask.squeeze().numpy()*255, os.path.join(mask_folder, os.path.splitext(os.path.basename(fl))[0]+'_mask.png'))
def create_model(opt):
model = None
print(opt.model)
if opt.model == 'shiftnet':
assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
from models.shift_net.shiftnet_model import ShiftNetModel
model = ShiftNetModel()
'''
elif opt.model == 'res_shiftnet':
assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
from models.res_shift_net.shiftnet_model import ResShiftNetModel
model = ResShiftNetModel()
elif opt.model == 'patch_soft_shiftnet':
assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
from models.patch_soft_shift.patch_soft_shiftnet_model import PatchSoftShiftNetModel
model = PatchSoftShiftNetModel()
elif opt.model == 'res_patch_soft_shiftnet':
assert (opt.dataset_mode == 'aligned' or opt.dataset_mode == 'aligned_resized')
from models.res_patch_soft_shift.res_patch_soft_shiftnet_model import ResPatchSoftShiftNetModel
model = ResPatchSoftShiftNetModel()
else:
raise ValueError("Model [%s] not recognized." % opt.model)
'''
model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model
from .discrimators import *
from .losses import *
from .modules import *
from .shift_unet import *
from .unet import *
\ No newline at end of file
This diff is collapsed.
import functools
import torch.nn as nn
from .denset_net import *
from .modules import *
################################### This is for D ###################################
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_spectral_norm=True):
super(NLayerDiscriminator, self).__init__()
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = 1
sequence = [
spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), use_spectral_norm),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias), use_spectral_norm),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw, bias=use_bias), use_spectral_norm),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw), use_spectral_norm)]
if use_sigmoid:
sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
def forward(self, input):
return self.model(input)
# Defines a DenseNet-inspired discriminator (should improve its ability to learn a stronger representation)
class DenseNetDiscrimator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_spectral_norm=True):
super(DenseNetDiscrimator, self).__init__()
self.model = densenet121(pretrained=True, use_spectral_norm=use_spectral_norm)
self.use_sigmoid = use_sigmoid
if self.use_sigmoid:
self.sigmoid = nn.Sigmoid()
def forward(self, input):
if self.use_sigmoid:
return self.sigmoid(self.model(input))
else:
return self.model(input)
import torch
import torch.nn as nn
import numpy as np
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
def __init__(self, gan_type='wgan_gp', target_real_label=1.0, target_fake_label=0.0):
super(GANLoss, self).__init__()
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
self.gan_type = gan_type
if gan_type == 'wgan_gp':
self.loss = nn.MSELoss()
elif gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif gan_type == 'vanilla':
self.loss = nn.BCELoss()
#######################################################################
### Relativistic GAN - https://github.com/AlexiaJM/RelativisticGAN ###
#######################################################################
# When Using `BCEWithLogitsLoss()`, remove the sigmoid layer in D.
elif gan_type == 're_s_gan':
self.loss = nn.BCEWithLogitsLoss()
elif gan_type == 're_avg_gan':
self.loss = nn.BCEWithLogitsLoss()
else:
raise ValueError("GAN type [%s] not recognized." % gan_type)
def get_target_tensor(self, prediction, target_is_real):
if target_is_real:
target_tensor = self.real_label
else:
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)
def __call__(self, prediction, target_is_real):
if self.gan_type == 'wgan_gp':
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
else:
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
return loss
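# A minimal, self-contained usage sketch (illustrative only): the tiny discriminator below is a
# stand-in, not the repo's NLayerDiscriminator, and the tensors are random.
if __name__ == '__main__':
    netD = nn.Sequential(nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1))
    criterionGAN = GANLoss(gan_type='lsgan')
    real, fake = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
    # D step: push predictions on real images towards 1 and on (detached) fake images towards 0.
    loss_D = 0.5 * (criterionGAN(netD(real), True) + criterionGAN(netD(fake.detach()), False))
    # G step: the prediction on the fake image is scored against the "real" target to fool D.
    loss_G = criterionGAN(netD(fake), True)
    print(loss_D.item(), loss_G.item())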
################# Discounting loss #########################
######################################################
class Discounted_L1(nn.Module):
def __init__(self, opt):
super(Discounted_L1, self).__init__()
# Register discounting template as a buffer
self.register_buffer('discounting_mask', torch.tensor(spatial_discounting_mask(opt.fineSize//2 - opt.overlap * 2, opt.fineSize//2 - opt.overlap * 2, 0.9, opt.discounting)))
self.L1 = nn.L1Loss()
def forward(self, input, target):
self._assert_no_grad(target)
input_tmp = input * self.discounting_mask
target_tmp = target * self.discounting_mask
return self.L1(input_tmp, target_tmp)
def _assert_no_grad(self, variable):
assert not variable.requires_grad, \
"nn criterions don't compute the gradient w.r.t. targets - please " \
"mark these variables as volatile or not requiring gradients"
def spatial_discounting_mask(mask_width, mask_height, discounting_gamma, discounting=1):
"""Generate spatial discounting mask constant.
Spatial discounting mask is first introduced in publication:
Generative Image Inpainting with Contextual Attention, Yu et al.
Returns:
np.ndarray: spatial discounting mask
"""
gamma = discounting_gamma
shape = [1, 1, mask_width, mask_height]
if discounting:
print('Use spatial discounting l1 loss.')
mask_values = np.ones((mask_width, mask_height), dtype='float32')
for i in range(mask_width):
for j in range(mask_height):
mask_values[i, j] = max(
gamma**min(i, mask_width-i),
gamma**min(j, mask_height-j))
mask_values = np.expand_dims(mask_values, 0)
mask_values = np.expand_dims(mask_values, 1)
else:
mask_values = np.ones(shape, dtype='float32')
return mask_values
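# A quick numeric check of the discounting mask above (illustrative only): weights decay with
# gamma towards the centre of the masked patch. For a 6x6 mask with gamma = 0.9, the corner
# weight is 0.9**0 = 1.0 and the centre weight is 0.9**3 = 0.729:
# m = spatial_discounting_mask(6, 6, 0.9, discounting=1)   # shape (1, 1, 6, 6)
# m[0, 0, 0, 0], m[0, 0, 3, 3]                             # -> 1.0, 0.729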
class TVLoss(nn.Module):
def __init__(self, tv_loss_weight=1):
super(TVLoss, self).__init__()
self.tv_loss_weight = tv_loss_weight
def forward(self, x):
bz, _, h, w = x.size()
count_h = self._tensor_size(x[:, :, 1:, :])
count_w = self._tensor_size(x[:, :, :, 1:])
h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h - 1, :]), 2).sum()
w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w - 1]), 2).sum()
return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / bz
@staticmethod
def _tensor_size(t):
return t.size(1) * t.size(2) * t.size(3)
\ No newline at end of file
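# For reference (illustrative, not part of the original file): TVLoss above penalises
# high-frequency noise by summing squared differences of neighbouring pixels along height and
# width, normalising each sum by its number of difference terms and averaging over the batch:
#     loss = tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / bz
# A typical (hypothetical) use would be: tv = TVLoss(tv_loss_weight=1e-4); loss = tv(fake_B)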
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
class Self_Attn(nn.Module):
""" Self attention Layer
https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py
"""
def __init__(self, in_dim, activation, with_attention=False):
super(Self_Attn, self).__init__()
self.chanel_in = in_dim
self.activation = activation
self.with_attention = with_attention
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X W X H)
returns :
out : self attention value + input feature
attention: B X N X N (N is Width*Height)
"""
m_batchsize, C, width, height = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x C' x N
energy = torch.bmm(proj_query, proj_key)  # B x N x N
attention = self.softmax(energy)  # B x N x N
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B x C x N
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, width, height)
out = self.gamma * out + x
if self.with_attention:
return out, attention
else:
return out
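# A minimal usage sketch (illustrative, not part of the original file): because gamma is
# initialised to zero, the block starts as an identity mapping and gradually learns how much
# attention output to mix in.
# attn = Self_Attn(in_dim=16, activation='relu')
# y = attn(torch.randn(1, 16, 8, 8))   # same shape as the input; equal to it while gamma == 0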
def l2normalize(v, eps=1e-12):
return v / (v.norm() + eps)
def spectral_norm(module, mode=True):
if mode:
return nn.utils.spectral_norm(module)
return module
class SwitchNorm2d(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.9, using_moving_average=True, using_bn=True,
last_gamma=False):
super(SwitchNorm2d, self).__init__()
self.eps = eps
self.momentum = momentum
self.using_moving_average = using_moving_average
self.using_bn = using_bn
self.last_gamma = last_gamma
self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
if self.using_bn:
self.mean_weight = nn.Parameter(torch.ones(3))
self.var_weight = nn.Parameter(torch.ones(3))
else:
self.mean_weight = nn.Parameter(torch.ones(2))
self.var_weight = nn.Parameter(torch.ones(2))
if self.using_bn:
self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
self.register_buffer('running_var', torch.zeros(1, num_features, 1))
self.reset_parameters()
def reset_parameters(self):
if self.using_bn:
self.running_mean.zero_()
self.running_var.zero_()
if self.last_gamma:
self.weight.data.fill_(0)
else:
self.weight.data.fill_(1)
self.bias.data.zero_()
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
def forward(self, x):
self._check_input_dim(x)
N, C, H, W = x.size()
x = x.view(N, C, -1)
mean_in = x.mean(-1, keepdim=True)
var_in = x.var(-1, keepdim=True)
mean_ln = mean_in.mean(1, keepdim=True)
temp = var_in + mean_in ** 2
var_ln = temp.mean(1, keepdim=True) - mean_ln ** 2
if self.using_bn:
if self.training:
mean_bn = mean_in.mean(0, keepdim=True)
var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2
if self.using_moving_average:
self.running_mean.mul_(self.momentum)
self.running_mean.add_((1 - self.momentum) * mean_bn.data)
self.running_var.mul_(self.momentum)
self.running_var.add_((1 - self.momentum) * var_bn.data)
else:
self.running_mean.add_(mean_bn.data)
self.running_var.add_(mean_bn.data ** 2 + var_bn.data)
else:
mean_bn = torch.autograd.Variable(self.running_mean)
var_bn = torch.autograd.Variable(self.running_var)
softmax = nn.Softmax(0)
mean_weight = softmax(self.mean_weight)
var_weight = softmax(self.var_weight)
if self.using_bn:
mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn
else:
mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln
var = var_weight[0] * var_in + var_weight[1] * var_ln
x = (x-mean) / (var+self.eps).sqrt()
x = x.view(N, C, H, W)
return x * self.weight + self.bias
class PartialConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(PartialConv, self).__init__()
self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
stride, padding, dilation, groups, bias)
self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
stride, padding, dilation, groups, False)
#self.input_conv.apply(weights_init('kaiming'))
torch.nn.init.constant_(self.mask_conv.weight, 1.0)
# mask is not updated
for param in self.mask_conv.parameters():
param.requires_grad = False
def forward(self, input, mask):
output = self.input_conv(input * mask)
if self.input_conv.bias is not None:
output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(
output)
else:
output_bias = torch.zeros_like(output)
with torch.no_grad():
output_mask = self.mask_conv(mask)
no_update_holes = output_mask == 0
mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)
output_pre = (output - output_bias) / mask_sum + output_bias
output = output_pre.masked_fill_(no_update_holes, 0.0)
new_mask = torch.ones_like(output)
new_mask = new_mask.masked_fill_(no_update_holes, 0.0)
return output, new_mask
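# A minimal, self-contained sketch (illustrative only) of how the partial convolution above is
# driven: the mask holds 1 for valid pixels and 0 for holes, and both the features and the mask
# are updated and passed on to the next layer.
if __name__ == '__main__':
    pconv = PartialConv(3, 8, kernel_size=3, padding=1)
    x = torch.randn(1, 3, 16, 16)
    mask = torch.ones(1, 3, 16, 16)
    mask[:, :, 4:12, 4:12] = 0                 # a square hole
    out, new_mask = pconv(x, mask)             # the hole shrinks in new_mask after each layer
    print(out.shape, new_mask[0, 0].sum().item())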
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_bias):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_bias)
def build_conv_block(self, dim, padding_type, norm_layer, use_bias):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim),
nn.ReLU(True)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
This diff is collapsed.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules import spectral_norm
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(UnetGenerator, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, use_spectral_norm=use_spectral_norm)
for i in range(num_downs - 5):
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
self.model = unet_block
def forward(self, input):
return self.model(input)
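# A quick shape check (illustrative, not part of the original file) of the num_downs comment
# above: with num_downs = 8, a 256x256 input is halved eight times down to 1x1 at the bottleneck
# and decoded back to 256x256.
if __name__ == '__main__':
    netG = UnetGenerator(input_nc=3, output_nc=3, num_downs=8, ngf=8)
    out = netG(torch.randn(1, 3, 256, 256))
    print(out.shape)   # torch.Size([1, 3, 256, 256])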
# construct network from the inside to the outside.
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if input_nc is None:
input_nc = outer_nc
downconv = spectral_norm(nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1), use_spectral_norm)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
# Blocks at different positions differ only in `upconv`.
# For the outermost block, the special part is the final `tanh`.
if outermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
# For the innermost block, upconv takes `inner_nc` channels instead of `inner_nc*2`.
elif innermost:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downrelu, downconv] # the innermost block has no submodule, and the BN after downconv is dropped
up = [uprelu, upconv, upnorm]
model = down + up
# otherwise, the normal block
else:
upconv = spectral_norm(nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1), use_spectral_norm)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost: # if it is the outermost, directly pass the input in.
return self.model(x)
else:
x_latter = self.model(x)
_, _, h, w = x.size()
if h != x_latter.size(2) or w != x_latter.size(3):
x_latter = F.interpolate(x_latter, (h, w), mode='bilinear')
return torch.cat([x_latter, x], 1) # cat in the C channel
# It is a simpler type of UNet, built without UnetSkipConnectionBlocks.
# In this way, everything is much clearer and more flexible to extend.
class EasyUnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64,
norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
super(EasyUnetGenerator, self).__init__()
# Encoder layers
self.e1_c = spectral_norm(nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e2_c = spectral_norm(nn.Conv2d(ngf, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e2_norm = norm_layer(ngf*2)
self.e3_c = spectral_norm(nn.Conv2d(ngf*2, ngf*4, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e3_norm = norm_layer(ngf*4)
self.e4_c = spectral_norm(nn.Conv2d(ngf*4, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e4_norm = norm_layer(ngf*8)
self.e5_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e5_norm = norm_layer(ngf*8)
self.e6_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e6_norm = norm_layer(ngf*8)
self.e7_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.e7_norm = norm_layer(ngf*8)
self.e8_c = spectral_norm(nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
# Decoder layers
self.d1_c = spectral_norm(nn.ConvTranspose2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d1_norm = norm_layer(ngf*8)
self.d2_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2 , ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d2_norm = norm_layer(ngf*8)
self.d3_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d3_norm = norm_layer(ngf*8)
self.d4_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d4_norm = norm_layer(ngf*8)
self.d5_c = spectral_norm(nn.ConvTranspose2d(ngf*8*2, ngf*4, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d5_norm = norm_layer(ngf*4)
self.d6_c = spectral_norm(nn.ConvTranspose2d(ngf*4*2, ngf*2, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d6_norm = norm_layer(ngf*2)
self.d7_c = spectral_norm(nn.ConvTranspose2d(ngf*2*2, ngf, kernel_size=4, stride=2, padding=1), use_spectral_norm)
self.d7_norm = norm_layer(ngf)
self.d8_c = spectral_norm(nn.ConvTranspose2d(ngf*2, output_nc, kernel_size=4, stride=2, padding=1), use_spectral_norm)
# In this way, we have a very flexible UNet construction mode.
def forward(self, input):
# Encoder
# No norm on the first layer
e1 = self.e1_c(input)
e2 = self.e2_norm(self.e2_c(F.leaky_relu_(e1, negative_slope=0.2)))
e3 = self.e3_norm(self.e3_c(F.leaky_relu_(e2, negative_slope=0.2)))
e4 = self.e4_norm(self.e4_c(F.leaky_relu_(e3, negative_slope=0.2)))
e5 = self.e5_norm(self.e5_c(F.leaky_relu_(e4, negative_slope=0.2)))
e6 = self.e6_norm(self.e6_c(F.leaky_relu_(e5, negative_slope=0.2)))
e7 = self.e7_norm(self.e7_c(F.leaky_relu_(e6, negative_slope=0.2)))
# No norm on the innermost layer
e8 = self.e8_c(F.leaky_relu_(e7, negative_slope=0.2))
# Decoder
d1 = self.d1_norm(self.d1_c(F.relu_(e8)))
d2 = self.d2_norm(self.d2_c(F.relu_(torch.cat([d1, e7], dim=1))))
d3 = self.d3_norm(self.d3_c(F.relu_(torch.cat([d2, e6], dim=1))))
d4 = self.d4_norm(self.d4_c(F.relu_(torch.cat([d3, e5], dim=1))))
d5 = self.d5_norm(self.d5_c(F.relu_(torch.cat([d4, e4], dim=1))))
d6 = self.d6_norm(self.d6_c(F.relu_(torch.cat([d5, e3], dim=1))))
d7 = self.d7_norm(self.d7_c(F.relu_(torch.cat([d6, e2], dim=1))))
# No norm on the last layer
d8 = self.d8_c(F.relu_(torch.cat([d7, e1], 1)))
d8 = torch.tanh(d8)
return d8
#-*-coding:utf-8-*-
from torch.nn import init
from torch.optim import lr_scheduler
from torchvision import models
from .modules import *
###############################################################################
# Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=True, track_running_stats=False)
elif norm_type == 'switchable':
norm_layer = functools.partial(SwitchNorm2d)
elif norm_type == 'none':
norm_layer = None
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def get_scheduler(optimizer, opt):
if opt.lr_policy == 'lambda':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif opt.lr_policy == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
else:
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
return scheduler
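# A quick numeric check of the 'lambda' policy above (illustrative only), assuming
# epoch_count=1, niter=100, niter_decay=100: lambda_rule(e) = 1 - max(0, e - 98) / 101, so the
# multiplier stays at 1.0 up to epoch 98 and then decays linearly, reaching 0 at epoch 199.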
def init_weights(net, init_type='normal', gain=0.02):
def init_func(m):
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1:
init.normal_(m.weight.data, 1.0, gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func)
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids)
init_weights(net, init_type, gain=init_gain)
return net
# Note: Adding SN to G tends to give inferior results. Need more checking.
def define_G(input_nc, output_nc, ngf, which_model_netG, opt, mask_global, norm='batch', use_spectral_norm=False, init_type='normal', gpu_ids=[], init_gain=0.02):
netG = None
norm_layer = get_norm_layer(norm_type=norm)
innerCos_list = []
shift_list = []
print('input_nc {}'.format(input_nc))
print('output_nc {}'.format(output_nc))
print('which_model_netG {}'.format(which_model_netG))
# Here we need to initialize an artificial mask_global to construct the initial model.
# When training, we need to set the mask for special layers (mostly Shift layers) first.
# If the mask is fixed during training, we only need to set it for these layers once;
# otherwise we need to set the masks each iteration, generating new random masks, masking the input,
# as well as setting masks for these special layers.
print('[CREATED] MODEL')
if which_model_netG == 'unet_256':
netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'easy_unet_256':
netG = EasyUnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'face_unet_shift_triple':
netG = FaceUnetGenerator(input_nc, output_nc, innerCos_list, shift_list, mask_global, opt, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'unet_shift_triple':
netG = UnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'res_unet_shift_triple':
netG = ResUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'patch_soft_unet_shift_triple':
netG = PatchSoftUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
elif which_model_netG == 'res_patch_soft_unet_shift_triple':
netG = ResPatchSoftUnetGeneratorShiftTriple(input_nc, output_nc, 8, opt, innerCos_list, shift_list, mask_global, \
ngf, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
print('[CREATED] MODEL')
print('Constraint in netG:')
print(innerCos_list)
print('Shift in netG:')
print(shift_list)
print('NetG:')
print(netG)
return init_net(netG, init_type, init_gain, gpu_ids), innerCos_list, shift_list
def define_D(input_nc, ndf, which_model_netD,
n_layers_D=3, norm='batch', use_sigmoid=False, use_spectral_norm=False, init_type='normal', gpu_ids=[], init_gain=0.02):
netD = None
norm_layer = get_norm_layer(norm_type=norm)
if which_model_netD == 'basic':
netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm)
elif which_model_netD == 'n_layers':
netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm)
elif which_model_netD == 'densenet':
netD = DenseNetDiscrimator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, use_spectral_norm=use_spectral_norm)
else:
print('Discriminator model name [%s] is not recognized' %
which_model_netD)
print('NetD:')
print(netD)
return init_net(netD, init_type, init_gain, gpu_ids)
import torch.nn as nn
import torch
import util.util as util
from .innerPatchSoftShiftTripleModule import InnerPatchSoftShiftTripleModule
# TODO: Make it compatible with show_flow.
#
class InnerPatchSoftShiftTriple(nn.Module):
def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, fuse=True, layer_to_last=3):
super(InnerPatchSoftShiftTriple, self).__init__()
self.shift_sz = shift_sz
self.stride = stride
self.mask_thred = mask_thred
self.triple_weight = triple_weight
self.show_flow = False # Default False. Do not set it to True; it is computation-heavy.
self.flow_srcs = None # Indicates the flow source (pixels in the non-masked region that will shift into the masked region).
self.fuse = fuse
self.layer_to_last = layer_to_last
self.softShift = InnerPatchSoftShiftTripleModule()
def set_mask(self, mask_global):
mask = util.cal_feat_mask(mask_global, self.layer_to_last)
self.mask = mask
return self.mask
# If mask changes, then need to set cal_fix_flag true each iteration.
def forward(self, input):
_, self.c, self.h, self.w = input.size()
# Just pass self.mask in, instead of self.flag.
final_out = self.softShift(input, self.stride, self.triple_weight, self.mask, self.mask_thred, self.shift_sz, self.show_flow, self.fuse)
if self.show_flow:
self.flow_srcs = self.softShift.get_flow_src()
return final_out
def get_flow(self):
return self.flow_srcs
def set_flow_true(self):
self.show_flow = True
def set_flow_false(self):
self.show_flow = False
def __repr__(self):
return self.__class__.__name__+ '(' \
+ ' ,triple_weight ' + str(self.triple_weight) + ')'
from util.NonparametricShift import Modified_NonparametricShift
from torch.nn import functional as F
import torch.nn as nn
import torch
import util.util as util
class InnerPatchSoftShiftTripleModule(nn.Module):
def forward(self, input, stride, triple_w, mask, mask_thred, shift_sz, show_flow, fuse=True):
assert input.dim() == 4, "Input Dim has to be 4"
assert mask.dim() == 4, "Mask Dim has to be 4"
self.triple_w = triple_w
self.mask = mask
self.mask_thred = mask_thred
self.show_flow = show_flow
self.bz, self.c, self.h, self.w = input.size()
self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
self.ind_lst = self.Tensor(self.bz, self.h * self.w, self.h * self.w).zero_()
# former and latter are all tensors
former_all = input.narrow(1, 0, self.c//2) ### decoder feature
latter_all = input.narrow(1, self.c//2, self.c//2) ### encoder feature
shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all) # addition feature
self.mask = self.mask.to(input)
# extract patches from latter.
latter_all_pad = F.pad(latter_all, [shift_sz//2, shift_sz//2, shift_sz//2, shift_sz//2], 'constant', 0)
latter_all_windows = latter_all_pad.unfold(2, shift_sz, stride).unfold(3, shift_sz, stride)
latter_all_windows = latter_all_windows.contiguous().view(self.bz, -1, self.c//2, shift_sz, shift_sz)
# Extract patches from mask
# Mention: mask here must be 1*1*H*W
m_pad = F.pad(self.mask, (shift_sz//2, shift_sz//2, shift_sz//2, shift_sz//2), 'constant', 0)
m = m_pad.unfold(2, shift_sz, stride).unfold(3, shift_sz, stride)
m = m.contiguous().view(self.bz, 1, -1, shift_sz, shift_sz)
# It implements similar functionality to `cal_flag_given_mask_thred`.
# However, the meaning of `mm` differs:
# here, the masked region of mm is filled with 0 and the non-masked region with 1,
# while in `cal_flag_given_mask_thred` it is the opposite.
m = torch.mean(torch.mean(m, dim=3, keepdim=True), dim=4, keepdim=True)
mm = m.le(self.mask_thred/(1.*shift_sz**2)).float() # bz*1*(32*32)*1*1
fuse_weight = torch.eye(shift_sz).view(1, 1, shift_sz, shift_sz).type_as(input)
self.shift_offsets = []
for idx in range(self.bz):
mm_cur = mm[idx]
# latter_win = latter_all_windows.narrow(0, idx, 1)[0]
latter_win = latter_all_windows.narrow(0, idx, 1)[0]
former = former_all.narrow(0, idx, 1)
# normalize latter for each patch.
latter_den = torch.sqrt(torch.einsum("bcij,bcij->b", [latter_win, latter_win]))
latter_den = torch.max(latter_den, self.Tensor([1e-4]))
latter_win_normed = latter_win/latter_den.view(-1, 1, 1, 1)
y_i = F.conv2d(former, latter_win_normed, stride=1, padding=shift_sz//2)
# conv implementation for fuse scores to encourage large patches
if fuse:
y_i = y_i.view(1, 1, self.h*self.w, self.h*self.w) # make all of depth of spatial resolution.
y_i = F.conv2d(y_i, fuse_weight, stride=1, padding=1)
y_i = y_i.contiguous().view(1, self.h, self.w, self.h, self.w)
y_i = y_i.permute(0, 2, 1, 4, 3)
y_i = y_i.contiguous().view(1, 1, self.h*self.w, self.h*self.w)
y_i = F.conv2d(y_i, fuse_weight, stride=1, padding=1)
y_i = y_i.contiguous().view(1, self.w, self.h, self.w, self.h)
y_i = y_i.permute(0, 2, 1, 4, 3)
y_i = y_i.contiguous().view(1, self.h*self.w, self.h, self.w) # 1*(32*32)*32*32
# firstly, wash away the masked region.
# multiply `mm` means (:, index_masked, :, :) will be 0.
y_i = y_i * mm_cur
# Then apply softmax to the nonmasked region.
cosine = F.softmax(y_i*10, dim=1)
# Finally, dummy parameters of the masked region are filtered out.
cosine = cosine * mm_cur
# paste
shift_i = F.conv_transpose2d(cosine, latter_win, stride=1, padding=shift_sz//2)/9.
shift_masked_all[idx] = shift_i
# Addition: show shift map
# TODO: fix me.
# cosine here is a full size of 32*32, not only the masked region in `shift_net`,
# which results in non-direct reusing the code.
# torch.set_printoptions(threshold=2015)
# if self.show_flow:
# _, indexes = torch.max(cosine, dim=1)
# # calculate self.flag from self.m
# self.flag = (1 - mm).view(-1)
# torch.set_printoptions(threshold=1025)
# print(self.flag)
# non_mask_indexes = (self.flag == 0.).nonzero()
# non_mask_indexes = non_mask_indexes[indexes]
# print('ll')
# print(non_mask_indexes.size())
# print(non_mask_indexes)
# # Here non_mask_index is too large, should be 192.
# shift_offset = torch.stack([non_mask_indexes.squeeze() // self.w, non_mask_indexes.squeeze() % self.w], dim=-1)
# print(shift_offset.size())
# self.shift_offsets.append(shift_offset)
# print('cc')
# if self.show_flow:
# # Note: Here we assume that each mask is the same for the same batch image.
# self.shift_offsets = torch.cat(self.shift_offsets, dim=0).float() # make it cudaFloatTensor
# # Assume mask is the same for each image in a batch.
# mask_nums = self.shift_offsets.size(0)//self.bz
# self.flow_srcs = torch.zeros(self.bz, 3, self.h, self.w).type_as(input)
# for idx in range(self.bz):
# shift_offset = self.shift_offsets.narrow(0, idx*mask_nums, mask_nums)
# # reconstruct the original shift_map.
# shift_offsets_map = torch.zeros(1, self.h, self.w, 2).type_as(input)
# print(shift_offsets_map.size())
# print(shift_offset.unsqueeze(0).size())
# print(shift_offsets_map[:, (self.flag == 1).nonzero().squeeze() // self.w, (self.flag == 1).nonzero().squeeze() % self.w, :].size())
# shift_offsets_map[:, (self.flag == 1).nonzero().squeeze() // self.w, (self.flag == 1).nonzero().squeeze() % self.w, :] = \
# shift_offset.unsqueeze(0)
# # It is indicating the pixels(non-masked) that will shift the the masked region.
# flow_src = util.highlight_flow(shift_offsets_map, self.flag.unsqueeze(0))
# self.flow_srcs[idx] = flow_src
return torch.cat((former_all, latter_all, shift_masked_all), 1)
def get_flow_src(self):
return self.flow_srcs
from models.shift_net.shiftnet_model import ShiftNetModel
class PatchSoftShiftNetModel(ShiftNetModel):
def name(self):
return 'PatchSoftShiftNetModel'
import torch.nn as nn
import torch
import util.util as util
from models.patch_soft_shift.innerPatchSoftShiftTripleModule import InnerPatchSoftShiftTripleModule
# TODO: Make it compatible with show_flow.
#
class InnerResPatchSoftShiftTriple(nn.Module):
def __init__(self, inner_nc, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, fuse=True, layer_to_last=3):
super(InnerResPatchSoftShiftTriple, self).__init__()
self.shift_sz = shift_sz
self.stride = stride
self.mask_thred = mask_thred
self.triple_weight = triple_weight
self.show_flow = False # Default False. Do not set it to True; it is computation-heavy.
self.flow_srcs = None # Indicates the flow source (pixels in the non-masked region that will shift into the masked region).
self.fuse = fuse
self.layer_to_last = layer_to_last
self.softShift = InnerPatchSoftShiftTripleModule()
# Additional for ResShift.
self.inner_nc = inner_nc
self.res_net = nn.Sequential(
nn.Conv2d(inner_nc*2, inner_nc, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm2d(inner_nc),
nn.ReLU(True),
nn.Conv2d(inner_nc, inner_nc, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm2d(inner_nc)
)
def set_mask(self, mask_global):
mask = util.cal_feat_mask(mask_global, self.layer_to_last)
self.mask = mask
return self.mask
# If mask changes, then need to set cal_fix_flag true each iteration.
def forward(self, input):
_, self.c, self.h, self.w = input.size()
# Just pass self.mask in, instead of self.flag.
# Try to make it faster by avoiding `cal_flag_given_mask_thred`.
shift_out = self.softShift(input, self.stride, self.triple_weight, self.mask, self.mask_thred, self.shift_sz, self.show_flow, self.fuse)
c_out = shift_out.size(1)
# get F_c, F_s, F_shift
F_c = shift_out.narrow(1, 0, c_out//3)
F_s = shift_out.narrow(1, c_out//3, c_out//3)
F_shift = shift_out.narrow(1, c_out*2//3, c_out//3)
F_fuse = F_c * F_shift
F_com = torch.cat([F_c, F_fuse], dim=1)
res_out = self.res_net(F_com)
F_c = F_c + res_out
final_out = torch.cat([F_c, F_s], dim=1)
if self.show_flow:
self.flow_srcs = self.softShift.get_flow_src()
return final_out
def get_flow(self):
return self.flow_srcs
def set_flow_true(self):
self.show_flow = True
def set_flow_false(self):
self.show_flow = False
def __repr__(self):
return self.__class__.__name__+ '(' \
+ ' ,triple_weight ' + str(self.triple_weight) + ')'
from models.shift_net.shiftnet_model import ShiftNetModel
class ResPatchSoftShiftNetModel(ShiftNetModel):
def name(self):
return 'ResPatchSoftShiftNetModel'
import torch.nn as nn
import torch
import util.util as util
from models.shift_net.InnerShiftTripleFunction import InnerShiftTripleFunction
class InnerResShiftTriple(nn.Module):
def __init__(self, inner_nc, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3):
super(InnerResShiftTriple, self).__init__()
self.shift_sz = shift_sz
self.stride = stride
self.mask_thred = mask_thred
self.triple_weight = triple_weight
self.show_flow = False # Default False. Do not set it to True; it is computation-heavy.
self.flow_srcs = None # Indicates the flow source (pixels in the non-masked region that will shift into the masked region).
self.layer_to_last = layer_to_last
# Additional for ResShift.
self.inner_nc = inner_nc
self.res_net = nn.Sequential(
nn.Conv2d(inner_nc*2, inner_nc, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm2d(inner_nc),
nn.ReLU(True),
nn.Conv2d(inner_nc, inner_nc, kernel_size=3, stride=1, padding=1),
nn.InstanceNorm2d(inner_nc)
)
def set_mask(self, mask_global):
mask = util.cal_feat_mask(mask_global, self.layer_to_last)
self.mask = mask.squeeze()
return self.mask
# If mask changes, then need to set cal_fix_flag true each iteration.
def forward(self, input):
#print(input.shape)
_, self.c, self.h, self.w = input.size()
self.flag = util.cal_flag_given_mask_thred(self.mask, self.shift_sz, self.stride, self.mask_thred)
shift_out = InnerShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.show_flow)
c_out = shift_out.size(1)
# get F_c, F_s, F_shift
F_c = shift_out.narrow(1, 0, c_out//3)
F_s = shift_out.narrow(1, c_out//3, c_out//3)
F_shift = shift_out.narrow(1, c_out*2//3, c_out//3)
F_fuse = F_c * F_shift
F_com = torch.cat([F_c, F_fuse], dim=1)
res_out = self.res_net(F_com)
F_c = F_c + res_out
final_out = torch.cat([F_c, F_s], dim=1)
if self.show_flow:
self.flow_srcs = InnerShiftTripleFunction.get_flow_src()
return final_out
def get_flow(self):
return self.flow_srcs
def set_flow_true(self):
self.show_flow = True
def set_flow_false(self):
self.show_flow = False
def __repr__(self):
return self.__class__.__name__+ '(' \
+ ' ,triple_weight ' + str(self.triple_weight) + ')'
from models.shift_net.shiftnet_model import ShiftNetModel
class ResShiftNetModel(ShiftNetModel):
def name(self):
return 'ResShiftNetModel'
\ No newline at end of file
import torch.nn as nn
import torch
import torch.nn.functional as F
import util.util as util
from .InnerCosFunction import InnerCosFunction
class InnerCos(nn.Module):
def __init__(self, crit='MSE', strength=1, skip=0, layer_to_last=3, device='gpu'):
super(InnerCos, self).__init__()
self.crit = crit
self.criterion = torch.nn.MSELoss() if self.crit == 'MSE' else torch.nn.L1Loss()
self.strength = strength
# To define whether this layer is skipped.
self.skip = skip
self.layer_to_last = layer_to_last
self.device = device
# Init a dummy value is fine.
self.target = torch.tensor(1.0)
self.bz = 0
self.c = 0
self.cur_mask = torch.tensor(0)
self.output = torch.tensor(0)
def set_mask(self, mask_global):
mask_all = util.cal_feat_mask(mask_global, self.layer_to_last)
self.mask_all = mask_all.float()
def _split_mask(self, cur_bsize):
# get the visible indexes of gpus and assign correct mask to set of images
cur_device = torch.cuda.current_device()
self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :]
def forward(self, in_data):
self.bz = in_data.size(0)
self.c = in_data.size(1)
self.cur_mask = self.mask_all
self.cur_mask = self.cur_mask.to(in_data)
# if not self.skip:
# # It works like this:
# # Each iteration contains 2 forward passes, In the first forward pass, we input a GT image, just to get the target.
# # In the second forward pass, we input the corresponding corrupted image, then back-propagate the network, the guidance loss works as expected.
# self.output = InnerCosFunction.apply(in_data, self.criterion, self.strength, self.target, self.cur_mask)
# self.target = in_data.narrow(1, self.c // 2, self.c // 2).detach() # the latter part
# else:
self.output = in_data
return self.output
def __repr__(self):
skip_str = 'True' if not self.skip else 'False'
return self.__class__.__name__+ '(' \
+ 'skip: ' + skip_str \
+ 'layer ' + str(self.layer_to_last) + ' to last' \
+ ' ,strength: ' + str(self.strength) + ')'
import torch
import torch.nn as nn
class InnerCosFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, criterion, strength, target, mask):
ctx.c = input.size(1)
ctx.strength = strength
ctx.criterion = criterion
if len(target.size()) == 0: # For the first iteration.
target = target.expand_as(input.narrow(1, ctx.c // 2, ctx.c // 2)).type_as(input)
ctx.save_for_backward(input, target, mask)
return input
@staticmethod
def backward(ctx, grad_output):
with torch.enable_grad():
input, target, mask = ctx.saved_tensors
former = input.narrow(1, 0, ctx.c//2)
former_in_mask = torch.mul(former, mask)
if former_in_mask.size() != target.size(): # For the last iteration of one epoch
target = target.narrow(0, 0, 1).expand_as(former_in_mask).type_as(former_in_mask)
former_in_mask_clone = former_in_mask.clone().detach().requires_grad_(True)
ctx.loss = ctx.criterion(former_in_mask_clone, target) * ctx.strength
ctx.loss.backward()
grad_output[:,0:ctx.c//2, :,:] += former_in_mask_clone.grad
return grad_output, None, None, None, None
\ No newline at end of file
import torch.nn as nn
import torch
import util.util as util
from testFun import InnerShiftTripleFunction
class InnerShiftTriple(nn.Module):
def __init__(self, shift_sz=1, stride=1, mask_thred=1, triple_weight=1, layer_to_last=3, device='gpu'):
super(InnerShiftTriple, self).__init__()
self.shift_sz = torch.tensor(shift_sz)
self.stride = torch.tensor(stride)
self.mask_thred = torch.tensor(mask_thred)
self.triple_weight = triple_weight
self.layer_to_last = layer_to_last
self.device = device
self.show_flow = False # Default False. Do not set it to True; it is computation-heavy.
self.flow_srcs = None # Indicates the flow source (pixels in the non-masked region that will shift into the masked region).
self.bz = 0
self.c = 0
self.h = 0
self.w = 0
self.cur_mask = torch.tensor(0)
self.flag = torch.tensor(0)
def set_mask(self, mask_global):
self.mask_all = util.cal_feat_mask(mask_global, self.layer_to_last)
def _split_mask(self, cur_bsize):
# get the visible indexes of gpus and assign correct mask to set of images
cur_device = torch.cuda.current_device()
self.cur_mask = self.mask_all[cur_device*cur_bsize:(cur_device+1)*cur_bsize, :, :, :]
# If mask changes, then need to set cal_fix_flag true each iteration.
def forward(self, input):
self.bz = input.size(0)
self.c = input.size(1)
self.h = input.size(2)
self.w = input.size(3)
self.cur_mask = self.mask_all
self.flag = util.cal_flag_given_mask_thred(self.cur_mask, self.shift_sz, self.stride, self.mask_thred)
# final_out = InnerShiftTripleFunction.apply(input, self.shift_sz, self.stride, self.triple_weight, self.flag, self.show_flow)
# if self.show_flow:
# self.flow_srcs = InnerShiftTripleFunction.get_flow_src()
final_out = InnerShiftTripleFunction(input, self.shift_sz, self.stride, torch.tensor(self.triple_weight), self.flag, torch.tensor(self.show_flow))
return final_out
def get_flow(self):
return self.flow_srcs
def set_flow_true(self):
self.show_flow = True
def set_flow_false(self):
self.show_flow = False
def __repr__(self):
return self.__class__.__name__+ '(' \
+ ' ,triple_weight ' + str(self.triple_weight) + ')'
import numpy as np
from util.NonparametricShift import Modified_NonparametricShift, Batch_NonShift
import torch
import util.util as util
import time
class InnerShiftTripleFunction(torch.autograd.Function):
ctx = None
@staticmethod
def forward(ctx, input, shift_sz, stride, triple_w, flag, show_flow):
InnerShiftTripleFunction.ctx = ctx
assert input.dim() == 4, "Input Dim has to be 4"
ctx.triple_w = triple_w
ctx.flag = flag
ctx.show_flow = show_flow
ctx.bz, c_real, ctx.h, ctx.w = input.size()
c = c_real
ctx.ind_lst = torch.Tensor(ctx.bz, ctx.h * ctx.w, ctx.h * ctx.w).zero_().to(input)
# former and latter are all tensors
former_all = input.narrow(1, 0, c//2) ### decoder feature
latter_all = input.narrow(1, c//2, c//2) ### encoder feature
shift_masked_all = torch.Tensor(former_all.size()).type_as(former_all).zero_() # the additional (shifted) feature
ctx.flag = ctx.flag.to(input).long()
# Batched shift implementation (a per-sample, non-batch version is kept commented out below for reference).
bNonparm = Batch_NonShift()
ctx.shift_offsets = []
cosine, latter_windows, i_2, i_3, i_1 = bNonparm.cosine_similarity(former_all.clone(), latter_all.clone(), 1, stride, flag)
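# `cosine` is assumed to be (bz, n_masked, n_non_masked): for each masked location, the argmax
# over dim 2 below selects the most similar non-masked encoder patch.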
_, indexes = torch.max(cosine, dim=2)
mask_indexes = (flag==1).nonzero(as_tuple=False)[:, 1].view(ctx.bz, -1)
non_mask_indexes = (flag==0).nonzero(as_tuple=False)[:, 1].view(ctx.bz, -1).gather(1, indexes)
idx_b = torch.arange(ctx.bz).long().unsqueeze(1).expand(ctx.bz, mask_indexes.size(1))
# Set the elements indexed by (idx_b, mask_indexes, non_mask_indexes) to 1.
# This builds the batched transition matrix in one shot.
ctx.ind_lst[(idx_b, mask_indexes, non_mask_indexes)] = 1
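# Toy illustration (hypothetical sizes): if masked location m best matches non-masked location n,
# then ind_lst[b, m, n] = 1 and the paste below copies the encoder feature at n into position m
# of the shifted feature.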
shift_masked_all = bNonparm._paste(latter_windows, ctx.ind_lst, i_2, i_3, i_1)
# --- Non-batch version ----
#for idx in range(ctx.bz):
# flag_cur = ctx.flag[idx]
# latter = latter_all.narrow(0, idx, 1) ### encoder feature
# former = former_all.narrow(0, idx, 1) ### decoder feature
# #GET COSINE, RESHAPED LATTER AND ITS INDEXES
# cosine, latter_windows, i_2, i_3, i_1 = Nonparm.cosine_similarity(former.clone().squeeze(), latter.clone().squeeze(), 1, stride, flag_cur)
# ## GET INDEXES THAT MAXIMIZE COSINE SIMILARITY
# _, indexes = torch.max(cosine, dim=1)
# # SET TRANSITION MATRIX
# mask_indexes = (flag_cur == 1).nonzero()
# non_mask_indexes = (flag_cur == 0).nonzero()[indexes]
# ctx.ind_lst[idx][mask_indexes, non_mask_indexes] = 1
# # GET FINAL SHIFT FEATURE
# shift_masked_all[idx] = Nonparm._paste(latter_windows, ctx.ind_lst[idx], i_2, i_3, i_1)
# if ctx.show_flow:
# shift_offset = torch.stack([non_mask_indexes.squeeze() // ctx.w, non_mask_indexes.squeeze() % ctx.w], dim=-1)
# ctx.shift_offsets.append(shift_offset)
if ctx.show_flow:
raise NotImplementedError("The `show_flow` functionality is not maintained in the batched implementation.")
ctx.shift_offsets = torch.cat(ctx.shift_offsets, dim=0).float() # make it cudaFloatTensor
# Assume mask is the same for each image in a batch.
mask_nums = ctx.shift_offsets.size(0)//ctx.bz
ctx.flow_srcs = torch.zeros(ctx.bz, 3, ctx.h, ctx.w).type_as(input)
for idx in range(ctx.bz):
shift_offset = ctx.shift_offsets.narrow(0, idx*mask_nums, mask_nums)
# reconstruct the original shift_map.
shift_offsets_map = torch.zeros(1, ctx.h, ctx.w, 2).type_as(input)
shift_offsets_map[:, (flag_cur == 1).nonzero(as_tuple=False).squeeze() // ctx.w, (flag_cur == 1).nonzero(as_tuple=False).squeeze() % ctx.w, :] = \
shift_offset.unsqueeze(0)
# It indicates the (non-masked) pixels that will be shifted into the masked region.
flow_src = util.highlight_flow(shift_offsets_map, flag_cur.unsqueeze(0))
ctx.flow_srcs[idx] = flow_src
return torch.cat((former_all, latter_all, shift_masked_all), 1)
@staticmethod
def get_flow_src():
return InnerShiftTripleFunction.ctx.flow_srcs
@staticmethod
def backward(ctx, grad_output):
ind_lst = ctx.ind_lst
c = grad_output.size(1)
# The former and the latter parts are kept as-is; only the third part is shifted.
# C: content, pixels in the masked region of the former part.
# S: style, pixels in the non-masked region of the latter part.
# N: the shifted feature, the new feature used as the third part of the feature maps.
# W_mat: ind_lst[idx], the shift (transition) matrix.
# Note: **only the masked region in N has values**.
# The gradient of the shifted feature must be added back to the latter part (more precisely, to S).
# `ind_lst[idx][i, j] = 1` means that the i-th pixel **is replaced** by the j-th pixel in the forward pass.
# Applying `S mm W_mat` transfers S to N
# (pixels in the non-masked region of the latter part are shifted into the masked region of the third part).
# In the backward pass we therefore transfer the gradient of the third part back to S,
# i.e. the gradient in S is **enhanced** by the gradient flowing through N.
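# Worked toy example (hypothetical): if W_mat[i, j] = 1, i.e. masked pixel i copied non-masked
# pixel j in the forward, then bmm(W_mat^T, grad_N) below routes the gradient arriving at shifted
# pixel i back to latter-feature pixel j, scaled by triple_w.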
grad_former_all = grad_output[:, 0:c//3, :, :]
grad_latter_all = grad_output[:, c//3: c*2//3, :, :].clone()
grad_shifted_all = grad_output[:, c*2//3:c, :, :].clone()
W_mat_t = ind_lst.permute(0, 2, 1).contiguous()
grad = grad_shifted_all.view(ctx.bz, c//3, -1).permute(0, 2, 1)
grad_shifted_weighted = torch.bmm(W_mat_t, grad)
grad_shifted_weighted = grad_shifted_weighted.permute(0, 2, 1).contiguous().view(ctx.bz, c//3, ctx.h, ctx.w)
grad_latter_all = torch.add(grad_latter_all, grad_shifted_weighted.mul(ctx.triple_w))
# ----- 'Non_batch version here' --------------------
# for idx in range(ctx.bz):
# # So we need to transpose `W_mat`
# W_mat_t = ind_lst[idx].t()
# grad = grad_shifted_all[idx].view(c//3, -1).t()
# grad_shifted_weighted = torch.mm(W_mat_t, grad)
# # Then transpose it back
# grad_shifted_weighted = grad_shifted_weighted.t().contiguous().view(1, c//3, ctx.h, ctx.w)
# grad_latter_all[idx] = torch.add(grad_latter_all[idx], grad_shifted_weighted.mul(ctx.triple_w))
# note the input channel and the output channel are all c, as no mask input for now.
grad_input = torch.cat([grad_former_all, grad_latter_all], 1)
return grad_input, None, None, None, None, None
import os
import torch
from collections import OrderedDict
from torchsummary import summary
class BaseModel():
def name(self):
return 'BaseModel'
def initialize(self, opt):
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
if opt.resize_or_crop != 'scale_width':
torch.backends.cudnn.benchmark = True
self.loss_names = []
self.model_names = []
self.visual_names = []
self.image_paths = []
def set_input(self, input):
self.input = input
def forward(self):
pass
# used in test time, wrapping `forward` in no_grad() so we don't save
# intermediate steps for backprop
def test(self):
with torch.no_grad():
self.forward()
# get image paths
def get_image_paths(self):
return self.image_paths
def optimize_parameters(self):
pass
# update learning rate (called once every epoch)
def update_learning_rate(self):
for scheduler in self.schedulers:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = %.7f' % lr)
# Return visualization images. train.py will display these images and save them to an HTML file.
def get_current_visuals(self):
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name)
return visual_ret
# Return training losses/errors. train.py will print out these errors as debugging information.
def get_current_losses(self):
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
# float(...) works for both scalar tensor and float number
errors_ret[name] = float(getattr(self, 'loss_' + name))
return errors_ret
# save models to the disk
def save_networks(self, which_epoch):
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net_%s.pth' % (which_epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
############################################
if name == 'G':
save_filename_pt_st = '%s_net_%s_st.pt' % (which_epoch, name)
### unet256: succeeded #############
# example = torch.rand(1, 4, 256, 256)
# example = example.cuda()
# net.eval()
# traced_script_module = torch.jit.trace(net.module.cuda().eval(), example)
# traced_script_module.save(save_filename_pt_st)
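# Presumably torch.jit.script is used below (rather than tracing) so that the exported module
# is not specialized to a fixed input shape; the two example inputs act as a quick
# variable-batch-size sanity check.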
example1 = torch.ones(1, 4, 64, 64).cuda()
example2 = torch.ones(2, 4, 64, 64).cuda()
# result1 = net.module(example)
net.eval()
# traced_script_module = torch.jit.trace(net.module.cuda().eval(), example)
traced_script_module = torch.jit.script(net.module)
result1 = traced_script_module(example1)
result2 = traced_script_module(example2)
print(result1)
print(result2)
# torch.jit.export_opnames(traced_script_module)
traced_script_module.save(save_filename_pt_st)
else:
torch.save(net.cpu().state_dict(), save_path)
############################################
if name == 'G':
save_filename_pt_st = '%s_net_%s_st.pt' % (which_epoch, name)
# torch.save(net.cpu(), save_filename_pt_st)
# example = torch.rand(1, 4, 256, 256)
# example = example.cuda()
# traced_script_module = torch.jit.trace(net.module, (example, example))
# traced_script_module.save(save_filename_pt_st)
example = torch.zeros(1, 4, 128, 128)
net.eval()
traced_script_module = torch.jit.trace(net.cpu(), example)
traced_script_module.save(save_filename_pt_st)
'''
model = torch.load(save_path)
net.load_state_dict(model)
net.eval()
example = torch.rand(1, 4, 256, 256).cuda() # generate a random input with the expected input shape
traced_script_module = torch.jit.trace(net, example)
traced_script_module.save(save_filename_pt_st)
model = torch.load(save_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(model, input_size=(1, 4, 256, 256))
model = model.to(device)
traced_script_module = torch.jit.trace(model, torch.ones(1, 4, 256, 256).to(device))
traced_script_module.save(save_filename_pt_st)
'''
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
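# Recursively walks a state_dict key such as 'model.3.running_mean' (keys = ['model', '3', 'running_mean'])
# and removes 'running_mean'/'running_var' entries from pre-0.4 checkpoints whenever the current
# InstanceNorm layer does not track running stats (i.e. its buffer is None).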
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
# load models from the disk
def load_networks(self, which_epoch):
for name in self.model_names:
if isinstance(name, str):
load_filename = '%s_net_%s.pth' % (which_epoch, name)
load_path = os.path.join(self.save_dir, load_filename)
net = getattr(self, 'net' + name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
# if you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device
state_dict = torch.load(load_path, map_location=str(self.device))
#state_dict = torch.load(load_path)
# patch InstanceNorm checkpoints prior to 0.4
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
net.load_state_dict(state_dict)
# print network information
def print_networks(self, verbose):
print('---------- Networks initialized -------------')
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
num_params = 0
for param in net.parameters():
num_params += param.numel()
if verbose:
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
print('-----------------------------------------------')
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad