"""
@Time    : 2020/7/3 11:13
@Author  : Qin Dian
@Manual  : 
"""
import torch
from torch import nn
import torch.nn.functional as f


class FocalLoss(nn.Module):
    """``nn.Module`` wrapper around the functional focal loss ``_neg_loss``.

    The module owns no parameters or buffers; it stores the functional loss
    as ``self.neg_loss`` and ``forward`` delegates straight to it.
    """

    def __init__(self):
        super().__init__()
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        """Compute the focal loss between ``out`` and ``target``.

        Args:
            out: predictions, passed unchanged to ``self.neg_loss``.
            target: ground truth, passed unchanged to ``self.neg_loss``.

        Returns:
            Whatever ``self.neg_loss(out, target)`` returns.
        """
        loss_fn = self.neg_loss
        return loss_fn(out, target)


class IndL1Loss1d(nn.Module):
    """L1 / smooth-L1 loss restricted to positions where ``target`` is non-zero.

    Predictions at positions where ``target == 0`` are zeroed before the loss
    is computed, so those positions contribute nothing. The summed loss is
    then divided by the number of non-zero targets (plus ``1e-4`` so that an
    all-zero target does not divide by zero).

    Args:
        task: ``'l1'`` to use ``F.l1_loss`` or ``'smooth_l1'`` to use
            ``F.smooth_l1_loss``.

    Raises:
        ValueError: if ``task`` is not one of the supported names. Previously
            an unknown name was accepted silently and only failed later in
            ``forward`` with an ``AttributeError``.
    """

    def __init__(self, task='l1'):
        super(IndL1Loss1d, self).__init__()
        if task == 'l1':
            self.loss = f.l1_loss
        elif task == 'smooth_l1':
            self.loss = f.smooth_l1_loss
        else:
            raise ValueError(
                f"Unsupported task {task!r}; expected 'l1' or 'smooth_l1'"
            )

    def forward(self, output, target):
        """Return the masked loss averaged over non-zero target entries.

        Args:
            output: predictions; must be broadcast-compatible with ``target``.
            target: ground truth; zero entries are ignored.

        Returns:
            Scalar tensor ``sum(loss over masked entries) / (count + 1e-4)``.
        """
        mask = target.bool()
        # Zero the prediction wherever the target is zero so those positions
        # yield |0 - 0| = 0 and do not affect the sum.
        output = output * mask
        loss = self.loss(output, target, reduction='sum')
        # Average over the non-zero targets only; the epsilon guards against
        # division by zero when the target is entirely zero.
        return loss / (mask.sum() + 1e-4)


def _neg_loss(pred, gt):
    """Modified (penalty-reduced) focal loss, as used in CornerNet.

    This is the dense formulation: rather than gathering positive/negative
    elements by indexing, every term is evaluated over the whole tensor and
    multiplied by a 0/1 mask. It runs faster and costs a little more memory.

    Positives are elements with ``gt == 1``. Negatives are elements with
    ``gt < 1``, down-weighted by ``(1 - gt) ** 4``. The focusing exponent on
    the prediction is 2.

    Arguments:
        pred (batch x c x h x w): predicted probabilities in [0, 1]. A value
            of exactly 0 at a positive, or exactly 1 at a negative, is a
            genuinely infinite loss. Callers should keep ``pred`` strictly
            inside (0, 1), e.g. by clamping; this is not enforced here.
        gt (batch x c x h x w): target heatmap with the same shape as ``pred``.

    Returns:
        Scalar tensor ``-(pos_loss + neg_loss) / num_pos``, or ``-neg_loss``
        when there are no positive elements.
    """
    pos_mask = gt.eq(1)
    neg_mask = gt.lt(1)
    pos_inds = pos_mask.float()
    neg_inds = neg_mask.float()

    neg_weights = torch.pow(1 - gt, 4)

    # Where a term is masked out, give log() a neutral input instead of the
    # raw prediction. Otherwise pred == 0 (positive term) or pred == 1
    # (negative term) yields log(0) = -inf, and -inf * 0 = NaN would poison
    # the sum and the gradient even though that element must contribute 0.
    pos_pred = torch.where(pos_mask, pred, torch.ones_like(pred))
    neg_pred = torch.where(neg_mask, pred, torch.zeros_like(pred))

    pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) * pos_inds
    neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights * neg_inds

    num_pos = pos_inds.sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        # No positives to normalize by: return the raw negative loss.
        return -neg_loss
    return -(pos_loss + neg_loss) / num_pos
