"""
@Time    : 2020/6/2 14:55
@Author  : Qin Dian
@Manual  :
"""

import os
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import argparse
from networks.deep_snake import U_Net
from datasets.liver import SliceDataset, CTDataset
import utils.net_utils as nu
import supports.loss_functions as lf
import torch.nn.functional as F


pl.seed_everything(123)


def _gpu_list(value):
    """Parse a comma-separated GPU id string such as ``"0,1"`` into ``[0, 1]``.

    ``type=list`` would split the string into single characters
    (``"0,1"`` -> ``['0', ',', '1']``), which is not a valid GPU selection.
    """
    return [int(gpu) for gpu in value.split(',') if gpu.strip()]


parser = argparse.ArgumentParser('deep snake')
parser.add_argument('--data_path', type=str, default=r'D:\A_UNET\out')
parser.add_argument('--tumor_index_path', type=str, default=r'D:\A_UNET')
parser.add_argument('--checkpoint_path', type=str, default=r'D:\A_UNET\out\checkpoints')
parser.add_argument('--test_path', type=str, default='/data/ts4.0')
parser.add_argument('--series_index', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--max_epochs', type=int, default=50)
# argparse does not apply `type` to non-string defaults, so the default stays [0].
parser.add_argument('--gpu', type=_gpu_list, default=[0],
                    help='comma-separated GPU ids, e.g. "0,1"')


def dice_loss(prediction, target):
    """Compute the soft Dice loss between a prediction and a target.

    Both tensors are flattened, so a single score is computed over every
    element of the batch at once (not averaged per sample).

    Args:
        prediction: predicted probabilities (e.g. sigmoid outputs); must have
            the same number of elements as ``target``.
        target: ground-truth mask.

    Returns:
        Scalar tensor ``1 - (2*sum(p*t) + 1) / (sum(p) + sum(t) + 1)``.
    """
    # Smoothing term keeps the ratio defined (and the loss 0) when both inputs are all-zero.
    smooth = 1.0

    # reshape (unlike view) also works on non-contiguous tensors, e.g. after
    # a transpose/permute.
    i_flat = prediction.reshape(-1)
    t_flat = target.reshape(-1)

    intersection = (i_flat * t_flat).sum()

    return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth))


def calc_loss(prediction, target, bce_weight=0.5):
    """Combine binary cross-entropy and Dice loss into one training loss.

    Args:
        prediction: raw logits; the sigmoid is applied here, not by the caller.
        target: ground-truth mask with the same size as ``prediction``
            (required by ``binary_cross_entropy_with_logits``).
        bce_weight: weight of the BCE term; the Dice term gets ``1 - bce_weight``.

    Returns:
        Scalar tensor ``bce_weight * BCE + (1 - bce_weight) * Dice``.
    """
    # BCE-with-logits works on the raw logits, which is numerically more
    # stable than applying sigmoid first.
    bce = F.binary_cross_entropy_with_logits(prediction, target)
    # torch.sigmoid replaces the deprecated F.sigmoid.
    dice = dice_loss(torch.sigmoid(prediction), target)

    return bce * bce_weight + dice * (1 - bce_weight)

class DeepSnake(pl.LightningModule):
    """Lightning module that trains a ``U_Net`` (3 in / 3 out channels).

    ``params`` is expected to be the argparse namespace built by the
    module-level ``parser``; the data loaders read paths and batch settings
    from it.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.net = U_Net(in_ch=3, out_ch=3)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute the BCE + Dice loss for one ``(ct, mask)`` batch.

        The returned ``'loss'`` must stay a graph-connected tensor, because
        Lightning calls ``backward()`` on it. Returning a Python float
        (e.g. via ``.item()``) here would make training impossible.
        """
        ct, mask = batch
        output = self(ct)
        loss = calc_loss(output, mask)  # BCE + Dice

        # Log a detached copy so the logger does not hold on to the graph.
        logs = {'train_loss': loss.detach()}

        return {'loss': loss, 'log': logs}

    def test_step(self, batch, batch_idx):
        """Run the network over ``batch[0]`` in chunks of 8 along dim 0.

        NOTE(review): the outputs are computed but neither returned nor
        logged, so this step currently has no observable effect. Confirm
        the intent.
        """
        x = batch[0]
        num = int(x.size(0))
        chunk = 8
        for start in range(0, num, chunk):
            # A trailing partial chunk is replaced by the last `chunk` items so
            # every forward pass keeps the same size (or the whole input when
            # num < chunk); it may overlap the previous chunk.
            part = x[start:start + chunk] if start + chunk <= num else x[-chunk:]
            out = self(part)
            # assumes the network output is a mapping with a 'ct_hm' entry;
            # TODO confirm against U_Net.
            out = nu.clipped_sigmoid(out['ct_hm'])

    def train_dataloader(self):
        return DataLoader(SliceDataset(load_path=self.params.data_path, index_path=self.params.tumor_index_path,
                                       series_index=self.params.series_index), batch_size=self.params.batch_size,
                          num_workers=8, pin_memory=True, shuffle=True)

    def test_dataloader(self):
        # Honour --test_path when present; fall back to the previous default.
        # NOTE(review): series_index is still hard-coded to 2.
        test_path = getattr(self.params, 'test_path', '/data/ts4.0')
        return DataLoader(CTDataset(load_path=test_path, series_index=2),
                          batch_size=1, num_workers=2, pin_memory=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4, betas=(0.5, 0.999))


def main():
    """Parse the command line, build the model and run training.

    Checkpoints are written under ``--checkpoint_path`` using the
    ``checkpoint_{epoch}`` file-name template.
    """
    cli_args = parser.parse_args()
    snake_model = DeepSnake(cli_args)

    ckpt_template = os.path.join(cli_args.checkpoint_path, 'checkpoint_{epoch}')
    # NOTE(review): test() loads 'last.ckpt'. This callback presumably does not
    # write it unless save_last=True is enabled (if the installed Lightning
    # version supports it). Confirm.
    saver = ModelCheckpoint(filepath=ckpt_template)

    # Mixed-precision variant: also pass amp_level='O2', precision=16.
    trainer = Trainer.from_argparse_args(
        cli_args,
        gpus=cli_args.gpu,
        checkpoint_callback=saver,
    )
    trainer.fit(snake_model)


def test(ckpt_path='/data/checkpoints/deep_snake/last.ckpt'):
    """Load a trained checkpoint and run Lightning's test loop on it.

    Args:
        ckpt_path: path of the checkpoint to evaluate (defaults to the
            previously hard-coded location).
    """
    args = parser.parse_args()
    # DeepSnake.__init__ requires `params`, and DeepSnake does not save it as
    # hyper-parameters, so it has to be supplied here. Lightning forwards
    # extra keyword arguments to the model constructor.
    model = DeepSnake.load_from_checkpoint(ckpt_path, params=args)

    trainer = Trainer(gpus=[0])
    trainer.test(model)


if __name__ == '__main__':
    # Train by default; call test() here instead of main() to evaluate a
    # saved checkpoint.
    main()