"""
@Time    : 2019/12/10 14:29
@Author  : Qin Dian
@Manual  : 
"""
import numpy as np
import os
import argparse
from glob import glob
from multiprocessing.dummy import Pool
from skimage import measure


def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: every non-empty string is
    truthy, so ``bool('False')`` is ``True`` and the option could never be
    switched off from the command line.

    :param value: raw option value (a ``bool`` is passed through unchanged).
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the text is not a recognised
        boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept as False because bool('') was False before.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='Slice Maker')
# Directory that is searched for the input "ct*.npz" volumes.
parser.add_argument('--in_path', type=str, default=r'D:\A_UNET\in')
# Directory that receives one .npz file per slice.
parser.add_argument('--out_path', type=str, default=r'D:\A_UNET\out')
# Help text: "whether to generate the tumor index list file".
parser.add_argument('--generate_tumor_index', type=_str2bool, default=True, help='是否生成肿瘤指示列表文件')
# Help text: "which phase the tumor index list is built for".
parser.add_argument('--series_index', type=int, default=0, help='哪一期的肿瘤指示列表')
# Directory in which the tumor index list (.npy) is saved.
parser.add_argument('--list_path', type=str, default=r'D:\A_UNET')
# Number of workers in the pool that processes the volumes.
parser.add_argument('--process_num', type=int, default=20)

# Parsed at import time because the functions below read the module-level ``args``.
args = parser.parse_args()


def main():
    """Slice every ``ct*.npz`` volume found in ``args.in_path``.

    Each volume is handed to ``make_slice`` through a worker pool. When
    ``args.generate_tumor_index`` is set, the slice names returned by the
    workers are merged into one list and saved as
    ``tumor_slices_s{args.series_index}.npy`` inside ``args.list_path``.
    """
    # makedirs(exist_ok=True) also creates missing parent directories and does
    # not fail if the directory appears between a check and the creation.
    os.makedirs(args.out_path, exist_ok=True)

    # glob returns files in arbitrary order; sorting makes the saved index
    # reproducible from run to run.
    paths = sorted(glob(os.path.join(args.in_path, "ct*.npz")))

    # multiprocessing.dummy.Pool is a thread pool; the context manager makes
    # sure its workers are released once the map has finished.
    with Pool(args.process_num) as pool:
        result = pool.map(make_slice, paths)

    if args.generate_tumor_index:
        # Flatten the per-volume lists into a single list of slice names.
        tumor_slices = [name for names in result for name in names]
        if args.list_path:
            os.makedirs(args.list_path, exist_ok=True)
        np.save(os.path.join(args.list_path, f'tumor_slices_s{args.series_index}.npy'), tumor_slices)


def _load_volume(path):
    """Load one ``.npz`` archive and return its array as ``(3, -1, 512, 512)``.

    The array is read from the ``'arrays'`` entry, falling back to
    ``'arr_0'`` (the name ``np.savez`` gives to a positional array).

    :param path: path of the ``.npz`` archive.
    :return: the stored array reshaped to ``(3, n_slices, 512, 512)``.
    :raises KeyError: if the archive has neither ``'arrays'`` nor ``'arr_0'``.
    """
    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only use this on trusted archives.
    # The ``with`` block closes the underlying zip file handle; many volumes
    # are processed concurrently, so leaked handles would accumulate.
    with np.load(path, allow_pickle=True) as archive:
        for key in ('arrays', 'arr_0'):
            if key in archive:
                volume = archive[key]
                break
        else:
            raise KeyError(f"{path} contains neither 'arrays' nor 'arr_0'")
    return volume.reshape(3, -1, 512, 512)


def _find_centers(mask_slice):
    """Locate the connected tumor regions of every phase in one mask slice.

    :param mask_slice: integer mask of one slice, indexed ``[phase, x, y]``.
    :return: dict ``{phase: [[x, y, h, w, c], ...]}`` where ``(x, y)`` is the
        center of the region's bounding box, ``h = max_x - min_x``,
        ``w = max_y - min_y`` and ``c`` is the value of ``mask >> 1`` at the
        center, i.e. one list of center points per phase.
    """
    centers = {}
    for series_i, phase_mask in enumerate(mask_slice):
        # Dropping the lowest bit maps the values 0 and 1 to 0, so only
        # values > 1 (the ones treated as tumor) stay non-zero.
        t_mask = phase_mask >> 1
        tumors = measure.label((t_mask > 0).astype(np.uint8), connectivity=1)
        s_centers = []
        for region in measure.regionprops(tumors):
            min_x, min_y, max_x, max_y = region.bbox
            center_x = (min_x + max_x) // 2
            center_y = (min_y + max_y) // 2
            # NOTE(review): the bounding-box center can lie outside a
            # non-convex region, in which case ``c`` is not that region's
            # class value -- confirm whether consumers rely on this.
            s_centers.append([center_x, center_y, max_x - min_x, max_y - min_y,
                              t_mask[center_x, center_y]])
        centers[series_i] = s_centers
    return centers


def make_slice(path):
    """Split one CT volume and its mask into per-slice ``.npz`` files.

    The mask is loaded from the same directory as ``path``, using the CT
    file name with its first three characters removed. Every slice ``i`` is
    written to ``{args.out_path}/{case}_{i}.npz`` with the entries ``ct`` and
    ``mask``; slices whose mask contains a value > 1 additionally get a
    ``centers`` entry (see ``_find_centers``).

    :param path: path of the CT ``.npz`` archive.
    :return: names of the written slice files whose mask for phase
        ``args.series_index`` contains a value > 1; always empty when
        ``args.generate_tumor_index`` is false.
    """
    result = []

    directory, ct_name = os.path.split(path)
    mask_name = ct_name[3:]
    ct = _load_volume(path)
    mask = _load_volume(os.path.join(directory, mask_name))
    case = mask_name.split('.')[0]

    for i in range(ct.shape[1]):
        ct_slice = ct[:, i, ...]
        mask_slice = mask[:, i, ...]
        slice_name = f'{case}_{i}.npz'
        out_file = f'{args.out_path}/{slice_name}'

        if np.any(mask_slice > 1):
            centers = _find_centers(mask_slice)
            np.savez_compressed(out_file, ct=ct_slice, mask=mask_slice, centers=centers)
            # Only slices that contain tumor in the selected phase are recorded.
            if args.generate_tumor_index and np.any(mask_slice[args.series_index] > 1):
                print(slice_name)
                result.append(slice_name)
        else:
            np.savez_compressed(out_file, ct=ct_slice, mask=mask_slice)

    print(f'complete making slices of {case}')
    return result


# Run the slicing pipeline only when this file is executed as a script.
if "__main__" == __name__:
    main()
