from fairseq.models import register_model, register_model_architecture
from fairseq.models.bart import BARTModel, mbart_base_architecture
from fairseq.models.bart.melody_modules import (
    AlignmentEncoder,
    LengthEncoder,
    SimpleDecoder,
    GroupingDecoder,
    TransformerMelodyDecoder,
    MelodyAlignmentEncoder,
    AlignmentAttentionDecoder)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import os
import glob
import re


def load_ckpt(cur_model, ckpt_base_dir, prefix_in_ckpt='model', force=True, strict=True):
    """Load the ``"model"`` state dict of a checkpoint into ``cur_model``.

    Args:
        cur_model: module whose ``load_state_dict`` receives the weights.
        ckpt_base_dir: either a checkpoint file, or a directory that holds
            ``checkpoint<N>.pt`` files; in the latter case the file with the
            largest ``N`` is used.
        prefix_in_ckpt: only used in the log line printed after loading.
        force: when no checkpoint is found, raise if True, otherwise just
            print the message and return.
        strict: forwarded to ``load_state_dict``. When False, entries whose
            shape differs from the matching entry of ``cur_model`` are dropped
            from the loaded state dict before loading.

    Raises:
        AssertionError: no checkpoint was found and ``force`` is True.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        candidates = [ckpt_base_dir]
    else:
        base_dir = ckpt_base_dir
        numbered = []
        # Escape the directory so glob metacharacters in the path (e.g. "[")
        # are taken literally.
        for path in glob.glob(f'{glob.escape(base_dir)}/checkpoint*.pt'):
            # Match on the file name only: the glob also returns files such as
            # "checkpoint_best.pt" that carry no number and must be skipped
            # instead of crashing the sort key.
            match = re.fullmatch(r'checkpoint(\d+)\.pt', os.path.basename(path))
            if match is not None:
                numbered.append((int(match.group(1)), path))
        numbered.sort(key=lambda item: item[0])
        candidates = [path for _, path in numbered]

    if not candidates:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O`` while keeping the exception type callers may catch.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    checkpoint_path = candidates[-1]
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    if not strict:
        cur_model_state_dict = cur_model.state_dict()
        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        unmatched_keys = []
        for key, param in state_dict.items():
            if key in cur_model_state_dict:
                new_param = cur_model_state_dict[key]
                if new_param.shape != param.shape:
                    unmatched_keys.append(key)
                    print("| Unmatched keys: ", key, new_param.shape, param.shape)
        for key in unmatched_keys:
            del state_dict[key]
    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{prefix_in_ckpt}' from '{checkpoint_path}'.")

# Lookup table from the value of ``args.alignment_decoder_type`` to the
# alignment decoder class instantiated by the model.
decoder_arch = dict(
    grouping=GroupingDecoder,
    attention=AlignmentAttentionDecoder,
    simple=SimpleDecoder,
)

@register_model("mbart_with_melody")
class mbart_with_melody(BARTModel):
    """``BARTModel`` extended with melody inputs and an alignment decoder.

    In addition to the inherited encoder/decoder the model owns:

    * ``alignment_decoder`` -- class picked from ``decoder_arch`` by
      ``args.alignment_decoder_type``;
    * ``length_encoder`` when ``args.length_control_type == 'explicit'``,
      otherwise ``alignment_encoder`` (exactly one of the two exists);
    * ``note_embeddings`` / ``dur_embeddings`` -- embeddings of the note and
      duration ids;
    * ``melody_alignment_encoder``.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)
        # The same portion embedding instance is handed to the alignment
        # decoder and (when built) to the alignment encoder.
        portion_embedding = nn.Embedding(args.portion_type_num + 1, args.decoder_embed_dim)
        self.alignment_decoder = decoder_arch[args.alignment_decoder_type](args, portion_embedding)

        if args.length_control_type == 'explicit':
            self.length_encoder = LengthEncoder(args)
        else:
            self.alignment_encoder = AlignmentEncoder(args, portion_embedding)
        self.note_embeddings = nn.Embedding(args.note_num + 1, args.encoder_embed_dim)
        self.dur_embeddings = nn.Embedding(args.dur_type_num + 1, args.encoder_embed_dim)
        self.melody_alignment_encoder = MelodyAlignmentEncoder()
        self._num_updates = 0
        # ``args`` may not carry the flag at all, hence the default.
        self.without_portion = getattr(args, 'without_portion', False)
        self.args = args

    @classmethod
    def build_model(cls, args, task):
        """Build the model and optionally warm-start it from a checkpoint.

        When ``args.pretrained_mt_ckpt_dir`` is set, weights are loaded with
        ``strict=False`` so that shape-mismatched entries are skipped.
        """
        model = super().build_model(args, task)
        pretrained_dir = getattr(args, "pretrained_mt_ckpt_dir", None)
        if pretrained_dir is not None:
            load_ckpt(model, pretrained_dir, strict=False)
        return model

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Return a ``TransformerMelodyDecoder`` as the model decoder."""
        return TransformerMelodyDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )

    @staticmethod
    def add_args(parser):
        """Add the ``BARTModel`` options plus the melody-specific ones."""
        BARTModel.add_args(parser)
        parser.add_argument(
            "--pretrained-mt-ckpt-dir",
            default=None,
            type=str,
            help="path to the directory that contains pretrained lyrics translation checkpoint",
        )
        parser.add_argument(
            "--note-num",
            default=0,
            type=int,
            help="numbers of midi notes to support",
        )
        parser.add_argument(
            "--dur-type-num",
            default=8,
            type=int,
            help="numbers of midi notes durations types",
        )
        parser.add_argument(
            "--portion-type-num",
            default=25,
            type=int,
            help="numbers of portion division",
        )
        parser.add_argument(
            "--act-stop-epsilon",
            default=0.1,
            type=float,
            help="ACT halting hyper-param epsilon",
        )
        parser.add_argument(
            "--alignment-decoder-type",
            default='grouping',
            type=str,
            help="type of alignment prediction NN",
        )
        parser.add_argument(
            "--grouping-arch",
            default='act',
            type=str,
            help="type of grouping prediction way",
        )
        parser.add_argument(
            "--melody-embed-scale",
            default=0.5,
            type=float,
            help="weight for melody embeddings when added to token embeddings",
        )
        parser.add_argument(
            "--length-control-type",
            default='explicit',
            type=str,
            help="type of translation length control",
        )
        parser.add_argument(
            "--hidden-size",
            default=512,
            type=int,
            help="size of representation vector",
        )
        parser.add_argument(
            "--kernel-size",
            default=3,
            type=int,
            help="kenerl size of 1D kernel",
        )
        parser.add_argument(
            "--predictor-layers",
            default=3,
            type=int,
            help="layer number of predictor Conv layers",
        )
        parser.add_argument(
            "--predictor-dropout",
            default=0.5,
            type=float,
        )
        parser.add_argument(
            "--predictor-padding",
            default='SAME',
            type=str,
            help="padding type of Conv of predictor",
        )
        parser.add_argument(
            "--without-portion",
            action='store_true',
            help="whether to ablate portion embedding",
        )

    def set_num_updates(self, num_updates):
        """Record the update count locally after notifying the base class."""
        super().set_num_updates(num_updates)
        self._num_updates = num_updates

    def _embed_melody(self, notes, durs, pad):
        """Return note + duration + position embeddings of the padded melody.

        ``pad`` is the ``F.pad`` spec applied to both ``notes`` and ``durs``
        (zero-filled); positions come from the encoder's ``embed_positions``
        applied to the padded notes.
        """
        padded_notes = F.pad(notes, pad)
        notes_emb = self.note_embeddings(padded_notes)
        durs_emb = self.dur_embeddings(F.pad(durs, pad))
        melody_position_emb = self.encoder.embed_positions(padded_notes)
        return notes_emb + durs_emb + melody_position_emb

    @staticmethod
    def _bos_padded_or_zeros(tensor, like):
        """Left-pad ``tensor`` by one zero, or return long zeros shaped like ``like`` if it is None."""
        if tensor is not None:
            return F.pad(tensor, (1, 0))
        return torch.zeros_like(like).long()

    @staticmethod
    def _explicit_needs_melody_cond_error():
        """Build the error for the unsupported explicit-length + attention-decoder setup."""
        return ValueError(
            "explicit length control requires the melody condition, which is not "
            "computed when the alignment decoder is an AlignmentAttentionDecoder"
        )

    def forward(
            self,
            src_tokens,
            src_lengths,
            prev_output_tokens,
            notes,
            durs,
            src_alignments,
            src_alignments_masks,
            tgt_alignments,
            tgt_alignments_masks,
            features_only: bool = False,
            classification_head_name: Optional[str] = None,
            token_embeddings: Optional[torch.Tensor] = None,
            return_all_hiddens: bool = True,
            alignment_layer: Optional[int] = None,
            alignment_heads: Optional[int] = None,
        ):
        '''
        Encode the source, decode the target and predict alignments.

        src_tokens: [B,T]
        src_lengths:
        prev_output_tokens,
        notes:[B,T]
        durs: [B,T]
        src_alignments: [B,T]
        tgt_alignments: [B,T]

        Returns a tuple ``(x, alignment_decoder_output, extra)`` where ``x``
        and ``extra`` come from the decoder and ``extra`` is additionally
        updated with alignment/length-control entries.

        Raises ValueError when explicit length control is combined with an
        ``AlignmentAttentionDecoder`` (the melody condition is not available).
        '''
        if classification_head_name is not None:
            features_only = True
        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            token_embeddings=token_embeddings,
            return_all_hiddens=return_all_hiddens
        )
        # One leading zero entry is prepended for the BOS position.
        melody_emb = self._embed_melody(notes, durs, (1, 0))
        uses_attention = isinstance(self.alignment_decoder, AlignmentAttentionDecoder)

        tgt_melody_cond = None
        if not uses_attention:
            # make <BOS> align to 1 note (0 pitch, 0 dur)
            tgt_melody_cond = self.melody_alignment_encoder(
                F.pad(tgt_alignments, (1, 0), value=1),
                melody_emb,
                F.pad(tgt_alignments_masks, (1, 0), value=1),
            )

        extra = None
        alignment_decoder_output = None
        if hasattr(self, "length_encoder"):
            if uses_attention:
                raise self._explicit_needs_melody_cond_error()
            length_control_hidden = self.length_encoder(prev_output_tokens, src_alignments, tgt_melody_cond)
            # NOTE(review): this stacks the raw encoder output with a tensor;
            # confirm the encoder output type supports torch.stack here.
            encoder_out = torch.stack([length_control_hidden, encoder_out], dim=1)
            x, extra = self.decoder(
                prev_output_tokens,
                encoder_out=encoder_out,
                features_only=features_only,
                alignment_layer=alignment_layer,
                alignment_heads=alignment_heads,
                src_lengths=src_lengths,
                return_all_hiddens=return_all_hiddens,
            )
            # NOTE(review): decoder_infer passes src_alignments_masks as an
            # extra positional argument in the same branch -- confirm which
            # call matches the alignment decoder signature.
            alignment_decoder_output = self.alignment_decoder(x, src_alignments, tgt_alignments, tgt_melody_cond)
            extra.update({'length_control_hidden': length_control_hidden})
        elif hasattr(self, "alignment_encoder"):
            if self.without_portion:
                # Ablation: all-zero tensor of width decoder_embed_dim
                # instead of the alignment encoder output.
                prev_output_alignment_embeddings = torch.zeros_like(
                    F.pad(tgt_alignments, (1, 0))
                ).unsqueeze(-1).repeat(1, 1, self.args.decoder_embed_dim)
            else:
                prev_output_alignment_embeddings = self.alignment_encoder(
                    src_alignments,
                    F.pad(tgt_alignments, (1, 0)),
                    F.pad(tgt_alignments_masks, (1, 0)),
                )
            x, extra = self.decoder(
                prev_output_tokens,
                prev_output_alignment_embeddings,
                encoder_out=encoder_out,
                features_only=features_only,
                alignment_layer=alignment_layer,
                alignment_heads=alignment_heads,
                src_lengths=src_lengths,
                return_all_hiddens=return_all_hiddens,
            )
            if uses_attention:
                alignment_decoder_output = self.alignment_decoder(extra['decoder_hidden_output'], melody_emb)
            else:
                # NOTE(review): the mask is padded with 1 here, whereas
                # decoder_infer pads it with 0 -- confirm this is intended.
                alignment_decoder_output, alignment_extra = self.alignment_decoder(
                    extra['decoder_hidden_output'],
                    src_alignments,
                    src_alignments_masks,
                    F.pad(tgt_alignments, (1, 0)),
                    F.pad(tgt_alignments_masks, (1, 0), value=1),
                    prev_output_alignment_embeddings,
                )
                extra.update(alignment_extra)
        return x, alignment_decoder_output, extra

    def decoder_infer(self,
                      prev_output_tokens,
                      notes,
                      durs,
                      src_lengths,
                      src_alignments,
                      src_alignments_masks,
                      tgt_alignments,
                      tgt_alignments_masks,
                      incremental_states=None,
                      encoder_out=None,
                      last_hx=None,
                      token_embeddings: Optional[torch.Tensor] = None,
                      return_all_hiddens: bool = True,
                      classification_head_name: Optional[str] = None,
                      alignment_layer: Optional[int] = None,
                      alignment_heads: Optional[int] = None,
                      ):
        """Run the decoder and alignment decoder for inference.

        Unlike ``forward`` the encoder is not run; ``encoder_out`` is supplied
        by the caller. ``tgt_alignments`` / ``tgt_alignments_masks`` may be
        None, in which case tensors shaped like ``prev_output_tokens`` are
        substituted (ones for the melody condition, zeros elsewhere).
        ``last_hx`` is forwarded to the alignment decoder only when given.
        ``token_embeddings`` and ``classification_head_name`` are accepted but
        not used.

        Returns a tuple ``(x, alignment_decoder_output, extra)``.

        Raises ValueError when explicit length control is combined with an
        ``AlignmentAttentionDecoder`` (the melody condition is not available).
        """
        melody_emb = self._embed_melody(notes, durs, (1, 0))
        uses_attention = isinstance(self.alignment_decoder, AlignmentAttentionDecoder)

        tgt_melody_cond = None
        if not uses_attention:
            # make <BOS> align to 1 note (0 pitch, 0 dur)
            tgt_melody_cond = self.melody_alignment_encoder(
                F.pad(tgt_alignments, (1, 0), value=1) if tgt_alignments is not None else torch.ones_like(prev_output_tokens),
                melody_emb,
                F.pad(tgt_alignments_masks, (1, 0), value=1) if tgt_alignments_masks is not None else torch.ones_like(prev_output_tokens),
            )

        extra = None
        alignment_decoder_output = None
        if hasattr(self, "length_encoder"):
            if uses_attention:
                raise self._explicit_needs_melody_cond_error()
            length_control_hidden = self.length_encoder(prev_output_tokens, src_alignments, tgt_melody_cond)
            # NOTE(review): confirm the caller-supplied encoder_out supports
            # torch.stack with a tensor.
            encoder_out = torch.stack([length_control_hidden, encoder_out], dim=1)
            x, extra = self.decoder(
                prev_output_tokens,
                encoder_out=encoder_out,
                incremental_state=incremental_states,
                alignment_layer=alignment_layer,
                alignment_heads=alignment_heads,
                src_lengths=src_lengths,
                return_all_hiddens=return_all_hiddens,
            )
            alignment_decoder_output = self.alignment_decoder(
                x,
                src_alignments,
                src_alignments_masks,
                tgt_alignments if tgt_alignments is not None else torch.ones_like(prev_output_tokens),
                tgt_melody_cond,
            )
        elif hasattr(self, "alignment_encoder"):
            if self.without_portion:
                # Ablation: all-zero tensor of width decoder_embed_dim,
                # shaped after the padded alignments when they are available.
                shape_ref = F.pad(tgt_alignments, (1, 0)) if tgt_alignments is not None else prev_output_tokens
                prev_output_alignment_embeddings = torch.zeros_like(shape_ref).unsqueeze(-1).repeat(
                    1, 1, self.args.decoder_embed_dim)
            else:
                prev_output_alignment_embeddings = self.alignment_encoder(
                    src_alignments,
                    self._bos_padded_or_zeros(tgt_alignments, prev_output_tokens),
                    self._bos_padded_or_zeros(tgt_alignments_masks, prev_output_tokens),
                )
            x, extra = self.decoder(
                prev_output_tokens,
                prev_output_alignment_embeddings,
                encoder_out=encoder_out,
                incremental_state=incremental_states,
                alignment_layer=alignment_layer,
                alignment_heads=alignment_heads,
                src_lengths=src_lengths,
                return_all_hiddens=return_all_hiddens,
            )
            if uses_attention:
                alignment_decoder_output = self.alignment_decoder(extra['decoder_hidden_output'], melody_emb)
            else:
                # ``last_hx`` is only passed when the caller provided one, so
                # the alignment decoder's own default applies otherwise.
                infer_kwargs = {'infer': True}
                if last_hx is not None:
                    infer_kwargs['last_hx'] = last_hx
                alignment_decoder_output, alignment_extra = self.alignment_decoder(
                    extra['decoder_hidden_output'],
                    src_alignments,
                    src_alignments_masks,
                    self._bos_padded_or_zeros(tgt_alignments, prev_output_tokens),
                    self._bos_padded_or_zeros(tgt_alignments_masks, prev_output_tokens),
                    prev_output_alignment_embeddings,
                    **infer_kwargs,
                )
                extra.update(alignment_extra)
        return x, alignment_decoder_output, extra

    def forward_align(self, prev_output_tokens, notes, durs):
        """Run only the alignment decoder on ``prev_output_tokens`` and the melody.

        Note that the melody is padded with one trailing zero entry here,
        whereas ``forward`` and ``decoder_infer`` pad at the front.
        """
        melody_emb = self._embed_melody(notes, durs, (0, 1))
        alignment_decoder_output = self.alignment_decoder(prev_output_tokens, melody_emb)
        return alignment_decoder_output


@register_model_architecture("mbart_with_melody", "mbart_base_with_melody")
def mbart_base_with_melody(args):
    """Architecture ``mbart_base_with_melody`` for the ``mbart_with_melody`` model.

    Delegates entirely to ``mbart_base_architecture``, which is applied to
    ``args`` in place of any melody-specific defaults.
    """
    mbart_base_architecture(args)