"""
@Time    : 2020/6/17 14:19
@Author  : Qin Dian
@Manual  : 
"""
import torch
import torch.nn.functional as F


def clipped_sigmoid(x):
    """
    Apply a sigmoid to x and clamp the result into [1e-4, 1 - 1e-4]

    Keeping the output strictly inside (0, 1) lets a caller take log(y) or
    log(1 - y) without hitting infinities.
    """
    lower = 1e-4
    upper = 1 - 1e-4
    return x.sigmoid().clamp(min=lower, max=upper)


def nms(heat, kernel=3):
    """
    Make finding top k center points more convenient by using max pooling

    Every value that is not the maximum of its kernel x kernel neighbourhood
    is set to zero, so only local peaks of the heat map survive.

    :param heat: heat map accepted by F.max_pool2d, e.g. (batch, cat, height, width)
    :param kernel: side length of the pooling window; should be odd so that
        the pooled map keeps the spatial size of heat
    :return: tensor shaped like heat with non-peak positions zeroed
    """
    pad = (kernel - 1) // 2
    hmax = F.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=pad)
    # A position is a peak when pooling leaves its value unchanged. Build the
    # float32 mask directly on the device of heat: casting through
    # torch.FloatTensor would create it on the CPU and copy it back.
    keep = (hmax == heat).float()
    return heat * keep


def gather_feat(feat, ind, mask=None):
    """
    Select feature vectors from feat at the positions listed in ind

    :param feat: tensor indexed as (batch, position, channel)
    :param ind: integer tensor (batch, n) of positions along dimension 1
    :param mask: optional boolean-style tensor (batch, n); when given, only
        the selected rows are kept and the batch dimension is flattened
    :return: (batch, n, channel) without mask, (num_kept, channel) with mask
    """
    channels = feat.size(2)
    rows, cols = ind.size(0), ind.size(1)
    # Repeat each index over the channel axis so gather pulls whole vectors.
    index = ind.unsqueeze(2).expand(rows, cols, channels)
    picked = feat.gather(1, index)
    if mask is None:
        return picked
    selector = mask.unsqueeze(2).expand_as(picked)
    return picked[selector].view(-1, channels)


def topk(scores, k=40):
    """
    Pick the k best scoring locations of a heat map over all categories

    The selection runs in two stages: the k best locations of every category
    are found first, then the k best of those cat * k candidates are kept.

    :param scores: heat map of shape (batch, cat, height, width)
    :param k: number of candidates to keep per batch item; torch.topk needs
        k <= height * width
    :return: tuple (topk_score, topk_inds, topk_clses, topk_ys, topk_xs),
        each of shape (batch, k): the scores, the flattened spatial indices
        (y * width + x), the category indices (int32) and the row / column
        coordinates (float32)
    """
    batch, cat, height, width = scores.size()

    # Stage 1: per-category top k. The indices address the flattened
    # height * width plane, so they already lie in [0, height * width).
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), k)

    # Use explicit integer floor division: "/" on integer tensors is
    # version-dependent (error or true division) and a float detour loses
    # precision for large indices.
    topk_ys = torch.div(topk_inds, width, rounding_mode="floor").float()
    topk_xs = (topk_inds % width).float()

    # Stage 2: top k over the cat * k candidates of each batch item.
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), k)
    # Candidates are laid out category-major with k entries per category.
    topk_clses = torch.div(topk_ind, k, rounding_mode="floor").int()

    # Carry the spatial index and coordinates of the winning candidates.
    topk_inds = topk_inds.view(batch, -1).gather(1, topk_ind)
    topk_ys = topk_ys.view(batch, -1).gather(1, topk_ind)
    topk_xs = topk_xs.view(batch, -1).gather(1, topk_ind)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs


def transpose_and_gather_feat(feat, ind):
    """
    Gather per-location feature vectors from a channel-first feature map

    :param feat: feature map of shape (batch, channel, height, width)
    :param ind: integer tensor (batch, n) of flattened height * width positions
    :return: the result of gather_feat on the flattened map, (batch, n, channel)
    """
    batch, channels = feat.size(0), feat.size(1)
    # Move channels last, then flatten the spatial grid into one axis so a
    # flattened index selects a whole channel vector.
    flattened = feat.permute(0, 2, 3, 1).contiguous().view(batch, -1, channels)
    return gather_feat(flattened, ind)


def extract_top_k_ctp_dt(ct_hm, wh, reg=None, k=100):
    """
    Decode the k strongest center points and their boxes from network maps

    :param ct_hm: center heat map of shape (batch, cat, height, width)
    :param wh: map holding two size values (width, height) per location
    :param reg: optional map holding two offset values (x, y) per location,
        added to the integer peak coordinates when given
    :param k: number of detections to keep per batch item
    :return: tuple (ct, detection); ct is (batch, k, 2) holding (x, y) and
        detection is (batch, k, 6) holding (x1, y1, x2, y2, score, class),
        all expressed in heat map coordinates
    """
    batch, _, _, _ = ct_hm.size()

    # Keep local maxima only, then take the k best over all categories.
    peaks = nms(ct_hm)
    scores, inds, clses, ys, xs = topk(peaks, k=k)

    sizes = transpose_and_gather_feat(wh, inds).view(batch, k, 2)

    xs = xs.view(batch, k, 1)
    ys = ys.view(batch, k, 1)
    if reg is not None:
        # Refine the integer grid position with the predicted offset.
        offsets = transpose_and_gather_feat(reg, inds).view(batch, k, 2)
        xs = xs + offsets[:, :, 0:1]
        ys = ys + offsets[:, :, 1:2]

    scores = scores.view(batch, k, 1)
    clses = clses.view(batch, k, 1).float()

    ct = torch.cat([xs, ys], dim=2)

    # Boxes are centered on the point and span the predicted width / height.
    box_w = sizes[..., 0:1]
    box_h = sizes[..., 1:2]
    corners = [xs - box_w / 2,
               ys - box_h / 2,
               xs + box_w / 2,
               ys + box_h / 2]
    bboxes = torch.cat(corners, dim=2)
    detection = torch.cat([bboxes, scores, clses], dim=2)

    return ct, detection
