"""
@Time    : 2020/6/2 14:55
@Author  : Qin Dian
@Manual  : U-Net segmentation network (the Deep Snake `Network` wrapper is kept only as commented-out code)
"""
import torch.nn as nn
# from networks.backbones.dla import DLASeg
# from utils import data_utils, net_utils
import torch


# class Network(nn.Module):
#     def __init__(self, num_layers, heads, head_conv=256, down_ratio=1):
#         super(Network, self).__init__()
#
#         self.dla = DLASeg('dla{}'.format(num_layers), heads,
#                           pretrained=True,
#                           down_ratio=down_ratio,
#                           final_kernel=1,
#                           last_level=5,
#                           head_conv=head_conv)
#
#     @staticmethod
#     def decode_detection(output, h, w):
#         ct_hm = output['ct_hm']
#         wh = output['hw']
#         ct, detection = net_utils.extract_top_k_ctp_dt(torch.sigmoid(ct_hm), wh)
#         detection[..., :4] = data_utils.clip_to_image(detection[..., :4], h, w)
#         output.update({'ct': ct, 'detection': detection})
#         return ct, detection
#
#     @staticmethod
#     def use_gt_detection(output, batch):
#         _, _, height, width = output['ct_hm'].size()
#         ct_01 = batch['ct_01'].byte()
#
#         ct_ind = batch['ct_ind'][ct_01]
#         xs, ys = ct_ind % width, ct_ind // width
#         xs, ys = xs[:, None].float(), ys[:, None].float()
#         ct = torch.cat([xs, ys], dim=1)
#
#         wh = batch['hw'][ct_01]
#         bboxes = torch.cat([xs - wh[..., 0:1] / 2,
#                             ys - wh[..., 1:2] / 2,
#                             xs + wh[..., 0:1] / 2,
#                             ys + wh[..., 1:2] / 2], dim=1)
#         score = torch.ones([len(bboxes)]).to(bboxes)[:, None]
#         ct_cls = batch['ct_cls'][ct_01].float()[:, None]
#         detection = torch.cat([bboxes, score, ct_cls], dim=1)
#
#         output['ct'] = ct[None]
#         output['detection'] = detection[None]
#
#         return output
#
#     def forward(self, x, batch=None):
#         # output holds the object-detection features
#         output, cnn_feature = self.dla(x)
#         with torch.no_grad():
#             # decode the detection features into center points and contours
#             self.decode_detection(output, cnn_feature.size(2), cnn_feature.size(3))
#         return output


class conv_block(nn.Module):
    """
    Double convolution block: (Conv3x3 -> BatchNorm -> ReLU) applied twice.

    Padding of 1 keeps the spatial size unchanged; only the channel count goes
    from ``in_ch`` to ``out_ch`` (the first conv changes it, the second keeps it).
    The layers live in ``self.conv`` as a single ``nn.Sequential``.
    """

    def __init__(self, in_ch, out_ch):
        super(conv_block, self).__init__()

        stages = []
        # First stage maps in_ch -> out_ch, second stage out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class up_conv(nn.Module):
    """
    Decoder upsampling step: 2x Upsample -> Conv3x3 -> BatchNorm -> ReLU.

    ``nn.Upsample`` is built with only ``scale_factor=2`` (its default mode),
    so height and width double. The 3x3 conv with padding 1 then maps
    ``in_ch`` to ``out_ch`` channels without changing the spatial size.
    """
    def __init__(self, in_ch, out_ch):
        super(up_conv, self).__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(out_ch)
        activation = nn.ReLU(inplace=True)
        self.up = nn.Sequential(upsample, conv, norm, activation)

    def forward(self, x):
        return self.up(x)


class U_Net(nn.Module):
    """
    UNet - Basic Implementation
    Paper : https://arxiv.org/abs/1505.04597

    Structure:
      * Encoder: five ``conv_block`` stages with 2x2 max-pooling between them.
        Channel widths are ``base_ch * [1, 2, 4, 8, 16]``.
      * Decoder: four ``up_conv`` + ``conv_block`` stages. Each one
        concatenates the matching encoder feature map (skip connection, placed
        first along the channel axis) before its ``conv_block``.
      * Head: a final 1x1 convolution mapping to ``out_ch`` channels.

    No activation is applied to the output: ``forward`` returns raw logits.

    Spatial sizes that are not multiples of 16 are supported. Each upsampled
    decoder map is zero-padded to its skip connection's size, so the output
    has the same height/width as the input. Each spatial dimension must still
    be at least 16 so that four poolings leave a non-empty map.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        base_ch: channel width of the first encoder stage (default 64).
    """
    def __init__(self, in_ch=3, out_ch=1, base_ch=64):
        super(U_Net, self).__init__()

        n1 = base_ch
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(in_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        # Each Up_convN takes filters[k] + filters[k] channels: the upsampled
        # decoder map concatenated with the same-width encoder skip.
        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad (or crop, if larger) ``x`` on its last two dims to ``ref``'s size.

        Max-pooling floors odd sizes, so an upsampled decoder map can be one
        pixel smaller than its skip connection. This restores alignment so
        ``torch.cat`` along channels succeeds. When sizes already match, ``x``
        is returned unchanged.
        """
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return nn.functional.pad(
            x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        """Return raw per-pixel logits with ``out_ch`` channels and x's spatial size."""
        # Encoder
        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)

        # Decoder: upsample, align to the skip's size, concat (skip first), fuse.
        d5 = self._match_size(self.Up5(e5), e4)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self._match_size(self.Up4(d5), e3)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self._match_size(self.Up3(d4), e2)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self._match_size(self.Up2(d3), e1)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        # 1x1 projection to out_ch; no sigmoid/softmax here (callers apply it).
        return self.Conv(d2)


if __name__ == '__main__':
    # Quick shape check of U_Net on random input. The Deep Snake `Network`
    # class is commented out in this file, so it cannot be used here.
    # Runs on GPU when one is available, otherwise on CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = U_Net(in_ch=3, out_ch=1).to(device)
    net.eval()
    tmp = torch.randn((2, 3, 256, 256), device=device)
    with torch.no_grad():
        out = net(tmp)
    print(out.size())