Commit 7fc4ad53 authored by 魏博昱

first commit

parent 466003cf
Download the datasets from the Google Drive links below and place them in this directory. Your directory tree should look like this:
`GoPro` <br/>
  `├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing) <br/>
  `└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing)
`HIDE` <br/>
  `└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing)
`RealBlur_J` <br/>
  `└──`[test](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing)
`RealBlur_R` <br/>
  `└──`[test](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing)
## Training
- Download the [Datasets](Datasets/README.md)
- Train the model with default arguments by running
```
python train.py
```
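train.py reads its hyper-parameters from `training.yml` through the bundled `config.py`; a minimal sketch of overriding values programmatically (the override values here are illustrative):
```
from config import Config

# Load training.yml, then apply attribute/value overrides from a flat list,
# as documented in the Config class included in this commit.
opt = Config('training.yml', ['OPTIM.BATCH_SIZE', 8, 'TRAINING.TRAIN_PS', 128])
print(opt.OPTIM.BATCH_SIZE)  # -> 8
```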
## Evaluation
- Download the [model](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing) and place it in `./pretrained_models/`
#### Testing on GoPro dataset
- Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and place them in `./Datasets/GoPro/test/`
- Run
```
python test.py --dataset GoPro
```
#### Testing on HIDE dataset
- Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and place them in `./Datasets/HIDE/test/`
- Run
```
python test.py --dataset HIDE
```
#### Testing on RealBlur-J dataset
- Download [images](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) of RealBlur-J and place them in `./Datasets/RealBlur_J/test/`
- Run
```
python test.py --dataset RealBlur_J
```
#### Testing on RealBlur-R dataset
- Download [images](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) of RealBlur-R and place them in `./Datasets/RealBlur_R/test/`
- Run
```
python test.py --dataset RealBlur_R
```
#### To reproduce PSNR/SSIM scores of the paper on GoPro and HIDE datasets, run this MATLAB script
```
evaluate_GOPRO_HIDE.m
```
#### To reproduce PSNR/SSIM scores of the paper on RealBlur dataset, run this Python script
```
python evaluate_RealBlur.py
```
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 23 14:35:48 2019
@author: aditya
"""
r"""This module provides package-wide configuration management."""
from typing import Any, List
from yacs.config import CfgNode as CN
class Config(object):
r"""
    A collection of all the required configuration parameters. This class is a nested dict-like
    structure, with nested keys accessible as attributes. It contains sensible default values for
    all the parameters, which may be overridden (first) through a YAML file and (second) through
    a list of attributes and values.
    Extended Summary
    ----------------
    This class definition contains default values corresponding to the ``joint_training`` phase, as it
    is the final training phase and uses almost all the configuration parameters. Modification of
    any parameter after instantiating this class is not possible, so you must override required
    parameter values either through the ``config_yaml`` file or the ``config_override`` list.
Parameters
----------
config_yaml: str
Path to a YAML file containing configuration parameters to override.
config_override: List[Any], optional (default= [])
A list of sequential attributes and values of parameters to override. This happens after
overriding from YAML file.
Examples
--------
Let a YAML file named "config.yaml" specify these parameters to override::
ALPHA: 1000.0
BETA: 0.5
>>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7])
>>> _C.ALPHA # default: 100.0
1000.0
>>> _C.BATCH_SIZE # default: 256
2048
>>> _C.BETA # default: 0.1
0.7
Attributes
----------
"""
def __init__(self, config_yaml: str, config_override: List[Any] = []):
self._C = CN()
self._C.GPU = [0]
self._C.VERBOSE = False
self._C.MODEL = CN()
self._C.MODEL.MODE = 'global'
self._C.MODEL.SESSION = 'ps128_bs1'
self._C.OPTIM = CN()
self._C.OPTIM.BATCH_SIZE = 1
self._C.OPTIM.NUM_EPOCHS = 100
self._C.OPTIM.NEPOCH_DECAY = [100]
self._C.OPTIM.LR_INITIAL = 0.0002
self._C.OPTIM.LR_MIN = 0.0002
self._C.OPTIM.BETA1 = 0.5
self._C.TRAINING = CN()
self._C.TRAINING.VAL_AFTER_EVERY = 3
self._C.TRAINING.RESUME = False
self._C.TRAINING.SAVE_IMAGES = False
self._C.TRAINING.TRAIN_DIR = 'images_dir/train'
self._C.TRAINING.VAL_DIR = 'images_dir/val'
self._C.TRAINING.SAVE_DIR = 'checkpoints'
self._C.TRAINING.TRAIN_PS = 64
self._C.TRAINING.VAL_PS = 64
# Override parameter values from YAML file first, then from override list.
self._C.merge_from_file(config_yaml)
self._C.merge_from_list(config_override)
# Make an instantiated object of this class immutable.
self._C.freeze()
def dump(self, file_path: str):
r"""Save config at the specified file path.
Parameters
----------
file_path: str
(YAML) path to save config at.
"""
self._C.dump(stream=open(file_path, "w"))
def __getattr__(self, attr: str):
return self._C.__getattr__(attr)
def __repr__(self):
return self._C.__repr__()
import os
from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest
def get_training_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTrain(rgb_dir, img_options)
def get_validation_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderVal(rgb_dir, img_options)
def get_test_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTest(rgb_dir, img_options)
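A minimal usage sketch of these helpers, mirroring how train.py builds its loaders (the path and options are illustrative):
```
from torch.utils.data import DataLoader
from data_RGB import get_training_data

# The directory must contain 'input' and 'target' subfolders of paired images.
train_dataset = get_training_data('./Datasets/GoPro/train', {'patch_size': 256})
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
tar, inp, name = train_dataset[0]  # each item is (target, input, filename stem)
```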
import os
import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
class DataLoaderTrain(Dataset):
def __init__(self, rgb_dir, img_options=None):
super(DataLoaderTrain, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
w,h = tar_img.size
padw = ps-w if w<ps else 0
padh = ps-h if h<ps else 0
# Reflect Pad in case image is smaller than patch_size
if padw!=0 or padh!=0:
inp_img = TF.pad(inp_img, (0,0,padw,padh), padding_mode='reflect')
tar_img = TF.pad(tar_img, (0,0,padw,padh), padding_mode='reflect')
        aug = random.randint(0, 2)
        if aug == 1:
            # note: gamma == 1 is an identity transform; kept as in the released code
            inp_img = TF.adjust_gamma(inp_img, 1)
            tar_img = TF.adjust_gamma(tar_img, 1)
aug = random.randint(0, 2)
if aug == 1:
sat_factor = 1 + (0.2 - 0.4*np.random.rand())
inp_img = TF.adjust_saturation(inp_img, sat_factor)
tar_img = TF.adjust_saturation(tar_img, sat_factor)
inp_img = TF.to_tensor(inp_img)
tar_img = TF.to_tensor(tar_img)
hh, ww = tar_img.shape[1], tar_img.shape[2]
rr = random.randint(0, hh-ps)
cc = random.randint(0, ww-ps)
aug = random.randint(0, 8)
# Crop patch
inp_img = inp_img[:, rr:rr+ps, cc:cc+ps]
tar_img = tar_img[:, rr:rr+ps, cc:cc+ps]
# Data Augmentations
if aug==1:
inp_img = inp_img.flip(1)
tar_img = tar_img.flip(1)
elif aug==2:
inp_img = inp_img.flip(2)
tar_img = tar_img.flip(2)
elif aug==3:
inp_img = torch.rot90(inp_img,dims=(1,2))
tar_img = torch.rot90(tar_img,dims=(1,2))
elif aug==4:
inp_img = torch.rot90(inp_img,dims=(1,2), k=2)
tar_img = torch.rot90(tar_img,dims=(1,2), k=2)
elif aug==5:
inp_img = torch.rot90(inp_img,dims=(1,2), k=3)
tar_img = torch.rot90(tar_img,dims=(1,2), k=3)
elif aug==6:
inp_img = torch.rot90(inp_img.flip(1),dims=(1,2))
tar_img = torch.rot90(tar_img.flip(1),dims=(1,2))
elif aug==7:
inp_img = torch.rot90(inp_img.flip(2),dims=(1,2))
tar_img = torch.rot90(tar_img.flip(2),dims=(1,2))
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
class DataLoaderVal(Dataset):
def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
super(DataLoaderVal, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
# Validate on center crop
if self.ps is not None:
inp_img = TF.center_crop(inp_img, (ps,ps))
tar_img = TF.center_crop(tar_img, (ps,ps))
inp_img = TF.to_tensor(inp_img)
tar_img = TF.to_tensor(tar_img)
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
class DataLoaderTest(Dataset):
def __init__(self, inp_dir, img_options):
super(DataLoaderTest, self).__init__()
inp_files = sorted(os.listdir(inp_dir))
self.inp_filenames = [os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)]
self.inp_size = len(self.inp_filenames)
self.img_options = img_options
def __len__(self):
return self.inp_size
def __getitem__(self, index):
path_inp = self.inp_filenames[index]
filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
inp = Image.open(path_inp)
inp = TF.to_tensor(inp)
return inp, filename
%% Multi-Stage Progressive Image Restoration
%% Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao
%% https://arxiv.org/abs/2102.02808
close all;clear all;
% datasets = {'GoPro'};
datasets = {'GoPro', 'HIDE'};
num_set = length(datasets);
for idx_set = 1:num_set
file_path = strcat('./results/', datasets{idx_set}, '/');
gt_path = strcat('./Datasets/', datasets{idx_set}, '/test/target/');
path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))];
gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
img_num = length(path_list);
total_psnr = 0;
total_ssim = 0;
if img_num > 0
for j = 1:img_num
image_name = path_list(j).name;
gt_name = gt_list(j).name;
input = imread(strcat(file_path,image_name));
gt = imread(strcat(gt_path, gt_name));
ssim_val = ssim(input, gt);
psnr_val = psnr(input, gt);
total_ssim = total_ssim + ssim_val;
total_psnr = total_psnr + psnr_val;
end
end
qm_psnr = total_psnr / img_num;
qm_ssim = total_ssim / img_num;
fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);
end
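For a quick sanity check outside MATLAB, roughly comparable scores can be computed with scikit-image ≥ 0.19 (a sketch with hypothetical filenames; the paper's numbers come from the MATLAB script above):
```
import cv2
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

restored = cv2.imread('./results/GoPro/0001.png')             # hypothetical filename
gt = cv2.imread('./Datasets/GoPro/test/target/0001.png')      # hypothetical filename
print(peak_signal_noise_ratio(gt, restored, data_range=255))
print(structural_similarity(gt, restored, channel_axis=-1, data_range=255))
```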
## Multi-Stage Progressive Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao
## https://arxiv.org/abs/2102.02808
import os
import numpy as np
from glob import glob
from natsort import natsorted
from skimage import io
import cv2
from skimage.metrics import structural_similarity
from tqdm import tqdm
import concurrent.futures
def image_align(deblurred, gt):
# this function is based on kohler evaluation code
z = deblurred
c = np.ones_like(z)
x = gt
zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
warp_mode = cv2.MOTION_HOMOGRAPHY
warp_matrix = np.eye(3, 3, dtype=np.float32)
# Specify the number of iterations.
number_of_iterations = 100
termination_eps = 0
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
number_of_iterations, termination_eps)
# Run the ECC algorithm. The results are stored in warp_matrix.
(cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY), warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)
target_shape = x.shape
shift = warp_matrix
zr = cv2.warpPerspective(
zs,
warp_matrix,
(target_shape[1], target_shape[0]),
flags=cv2.INTER_CUBIC+ cv2.WARP_INVERSE_MAP,
borderMode=cv2.BORDER_REFLECT)
cr = cv2.warpPerspective(
np.ones_like(zs, dtype='float32'),
warp_matrix,
(target_shape[1], target_shape[0]),
flags=cv2.INTER_NEAREST+ cv2.WARP_INVERSE_MAP,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0)
zr = zr * cr
xr = x * cr
return zr, xr, cr, shift
def compute_psnr(image_true, image_test, image_mask, data_range=None):
# this function is based on skimage.metrics.peak_signal_noise_ratio
err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
return 10 * np.log10((data_range ** 2) / err)
def compute_ssim(tar_img, prd_img, cr1):
    ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True, use_sample_covariance=False, data_range=1.0, full=True)  # scikit-image >= 0.19 replaces multichannel=True with channel_axis=-1
ssim_map = ssim_map * cr1
    r = int(3.5 * 1.5 + 0.5)  # Gaussian window radius: truncate (3.5) * sigma (1.5), as in scipy.ndimage
win_size = 2 * r + 1
pad = (win_size - 1) // 2
ssim = ssim_map[pad:-pad,pad:-pad,:]
crop_cr1 = cr1[pad:-pad,pad:-pad,:]
ssim = ssim.sum(axis=0).sum(axis=0)/crop_cr1.sum(axis=0).sum(axis=0)
ssim = np.mean(ssim)
return ssim
def proc(filename):
tar,prd = filename
tar_img = io.imread(tar)
prd_img = io.imread(prd)
tar_img = tar_img.astype(np.float32)/255.0
prd_img = prd_img.astype(np.float32)/255.0
prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
SSIM = compute_ssim(tar_img, prd_img, cr1)
return (PSNR,SSIM)
datasets = ['RealBlur_J', 'RealBlur_R']
for dataset in datasets:
file_path = os.path.join('results' , dataset)
gt_path = os.path.join('Datasets', dataset, 'test', 'target')
path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))
assert len(path_list) != 0, "Predicted files not found"
assert len(gt_list) != 0, "Target files not found"
psnr, ssim = [], []
img_files =[(i, j) for i,j in zip(gt_list,path_list)]
with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
psnr.append(PSNR_SSIM[0])
ssim.append(PSNR_SSIM[1])
avg_psnr = sum(psnr)/len(psnr)
avg_ssim = sum(ssim)/len(ssim)
print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
import torch
import torch.nn as nn
import torch.nn.functional as F
class CharbonnierLoss(nn.Module):
"""Charbonnier Loss (L1)"""
def __init__(self, eps=1e-3):
super(CharbonnierLoss, self).__init__()
self.eps = eps
def forward(self, x, y):
diff = x - y
# loss = torch.sum(torch.sqrt(diff * diff + self.eps))
loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
return loss
class EdgeLoss(nn.Module):
def __init__(self):
super(EdgeLoss, self).__init__()
k = torch.Tensor([[.05, .25, .4, .25, .05]])
self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
if torch.cuda.is_available():
self.kernel = self.kernel.cuda()
self.loss = CharbonnierLoss()
def conv_gauss(self, img):
n_channels, _, kw, kh = self.kernel.shape
img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
return F.conv2d(img, self.kernel, groups=n_channels)
def laplacian_kernel(self, current):
filtered = self.conv_gauss(current) # filter
down = filtered[:,:,::2,::2] # downsample
new_filter = torch.zeros_like(filtered)
new_filter[:,:,::2,::2] = down*4 # upsample
filtered = self.conv_gauss(new_filter) # filter
diff = current - filtered
return diff
def forward(self, x, y):
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
return loss
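A small usage sketch of the two losses, combined across MPRNet's output stages with the same 0.05 edge weight that train.py uses below (shapes are illustrative):
```
import torch
from losses import CharbonnierLoss, EdgeLoss

criterion_char = CharbonnierLoss()
criterion_edge = EdgeLoss()  # keeps its Gaussian kernel on the GPU when CUDA is available

target = torch.rand(1, 3, 64, 64)
restored = [torch.rand(1, 3, 64, 64) for _ in range(3)]  # one output per stage
if torch.cuda.is_available():
    target = target.cuda()
    restored = [r.cuda() for r in restored]

loss_char = sum(criterion_char(r, target) for r in restored)
loss_edge = sum(criterion_edge(r, target) for r in restored)
loss = loss_char + 0.05 * loss_edge
```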
A pre-trained deblurring model is available [here](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing).
"""
## Multi-Stage Progressive Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao
## https://arxiv.org/abs/2102.02808
"""
import numpy as np
import os
import argparse
from tqdm import tqdm
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from data_RGB import get_test_data
from MPRNet import MPRNet
from skimage import img_as_ubyte
from pdb import set_trace as stx
parser = argparse.ArgumentParser(description='Image Deblurring using MPRNet')
parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/model_deblurring.pth', type=str, help='Path to weights')
parser.add_argument('--dataset', default='GoPro', type=str, help='Test Dataset') # ['GoPro', 'HIDE', 'RealBlur_J', 'RealBlur_R']
parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')
args = parser.parse_args()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
model_restoration = MPRNet()
utils.load_checkpoint(model_restoration,args.weights)
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()
dataset = args.dataset
rgb_dir_test = os.path.join(args.input_dir, dataset, 'test', 'input')
test_dataset = get_test_data(rgb_dir_test, img_options={})
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)
result_dir = os.path.join(args.result_dir, dataset)
utils.mkdir(result_dir)
with torch.no_grad():
for ii, data_test in enumerate(tqdm(test_loader), 0):
torch.cuda.ipc_collect()
torch.cuda.empty_cache()
input_ = data_test[0].cuda()
filenames = data_test[1]
# Padding in case images are not multiples of 8
if dataset == 'RealBlur_J' or dataset == 'RealBlur_R':
factor = 8
h,w = input_.shape[2], input_.shape[3]
H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
padh = H-h if h%factor!=0 else 0
padw = W-w if w%factor!=0 else 0
input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
restored = model_restoration(input_)
restored = torch.clamp(restored[0],0,1)
# Unpad images to original dimensions
if dataset == 'RealBlur_J' or dataset == 'RealBlur_R':
restored = restored[:,:,:h,:w]
restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
for batch in range(len(restored)):
restored_img = img_as_ubyte(restored[batch])
utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img)
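A worked check of the reflect-padding arithmetic used above for the RealBlur sets (the frame size is illustrative): images are padded up to the next multiple of 8 and cropped back after restoration.
```
factor = 8
h, w = 679, 720                         # illustrative frame size
H = ((h + factor) // factor) * factor   # 680
W = ((w + factor) // factor) * factor   # 728
padh = H - h if h % factor != 0 else 0  # 1
padw = W - w if w % factor != 0 else 0  # 0, since 720 is already a multiple of 8
```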
import os
from config import Config
opt = Config('training.yml')
gpus = ','.join([str(i) for i in opt.GPU])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus
import torch
torch.backends.cudnn.benchmark = True
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import random
import time
import numpy as np
import utils
from data_RGB import get_training_data, get_validation_data
from MPRNet import MPRNet
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from pdb import set_trace as stx
######### Set Seeds ###########
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)
start_epoch = 1
mode = opt.MODEL.MODE
session = opt.MODEL.SESSION
result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
utils.mkdir(result_dir)
utils.mkdir(model_dir)
train_dir = opt.TRAINING.TRAIN_DIR
val_dir = opt.TRAINING.VAL_DIR
######### Model ###########
model_restoration = MPRNet()
model_restoration.cuda()
device_ids = [i for i in range(torch.cuda.device_count())]
if torch.cuda.device_count() > 1:
print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
new_lr = opt.OPTIM.LR_INITIAL
optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8)
######### Scheduler ###########
warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
scheduler.step()
######### Resume ###########
if opt.TRAINING.RESUME:
path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
utils.load_checkpoint(model_restoration,path_chk_rest)
start_epoch = utils.load_start_epoch(path_chk_rest) + 1
utils.load_optim(optimizer, path_chk_rest)
for i in range(1, start_epoch):
scheduler.step()
new_lr = scheduler.get_lr()[0]
print('------------------------------------------------------------------------------')
print("==> Resuming Training with learning rate:", new_lr)
print('------------------------------------------------------------------------------')
if len(device_ids)>1:
model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)
######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()
######### DataLoaders ###########
train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)
val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS})
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS))
print('===> Loading datasets')
best_psnr = 0
best_epoch = 0
for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
epoch_start_time = time.time()
epoch_loss = 0
train_id = 1
model_restoration.train()
for i, data in enumerate(tqdm(train_loader), 0):
# zero_grad
for param in model_restoration.parameters():
param.grad = None
target = data[0].cuda()
input_ = data[1].cuda()
restored = model_restoration(input_)
# Compute loss at each stage
        loss_char = sum(criterion_char(restored[j], target) for j in range(len(restored)))
        loss_edge = sum(criterion_edge(restored[j], target) for j in range(len(restored)))
loss = (loss_char) + (0.05*loss_edge)
loss.backward()
optimizer.step()
epoch_loss +=loss.item()
#### Evaluation ####
if epoch%opt.TRAINING.VAL_AFTER_EVERY == 0:
model_restoration.eval()
psnr_val_rgb = []
for ii, data_val in enumerate((val_loader), 0):
target = data_val[0].cuda()
input_ = data_val[1].cuda()
with torch.no_grad():
restored = model_restoration(input_)
restored = restored[0]
for res,tar in zip(restored,target):
psnr_val_rgb.append(utils.torchPSNR(res, tar))
psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
if psnr_val_rgb > best_psnr:
best_psnr = psnr_val_rgb
best_epoch = epoch
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer' : optimizer.state_dict()
}, os.path.join(model_dir,"model_best.pth"))
print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer' : optimizer.state_dict()
}, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))
scheduler.step()
print("------------------------------------------------------------------")
print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
print("------------------------------------------------------------------")
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer' : optimizer.state_dict()
}, os.path.join(model_dir,"model_latest.pth"))
###############
##
####
GPU: [0,1,2,3]
VERBOSE: True
MODEL:
MODE: 'Deblurring'
SESSION: 'MPRNet'
# Optimization arguments.
OPTIM:
BATCH_SIZE: 16
NUM_EPOCHS: 3000
# NEPOCH_DECAY: [10]
  LR_INITIAL: 2.0e-4   # decimal point ensures YAML parses a float, not a string
  LR_MIN: 1.0e-6
# BETA1: 0.9
TRAINING:
VAL_AFTER_EVERY: 20
RESUME: False
TRAIN_PS: 256
VAL_PS: 256
TRAIN_DIR: './Datasets/GoPro/train' # path to training data
VAL_DIR: './Datasets/GoPro/test' # path to validation data
SAVE_DIR: './checkpoints' # path to save models and images
# SAVE_IMAGES: False
from .dir_utils import *
from .image_utils import *
from .model_utils import *
from .dataset_utils import *
import torch
class MixUp_AUG:
def __init__(self):
self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
def aug(self, rgb_gt, rgb_noisy):
bs = rgb_gt.size(0)
indices = torch.randperm(bs)
rgb_gt2 = rgb_gt[indices]
rgb_noisy2 = rgb_noisy[indices]
lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
return rgb_gt, rgb_noisy
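A usage sketch of the mixup augmentation (batch shape illustrative; `aug` moves the Beta-sampled weights to the GPU, so CUDA tensors are assumed), as the denoising training loop in this commit applies it after epoch 5:
```
import torch

mixup = MixUp_AUG()
rgb_gt = torch.rand(4, 3, 128, 128).cuda()        # clean batch
rgb_noisy = torch.rand(4, 3, 128, 128).cuda()     # degraded batch
rgb_gt, rgb_noisy = mixup.aug(rgb_gt, rgb_noisy)  # same convex mix applied to both
```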
import os
from natsort import natsorted
from glob import glob
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def get_last_path(path, session):
x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
return x
import torch
import numpy as np
import cv2
def torchPSNR(tar_img, prd_img):
imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
rmse = (imdff**2).mean().sqrt()
ps = 20*torch.log10(1/rmse)
return ps
def save_img(filepath, img):
cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
def numpyPSNR(tar_img, prd_img):
imdff = np.float32(prd_img) - np.float32(tar_img)
rmse = np.sqrt(np.mean(imdff**2))
ps = 20*np.log10(255/rmse)
return ps
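A quick numeric check of torchPSNR (the offset is illustrative): a uniform error of 0.05 on a [0,1] image corresponds to roughly 20*log10(1/0.05) ≈ 26 dB.
```
import torch
# torchPSNR is defined above; when these utils are used as a package: from utils import torchPSNR

a = torch.rand(3, 64, 64)
b = (a + 0.05).clamp(0, 1)  # shift every pixel by ~0.05
print(torchPSNR(a, b))      # about 26 dB (clamping nudges it slightly higher)
```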
import torch
import os
from collections import OrderedDict
def freeze(model):
for p in model.parameters():
p.requires_grad=False
def unfreeze(model):
for p in model.parameters():
p.requires_grad=True
def is_frozen(model):
x = [p.requires_grad for p in model.parameters()]
return not all(x)
def save_checkpoint(model_dir, state, session):
epoch = state['epoch']
model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
torch.save(state, model_out_path)
def load_checkpoint(model, weights):
checkpoint = torch.load(weights)
try:
model.load_state_dict(checkpoint["state_dict"])
    except RuntimeError:  # checkpoint saved from nn.DataParallel: strip "module." prefixes
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
def load_checkpoint_multigpu(model, weights):
checkpoint = torch.load(weights)
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
def load_start_epoch(weights):
checkpoint = torch.load(weights)
epoch = checkpoint["epoch"]
return epoch
def load_optim(optimizer, weights):
checkpoint = torch.load(weights)
optimizer.load_state_dict(checkpoint['optimizer'])
# for p in optimizer.param_groups: lr = p['lr']
# return lr
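A round-trip sketch for the checkpoint helpers in this file (the model and session name are illustrative):
```
import torch.nn as nn

model = nn.Conv2d(3, 3, 3)
state = {'epoch': 1, 'state_dict': model.state_dict(), 'optimizer': {}}
save_checkpoint('.', state, 'demo')                   # writes ./model_epoch_1_demo.pth
load_checkpoint(model, './model_epoch_1_demo.pth')    # restores the weights
print(load_start_epoch('./model_epoch_1_demo.pth'))   # 1
```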
Download the datasets from the provided links and place them in this directory. Your directory structure should look like this:
`SIDD` <br/>
  `├──`[train](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) <br/>
  `├──`[val](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su?usp=sharing) <br/>
  `└──`[test](https://www.eecs.yorku.ca/~kamel/sidd/benchmark.php) <br/>
      `├──ValidationNoisyBlocksSrgb.mat` <br/>
      `└──ValidationGtBlocksSrgb.mat`
`DND` <br/>
  `└──`[test](https://noise.visinf.tu-darmstadt.de/downloads/) <br/>
      `├──info.mat` <br/>
      `└──images_srgb` <br/>
            `├──0001.mat` <br/>
            `├──0002.mat` <br/>
            `├── ... ` <br/>
            `└──0050.mat`
## Training
- Download the [Datasets](Datasets/README.md)
- Generate image patches from full-resolution training images of the SIDD dataset (the expected output layout is sketched after this section)
```
python generate_patches_SIDD.py --ps 256 --num_patches 300 --num_cores 10
```
- Train the model with default arguments by running
```
python train.py
```
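The patch script bundled in this commit writes noisy patches to `<tar_dir>/input` and clean patches to `<tar_dir>/target`. To match `TRAIN_DIR` in `training.yml`, pass `--tar_dir ./Datasets/SIDD/train`; the resulting layout (illustrative) is:
```
./Datasets/SIDD/train
├── input/    # noisy patches: 1_1.png, 1_2.png, ...
└── target/   # clean patches with matching names
```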
## Evaluation
- Download the [model](https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view?usp=sharing) and place it in `./pretrained_models/`
#### Testing on SIDD dataset
- Download SIDD Validation Data and Ground Truth from [here](https://www.eecs.yorku.ca/~kamel/sidd/benchmark.php) and place them in `./Datasets/SIDD/test/`
- Run
```
python test_SIDD.py --save_images
```
#### Testing on DND dataset
- Download DND Benchmark Data from [here](https://noise.visinf.tu-darmstadt.de/downloads/) and place it in `./Datasets/DND/test/`
- Run
```
python test_DND.py --save_images
```
#### To reproduce PSNR/SSIM scores of the paper, run this MATLAB script
```
evaluate_SIDD.m
```
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 23 14:35:48 2019
@author: aditya
"""
r"""This module provides package-wide configuration management."""
from typing import Any, List
from yacs.config import CfgNode as CN
class Config(object):
r"""
    A collection of all the required configuration parameters. This class is a nested dict-like
    structure, with nested keys accessible as attributes. It contains sensible default values for
    all the parameters, which may be overridden (first) through a YAML file and (second) through
    a list of attributes and values.
    Extended Summary
    ----------------
    This class definition contains default values corresponding to the ``joint_training`` phase, as it
    is the final training phase and uses almost all the configuration parameters. Modification of
    any parameter after instantiating this class is not possible, so you must override required
    parameter values either through the ``config_yaml`` file or the ``config_override`` list.
Parameters
----------
config_yaml: str
Path to a YAML file containing configuration parameters to override.
config_override: List[Any], optional (default= [])
A list of sequential attributes and values of parameters to override. This happens after
overriding from YAML file.
Examples
--------
Let a YAML file named "config.yaml" specify these parameters to override::
ALPHA: 1000.0
BETA: 0.5
>>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7])
>>> _C.ALPHA # default: 100.0
1000.0
>>> _C.BATCH_SIZE # default: 256
2048
>>> _C.BETA # default: 0.1
0.7
Attributes
----------
"""
def __init__(self, config_yaml: str, config_override: List[Any] = []):
self._C = CN()
self._C.GPU = [0]
self._C.VERBOSE = False
self._C.MODEL = CN()
self._C.MODEL.MODE = 'global'
self._C.MODEL.SESSION = 'ps128_bs1'
self._C.OPTIM = CN()
self._C.OPTIM.BATCH_SIZE = 1
self._C.OPTIM.NUM_EPOCHS = 100
self._C.OPTIM.NEPOCH_DECAY = [100]
self._C.OPTIM.LR_INITIAL = 0.0002
self._C.OPTIM.LR_MIN = 0.0002
self._C.OPTIM.BETA1 = 0.5
self._C.TRAINING = CN()
self._C.TRAINING.VAL_AFTER_EVERY = 3
self._C.TRAINING.RESUME = False
self._C.TRAINING.SAVE_IMAGES = False
self._C.TRAINING.TRAIN_DIR = 'images_dir/train'
self._C.TRAINING.VAL_DIR = 'images_dir/val'
self._C.TRAINING.SAVE_DIR = 'checkpoints'
self._C.TRAINING.TRAIN_PS = 64
self._C.TRAINING.VAL_PS = 64
# Override parameter values from YAML file first, then from override list.
self._C.merge_from_file(config_yaml)
self._C.merge_from_list(config_override)
# Make an instantiated object of this class immutable.
self._C.freeze()
def dump(self, file_path: str):
r"""Save config at the specified file path.
Parameters
----------
file_path: str
(YAML) path to save config at.
"""
self._C.dump(stream=open(file_path, "w"))
def __getattr__(self, attr: str):
return self._C.__getattr__(attr)
def __repr__(self):
return self._C.__repr__()
import os
from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest
def get_training_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTrain(rgb_dir, img_options)
def get_validation_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderVal(rgb_dir, img_options)
def get_test_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTest(rgb_dir, img_options)
import os
import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
class DataLoaderTrain(Dataset):
def __init__(self, rgb_dir, img_options=None):
super(DataLoaderTrain, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
w,h = tar_img.size
padw = ps-w if w<ps else 0
padh = ps-h if h<ps else 0
# Reflect Pad in case image is smaller than patch_size
if padw!=0 or padh!=0:
inp_img = TF.pad(inp_img, (0,0,padw,padh), padding_mode='reflect')
tar_img = TF.pad(tar_img, (0,0,padw,padh), padding_mode='reflect')
inp_img = TF.to_tensor(inp_img)
tar_img = TF.to_tensor(tar_img)
hh, ww = tar_img.shape[1], tar_img.shape[2]
rr = random.randint(0, hh-ps)
cc = random.randint(0, ww-ps)
aug = random.randint(0, 8)
# Crop patch
inp_img = inp_img[:, rr:rr+ps, cc:cc+ps]
tar_img = tar_img[:, rr:rr+ps, cc:cc+ps]
# Data Augmentations
if aug==1:
inp_img = inp_img.flip(1)
tar_img = tar_img.flip(1)
elif aug==2:
inp_img = inp_img.flip(2)
tar_img = tar_img.flip(2)
elif aug==3:
inp_img = torch.rot90(inp_img,dims=(1,2))
tar_img = torch.rot90(tar_img,dims=(1,2))
elif aug==4:
inp_img = torch.rot90(inp_img,dims=(1,2), k=2)
tar_img = torch.rot90(tar_img,dims=(1,2), k=2)
elif aug==5:
inp_img = torch.rot90(inp_img,dims=(1,2), k=3)
tar_img = torch.rot90(tar_img,dims=(1,2), k=3)
elif aug==6:
inp_img = torch.rot90(inp_img.flip(1),dims=(1,2))
tar_img = torch.rot90(tar_img.flip(1),dims=(1,2))
elif aug==7:
inp_img = torch.rot90(inp_img.flip(2),dims=(1,2))
tar_img = torch.rot90(tar_img.flip(2),dims=(1,2))
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
class DataLoaderVal(Dataset):
def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
super(DataLoaderVal, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
# Validate on center crop
if self.ps is not None:
inp_img = TF.center_crop(inp_img, (ps,ps))
tar_img = TF.center_crop(tar_img, (ps,ps))
inp_img = TF.to_tensor(inp_img)
tar_img = TF.to_tensor(tar_img)
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
class DataLoaderTest(Dataset):
def __init__(self, inp_dir, img_options):
super(DataLoaderTest, self).__init__()
inp_files = sorted(os.listdir(inp_dir))
self.inp_filenames = [os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)]
self.inp_size = len(self.inp_filenames)
self.img_options = img_options
def __len__(self):
return self.inp_size
def __getitem__(self, index):
path_inp = self.inp_filenames[index]
filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
inp = Image.open(path_inp)
inp = TF.to_tensor(inp)
return inp, filename
close all;clear all;
denoised = load('Idenoised.mat');
gt = load('ValidationGtBlocksSrgb.mat');
denoised = denoised.Idenoised;
gt = gt.ValidationGtBlocksSrgb;
gt = im2single(gt);
total_psnr = 0;
total_ssim = 0;
for i = 1:40
for k = 1:32
denoised_patch = squeeze(denoised(i,k,:,:,:));
gt_patch = squeeze(gt(i,k,:,:,:));
ssim_val = ssim(denoised_patch, gt_patch);
psnr_val = psnr(denoised_patch, gt_patch);
total_ssim = total_ssim + ssim_val;
total_psnr = total_psnr + psnr_val;
end
end
qm_psnr = total_psnr / (40*32);
qm_ssim = total_ssim / (40*32);
fprintf('PSNR: %f SSIM: %f\n', qm_psnr, qm_ssim);
from glob import glob
from tqdm import tqdm
import numpy as np
import os
import shutil
from natsort import natsorted
import cv2
from joblib import Parallel, delayed
import multiprocessing
import argparse
parser = argparse.ArgumentParser(description='Generate patches from Full Resolution images')
parser.add_argument('--src_dir', default='../SIDD_Medium_Srgb/Data', type=str, help='Directory for full resolution images')
parser.add_argument('--tar_dir', default='../SIDD_patches/train',type=str, help='Directory for image patches')
parser.add_argument('--ps', default=256, type=int, help='Image Patch Size')
parser.add_argument('--num_patches', default=300, type=int, help='Number of patches per image')
parser.add_argument('--num_cores', default=10, type=int, help='Number of CPU Cores')
args = parser.parse_args()
src = args.src_dir
tar = args.tar_dir
PS = args.ps
NUM_PATCHES = args.num_patches
NUM_CORES = args.num_cores
noisy_patchDir = os.path.join(tar, 'input')
clean_patchDir = os.path.join(tar, 'target')  # DataLoaderTrain expects 'input'/'target' subfolders
if os.path.exists(tar):
    shutil.rmtree(tar)  # portable equivalent of "rm -r"
os.makedirs(noisy_patchDir)
os.makedirs(clean_patchDir)
#get sorted folders
files = natsorted(glob(os.path.join(src, '*', '*.PNG')))
noisy_files, clean_files = [], []
for file_ in files:
filename = os.path.split(file_)[-1]
if 'GT' in filename:
clean_files.append(file_)
if 'NOISY' in filename:
noisy_files.append(file_)
def save_files(i):
noisy_file, clean_file = noisy_files[i], clean_files[i]
noisy_img = cv2.imread(noisy_file)
clean_img = cv2.imread(clean_file)
H = noisy_img.shape[0]
W = noisy_img.shape[1]
for j in range(NUM_PATCHES):
rr = np.random.randint(0, H - PS)
cc = np.random.randint(0, W - PS)
noisy_patch = noisy_img[rr:rr + PS, cc:cc + PS, :]
clean_patch = clean_img[rr:rr + PS, cc:cc + PS, :]
cv2.imwrite(os.path.join(noisy_patchDir, '{}_{}.png'.format(i+1,j+1)), noisy_patch)
cv2.imwrite(os.path.join(clean_patchDir, '{}_{}.png'.format(i+1,j+1)), clean_patch)
Parallel(n_jobs=NUM_CORES)(delayed(save_files)(i) for i in tqdm(range(len(noisy_files))))
import torch
import torch.nn as nn
import torch.nn.functional as F
class CharbonnierLoss(nn.Module):
"""Charbonnier Loss (L1)"""
def __init__(self, eps=1e-3):
super(CharbonnierLoss, self).__init__()
self.eps = eps
def forward(self, x, y):
diff = x - y
# loss = torch.sum(torch.sqrt(diff * diff + self.eps))
loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
return loss
class EdgeLoss(nn.Module):
def __init__(self):
super(EdgeLoss, self).__init__()
k = torch.Tensor([[.05, .25, .4, .25, .05]])
self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
if torch.cuda.is_available():
self.kernel = self.kernel.cuda()
self.loss = CharbonnierLoss()
def conv_gauss(self, img):
n_channels, _, kw, kh = self.kernel.shape
img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
return F.conv2d(img, self.kernel, groups=n_channels)
def laplacian_kernel(self, current):
filtered = self.conv_gauss(current) # filter
down = filtered[:,:,::2,::2] # downsample
new_filter = torch.zeros_like(filtered)
new_filter[:,:,::2,::2] = down*4 # upsample
filtered = self.conv_gauss(new_filter) # filter
diff = current - filtered
return diff
def forward(self, x, y):
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
return loss
A pre-trained denoising model is available [here](https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view?usp=sharing).
"""
## Multi-Stage Progressive Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao
## https://arxiv.org/abs/2102.02808
"""
import numpy as np
import os
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from MPRNet import MPRNet
from skimage import img_as_ubyte
import h5py
import scipy.io as sio
from pdb import set_trace as stx
parser = argparse.ArgumentParser(description='Image Denoising using MPRNet')
parser.add_argument('--input_dir', default='./Datasets/DND/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/DND/test/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/model_denoising.pth', type=str, help='Path to weights')
parser.add_argument('--gpus', default='1', type=str, help='CUDA_VISIBLE_DEVICES')
parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
args = parser.parse_args()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
result_dir = os.path.join(args.result_dir, 'mat')
utils.mkdir(result_dir)
if args.save_images:
result_dir_img = os.path.join(args.result_dir, 'png')
utils.mkdir(result_dir_img)
model_restoration = MPRNet()
utils.load_checkpoint(model_restoration,args.weights)
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()
israw = False
eval_version="1.0"
# Load info
infos = h5py.File(os.path.join(args.input_dir, 'info.mat'), 'r')
info = infos['info']
bb = info['boundingboxes']
# Process data
with torch.no_grad():
for i in tqdm(range(50)):
        Idenoised = np.zeros((20,), dtype=object)  # np.object was removed in NumPy 1.24
filename = '%04d.mat'%(i+1)
filepath = os.path.join(args.input_dir, 'images_srgb', filename)
img = h5py.File(filepath, 'r')
Inoisy = np.float32(np.array(img['InoisySRGB']).T)
# bounding box
ref = bb[0][i]
boxes = np.array(info[ref]).T
for k in range(20):
idx = [int(boxes[k,0]-1),int(boxes[k,2]),int(boxes[k,1]-1),int(boxes[k,3])]
noisy_patch = torch.from_numpy(Inoisy[idx[0]:idx[1],idx[2]:idx[3],:]).unsqueeze(0).permute(0,3,1,2).cuda()
restored_patch = model_restoration(noisy_patch)
restored_patch = torch.clamp(restored_patch[0],0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
Idenoised[k] = restored_patch
if args.save_images:
save_file = os.path.join(result_dir_img, '%04d_%02d.png'%(i+1,k+1))
denoised_img = img_as_ubyte(restored_patch)
utils.save_img(save_file, denoised_img)
# save denoised data
sio.savemat(os.path.join(result_dir, filename),
{"Idenoised": Idenoised,
"israw": israw,
"eval_version": eval_version},
)
"""
## Multi-Stage Progressive Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao
## https://arxiv.org/abs/2102.02808
"""
import numpy as np
import os
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from MPRNet import MPRNet
from skimage import img_as_ubyte
import h5py
import scipy.io as sio
from pdb import set_trace as stx
parser = argparse.ArgumentParser(description='Image Denoising using MPRNet')
parser.add_argument('--input_dir', default='./Datasets/SIDD/test/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/SIDD/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/model_denoising.pth', type=str, help='Path to weights')
parser.add_argument('--gpus', default='1', type=str, help='CUDA_VISIBLE_DEVICES')
parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
args = parser.parse_args()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
result_dir = os.path.join(args.result_dir, 'mat')
utils.mkdir(result_dir)
if args.save_images:
result_dir_img = os.path.join(args.result_dir, 'png')
utils.mkdir(result_dir_img)
model_restoration = MPRNet()
utils.load_checkpoint(model_restoration,args.weights)
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()
# Process data
filepath = os.path.join(args.input_dir, 'ValidationNoisyBlocksSrgb.mat')
img = sio.loadmat(filepath)
Inoisy = np.float32(np.array(img['ValidationNoisyBlocksSrgb']))
Inoisy /=255.
restored = np.zeros_like(Inoisy)
with torch.no_grad():
for i in tqdm(range(40)):
for k in range(32):
noisy_patch = torch.from_numpy(Inoisy[i,k,:,:,:]).unsqueeze(0).permute(0,3,1,2).cuda()
restored_patch = model_restoration(noisy_patch)
restored_patch = torch.clamp(restored_patch[0],0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0)
restored[i,k,:,:,:] = restored_patch
if args.save_images:
save_file = os.path.join(result_dir_img, '%04d_%02d.png'%(i+1,k+1))
utils.save_img(save_file, img_as_ubyte(restored_patch))
# save denoised data
sio.savemat(os.path.join(result_dir, 'Idenoised.mat'), {"Idenoised": restored,})
import os
from config import Config
opt = Config('training.yml')
gpus = ','.join([str(i) for i in opt.GPU])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus
import torch
torch.backends.cudnn.benchmark = True
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import random
import time
import numpy as np
import utils
from data_RGB import get_training_data, get_validation_data
from MPRNet import MPRNet
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from pdb import set_trace as stx
######### Set Seeds ###########
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)
start_epoch = 1
mode = opt.MODEL.MODE
session = opt.MODEL.SESSION
result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
utils.mkdir(result_dir)
utils.mkdir(model_dir)
train_dir = opt.TRAINING.TRAIN_DIR
val_dir = opt.TRAINING.VAL_DIR
######### Model ###########
model_restoration = MPRNet()
model_restoration.cuda()
device_ids = [i for i in range(torch.cuda.device_count())]
if torch.cuda.device_count() > 1:
print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
new_lr = opt.OPTIM.LR_INITIAL
optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8, weight_decay=1e-8)
######### Scheduler ###########
warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs+40, eta_min=opt.OPTIM.LR_MIN)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
scheduler.step()
######### Resume ###########
if opt.TRAINING.RESUME:
path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
utils.load_checkpoint(model_restoration,path_chk_rest)
start_epoch = utils.load_start_epoch(path_chk_rest) + 1
utils.load_optim(optimizer, path_chk_rest)
for i in range(1, start_epoch):
scheduler.step()
new_lr = scheduler.get_lr()[0]
print('------------------------------------------------------------------------------')
print("==> Resuming Training with learning rate:", new_lr)
print('------------------------------------------------------------------------------')
if len(device_ids)>1:
model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)
######### Loss ###########
criterion = losses.CharbonnierLoss()
######### DataLoaders ###########
train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)
val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS})
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS))
print('===> Loading datasets')
best_psnr = 0
best_epoch = 0
best_iter = 0
eval_now = len(train_loader)//3 - 1
print(f"\nEval after every {eval_now} Iterations !!!\n")
mixup = utils.MixUp_AUG()
for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
epoch_start_time = time.time()
epoch_loss = 0
train_id = 1
model_restoration.train()
for i, data in enumerate(tqdm(train_loader), 0):
# zero_grad
for param in model_restoration.parameters():
param.grad = None
target = data[0].cuda()
input_ = data[1].cuda()
if epoch>5:
target, input_ = mixup.aug(target, input_)
restored = model_restoration(input_)
# Compute loss at each stage
        loss = sum(criterion(torch.clamp(restored[j], 0, 1), target) for j in range(len(restored)))
loss.backward()
optimizer.step()
epoch_loss +=loss.item()
#### Evaluation ####
if i%eval_now==0 and i>0 and (epoch in [1,25,45] or epoch>60):
model_restoration.eval()
psnr_val_rgb = []
for ii, data_val in enumerate((val_loader), 0):
target = data_val[0].cuda()
input_ = data_val[1].cuda()
with torch.no_grad():
restored = model_restoration(input_)
restored = restored[0]
for res,tar in zip(restored,target):
psnr_val_rgb.append(utils.torchPSNR(res, tar))
psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
if psnr_val_rgb > best_psnr:
best_psnr = psnr_val_rgb
best_epoch = epoch
best_iter = i
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer' : optimizer.state_dict()
}, os.path.join(model_dir,"model_best.pth"))
print("[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]" % (epoch, i, psnr_val_rgb, best_epoch, best_iter, best_psnr))
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer' : optimizer.state_dict()
}, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))
model_restoration.train()
scheduler.step()
print("------------------------------------------------------------------")
print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
print("------------------------------------------------------------------")
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer' : optimizer.state_dict()
}, os.path.join(model_dir,"model_latest.pth"))
###############
##
####
GPU: [0,1,2,3]
VERBOSE: True
MODEL:
MODE: 'Denoising'
SESSION: 'MPRNet'
# Optimization arguments.
OPTIM:
BATCH_SIZE: 16
NUM_EPOCHS: 80
# NEPOCH_DECAY: [10]
  LR_INITIAL: 2.0e-4   # decimal point ensures YAML parses a float, not a string
  LR_MIN: 1.0e-6
# BETA1: 0.9
TRAINING:
VAL_AFTER_EVERY: 1
RESUME: False
TRAIN_PS: 128
VAL_PS: 256
TRAIN_DIR: './Datasets/SIDD/train' # path to training data
VAL_DIR: './Datasets/SIDD/val' # path to validation data
SAVE_DIR: './checkpoints' # path to save models and images
# SAVE_IMAGES: False
from .dir_utils import *
from .image_utils import *
from .model_utils import *
from .dataset_utils import *
import torch
class MixUp_AUG:
def __init__(self):
self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
def aug(self, rgb_gt, rgb_noisy):
bs = rgb_gt.size(0)
indices = torch.randperm(bs)
rgb_gt2 = rgb_gt[indices]
rgb_noisy2 = rgb_noisy[indices]
lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
return rgb_gt, rgb_noisy
import os
from natsort import natsorted
from glob import glob
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def get_last_path(path, session):
x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
return x
import torch
import numpy as np
import cv2
def torchPSNR(tar_img, prd_img):
imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
rmse = (imdff**2).mean().sqrt()
ps = 20*torch.log10(1/rmse)
return ps
def save_img(filepath, img):
cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
def numpyPSNR(tar_img, prd_img):
imdff = np.float32(prd_img) - np.float32(tar_img)
rmse = np.sqrt(np.mean(imdff**2))
ps = 20*np.log10(255/rmse)
return ps
import torch
import os
from collections import OrderedDict
def freeze(model):
for p in model.parameters():
p.requires_grad=False
def unfreeze(model):
for p in model.parameters():
p.requires_grad=True
def is_frozen(model):
x = [p.requires_grad for p in model.parameters()]
return not all(x)
def save_checkpoint(model_dir, state, session):
epoch = state['epoch']
model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
torch.save(state, model_out_path)
def load_checkpoint(model, weights):
checkpoint = torch.load(weights)
try:
model.load_state_dict(checkpoint["state_dict"])
    except RuntimeError:  # checkpoint saved from nn.DataParallel: strip "module." prefixes
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
def load_checkpoint_multigpu(model, weights):
checkpoint = torch.load(weights)
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
def load_start_epoch(weights):
checkpoint = torch.load(weights)
epoch = checkpoint["epoch"]
return epoch
def load_optim(optimizer, weights):
checkpoint = torch.load(weights)
optimizer.load_state_dict(checkpoint['optimizer'])
# for p in optimizer.param_groups: lr = p['lr']
# return lr
Download the datasets from the Google Drive links below and place them in this directory. Your directory structure should look like this:
`Synthetic_Rain_Datasets` <br/>
  `├──`[train](https://drive.google.com/drive/folders/1Hnnlc5kI0v9_BtfMytC2LR5VpLAFZtVe?usp=sharing) <br/>
  `└──`[test](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs?usp=sharing) <br/>
      `├──Test100` <br/>
      `├──Rain100H` <br/>
      `├──Rain100L` <br/>
      `├──Test1200` <br/>
      `└──Test2800`
## Training
- Download the [Datasets](Datasets/README.md)
- Train the model with default arguments by running
```
python train.py
```
## Evaluation
1. Download the [model](https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view?usp=sharing) and place it in `./pretrained_models/`
2. Download test datasets (Test100, Rain100H, Rain100L, Test1200, Test2800) from [here](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs?usp=sharing) and place them in `./Datasets/Synthetic_Rain_Datasets/test/`
3. Run
```
python test.py
```
#### To reproduce PSNR/SSIM scores of the paper, run this MATLAB script
```
evaluate_PSNR_SSIM.m
```
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 23 14:35:48 2019
@author: aditya
"""
r"""This module provides package-wide configuration management."""
from typing import Any, List
from yacs.config import CfgNode as CN
class Config(object):
r"""
    A collection of all the required configuration parameters. This class is a nested dict-like
    structure, with nested keys accessible as attributes. It contains sensible default values for
    all the parameters, which may be overridden (first) through a YAML file and (second) through
    a list of attributes and values.
    Extended Summary
    ----------------
    This class definition contains default values corresponding to the ``joint_training`` phase, as it
    is the final training phase and uses almost all the configuration parameters. Modification of
    any parameter after instantiating this class is not possible, so you must override required
    parameter values either through the ``config_yaml`` file or the ``config_override`` list.
Parameters
----------
config_yaml: str
Path to a YAML file containing configuration parameters to override.
config_override: List[Any], optional (default= [])
A list of sequential attributes and values of parameters to override. This happens after
overriding from YAML file.
Examples
--------
Let a YAML file named "config.yaml" specify these parameters to override::
ALPHA: 1000.0
BETA: 0.5
>>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7])
>>> _C.ALPHA # default: 100.0
1000.0
    >>> _C.OPTIM.BATCH_SIZE # default: 256
2048
>>> _C.BETA # default: 0.1
0.7
Attributes
----------
"""
def __init__(self, config_yaml: str, config_override: List[Any] = []):
self._C = CN()
self._C.GPU = [0]
self._C.VERBOSE = False
self._C.MODEL = CN()
self._C.MODEL.MODE = 'global'
self._C.MODEL.SESSION = 'ps128_bs1'
self._C.OPTIM = CN()
self._C.OPTIM.BATCH_SIZE = 1
self._C.OPTIM.NUM_EPOCHS = 100
self._C.OPTIM.NEPOCH_DECAY = [100]
self._C.OPTIM.LR_INITIAL = 0.0002
self._C.OPTIM.LR_MIN = 0.0002
self._C.OPTIM.BETA1 = 0.5
self._C.TRAINING = CN()
self._C.TRAINING.VAL_AFTER_EVERY = 3
self._C.TRAINING.RESUME = False
self._C.TRAINING.SAVE_IMAGES = False
self._C.TRAINING.TRAIN_DIR = 'images_dir/train'
self._C.TRAINING.VAL_DIR = 'images_dir/val'
self._C.TRAINING.SAVE_DIR = 'checkpoints'
self._C.TRAINING.TRAIN_PS = 64
self._C.TRAINING.VAL_PS = 64
# Override parameter values from YAML file first, then from override list.
self._C.merge_from_file(config_yaml)
self._C.merge_from_list(config_override)
# Make an instantiated object of this class immutable.
self._C.freeze()
def dump(self, file_path: str):
r"""Save config at the specified file path.
Parameters
----------
file_path: str
(YAML) path to save config at.
"""
self._C.dump(stream=open(file_path, "w"))
def __getattr__(self, attr: str):
return self._C.__getattr__(attr)
def __repr__(self):
return self._C.__repr__()
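For reference, a minimal usage sketch of this class (it assumes a `training.yml` like the one further down in this repo is in the working directory; the override keys are illustrative):
```
from config import Config

# Defaults from __init__ are overridden by the YAML file first,
# then by the list, so the list values win.
opt = Config('training.yml', ['OPTIM.BATCH_SIZE', 2, 'TRAINING.TRAIN_PS', 128])
print(opt.OPTIM.BATCH_SIZE)   # 2: the list override wins over the YAML value
print(opt.MODEL.SESSION)      # value taken from training.yml
opt.dump('config_used.yml')   # snapshot the frozen config for reproducibility
```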
import os
from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest
def get_training_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTrain(rgb_dir, img_options)
def get_validation_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderVal(rgb_dir, img_options)
def get_test_data(rgb_dir, img_options):
assert os.path.exists(rgb_dir)
return DataLoaderTest(rgb_dir, img_options)
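A sketch of how these helpers are wired into a `DataLoader`, mirroring the calls made in `train.py` further down; the dataset path and loader settings are placeholders:
```
from torch.utils.data import DataLoader
from data_RGB import get_training_data

# 'patch_size' is the only img_options key the training dataset reads.
train_dataset = get_training_data('./Datasets/train', {'patch_size': 256})
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
target, input_, filenames = next(iter(train_loader))  # note: target comes first
```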
import os
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
class DataLoaderTrain(Dataset):
def __init__(self, rgb_dir, img_options=None):
super(DataLoaderTrain, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
w,h = tar_img.size
padw = ps-w if w<ps else 0
padh = ps-h if h<ps else 0
# Reflect Pad in case image is smaller than patch_size
if padw!=0 or padh!=0:
inp_img = TF.pad(inp_img, (0,0,padw,padh), padding_mode='reflect')
tar_img = TF.pad(tar_img, (0,0,padw,padh), padding_mode='reflect')
inp_img = TF.to_tensor(inp_img)
tar_img = TF.to_tensor(tar_img)
hh, ww = tar_img.shape[1], tar_img.shape[2]
rr = random.randint(0, hh-ps)
cc = random.randint(0, ww-ps)
        aug = random.randint(0, 8)  # 0 and 8 leave the patch unchanged; 1-7 pick a flip/rotation below
# Crop patch
inp_img = inp_img[:, rr:rr+ps, cc:cc+ps]
tar_img = tar_img[:, rr:rr+ps, cc:cc+ps]
# Data Augmentations
if aug==1:
inp_img = inp_img.flip(1)
tar_img = tar_img.flip(1)
elif aug==2:
inp_img = inp_img.flip(2)
tar_img = tar_img.flip(2)
elif aug==3:
inp_img = torch.rot90(inp_img,dims=(1,2))
tar_img = torch.rot90(tar_img,dims=(1,2))
elif aug==4:
inp_img = torch.rot90(inp_img,dims=(1,2), k=2)
tar_img = torch.rot90(tar_img,dims=(1,2), k=2)
elif aug==5:
inp_img = torch.rot90(inp_img,dims=(1,2), k=3)
tar_img = torch.rot90(tar_img,dims=(1,2), k=3)
elif aug==6:
inp_img = torch.rot90(inp_img.flip(1),dims=(1,2))
tar_img = torch.rot90(tar_img.flip(1),dims=(1,2))
elif aug==7:
inp_img = torch.rot90(inp_img.flip(2),dims=(1,2))
tar_img = torch.rot90(tar_img.flip(2),dims=(1,2))
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
class DataLoaderVal(Dataset):
def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
super(DataLoaderVal, self).__init__()
inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
self.img_options = img_options
self.sizex = len(self.tar_filenames) # get the size of target
self.ps = self.img_options['patch_size']
def __len__(self):
return self.sizex
def __getitem__(self, index):
index_ = index % self.sizex
ps = self.ps
inp_path = self.inp_filenames[index_]
tar_path = self.tar_filenames[index_]
inp_img = Image.open(inp_path)
tar_img = Image.open(tar_path)
# Validate on center crop
if self.ps is not None:
inp_img = TF.center_crop(inp_img, (ps,ps))
tar_img = TF.center_crop(tar_img, (ps,ps))
inp_img = TF.to_tensor(inp_img)
tar_img = TF.to_tensor(tar_img)
filename = os.path.splitext(os.path.split(tar_path)[-1])[0]
return tar_img, inp_img, filename
class DataLoaderTest(Dataset):
def __init__(self, inp_dir, img_options):
super(DataLoaderTest, self).__init__()
inp_files = sorted(os.listdir(inp_dir))
self.inp_filenames = [os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)]
self.inp_size = len(self.inp_filenames)
self.img_options = img_options
def __len__(self):
return self.inp_size
def __getitem__(self, index):
path_inp = self.inp_filenames[index]
filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
inp = Image.open(path_inp)
inp = TF.to_tensor(inp)
return inp, filename
clc;close all;clear all;addpath(genpath('./'));
% datasets = {'Rain100L'};
datasets = {'Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800'};
num_set = length(datasets);
psnr_alldatasets = 0;
ssim_alldatasets = 0;
for idx_set = 1:num_set
file_path = strcat('./results/', datasets{idx_set}, '/');
gt_path = strcat('./Datasets/test/', datasets{idx_set}, '/target/');
path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))];
gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
img_num = length(path_list);
total_psnr = 0;
total_ssim = 0;
if img_num > 0
for j = 1:img_num
image_name = path_list(j).name;
gt_name = gt_list(j).name;
input = imread(strcat(file_path,image_name));
gt = imread(strcat(gt_path, gt_name));
ssim_val = compute_ssim(input, gt);
psnr_val = compute_psnr(input, gt);
total_ssim = total_ssim + ssim_val;
total_psnr = total_psnr + psnr_val;
end
end
qm_psnr = total_psnr / img_num;
qm_ssim = total_ssim / img_num;
fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);
psnr_alldatasets = psnr_alldatasets + qm_psnr;
ssim_alldatasets = ssim_alldatasets + qm_ssim;
end
fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_set, ssim_alldatasets/num_set);
function ssim_mean=compute_ssim(img1,img2)
if size(img1, 3) == 3
img1 = rgb2ycbcr(img1);
img1 = img1(:, :, 1);
end
if size(img2, 3) == 3
img2 = rgb2ycbcr(img2);
img2 = img2(:, :, 1);
end
ssim_mean = SSIM_index(img1, img2);
end
function psnr=compute_psnr(img1,img2)
if size(img1, 3) == 3
img1 = rgb2ycbcr(img1);
img1 = img1(:, :, 1);
end
if size(img2, 3) == 3
img2 = rgb2ycbcr(img2);
img2 = img2(:, :, 1);
end
imdff = double(img1) - double(img2);
imdff = imdff(:);
rmse = sqrt(mean(imdff.^2));
psnr = 20*log10(255/rmse);
end
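For reference, `compute_psnr` above evaluates the standard peak signal-to-noise ratio on the luma (Y) channel; written out, with x_i, y_i the pixel values of the two images and N the number of pixels:
```
\mathrm{RMSE} = \sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(x_i - y_i\right)^2},
\qquad
\mathrm{PSNR} = 20\,\log_{10}\frac{255}{\mathrm{RMSE}} \;\;\text{dB}
```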
function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L)
%========================================================================
%SSIM Index, Version 1.0
%Copyright(c) 2003 Zhou Wang
%All Rights Reserved.
%
%The author is with Howard Hughes Medical Institute, and Laboratory
%for Computational Vision at Center for Neural Science and Courant
%Institute of Mathematical Sciences, New York University.
%
%----------------------------------------------------------------------
%Permission to use, copy, or modify this software and its documentation
%for educational and research purposes only and without fee is hereby
%granted, provided that this copyright notice and the original authors'
%names appear on all copies and supporting documentation. This program
%shall not be used, rewritten, or adapted as the basis of a commercial
%software or hardware product without first obtaining permission of the
%authors. The authors make no representations about the suitability of
%this software for any purpose. It is provided "as is" without express
%or implied warranty.
%----------------------------------------------------------------------
%
%This is an implementation of the algorithm for calculating the
%Structural SIMilarity (SSIM) index between two images. Please refer
%to the following paper:
%
%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
%quality assessment: From error measurement to structural similarity"
%IEEE Transactions on Image Processing, vol. 13, no. 1, Jan. 2004.
%
%Kindly report any suggestions or corrections to zhouwang@ieee.org
%
%----------------------------------------------------------------------
%
%Input : (1) img1: the first image being compared
% (2) img2: the second image being compared
% (3) K: constants in the SSIM index formula (see the above
% reference). default value: K = [0.01 0.03]
% (4) window: local window for statistics (see the above
% reference). default window is Gaussian given by
% window = fspecial('gaussian', 11, 1.5);
% (5) L: dynamic range of the images. default: L = 255
%
%Output: (1) mssim: the mean SSIM index value between 2 images.
% If one of the images being compared is regarded as
% perfect quality, then mssim can be considered as the
% quality measure of the other image.
% If img1 = img2, then mssim = 1.
% (2) ssim_map: the SSIM index map of the test image. The map
% has a smaller size than the input images. The actual size:
% size(img1) - size(window) + 1.
%
%Default Usage:
% Given 2 test images img1 and img2, whose dynamic range is 0-255
%
% [mssim ssim_map] = ssim_index(img1, img2);
%
%Advanced Usage:
% User defined parameters. For example
%
% K = [0.05 0.05];
% window = ones(8);
% L = 100;
% [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
%
%See the results:
%
% mssim %Gives the mssim value
% imshow(max(0, ssim_map).^4) %Shows the SSIM index map
%
%========================================================================
if (nargin < 2 || nargin > 5)
   mssim = -Inf;
ssim_map = -Inf;
return;
end
if ~isequal(size(img1), size(img2))
   mssim = -Inf;
ssim_map = -Inf;
return;
end
[M N] = size(img1);
if (nargin == 2)
if ((M < 11) || (N < 11))
   mssim = -Inf;
ssim_map = -Inf;
return
end
window = fspecial('gaussian', 11, 1.5); %
K(1) = 0.01; % default settings
K(2) = 0.03; %
L = 255; %
end
if (nargin == 3)
if ((M < 11) || (N < 11))
   mssim = -Inf;
ssim_map = -Inf;
return
end
window = fspecial('gaussian', 11, 1.5);
L = 255;
if (length(K) == 2)
if (K(1) < 0 || K(2) < 0)
   mssim = -Inf;
ssim_map = -Inf;
return;
end
else
   mssim = -Inf;
ssim_map = -Inf;
return;
end
end
if (nargin == 4)
[H W] = size(window);
if ((H*W) < 4 || (H > M) || (W > N))
   mssim = -Inf;
ssim_map = -Inf;
return
end
L = 255;
if (length(K) == 2)
if (K(1) < 0 || K(2) < 0)
   mssim = -Inf;
ssim_map = -Inf;
return;
end
else
   mssim = -Inf;
ssim_map = -Inf;
return;
end
end
if (nargin == 5)
[H W] = size(window);
if ((H*W) < 4 || (H > M) || (W > N))
   mssim = -Inf;
ssim_map = -Inf;
return
end
if (length(K) == 2)
if (K(1) < 0 || K(2) < 0)
   mssim = -Inf;
ssim_map = -Inf;
return;
end
else
   mssim = -Inf;
ssim_map = -Inf;
return;
end
end
C1 = (K(1)*L)^2;
C2 = (K(2)*L)^2;
window = window/sum(sum(window));
img1 = double(img1);
img2 = double(img2);
mu1 = filter2(window, img1, 'valid');
mu2 = filter2(window, img2, 'valid');
mu1_sq = mu1.*mu1;
mu2_sq = mu2.*mu2;
mu1_mu2 = mu1.*mu2;
sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
if (C1 > 0 & C2 > 0)
ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
else
numerator1 = 2*mu1_mu2 + C1;
numerator2 = 2*sigma12 + C2;
denominator1 = mu1_sq + mu2_sq + C1;
denominator2 = sigma1_sq + sigma2_sq + C2;
ssim_map = ones(size(mu1));
index = (denominator1.*denominator2 > 0);
ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
index = (denominator1 ~= 0) & (denominator2 == 0);
ssim_map(index) = numerator1(index)./denominator1(index);
end
mssim = mean2(ssim_map);
end
import torch
import torch.nn as nn
import torch.nn.functional as F
class CharbonnierLoss(nn.Module):
"""Charbonnier Loss (L1)"""
def __init__(self, eps=1e-3):
super(CharbonnierLoss, self).__init__()
self.eps = eps
def forward(self, x, y):
diff = x - y
# loss = torch.sum(torch.sqrt(diff * diff + self.eps))
loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
return loss
class EdgeLoss(nn.Module):
def __init__(self):
super(EdgeLoss, self).__init__()
k = torch.Tensor([[.05, .25, .4, .25, .05]])
self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
if torch.cuda.is_available():
self.kernel = self.kernel.cuda()
self.loss = CharbonnierLoss()
def conv_gauss(self, img):
n_channels, _, kw, kh = self.kernel.shape
img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
return F.conv2d(img, self.kernel, groups=n_channels)
def laplacian_kernel(self, current):
filtered = self.conv_gauss(current) # filter
down = filtered[:,:,::2,::2] # downsample
new_filter = torch.zeros_like(filtered)
new_filter[:,:,::2,::2] = down*4 # upsample
filtered = self.conv_gauss(new_filter) # filter
diff = current - filtered
return diff
def forward(self, x, y):
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
return loss
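A minimal sketch of how the two losses combine, using the same 0.05 edge-loss weight as `train.py` below; the dummy tensors are placeholders:
```
import torch
from losses import CharbonnierLoss, EdgeLoss

criterion_char = CharbonnierLoss()
criterion_edge = EdgeLoss()
# On a CUDA machine, move the tensors to the GPU first: EdgeLoss keeps its
# Gaussian kernel on the GPU whenever torch.cuda.is_available() is True.
restored = torch.rand(2, 3, 64, 64)
target = torch.rand(2, 3, 64, 64)
loss = criterion_char(restored, target) + 0.05 * criterion_edge(restored, target)
print(loss.item())
```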
The pre-trained deraining model is available [here](https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view?usp=sharing).
"""
## Multi-Stage Progressive Image Restoration
## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao
## https://arxiv.org/abs/2102.02808
"""
import numpy as np
import os
import argparse
from tqdm import tqdm
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from data_RGB import get_test_data
from MPRNet import MPRNet
from skimage import img_as_ubyte
from pdb import set_trace as stx
parser = argparse.ArgumentParser(description='Image Deraining using MPRNet')
parser.add_argument('--input_dir', default='./Datasets/test/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pretrained_models/model_deraining.pth', type=str, help='Path to weights')
parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')
args = parser.parse_args()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
model_restoration = MPRNet()
utils.load_checkpoint(model_restoration,args.weights)
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()
datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800']
# datasets = ['Rain100L']
for dataset in datasets:
rgb_dir_test = os.path.join(args.input_dir, dataset, 'input')
test_dataset = get_test_data(rgb_dir_test, img_options={})
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)
result_dir = os.path.join(args.result_dir, dataset)
utils.mkdir(result_dir)
with torch.no_grad():
for ii, data_test in enumerate(tqdm(test_loader), 0):
torch.cuda.ipc_collect()
torch.cuda.empty_cache()
input_ = data_test[0].cuda()
filenames = data_test[1]
restored = model_restoration(input_)
restored = torch.clamp(restored[0],0,1)
restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
for batch in range(len(restored)):
restored_img = img_as_ubyte(restored[batch])
utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img)
import os
from config import Config
opt = Config('training.yml')
gpus = ','.join([str(i) for i in opt.GPU])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus
import torch
torch.backends.cudnn.benchmark = True
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import random
import time
import numpy as np
import utils
from data_RGB import get_training_data, get_validation_data
from MPRNet import MPRNet
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from pdb import set_trace as stx
######### Set Seeds ###########
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)
start_epoch = 1
mode = opt.MODEL.MODE
session = opt.MODEL.SESSION
result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
utils.mkdir(result_dir)
utils.mkdir(model_dir)
train_dir = opt.TRAINING.TRAIN_DIR
val_dir = opt.TRAINING.VAL_DIR
######### Model ###########
model_restoration = MPRNet()
model_restoration.cuda()
device_ids = [i for i in range(torch.cuda.device_count())]
if torch.cuda.device_count() > 1:
print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
new_lr = opt.OPTIM.LR_INITIAL
optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8)
######### Scheduler ###########
warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
scheduler.step()
######### Resume ###########
if opt.TRAINING.RESUME:
path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
utils.load_checkpoint(model_restoration,path_chk_rest)
start_epoch = utils.load_start_epoch(path_chk_rest) + 1
utils.load_optim(optimizer, path_chk_rest)
for i in range(1, start_epoch):
scheduler.step()
new_lr = scheduler.get_lr()[0]
print('------------------------------------------------------------------------------')
print("==> Resuming Training with learning rate:", new_lr)
print('------------------------------------------------------------------------------')
if len(device_ids)>1:
model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)
######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()
######### DataLoaders ###########
train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)
val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS})
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.OPTIM.NUM_EPOCHS))
print('===> Loading datasets')
best_psnr = 0
best_epoch = 0
for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
epoch_start_time = time.time()
epoch_loss = 0
train_id = 1
model_restoration.train()
for i, data in enumerate(tqdm(train_loader), 0):
# zero_grad
for param in model_restoration.parameters():
param.grad = None
target = data[0].cuda()
input_ = data[1].cuda()
restored = model_restoration(input_)
# Compute loss at each stage
        # use built-in sum so the stage-wise losses stay torch tensors (np.sum on a list of tensors is fragile)
        loss_char = sum(criterion_char(restored[j], target) for j in range(len(restored)))
        loss_edge = sum(criterion_edge(restored[j], target) for j in range(len(restored)))
loss = (loss_char) + (0.05*loss_edge)
loss.backward()
optimizer.step()
epoch_loss +=loss.item()
#### Evaluation ####
if epoch%opt.TRAINING.VAL_AFTER_EVERY == 0:
model_restoration.eval()
psnr_val_rgb = []
for ii, data_val in enumerate((val_loader), 0):
target = data_val[0].cuda()
input_ = data_val[1].cuda()
with torch.no_grad():
restored = model_restoration(input_)
restored = restored[0]
for res,tar in zip(restored,target):
psnr_val_rgb.append(utils.torchPSNR(res, tar))
psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
if psnr_val_rgb > best_psnr:
best_psnr = psnr_val_rgb
best_epoch = epoch
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer' : optimizer.state_dict()
}, os.path.join(model_dir,"model_best.pth"))
print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer' : optimizer.state_dict()
}, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))
scheduler.step()
print("------------------------------------------------------------------")
print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.8f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
print("------------------------------------------------------------------")
torch.save({'epoch': epoch,
'state_dict': model_restoration.state_dict(),
'optimizer' : optimizer.state_dict()
}, os.path.join(model_dir,"model_latest.pth"))
GPU: [0,1,2,3]
VERBOSE: True
MODEL:
MODE: 'Deraining'
SESSION: 'MPRNet'
# Optimization arguments.
OPTIM:
BATCH_SIZE: 16
NUM_EPOCHS: 250
# NEPOCH_DECAY: [10]
LR_INITIAL: 2e-4
LR_MIN: 1e-6
# BETA1: 0.9
TRAINING:
VAL_AFTER_EVERY: 5
RESUME: False
TRAIN_PS: 256
VAL_PS: 128
TRAIN_DIR: './Datasets/train' # path to training data
VAL_DIR: './Datasets/test/Rain100L' # path to validation data
SAVE_DIR: './checkpoints' # path to save models and images
# SAVE_IMAGES: False
from .dir_utils import *
from .image_utils import *
from .model_utils import *
from .dataset_utils import *
import torch
class MixUp_AUG:
def __init__(self):
self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
def aug(self, rgb_gt, rgb_noisy):
bs = rgb_gt.size(0)
indices = torch.randperm(bs)
rgb_gt2 = rgb_gt[indices]
rgb_noisy2 = rgb_noisy[indices]
lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
return rgb_gt, rgb_noisy
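A usage sketch for `MixUp_AUG`; `aug` calls `.cuda()` on the sampled weights, so a GPU is assumed, and the batch tensors are placeholders:
```
import torch
from utils.dataset_utils import MixUp_AUG

mixup = MixUp_AUG()
rgb_gt = torch.rand(4, 3, 128, 128).cuda()
rgb_noisy = torch.rand(4, 3, 128, 128).cuda()
# Each pair in the batch is blended with a Beta(0.6, 0.6) weight,
# with the same random permutation applied to targets and inputs.
rgb_gt, rgb_noisy = mixup.aug(rgb_gt, rgb_noisy)
```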
import os
from natsort import natsorted
from glob import glob
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def get_last_path(path, session):
x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
return x
import torch
import numpy as np
import cv2
def torchPSNR(tar_img, prd_img):
imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
rmse = (imdff**2).mean().sqrt()
ps = 20*torch.log10(1/rmse)
return ps
def save_img(filepath, img):
cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
def numpyPSNR(tar_img, prd_img):
imdff = np.float32(prd_img) - np.float32(tar_img)
rmse = np.sqrt(np.mean(imdff**2))
ps = 20*np.log10(255/rmse)
return ps
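As a quick consistency check, the two PSNR helpers should agree up to uint8 rounding when given the same pair in their respective conventions (tensors in [0, 1] versus uint8 arrays in [0, 255]); a sketch:
```
import numpy as np
import torch
from utils import torchPSNR, numpyPSNR

a = torch.rand(3, 64, 64)
b = (a + 0.02).clamp(0, 1)
print(torchPSNR(a, b))                      # tensors in [0, 1]
print(numpyPSNR(np.uint8(a.numpy() * 255),  # uint8 arrays in [0, 255]
                np.uint8(b.numpy() * 255)))
```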
import torch
import os
from collections import OrderedDict
def freeze(model):
for p in model.parameters():
p.requires_grad=False
def unfreeze(model):
for p in model.parameters():
p.requires_grad=True
def is_frozen(model):
x = [p.requires_grad for p in model.parameters()]
    return not any(x)  # a model is frozen when no parameter requires grad
def save_checkpoint(model_dir, state, session):
epoch = state['epoch']
model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
torch.save(state, model_out_path)
def load_checkpoint(model, weights):
checkpoint = torch.load(weights)
try:
model.load_state_dict(checkpoint["state_dict"])
except:
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
def load_checkpoint_multigpu(model, weights):
checkpoint = torch.load(weights)
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
def load_start_epoch(weights):
checkpoint = torch.load(weights)
epoch = checkpoint["epoch"]
return epoch
def load_optim(optimizer, weights):
checkpoint = torch.load(weights)
optimizer.load_state_dict(checkpoint['optimizer'])
# for p in optimizer.param_groups: lr = p['lr']
# return lr
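A short sketch of how these checkpoint helpers fit together when resuming training, as `train.py` does; the checkpoint path is a placeholder:
```
import torch.optim as optim
from MPRNet import MPRNet
from utils import load_checkpoint, load_start_epoch, load_optim

model = MPRNet()
optimizer = optim.Adam(model.parameters(), lr=2e-4)
ckpt = './checkpoints/Deraining/models/MPRNet/model_latest.pth'  # placeholder
load_checkpoint(model, ckpt)          # strips any 'module.' prefix left by DataParallel
start_epoch = load_start_epoch(ckpt) + 1
load_optim(optimizer, ckpt)           # restores the optimizer state for the resume
```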
# MPRNet
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/deblurring-on-gopro)](https://paperswithcode.com/sota/deblurring-on-gopro?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/deblurring-on-hide-trained-on-gopro)](https://paperswithcode.com/sota/deblurring-on-hide-trained-on-gopro?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/deblurring-on-realblur-r-trained-on-gopro)](https://paperswithcode.com/sota/deblurring-on-realblur-r-trained-on-gopro?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/deblurring-on-realblur-j-trained-on-gopro)](https://paperswithcode.com/sota/deblurring-on-realblur-j-trained-on-gopro?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/deblurring-on-realblur-r)](https://paperswithcode.com/sota/deblurring-on-realblur-r?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/deblurring-on-realblur-j-1)](https://paperswithcode.com/sota/deblurring-on-realblur-j-1?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/single-image-deraining-on-rain100h)](https://paperswithcode.com/sota/single-image-deraining-on-rain100h?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/single-image-deraining-on-rain100l)](https://paperswithcode.com/sota/single-image-deraining-on-rain100l?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/single-image-deraining-on-test100)](https://paperswithcode.com/sota/single-image-deraining-on-test100?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/single-image-deraining-on-test1200)](https://paperswithcode.com/sota/single-image-deraining-on-test1200?p=multi-stage-progressive-image-restoration)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-stage-progressive-image-restoration/single-image-deraining-on-test2800)](https://paperswithcode.com/sota/single-image-deraining-on-test2800?p=multi-stage-progressive-image-restoration)
# Multi-Stage Progressive Image Restoration (CVPR 2021)
[Syed Waqas Zamir](https://scholar.google.es/citations?user=WNGPkVQAAAAJ&hl=en), [Aditya Arora](https://adityac8.github.io/), [Salman Khan](https://salman-h-khan.github.io/), [Munawar Hayat](https://scholar.google.com/citations?user=Mx8MbWYAAAAJ&hl=en), [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en), [Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en), and [Ling Shao](https://scholar.google.com/citations?user=z84rLjoAAAAJ&hl=en)
**Paper**: https://arxiv.org/abs/2102.02808
**Supplementary**: [pdf](https://drive.google.com/file/d/1mbfljawUuFUQN9V5g0Rmw1UdauJdckCu/view?usp=sharing)
> **Abstract:** *Image restoration tasks demand a complex balance between spatial details and high-level contextualized information while recovering images. In this paper, we propose a novel synergistic design that can optimally balance these competing goals. Our main proposal is a multi-stage architecture, that progressively learns restoration functions for the degraded inputs, thereby breaking down the overall recovery process into more manageable steps. Specifically, our model first learns the contextualized features using encoder-decoder architectures and later combines them with a high-resolution branch that retains local information. At each stage, we introduce a novel per-pixel adaptive design that leverages in-situ supervised attention to reweight the local features. A key ingredient in such a multi-stage architecture is the information exchange between different stages. To this end, we propose a two-faceted approach where the information is not only exchanged sequentially from early to late stages, but lateral connections between feature processing blocks also exist to avoid any loss of information. The resulting tightly interlinked multi-stage architecture, named as MPRNet, delivers strong performance gains on ten datasets across a range of tasks including image deraining, deblurring, and denoising. For example, on the Rain100L, GoPro and DND datasets, we obtain PSNR gains of 4 dB, 0.81 dB and 0.21 dB, respectively, compared to the state-of-the-art.*
## Network Architecture
<table>
<tr>
<td> <img src = "https://i.imgur.com/69c0pQv.png" width="500"> </td>
<td> <img src = "https://i.imgur.com/JJAKXOi.png" width="400"> </td>
</tr>
<tr>
<td><p align="center"><b>Overall Framework of MPRNet</b></p></td>
<td><p align="center"> <b>Supervised Attention Module (SAM)</b></p></td>
</tr>
</table>
## Installation
The model is built in PyTorch 1.1.0 and tested in an Ubuntu 16.04 environment (Python 3.7, CUDA 9.0, cuDNN 7.5).
To install, follow these instructions
```
conda create -n pytorch1 python=3.7
conda activate pytorch1
conda install pytorch=1.1 torchvision=0.3 cudatoolkit=9.0 -c pytorch
pip install matplotlib scikit-image opencv-python yacs joblib natsort h5py tqdm
```
Install warmup scheduler
```
cd pytorch-gradual-warmup-lr; python setup.py install; cd ..
```
## Quick Run
To test the pre-trained models of [Deblurring](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing), [Deraining](https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view?usp=sharing), [Denoising](https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view?usp=sharing) on your own images, run
```
python demo.py --task Task_Name --input_dir path_to_images --result_dir save_images_here
```
Here is an example to perform Deblurring:
```
python demo.py --task Deblurring --input_dir ./samples/input/ --result_dir ./samples/output/
```
## Training and Evaluation
Training and testing code for deblurring, deraining, and denoising is provided in the respective directories.
## Results
Experiments are performed for different image processing tasks, including image deblurring, image deraining, and image denoising.
### Image Deblurring
<table>
<tr>
<td> <img src = "https://i.imgur.com/UIwmY13.png" width="450"> </td>
<td> <img src = "https://i.imgur.com/ecSlcEo.png" width="450"> </td>
</tr>
<tr>
<td><p align="center"><b>Deblurring on Synthetic Datasets.</b></p></td>
<td><p align="center"><b>Deblurring on Real Dataset.</b></p></td>
</tr>
</table>
### Image Deraining
<img src = "https://i.imgur.com/YVXWRJT.png" width="900">
### Image Denoising
<p align="center"> <img src = "https://i.imgur.com/Wssu6Xu.png" width="450"> </p>
## Citation
If you use MPRNet, please consider citing:
@inproceedings{Zamir2021MPRNet,
title={Multi-Stage Progressive Image Restoration},
author={Syed Waqas Zamir and Aditya Arora and Salman Khan and Munawar Hayat
and Fahad Shahbaz Khan and Ming-Hsuan Yang and Ling Shao},
booktitle={CVPR},
year={2021}
}
## Contact
Should you have any questions, please contact waqas.zamir@inceptioniai.org
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
import os
from runpy import run_path
from skimage import img_as_ubyte
from collections import OrderedDict
from natsort import natsorted
from glob import glob
import cv2
import argparse
parser = argparse.ArgumentParser(description='Demo MPRNet')
parser.add_argument('--input_dir', default='./samples/input/', type=str, help='Input images')
parser.add_argument('--result_dir', default='./samples/output/', type=str, help='Directory for results')
parser.add_argument('--task', default='Denoising', type=str, help='Task to run', choices=['Deblurring', 'Denoising', 'Deraining'])
args = parser.parse_args()
def save_img(filepath, img):
cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
def load_checkpoint(model, weights):
checkpoint = torch.load(weights)
try:
model.load_state_dict(checkpoint["state_dict"])
except:
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
task = args.task
inp_dir = args.input_dir
out_dir = args.result_dir
os.makedirs(out_dir, exist_ok=True)
files = natsorted(glob(os.path.join(inp_dir, '*.jpg'))
+ glob(os.path.join(inp_dir, '*.JPG'))
+ glob(os.path.join(inp_dir, '*.png'))
+ glob(os.path.join(inp_dir, '*.PNG')))
if len(files) == 0:
raise Exception(f"No files found at {inp_dir}")
# Load corresponding model architecture and weights
load_file = run_path(os.path.join(task, "MPRNet.py"))
model = load_file['MPRNet']()
model.cuda()
weights = os.path.join(task, "pretrained_models", "model_"+task.lower()+".pth")
load_checkpoint(model, weights)
model.eval()
img_multiple_of = 8
for file_ in files:
    # cv2.imread returns BGR; convert to RGB, since save_img converts RGB back to BGR on write
    img = cv2.cvtColor(cv2.imread(file_), cv2.COLOR_BGR2RGB)
    input_ = TF.to_tensor(img).unsqueeze(0).cuda()
# Pad the input if not_multiple_of 8
h,w = input_.shape[2], input_.shape[3]
H,W = ((h+img_multiple_of)//img_multiple_of)*img_multiple_of, ((w+img_multiple_of)//img_multiple_of)*img_multiple_of
padh = H-h if h%img_multiple_of!=0 else 0
padw = W-w if w%img_multiple_of!=0 else 0
input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
with torch.no_grad():
restored = model(input_)
restored = restored[0]
restored = torch.clamp(restored, 0, 1)
# Unpad the output
restored = restored[:,:,:h,:w]
restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
restored = img_as_ubyte(restored[0])
f = os.path.splitext(os.path.split(file_)[-1])[0]
save_img((os.path.join(out_dir, f+'.png')), restored)
print(f"Files saved at {out_dir}")
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import setuptools
_VERSION = '0.3'
REQUIRED_PACKAGES = [
]
DEPENDENCY_LINKS = [
]
setuptools.setup(
name='warmup_scheduler',
version=_VERSION,
description='Gradually Warm-up LR Scheduler for Pytorch',
install_requires=REQUIRED_PACKAGES,
dependency_links=DEPENDENCY_LINKS,
url='https://github.com/ildoonet/pytorch-gradual-warmup-lr',
license='MIT License',
package_dir={},
packages=setuptools.find_packages(exclude=['tests']),
)
from warmup_scheduler.scheduler import GradualWarmupScheduler
import torch
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.optim.sgd import SGD
from warmup_scheduler import GradualWarmupScheduler
if __name__ == '__main__':
model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optim = SGD(model, 0.1)
    # scheduler_warmup is chained with scheduler_steplr
scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
# this zero gradient update is needed to avoid a warning message, issue #8.
optim.zero_grad()
optim.step()
for epoch in range(1, 20):
scheduler_warmup.step(epoch)
print(epoch, optim.param_groups[0]['lr'])
        optim.step()    # optimizer step (updates the parameters)
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
class GradualWarmupScheduler(_LRScheduler):
""" Gradually warm-up(increasing) learning rate in optimizer.
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
Args:
optimizer (Optimizer): Wrapped optimizer.
multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
total_epoch: target learning rate is reached at total_epoch, gradually
after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
"""
def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
self.multiplier = multiplier
if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
self.total_epoch = total_epoch
self.after_scheduler = after_scheduler
self.finished = False
super(GradualWarmupScheduler, self).__init__(optimizer)
def get_lr(self):
if self.last_epoch > self.total_epoch:
if self.after_scheduler:
if not self.finished:
self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
self.finished = True
return self.after_scheduler.get_lr()
return [base_lr * self.multiplier for base_lr in self.base_lrs]
if self.multiplier == 1.0:
return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
else:
return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
def step_ReduceLROnPlateau(self, metrics, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
if self.last_epoch <= self.total_epoch:
warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
param_group['lr'] = lr
else:
if epoch is None:
self.after_scheduler.step(metrics, None)
else:
self.after_scheduler.step(metrics, epoch - self.total_epoch)
def step(self, epoch=None, metrics=None):
if type(self.after_scheduler) != ReduceLROnPlateau:
if self.finished and self.after_scheduler:
if epoch is None:
self.after_scheduler.step(None)
else:
self.after_scheduler.step(epoch - self.total_epoch)
else:
return super(GradualWarmupScheduler, self).step(epoch)
else:
self.step_ReduceLROnPlateau(metrics, epoch)
import argparse
import os
from runpy import run_path
from collections import OrderedDict
import cv2
import torchvision.transforms.functional as TF
import torch
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
parser = argparse.ArgumentParser(description='Demo MPRNet')
parser.add_argument('--input_dir', default='./samples/input/', type=str, help='Input images')
parser.add_argument('--result_dir', default='./samples/output/', type=str, help='Directory for results')
parser.add_argument('--task', default='Denoising', type=str, help='Task to run', choices=['Deblurring', 'Denoising', 'Deraining'])
args = parser.parse_args()
def load_checkpoint(model, weights):
checkpoint = torch.load(weights)
try:
model.load_state_dict(checkpoint["state_dict"])
except:
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
# Load corresponding model architecture and weights
task = args.task
load_file = run_path(os.path.join(task, "MPRNet.py"))
model = load_file['MPRNet']()
weights = os.path.join(task, "pretrained_models", "model_"+task.lower()+".pth")
load_checkpoint(model, weights)
model.eval()
# cv2.imread returns BGR; the model is trained on RGB inputs
img = cv2.cvtColor(cv2.imread('samples/input/4.jpg'), cv2.COLOR_BGR2RGB)
input_ = TF.to_tensor(img).unsqueeze(0)
with torch.no_grad():
example = torch.ones((1, 3, 128, 128))
trace_model = torch.jit.trace(model, example)
r1 = model(input_)
r2 = trace_model(input_)
    # sanity check: the traced module should reproduce the eager model's output
    print(r1)
    print(r2)
trace_model.save('denoise.pt')
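Once `denoise.pt` is saved, the traced module can be reloaded without the Python model definition; a minimal sketch (the input shape matches the tracing example above, since a traced graph may bake in size arithmetic):
```
import torch

traced = torch.jit.load('denoise.pt')
traced.eval()
inp = torch.rand(1, 3, 128, 128)  # same shape as the tracing example
with torch.no_grad():
    out = traced(inp)[0].clamp(0, 1)  # index 0 is the restored image, as in the scripts above
print(out.shape)
```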