# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import metrics, utils
from fairseq.criterions import register_criterion
import torch
import torch.nn as nn
import torch.nn.functional as F
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion, label_smoothed_nll_loss


@register_criterion("label_smoothed_cross_entropy_with_alignment")
class LabelSmoothedCrossEntropyCriterionWithAlignment(
    LabelSmoothedCrossEntropyCriterion
):
    """Label-smoothed cross entropy plus an optional auxiliary alignment loss.

    On top of the parent criterion's token loss, ``forward`` adds one extra
    term whenever the sample carries non-None ``target_alignments``. The term
    is selected by ``alignment_decoder_type``:

    * ``'grouping'`` -- an ACT-style ponder loss (``grouping_arch == 'act'``)
      or otherwise a 3-class grouping classification loss;
    * ``'simple'`` -- a label-smoothed classification loss over the
      alignment targets;
    * ``'attention'`` -- a supervised attention loss against
      ``sample["alignments"]``.

    Each auxiliary term is scaled by ``alignment_lambda``. A length-token
    distillation term (``distill_length_token_lambda``) is declared but not
    implemented yet.
    """

    def __init__(self, task, sentence_avg, label_smoothing,
                 ignore_prefix_size=0,
                 report_accuracy=False,
                 alignment_lambda=0.0,
                 multi_alignment_weight=1.0,
                 distill_length_token_lambda=0.0,
                 act_reg_type=None,
                 alignment_decoder_type=None,
                 grouping_arch=None,):
        super().__init__(task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy)
        # Scale applied to every auxiliary alignment term in ``forward``.
        # NOTE(review): this default (0.0) differs from the CLI default
        # registered in ``add_args`` (1.0) -- confirm which one is intended.
        self.alignment_lambda = alignment_lambda
        # Per-position weight; each ``compute_*`` method documents how it is used.
        self.alignment_weight = multi_alignment_weight
        # Scale of the length-token distillation term; 0 disables it.
        self.distill_lambda = distill_length_token_lambda
        # Only stored for now: the regularizer that would read it is a TODO
        # in ``compute_act_loss``.
        self.act_reg_type = act_reg_type
        # 'grouping', 'simple' or 'attention'; any other value (including
        # None) disables the alignment loss.
        self.alignment_decoder_type = alignment_decoder_type
        # 'act' selects ``compute_act_loss`` inside the 'grouping' branch.
        self.grouping_arch = grouping_arch

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        LabelSmoothedCrossEntropyCriterion.add_args(parser)
        parser.add_argument(
            "--alignment-lambda",
            default=1.0,
            type=float,
            help="lambda for the alignment loss",
        )
        parser.add_argument(
            "--multi-alignment-weight",
            default=1.0,
            type=float,
            help="weight for the alignment loss",
        )
        parser.add_argument(
            "--distill-length-token-lambda",
            default=0.0,
            type=float,
            help="weight for the distill-length-token loss",
        )
        parser.add_argument(
            "--act-reg-type",
            default='gaussian',  # or geometric
            type=str,
            help="ACT regularization distribution type",
        )

    def forward(self, model, sample, reduce=True, infer=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        ``infer`` is accepted for call-site compatibility but is not used.
        """

        def _to_log(value):
            # Reduced losses are logged as Python numbers, unreduced ones as tensors.
            return utils.item(value.data) if reduce else value.data

        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "nll_loss": _to_log(nll_loss),
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        # Compute alignment loss only for training set and non dummy batches.
        if sample.get("target_alignments") is not None:
            if self.alignment_decoder_type == 'grouping':
                if self.grouping_arch == 'act':
                    gp_loss = self.compute_act_loss(sample, net_output, reduce=reduce)
                else:
                    gp_loss, _ = self.compute_grouping_loss(sample, net_output, reduce=reduce)
                logging_output["gp_loss"] = _to_log(gp_loss)
                loss += self.alignment_lambda * gp_loss

            elif self.alignment_decoder_type == 'simple':
                alignment_loss, _ = self.compute_simple_alignment_loss(
                    model, sample, net_output, reduce=reduce
                )
                logging_output["simple_alignment_loss"] = _to_log(alignment_loss)
                loss += self.alignment_lambda * alignment_loss

            elif self.alignment_decoder_type == 'attention':
                alignment_loss = self.compute_attn_alignment_loss(
                    sample, (None, {'attn': net_output[2]['alignment_attn']})
                )
                # None means the batch has no alignment pairs: skip the term
                # instead of crashing on ``None.data``.
                if alignment_loss is not None:
                    logging_output["attn_alignment_loss"] = _to_log(alignment_loss)
                    loss += self.alignment_lambda * alignment_loss

        if self.distill_lambda > 0:
            distill_loss = self.distill_with_length_token(net_output, reduce=reduce)
            logging_output["distill_loss"] = _to_log(distill_loss)
            loss += self.distill_lambda * distill_loss

        logging_output['loss'] = _to_log(loss)

        return loss, sample_size, logging_output

    def compute_simple_alignment_loss(self, model, sample, net_output, reduce=True):
        """Label-smoothed classification loss over the alignment targets.

        ``net_output[1]`` is normalized through the parent's
        ``get_lprobs_and_target``. The targets are ``sample["target_alignments"]``
        right-padded with one zero column. The log-probabilities of positions
        whose target is greater than 1 are scaled by ``self.alignment_weight``;
        all other positions keep weight 1.

        Returns:
            ``(loss, nll_loss)`` as produced by ``label_smoothed_nll_loss``.
        """
        tgt_alignments = F.pad(sample["target_alignments"], (0, 1))
        lprobs, tgt_alignments = self.get_lprobs_and_target(
            model, [net_output[1]], {"target": tgt_alignments}
        )
        # Build the weights as float directly: filling an integer tensor would
        # truncate a fractional alignment_weight (e.g. 0.5 -> 0).
        step_weight = torch.ones_like(tgt_alignments, dtype=torch.float).masked_fill(
            tgt_alignments > 1, self.alignment_weight
        )
        # One weight per position, broadcast over the class dimension.
        lprobs = lprobs * step_weight.unsqueeze(-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs, tgt_alignments, self.eps, reduce=reduce
        )
        return loss, nll_loss

    def compute_attn_alignment_loss(self, sample, net_output):
        """Supervised attention loss against reference alignment pairs.

        ``net_output[1]["attn"][0]`` is flattened to ``(bsz * tgt_len, src_len)``.
        ``sample["alignments"][:, 1]`` indexes its rows and
        ``sample["alignments"][:, 0]`` its columns. Each pair's negative log
        attention is weighted by ``sample["align_weights"]``.

        Returns:
            The summed weighted loss, or ``None`` when the batch has no
            alignment pairs.
        """
        attn_prob = net_output[1]["attn"][0]
        bsz, tgt_sz, src_sz = attn_prob.shape
        attn = attn_prob.view(bsz * tgt_sz, src_sz)

        align = sample["alignments"]
        align_weights = sample["align_weights"].float()

        if len(align) == 0:
            return None

        # Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
        # the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
        return -(
            (attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
            * align_weights[:, None]
        ).sum()

    def compute_grouping_loss(self, sample, net_output, reduce=True):
        """3-class grouping loss over the note sequence.

        A per-note target is built from ``sample['net_input']['notes']`` and
        ``sample["target_alignments"]`` (right-padded with one zero column):

        * 0 -- positions where ``notes == 0``; ignored via ``ignore_index=0``;
        * 2 -- index ``cumsum(target_alignments) - 1`` for each target
          position. Presumably this is the last note of each group when
          ``target_alignments`` holds group sizes (TODO confirm);
        * 1 -- every other non-zero note.

        ``net_output[2]['ponder_cost']`` is used as logits over these classes.
        Positions labelled 1 are weighted by ``self.alignment_weight``.

        Returns:
            ``(loss, nll_loss)``: summed when ``reduce``, otherwise per position.
        """
        notes = sample['net_input']['notes']
        tgt_alignments = F.pad(sample["target_alignments"], (0, 1))

        tgt_grouping = torch.zeros_like(notes).long().masked_fill(notes != 0, 1)

        # Mark index (running total - 1) for every position with a non-zero
        # running total. The total stays constant after the last group, so
        # trailing positions just re-mark the same index.
        temp_cumsum = torch.cumsum(tgt_alignments, dim=-1)
        nz = temp_cumsum.nonzero()
        rows, cols = nz[:, 0], nz[:, 1]
        group_ends = temp_cumsum[rows, cols]
        tgt_grouping[rows, group_ends - 1] = 2

        step_logits = net_output[2]['ponder_cost']
        if tgt_grouping.shape[1] != step_logits.shape[1]:
            # Debug aid: dump rows whose note count disagrees with the alignment
            # total. NOTE(review): prints to stdout; consider logging or raising.
            for i in range(notes.shape[0]):
                if (notes[i] != 0).sum() != sample['net_input']['tgt_alignments'][i].sum():
                    print(notes[i])
                    print(sample['net_input']['tgt_alignments'][i])
        step_lprobs = F.log_softmax(step_logits, dim=-1)

        grouping_loss, grouping_nll_loss = label_smoothed_nll_loss(
            step_lprobs,
            tgt_grouping,
            self.eps,
            ignore_index=0,
            reduce=False,
        )
        # Float weights built directly so a fractional alignment_weight is not
        # truncated by the integer target dtype.
        step_weight = torch.ones_like(tgt_grouping, dtype=torch.float).masked_fill(
            tgt_grouping == 1, self.alignment_weight
        ).unsqueeze(-1)

        grouping_loss = grouping_loss * step_weight
        grouping_nll_loss = grouping_nll_loss * step_weight
        if reduce:
            grouping_loss = grouping_loss.sum()
            grouping_nll_loss = grouping_nll_loss.sum()
        return grouping_loss, grouping_nll_loss

    def compute_act_loss(self, sample, net_output, reduce=True):
        """ACT-style ponder loss for the 'grouping' decoder.

        The predicted count per position is
        ``net_output[1] * sample["target_masks"] + net_output[2]['ponder_cost']``.
        It is compared with ``sample["target_alignments"]`` right-padded with
        one zero column. Two terms are summed:

        * a per-position smooth-L1 term, divided by the number of unmasked
          positions. It is averaged over the positions whose target is greater
          than 1, with ``self.alignment_weight`` as the weight. The result is
          0 when there is no such position.
        * an L1 term (sum reduction) between the per-row totals.

        Both terms are already scalars, so ``reduce`` has no effect; it is kept
        for interface compatibility.
        """
        tgt_alignments = F.pad(sample["target_alignments"], (0, 1))
        pred_alignments = net_output[1]
        ponder_cost = net_output[2]['ponder_cost']
        tgt_masks = sample["target_masks"].float()

        masked_pred = pred_alignments * tgt_masks
        step_ponder_loss = F.smooth_l1_loss(
            masked_pred + ponder_cost, tgt_alignments.float(), reduction='none'
        ) / tgt_masks.sum()

        step_weight = torch.zeros_like(tgt_alignments, dtype=torch.float).masked_fill(
            tgt_alignments > 1, self.alignment_weight
        )
        weighted_sum = (step_ponder_loss * step_weight).sum()
        weight_total = step_weight.sum()
        # Guard the weighted mean: a batch with no target > 1 (or a zero
        # alignment_weight) would otherwise yield 0 / 0 = NaN.
        if weight_total > 0:
            step_ponder_loss = weighted_sum / weight_total
        else:
            step_ponder_loss = weighted_sum

        sum_step_ponder_loss = F.l1_loss(
            torch.sum(masked_pred, dim=-1) + torch.sum(ponder_cost, dim=-1),
            torch.sum(tgt_alignments, dim=-1).float(),
            reduction='sum',
        )
        # TODO: add the act_reg_type ('gaussian' / 'geometric') regularizer.
        return step_ponder_loss + sum_step_ponder_loss

    def distill_with_length_token(self, net_output, reduce=False):
        """Length-token distillation loss -- not implemented yet.

        TODO: decide which outputs hold the length-token embedding and the
        length-control hidden state, then return a loss tensor (e.g. via
        ``F.mse_loss``). Until then this fails loudly rather than handing
        ``forward`` something that is not a loss value.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "distill_with_length_token is not implemented; "
            "set --distill-length-token-lambda to 0"
        )

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        def _sum(key):
            # Missing keys count as 0 so partially-populated outputs still sum.
            return utils.item(sum(log.get(key, 0) for log in logging_outputs))

        loss_sum = _sum("loss")
        nll_loss_sum = _sum("nll_loss")
        ntokens = _sum("ntokens")
        sample_size = _sum("sample_size")

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
        )
        # Auxiliary losses are only reported when some worker produced them.
        for key in ("simple_alignment_loss", "attn_alignment_loss", "gp_loss", "distill_loss"):
            key_sum = _sum(key)
            if key_sum != 0:
                metrics.log_scalar(
                    key,
                    key_sum / sample_size / math.log(2),
                    sample_size,
                    round=3,
                )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True improves distributed training speed.
        """
        return True
