# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy
import numpy as np
import torch
from fairseq.data import FairseqDataset, data_utils

# Module-level logger; the dataset below uses it to report bucket sizes
# when ``num_buckets > 0``.
logger = logging.getLogger(__name__)


def collate(
        samples,
        pad_idx,
        bos_idx,
        left_pad_source=True,
        left_pad_target=False,
        input_feeding=True,
        pad_to_length=None,
        pad_to_multiple=1,
):
    """Merge a list of samples into a padded mini-batch sorted by source length.

    Samples are the dicts built by the melody dataset's ``__getitem__``. Besides
    ``id``, ``source`` and ``target``, each sample carries ``source_alignment``,
    ``target_alignment``, ``notes``, ``durs`` and ``target_masks`` tensors.
    Every per-sample tensor is padded with ``data_utils.collate_tokens``. The
    whole batch is then reordered by descending (unpadded) source length.

    Args:
        samples (List[dict]): samples to collate.
        pad_idx (int): padding index for token tensors (``source``,
            ``target``, ``prev_output_tokens``). Alignment, melody and mask
            tensors are always padded with 0.
        bos_idx (int): passed to ``collate_tokens`` as ``eos_idx``. It is
            presumably the token placed at the front of the shifted
            ``prev_output_tokens`` (confirm against ``collate_tokens``).
        left_pad_source (bool): left-pad ``src_tokens``.
        left_pad_target (bool): left-pad ``target`` / ``prev_output_tokens``.
        input_feeding (bool): build ``prev_output_tokens`` by shifting the
            target when the samples do not already provide it.
        pad_to_length (dict, optional): fixed pad length per sample field
            (``"source"``, ``"target"``, ``"source_alignment"``,
            ``"target_alignment"``, ``"notes"``, ``"durs"``,
            ``"target_masks"``). Fields absent from the dict are not padded
            to a fixed length.
        pad_to_multiple (int): pad every merged tensor to a multiple of this.

    Returns:
        dict: the mini-batch, or ``{}`` when ``samples`` is empty. ``target``
        and ``target_masks`` are ``None`` when the samples carry no target.
    """
    if len(samples) == 0:
        return {}

    def pad_len(key):
        # Requested fixed length for ``key``. A missing key means no fixed
        # length, rather than raising KeyError.
        if pad_to_length is None:
            return None
        return pad_to_length.get(key)

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None, pad_idx=pad_idx):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=bos_idx,
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )

    ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_len("source"),
    )
    src_lengths = torch.LongTensor(
        [s["source"].ne(pad_idx).long().sum() for s in samples]
    )
    # Sort by descending source length. Every other tensor is reordered with
    # ``sort_order`` so that rows stay aligned.
    src_lengths, sort_order = src_lengths.sort(descending=True)
    ids = ids.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    # Alignments used for teacher-forcing training. They are padded with 0, so
    # ``> 0`` also serves as the padding mask. This assumes real alignment
    # entries are strictly positive (TODO confirm with the data pipeline).
    src_alignments = merge(
        "source_alignment",
        left_pad=False,
        pad_to_length=pad_len("source_alignment"),
        pad_idx=0,
    ).index_select(0, sort_order)
    src_alignments_masks = src_alignments > 0

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_len("target"),
        ).index_select(0, sort_order)
        tgt_lengths = torch.LongTensor(
            [s["target"].ne(pad_idx).long().sum() for s in samples]
        ).index_select(0, sort_order)
        ntokens = tgt_lengths.sum().item()

        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target)
        elif input_feeding:
            # Shifted copy of the targets, used to feed the previous output
            # token(s) into the next decoder step.
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_len("target"),
            )
    else:
        ntokens = src_lengths.sum().item()

    # Target-side alignments for teacher-forcing training (0-padded, see above).
    tgt_alignments = merge(
        "target_alignment",
        left_pad=False,
        pad_to_length=pad_len("target_alignment"),
        pad_idx=0,
    ).index_select(0, sort_order)
    tgt_alignments_masks = tgt_alignments > 0

    notes = merge(
        "notes",
        left_pad=False,
        pad_to_length=pad_len("notes"),
        pad_idx=0,
    ).index_select(0, sort_order)
    durs = merge(
        "durs",
        left_pad=False,
        pad_to_length=pad_len("durs"),
        pad_idx=0,
    ).index_select(0, sort_order)

    # ``target_masks`` is only a tensor when the samples carry a target.
    target_masks = None
    if torch.is_tensor(samples[0].get("target_masks", None)):
        target_masks = merge(
            "target_masks",
            left_pad=False,
            pad_to_length=pad_len("target_masks"),
            pad_idx=0,
        ).index_select(0, sort_order)

    batch = {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths,
                      "src_alignments": src_alignments,
                      "tgt_alignments": tgt_alignments,
                      "src_alignments_masks": src_alignments_masks,
                      "tgt_alignments_masks": tgt_alignments_masks,
                      "notes": notes,
                      "durs": durs,
                      },
        "target": target,
        # Source tokens again, always right-padded regardless of left_pad_source.
        "source": merge(
            "source",
            left_pad=False,
            pad_to_length=pad_len("source"),
        ).index_select(0, sort_order),
        "target_masks": target_masks,
        "target_alignments": tgt_alignments,
    }

    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(0, sort_order)

    return batch


class MultiLanguagePairWithMelodyDataset(FairseqDataset):
    """
    A pair of torch.utils.data.Datasets, together with per-sentence source and
    target alignments and melody (notes / durations) information.

    Args:
        split (str): name of the data split; stored as ``self.split``.
        src (torch.utils.data.Dataset): source dataset to wrap
        src_align_dataset: indexable source-side alignments. Item ``i`` is
            converted to a ``LongTensor`` and exposed as ``source_alignment``.
        src_sizes (List[int]): source sentence lengths
        src_dict (~fairseq.data.Dictionary): source vocabulary
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_align_dataset (optional): indexable target-side alignments. Item
            ``i`` becomes ``target_alignment``. ``__getitem__`` indexes it
            unconditionally, so it is effectively required.
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        melody (optional): indexable melody data. Item ``i`` must provide
            ``'notes'``, ``'durs'`` and ``'is_slur'`` entries. ``__getitem__``
            indexes it unconditionally, so it is effectively required.
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for teacher forcing (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent (default: False).
        constraints (Tensor, optional): 2d tensor with a concatenated, zero-
            delimited list of constraints for each sentence.
        append_bos (bool, optional): if set, appends bos to the beginning of
            source/target sentence.
        eos (int, optional): stored as ``self.eos`` (default: ``src_dict.eos()``).
        bos (int, optional): index passed to ``collate`` as ``bos_idx``
            (default: ``src_dict.bos()``).
        num_buckets (int, optional): if set to a value greater than 0, then
            batches will be bucketed into the given number of batch shapes.
        src_lang_id (int, optional): source language ID, if set, the collated batch
            will contain a field 'src_lang_id' in 'net_input' which indicates the
            source language of the samples.
        tgt_lang_id (int, optional): target language ID, if set, the collated batch
            will contain a field 'tgt_lang_id' which indicates the target language
            of the samples.
        pad_to_multiple (int, optional): pad collated tensors to a multiple of
            this value (default: 1).
        max_delta_note (int, optional): upper bound used to clamp the
            top-level ``target_alignments`` of each collated batch (default: 15).
    """

    def __init__(
            self,
            split,
            src,
            src_align_dataset,
            src_sizes,
            src_dict,
            tgt=None,
            tgt_align_dataset=None,
            tgt_sizes=None,
            tgt_dict=None,
            melody=None,
            left_pad_source=True,
            left_pad_target=False,
            shuffle=True,
            input_feeding=True,
            remove_eos_from_source=False,
            append_eos_to_target=False,
            constraints=None,
            append_bos=False,
            eos=None,
            bos=None,
            num_buckets=0,
            src_lang_id=None,
            tgt_lang_id=None,
            pad_to_multiple=1,
            max_delta_note=15,
    ):
        if tgt_dict is not None:
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
        if tgt is not None:
            assert len(src) == len(tgt), "Source and target must contain the same number of examples"
        if src_align_dataset is not None and tgt_align_dataset is not None:
            assert len(src_align_dataset) == len(tgt_align_dataset), "Source and target alignments must contain the same number of examples"
        self.src = src
        self.tgt = tgt
        self.src_sizes = np.array(src_sizes)
        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
        self.split = split
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.src_align_dataset = src_align_dataset
        self.tgt_align_dataset = tgt_align_dataset
        self.melody = melody
        self.left_pad_source = left_pad_source
        self.left_pad_target = left_pad_target
        self.shuffle = shuffle
        self.input_feeding = input_feeding
        self.remove_eos_from_source = remove_eos_from_source
        self.append_eos_to_target = append_eos_to_target
        if self.src_align_dataset is not None and self.tgt_align_dataset is not None:
            assert (self.tgt_sizes is not None), "Both source and target needed when alignments are provided"
        self.constraints = constraints
        self.append_bos = append_bos
        self.eos = eos if eos is not None else src_dict.eos()
        self.bos = bos if bos is not None else src_dict.bos()
        self.src_lang_id = src_lang_id
        self.tgt_lang_id = tgt_lang_id
        self.max_delta_note = max_delta_note
        if num_buckets > 0:
            from fairseq.data import BucketPadLengthDataset

            self.src = BucketPadLengthDataset(
                self.src,
                sizes=self.src_sizes,
                num_buckets=num_buckets,
                pad_idx=self.src_dict.pad(),
                left_pad=self.left_pad_source,
            )
            self.src_sizes = self.src.sizes
            logger.info("bucketing source lengths: {}".format(list(self.src.buckets)))
            if self.tgt is not None:
                self.tgt = BucketPadLengthDataset(
                    self.tgt,
                    sizes=self.tgt_sizes,
                    num_buckets=num_buckets,
                    pad_idx=self.tgt_dict.pad(),
                    left_pad=self.left_pad_target,
                )
                self.tgt_sizes = self.tgt.sizes
                logger.info(
                    "bucketing target lengths: {}".format(list(self.tgt.buckets))
                )

            # Determine bucket sizes using self.num_tokens, which returns the
            # padded lengths (thanks to BucketPadLengthDataset). ``np.int64``
            # replaces the deprecated ``np.compat.long`` Python-2 shim.
            num_tokens = np.vectorize(self.num_tokens, otypes=[np.int64])
            self.bucketed_num_tokens = num_tokens(np.arange(len(self.src)))
            self.buckets = [
                (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens)
            ]
        else:
            self.buckets = None
        self.pad_to_multiple = pad_to_multiple

    def get_batch_shapes(self):
        return self.buckets

    def __getitem__(self, index):
        """Return the sample dict for ``index``, which ``collate`` consumes."""
        tgt_item = self.tgt[index] if self.tgt is not None else None
        src_item = self.src[index]
        melody = self.melody[index]
        src_alignment = self.src_align_dataset[index]
        trg_alignment = self.tgt_align_dataset[index]
        # Append EOS to the end of the tgt sentence if it lacks one, and remove
        # EOS from the end of the src sentence if present. This is useful when
        # existing datasets are reused for the opposite direction, i.e. using
        # tgt_dataset as src_dataset and vice versa. Each step below builds on
        # the item produced by the previous step, so enabling several options
        # composes them instead of discarding earlier edits.
        if self.append_eos_to_target and tgt_item is not None:
            eos = self.tgt_dict.eos() if self.tgt_dict is not None else self.src_dict.eos()
            if tgt_item[-1] != eos:
                tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

        if self.append_bos:
            bos = self.tgt_dict.bos() if self.tgt_dict is not None else self.src_dict.bos()
            if tgt_item is not None and tgt_item[0] != bos:
                tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

            bos = self.src_dict.bos()
            if src_item[0] != bos:
                src_item = torch.cat([torch.LongTensor([bos]), src_item])

        if self.remove_eos_from_source:
            eos = self.src_dict.eos()
            if src_item[-1] == eos:
                src_item = src_item[:-1]

        # Pad indices are asserted equal in __init__ when tgt_dict is given;
        # fall back to src_dict so a missing tgt_dict does not raise.
        pad = self.tgt_dict.pad() if self.tgt_dict is not None else self.src_dict.pad()
        example = {
            "id": index,
            "source": src_item,
            "target": tgt_item,
            "notes": torch.LongTensor(melody['notes']),
            "durs": torch.LongTensor(melody['durs']),
            "is_slur": melody['is_slur'],
            "source_alignment": torch.LongTensor(src_alignment),
            "target_alignment": torch.LongTensor(trg_alignment),
            "target_masks": tgt_item.ne(pad) if tgt_item is not None else None,
        }
        if self.constraints is not None:
            example["constraints"] = self.constraints[index]
        return example

    def __len__(self):
        return len(self.src)

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary mapping sample field
                names (e.g. ``'source'``, ``'target'``) to the length to pad
                that field to.

        Returns:
            dict: a mini-batch (``{}`` for an empty ``samples``) with keys:

                - `id` (LongTensor): example IDs in the original input order
                - `nsentences` (int): number of samples in the batch
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the left if *left_pad_source* is ``True``.
                  - `src_lengths` (LongTensor): 1D Tensor of the unpadded
                    lengths of each source sentence of shape `(bsz)`
                  - `src_alignments` / `tgt_alignments` (LongTensor): 0-padded
                    alignments, plus boolean `src_alignments_masks` /
                    `tgt_alignments_masks` (``alignment > 0``).
                  - `notes` / `durs` (LongTensor): 0-padded melody tensors.
                  - `prev_output_tokens` (LongTensor): a padded 2D Tensor of
                    tokens in the target sentence, shifted right by one
                    position for teacher forcing, of shape `(bsz, tgt_len)`.
                    This key will not be present if *input_feeding* is
                    ``False``.  Padding will appear on the left if
                    *left_pad_target* is ``True``.
                  - `src_lang_id` (LongTensor): a long Tensor which contains source
                    language IDs of each sample in the batch

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the left if *left_pad_target* is ``True``.
                - `source` (LongTensor): source tokens, always right-padded.
                - `target_masks` (BoolTensor): non-pad positions of `target`.
                - `target_alignments` (LongTensor): target alignments clamped
                  to ``[0, max_delta_note]``.
                - `tgt_lang_id` (LongTensor): a long Tensor which contains target language
                   IDs of each sample in the batch
        """
        res = collate(
            samples,
            pad_idx=self.src_dict.pad(),
            bos_idx=self.bos,
            left_pad_source=self.left_pad_source,
            left_pad_target=self.left_pad_target,
            input_feeding=self.input_feeding,
            pad_to_length=pad_to_length,
            pad_to_multiple=self.pad_to_multiple,
        )
        if not res:
            # Empty batch: there is no net_input to attach language ids to.
            return res

        if 'target_alignments' in res:
            # torch.clamp is out-of-place: net_input['tgt_alignments'] keeps
            # the unclamped values; only this top-level copy is clamped.
            res['target_alignments'] = torch.clamp(res['target_alignments'], min=0, max=self.max_delta_note)
        if self.src_lang_id is not None or self.tgt_lang_id is not None:
            src_tokens = res["net_input"]["src_tokens"]
            bsz = src_tokens.size(0)
            if self.src_lang_id is not None:
                res["net_input"]["src_lang_id"] = (
                    torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
            if self.tgt_lang_id is not None:
                res["tgt_lang_id"] = (
                    torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens)
                )
        return res

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        sizes = self.src_sizes[indices]
        if self.tgt_sizes is not None:
            sizes = np.maximum(sizes, self.tgt_sizes[indices])
        return sizes

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.src_sizes[index],
            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        if self.buckets is None:
            # sort by target length, then source length (stable mergesort keeps
            # the target ordering among equal source lengths)
            if self.tgt_sizes is not None:
                indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]
        else:
            # sort by bucketed_num_tokens, which is:
            #   max(padded_src_len, padded_tgt_len)
            return indices[
                np.argsort(self.bucketed_num_tokens[indices], kind="mergesort")
            ]

    @property
    def supports_prefetch(self):
        return getattr(self.src, "supports_prefetch", False) and (
                getattr(self.tgt, "supports_prefetch", False) or self.tgt is None
        )

    @property
    def sizes(self):
        return self.src_sizes

    def prefetch(self, indices):
        self.src.prefetch(indices)
        if self.tgt is not None:
            self.tgt.prefetch(indices)
        # Alignment and melody sources are not required to be prefetchable
        # datasets, so only prefetch those that advertise support for it.
        for aux in (self.src_align_dataset, self.tgt_align_dataset, self.melody):
            if aux is not None and getattr(aux, "supports_prefetch", False):
                aux.prefetch(indices)
