"""
@Time    : 2020/6/30 11:23
@Author  : Qin Dian
@Manual  : liver tumor dataset
"""
from torch.utils.data import Dataset
import numpy as np
import os
import torch
import utils.data_utils as du
from glob import glob


LOWER_BOUND = -60
UPPER_BOUND = 140


class SliceDataset(Dataset):
    """Dataset of individual slices listed in a per-series index file.

    Each item is read from an ``.npz`` archive under ``load_path`` that holds
    a ``ct`` array and a ``mask`` array, and is returned as a ``(ct, mask)``
    pair of float tensors.

    Args:
        load_path: Directory containing the ``.npz`` slice archives.
        index_path: Directory containing ``tumor_slices_s{series_index}.npy``,
            whose entries are archive names relative to ``load_path``.
        task: Label scheme selector; only ``'3tumors'`` triggers relabelling.
        series_index: Series to use. It selects the index file and, for the
            ``'3tumors'`` task, the first-axis entry of ``ct`` and ``mask``.

    NOTE(review): ``tumor_slices`` is only assigned when ``series_index`` is
    not None, so ``len()`` and indexing raise ``AttributeError`` otherwise.
    """

    def __init__(self, load_path, index_path, task='3tumors', series_index=None):
        super().__init__()
        self.load_path = load_path
        self.series_index = series_index
        self.task = task
        if series_index is not None:
            index_file = os.path.join(index_path, f'tumor_slices_s{series_index}.npy')
            self.tumor_slices = np.load(index_file)

    def __len__(self):
        """Return the number of slice archives listed in the index file."""
        return len(self.tumor_slices)

    def __getitem__(self, item):
        """Load, relabel and convert one slice.

        Args:
            item: Position in the loaded slice index.

        Returns:
            Tuple ``(ct, mask)`` of ``torch.float32`` tensors.

        Raises:
            KeyError: If the archive has no ``ct`` or ``mask`` entry.
        """
        # Data loading
        f_name = self.tumor_slices[item]
        # NOTE(review): allow_pickle=True can execute arbitrary code when an
        # archive contains pickled objects; only load trusted dataset files.
        # The context manager closes the archive's file handle right away
        # instead of leaving it open until garbage collection.
        with np.load(os.path.join(self.load_path, f_name), allow_pickle=True) as npz:
            ct = npz['ct']
            # The shift allocates a new array, so the in-place edits below
            # never touch data owned by the archive.
            mask = npz['mask'] >> 1

        # Preprocess
        if self.task == '3tumors':
            # Merge labels 6, 2 and 3 into label 1, then shift the remaining
            # labels above 3 down by 2.
            mask[(mask == 6) | (mask == 2) | (mask == 3)] = 1
            mask[mask > 3] -= 2

            # NOTE(review): series selection only happens for the '3tumors'
            # task, whereas CTDataset selects the series regardless of task.
            # Confirm the nesting is intentional.
            if self.series_index is not None:
                ct = ct[self.series_index]
                mask = mask[self.series_index]

        ct = du.window_standardize(ct, LOWER_BOUND, UPPER_BOUND)
        # self replication to 3 channel for adapting RGB channel
        ct = du.to_3channel_replica(ct)
        mask = du.to_3channel_replica(mask)

        # To tensor & cut to 384
        ct = torch.from_numpy(du.cut_384(ct.copy())).float()
        mask = torch.from_numpy(du.cut_384(mask.copy())).float()

        return ct, mask


class CTDataset(Dataset):
    """Dataset of CT volumes stored as ``ct*.npz`` archives in one directory.

    Args:
        load_path: Directory searched (non-recursively) for ``ct*.npz`` files.
        task: Stored on the instance; not used by this class itself.
        series_index: If given, the first-axis entry selected from each volume
            after it is reshaped to ``(3, -1, 512, 512)``.
    """

    def __init__(self, load_path, task='3tumors', series_index=None):
        super().__init__()
        # glob() returns matches in arbitrary order; sort so that an index
        # always maps to the same file across runs and machines.
        self.files = sorted(glob(os.path.join(load_path, 'ct*.npz')))
        self.series_index = series_index
        self.task = task

    def __len__(self):
        """Return the number of matched ``ct*.npz`` files."""
        return len(self.files)

    def __getitem__(self, item):
        """Load one volume and return it as a float tensor.

        Args:
            item: Position in the sorted file list.

        Returns:
            ``torch.float32`` tensor of the preprocessed volume.

        Raises:
            KeyError: If the archive has neither an ``arrays`` nor an
                ``arr_0`` entry.
        """
        # NOTE(review): allow_pickle=True can execute arbitrary code when an
        # archive contains pickled objects; only load trusted dataset files.
        # The context manager closes the archive's file handle right away.
        with np.load(self.files[item], allow_pickle=True) as archive:
            # Prefer the named entry; fall back to numpy's default key for
            # arrays saved positionally.
            key = 'arrays' if 'arrays' in archive else 'arr_0'
            ct = archive[key]

        ct = ct.reshape(3, -1, 512, 512)
        if self.series_index is not None:
            ct = ct[self.series_index]
        ct = du.window_standardize(ct, LOWER_BOUND, UPPER_BOUND)
        ct = du.cut_384(ct)
        ct = du.to_3channel_replica(ct)

        ct = torch.from_numpy(ct.copy()).float()
        return ct


# class Images_Dataset_folder(torch.utils.data.Dataset):
#     """Class for getting individual transformations and data
#     Args:
#         images_dir = path of input images
#         labels_dir = path of labeled images
#         transformI = Input Images transformation (default: None)
#         transformM = Input Labels transformation (default: None)
#     Output:
#         tx = Transformed images
#         lx = Transformed labels
#     """
#
#     def __init__(self, images_dir, labels_dir, transformI=None, transformM=None):
#         self.images = sorted(os.listdir(images_dir))
#         self.labels = sorted(os.listdir(labels_dir))
#         self.images_dir = images_dir
#         self.labels_dir = labels_dir
#         self.transformI = transformI
#         self.transformM = transformM
#
#         if self.transformI:
#             self.tx = self.transformI
#         else:
#             self.tx = torchvision.transforms.Compose([
#                 #  torchvision.transforms.Resize((128,128)),
#                 torchvision.transforms.CenterCrop(96),
#                 torchvision.transforms.RandomRotation((-10, 10)),
#                 # torchvision.transforms.RandomHorizontalFlip(),
#                 torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
#                 torchvision.transforms.ToTensor(),
#                 torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
#             ])
#
#         if self.transformM:
#             self.lx = self.transformM
#         else:
#             self.lx = torchvision.transforms.Compose([
#                 #  torchvision.transforms.Resize((128,128)),
#                 torchvision.transforms.CenterCrop(96),
#                 torchvision.transforms.RandomRotation((-10, 10)),
#                 torchvision.transforms.Grayscale(),
#                 torchvision.transforms.ToTensor(),
#                 # torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
#             ])
#
#     def __len__(self):
#
#         return len(self.images)
#
#     def __getitem__(self, i):
#         i1 = Image.open(self.images_dir + self.images[i])
#         l1 = Image.open(self.labels_dir + self.labels[i])
#
#         seed = np.random.randint(0, 2 ** 32)  # make a seed with numpy generator
#
#         # apply this seed to img transforms
#         random.seed(seed)
#         torch.manual_seed(seed)
#         img = self.tx(i1)
#
#         # apply this seed to target/label transforms
#         random.seed(seed)
#         torch.manual_seed(seed)
#         label = self.lx(l1)
#         return img, label

















