import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from fairseq.models.transformer import TransformerDecoder
from typing import Optional, Dict, Any, List
from torch import Tensor
import os
from fairseq.modules import(
    AdaptiveSoftmax,
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    TransformerDecoderLayer,
    TransformerEncoderLayer,
    GradMultiply
)
# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous (a debugging
# aid that slows execution). setdefault keeps any value the user already
# exported instead of overwriting it for every importer of this module.
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', "1")


class LayerNorm(nn.LayerNorm):
    """Layer normalization over a configurable axis (eps=1e-12).

    Note: this definition shadows ``fairseq.modules.LayerNorm`` imported at the
    top of the module, so ``LayerNorm(...)`` elsewhere in this file refers to
    this class.

    :param int nout: size of the normalized axis
    :param int dim: axis to normalize (default: the last axis)
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization along ``self.dim``.

        :param torch.Tensor x: input tensor whose ``self.dim`` axis has size ``nout``
        :return: layer normalized tensor, same shape as ``x``
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        # nn.LayerNorm normalizes the trailing axis: move the requested axis
        # there, normalize, and move it back.
        return super(LayerNorm, self).forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)

class TransformerMelodyDecoder(TransformerDecoder):
    """fairseq ``TransformerDecoder`` whose input embedding is augmented with
    per-token melody/alignment embeddings.

    Differences from the parent visible in this class:

    * ``forward`` / ``extract_features`` take an extra ``prev_output_alignments``
      argument, which is added (scaled by ``args.melody_embed_scale``) to the
      scaled token embeddings;
    * ``GradMultiply`` is applied with factor 10.0 to the alignment input and
      with factor 0.1 to the returned features;
    * when ``features_only`` is False, ``forward`` also stores the
      pre-projection features in ``extra['decoder_hidden_output']``.
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        super(TransformerMelodyDecoder, self).__init__(args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn)
        # Weight of the alignment/melody embedding relative to the token embedding.
        self.melody_embed_scale = args.melody_embed_scale


    def forward(
        self,
        prev_output_tokens,
        prev_output_alignments,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            prev_output_alignments (Tensor): previous alignment embeddings of
                shape `(batch, tgt_len, hidden)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer, alignment_heads: forwarded to
                :meth:`extract_features`.
            src_lengths, return_all_hiddens: accepted for interface
                compatibility with fairseq decoders; not used here.

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                  (or the features when ``features_only`` is True)
                - a dictionary with "attn", "inner_states" and, when
                  ``features_only`` is False, "decoder_hidden_output"
        """
        x, extra = self.extract_features(
            prev_output_tokens,
            prev_output_alignments,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )
        if not features_only:
            # Keep the features before the output projection for callers.
            extra['decoder_hidden_output'] = x
            x = self.output_layer(x)

        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        prev_output_alignments,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """Thin wrapper that delegates to :meth:`extract_features_scriptable`."""
        return self.extract_features_scriptable(
            prev_output_tokens,
            prev_output_alignments,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        prev_output_alignments,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            prev_output_alignments (Tensor): alignment embeddings added to the
                token embeddings; indexed as `(batch, tgt_len, hidden)`.
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with "attn" (one-element list) and
                  "inner_states" (input to and output of every layer)
        """
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        # Incremental decoding: only the newest position is processed. The
        # alignment history is sliced the same way as the tokens.
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            prev_output_alignments = prev_output_alignments[:, -1:, :]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions

        # fairseq GradMultiply: presumably identity in the forward pass with the
        # gradient scaled by 10.0 in backward — confirm against
        # fairseq.modules.GradMultiply.
        prev_output_alignments = GradMultiply.apply(prev_output_alignments, 10.0)

        # Token and alignment embeddings are summed, each with its own scale.
        x = self.embed_scale * self.embed_tokens(prev_output_tokens) + self.melody_embed_scale * prev_output_alignments
        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            # Causal mask only for full-sequence (non-incremental) decoding.
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = layer(
                x,
                encoder_out["encoder_out"][0]
                if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
                else None,
                encoder_out["encoder_padding_mask"][0]
                if (
                    encoder_out is not None
                    and len(encoder_out["encoder_padding_mask"]) > 0
                )
                else None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)
        # Gradient scaling (factor 0.1) on the returned features; see the
        # GradMultiply note above.
        x = GradMultiply.apply(x, 0.1)
        return x, {"attn": [attn], "inner_states": inner_states}
'''
prev_output_token:<BOS>, x1, x2, x3, x4
                      1,...............
pitch                    0
dur                      0
pos                      0

decoder_hidden:   x1, x2, x3, x4, <EOS>
                  x1,               0 
                  x2_alignment,     0
'''

class AlignmentAttentionDecoder(TransformerDecoder):
    """Stack of fairseq decoder layers used only to obtain an attention map
    between melody queries and lyrics keys/values.

    Reuses ``build_decoder_layer`` and ``buffered_future_mask`` from
    ``TransformerDecoder`` but does not run its constructor (no dictionary or
    token embeddings are involved). ``forward`` returns only the head-averaged
    cross-attention of one layer; hidden states are discarded, and the
    ``dropout_module`` / ``layernorm_embedding`` / ``layer_norm`` built here are
    not applied in ``forward``.
    """

    def __init__(self, args, **extra_kwargs):
        # TransformerDecoder.__init__ is deliberately skipped (it needs a
        # dictionary and token embeddings), but nn.Module's own initialisation
        # must still run: register_buffer and sub-module assignment raise
        # AttributeError without it.
        nn.Module.__init__(self)
        self.args = args
        self.register_buffer("version", torch.Tensor([3]))
        # Cache read by the inherited buffered_future_mask(); fairseq normally
        # creates it in TransformerDecoder.__init__, which is skipped above.
        # NOTE(review): confirm against the pinned fairseq version. Inherited
        # helpers such as upgrade_state_dict_named may read further parent
        # attributes (e.g. embed_positions) that are not set here.
        self._future_mask = torch.empty(0)

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.decoder_layerdrop = args.decoder_layerdrop

        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(args.embed_dim)
        else:
            self.layernorm_embedding = None

        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(args, False)
                for _ in range(args.alignment_decoder_layers)
            ]
        )
        self.num_layers = len(self.layers)

        if args.decoder_normalize_before and not getattr(
            args, "no_decoder_final_norm", False
        ):
            self.layer_norm = LayerNorm(args.embed_dim)
        else:
            self.layer_norm = None

    def forward(self, lyrics_cond,
                lyrics_padding_mask,
                melody_cond,
                incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
                full_context_alignment: bool = False,
                alignment_layer: Optional[int] = None,
                alignment_heads: Optional[int] = None,):
        """Return the head-averaged cross-attention of ``alignment_layer``.

        Args:
            lyrics_cond: keys/values for cross-attention, either a tensor or a
                fairseq-style list whose first element is that tensor (an empty
                list or None disables cross-attention input).
                NOTE(review): layout presumably (src_len, batch, dim) as fairseq
                layers expect — confirm with the caller.
            lyrics_padding_mask: padding mask for ``lyrics_cond``: a tensor, a
                list holding one tensor, or None.
            melody_cond: query states fed to the first layer. NOTE(review):
                presumably (tgt_len, batch, dim); the causal mask is built from
                its first dimension.
            incremental_state: forwarded unchanged to every layer.
            full_context_alignment: if True, no causal self-attention mask.
            alignment_layer: layer whose attention is returned (default: last).
            alignment_heads: if given, average only the first this-many heads.

        Returns:
            The averaged attention tensor of ``alignment_layer``, or None if
            that layer returned no attention weights.
        """
        if alignment_layer is None:
            # Without this default no layer index matches and the result is None.
            alignment_layer = self.num_layers - 1

        # Normalise encoder inputs to fairseq's list-of-tensors convention so
        # that ``[0]`` selects the whole tensor rather than its first row.
        if lyrics_cond is None:
            lyrics_cond = []
        elif isinstance(lyrics_cond, Tensor):
            lyrics_cond = [lyrics_cond]
        if lyrics_padding_mask is None:
            lyrics_padding_mask = []
        elif isinstance(lyrics_padding_mask, Tensor):
            lyrics_padding_mask = [lyrics_padding_mask]

        # Loop-invariant encoder inputs.
        encoder_states = lyrics_cond[0] if len(lyrics_cond) > 0 else None
        encoder_padding_mask = lyrics_padding_mask[0] if len(lyrics_padding_mask) > 0 else None

        x = melody_cond
        self_attn_padding_mask: Optional[Tensor] = None
        attn: Optional[Tensor] = None

        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = layer(
                x,
                encoder_states,
                encoder_padding_mask,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        return attn


class MelodyAlignmentEncoder(nn.Module):
    """Average melody features over the note span aligned to each lyric token.

    ``alignments[b, t]`` is the number of consecutive melody positions covered
    by token ``t``. Each token with a nonzero count starts a new group spanning
    the next ``count`` melody positions. A token with a zero count reuses the
    group of the next token that has a nonzero count. For example, counts
    ``[1, 0, 0, 2]`` give groups ``[0, 1, 1, 1]``.
    """

    def __init__(self):
        super(MelodyAlignmentEncoder, self).__init__()
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.pooling = nn.AdaptiveAvgPool1d(output_size=1)

    @staticmethod
    def _ones_index(alignments):
        """Return (batch, group, melody-position) indices of every 1 in the
        pooling matrix, built without Python loops.

        Entries with ``alignments > 0`` are taken in row-major order. The k-th
        such entry of a row is group k, covering melody positions
        ``[cum - count, cum)``, where ``cum`` is the row-wise cumulative count.
        """
        nonzero = alignments > 0
        counts = alignments[nonzero].long()
        batch_idx = nonzero.nonzero(as_tuple=True)[0]
        group_idx = torch.cumsum(nonzero.long(), dim=-1)[nonzero] - 1
        span_end = torch.cumsum(alignments, dim=-1)[nonzero].long()
        span_start = span_end - counts

        index_dim0 = torch.repeat_interleave(batch_idx, counts)
        index_dim1 = torch.repeat_interleave(group_idx, counts)
        # Offset of each flat element within its own span: 0, 1, ..., count-1.
        run_start = torch.cumsum(counts, dim=0) - counts
        offset = (torch.arange(index_dim0.shape[0], device=alignments.device)
                  - torch.repeat_interleave(run_start, counts))
        index_dim2 = torch.repeat_interleave(span_start, counts) + offset
        return index_dim0, index_dim1, index_dim2

    def forward(self, alignments, melody_cond, alignments_mask):
        '''
        alignments: [B, T] non-negative integer note counts per token
        melody_cond: [B, Tm, H]
        alignments_mask: [B, T]; multiplied into the output

        Returns [B, T, H]: for each token, the mean of ``melody_cond`` over its
        group's melody span (zero for empty groups), times the mask. Melody
        positions beyond Tm - 1 are clamped to Tm - 1.
        '''
        batch_size, num_notes, hidden = melody_cond.shape

        # Group id per token: counting nonzero tokens from the right makes each
        # zero-count token share the id of the next nonzero token.
        from_right = torch.cumsum(alignments.flip(dims=[-1]) != 0, dim=-1)
        scattered_pos = from_right.max(dim=-1).values.unsqueeze(-1) - from_right.flip(dims=[-1])

        num_groups = int(scattered_pos.max()) + 1
        # Match melody_cond's dtype so the matmul below also works in half precision.
        pooling_mat = torch.zeros((batch_size, num_groups, num_notes),
                                  dtype=melody_cond.dtype, device=alignments.device)

        ones_dim0, ones_dim1, ones_dim2 = self._ones_index(alignments)
        ones_dim2 = torch.clamp(ones_dim2, min=0, max=num_notes - 1)
        pooling_mat[(ones_dim0, ones_dim1, ones_dim2)] = 1.0

        # Row-normalise to a mean; empty groups keep an all-zero row.
        norm_ = torch.sum(pooling_mat, dim=-1)
        norm_ = norm_.masked_fill(norm_ == 0, 1)
        pooling_mat = pooling_mat / norm_.unsqueeze(-1)

        pooled_melody_cond = torch.bmm(pooling_mat, melody_cond)  # [B, groups, H]
        gather_index = scattered_pos.unsqueeze(-1).expand(-1, -1, hidden)
        return torch.gather(pooled_melody_cond, 1, gather_index) * alignments_mask.unsqueeze(-1)

# Bad Rom@@ man@@ ce
# 1   0     0     2  0 3 1 1
#       durs =0
#       pitch=0
# durs_emb + notes_emb
# 2 3
# mean_pool(Sigma[dur_emb + notes_emb])

class LengthEncoder(nn.Module):
    """Summarise melody conditioning into one vector per sentence.

    Causal Conv1d blocks run over the note axis, followed by mean pooling and a
    linear projection.
    """

    def __init__(self, args, out_dim=None):
        """
        :param args: needs ``decoder_embed_dim``, ``predictor_layers``,
            ``kernel_size`` and ``predictor_dropout``
        :param int out_dim: output width; defaults to ``decoder_embed_dim``
            (the forward pass is described as producing a ``[B, H]`` hidden)
        """
        super(LengthEncoder, self).__init__()
        hidden = args.decoder_embed_dim
        # Collapses the note axis: [B, H, Nn] -> [B, H, 1].
        self.pooling = nn.AdaptiveAvgPool1d(output_size=1)
        self.conv = torch.nn.ModuleList()
        for _ in range(args.predictor_layers):
            self.conv += [nn.Sequential(
                          nn.ConstantPad1d((args.kernel_size - 1, 0), 0),  # left pad => causal conv
                          nn.Conv1d(hidden, hidden, args.kernel_size, stride=1, padding=0),
                          nn.ReLU(),
                          LayerNorm(hidden, dim=1),
                          nn.Dropout(args.predictor_dropout)
            )]
        self.linear = nn.Linear(hidden, hidden if out_dim is None else out_dim)

    def forward(self, x, src_alignments, melody_cond):
        '''
        x: [B, T, H]  (unused; kept for interface compatibility)
        src_alignments: [B, Tsrc]  (unused; kept for interface compatibility)
        melody_cond: [B, Nn, H]

        Returns [B, out_dim], the length-control hidden.
        '''
        # Teacher forcing. Conv1d expects channels first: [B, Nn, H] -> [B, H, Nn].
        cond = melody_cond.transpose(1, 2)
        for conv_layer in self.conv:
            cond = conv_layer(cond)
        # [B, H, Nn] -> [B, H]
        cond = self.pooling(cond).squeeze(-1)
        return self.linear(cond)


class AlignmentEncoder(nn.Module):
    """Embed, per target token, the share of melody notes still unassigned.

    The remaining share is quantised into ``args.portion_type_num`` buckets,
    embedded with the shared ``portion_embedding`` and refined by causal
    Conv1d blocks.
    """

    def __init__(self, args, portion_embedding):
        super(AlignmentEncoder, self).__init__()
        self.portion_type = args.portion_type_num
        self.portion_embedding = portion_embedding
        self.conv = torch.nn.ModuleList()
        for _ in range(args.predictor_layers):
            self.conv += [nn.Sequential(
                          nn.ConstantPad1d((args.kernel_size - 1, 0), 0),  # left pad => causal conv
                          # Output width must be decoder_embed_dim: the LayerNorm
                          # below and the next conv block both expect it.
                          nn.Conv1d(args.decoder_embed_dim, args.decoder_embed_dim, args.kernel_size, stride=1, padding=0),
                          nn.ReLU(),
                          LayerNorm(args.decoder_embed_dim, dim=1),
                          nn.Dropout(args.predictor_dropout)
                        )]

    def forward(self, src_alignments, tgt_alignments, tgt_alignments_masks, infer=False):
        '''
        src_alignments: [B, Tsrc] note counts per source token
        tgt_alignments: [B, Ttgt] note counts per target token
        tgt_alignments_masks: [B, Ttgt]; masked positions map to bucket 0
        infer: inference mode is not implemented

        Returns [B, Ttgt, H].
        Raises NotImplementedError when ``infer`` is True.
        '''
        if infer:
            raise NotImplementedError("AlignmentEncoder only supports teacher forcing (infer=False)")

        # Teacher forcing.
        aligned_so_far = torch.cumsum(tgt_alignments, dim=-1)
        total_notes = torch.sum(src_alignments, dim=-1).unsqueeze(-1)
        # clamp(min=1) only affects rows with zero notes, which would otherwise
        # give NaN (0/0) before the integer cast.
        remained = (total_notes - aligned_so_far) / total_notes.clamp(min=1) * self.portion_type

        remained = (remained.round() * tgt_alignments_masks).long()
        remained = torch.clamp(remained, 0, self.portion_type)
        feats = self.portion_embedding(remained).transpose(1, 2)  # [B, H, Ttgt]
        for conv_layer in self.conv:
            feats = conv_layer(feats)
        return feats.transpose(1, 2)

# 1 1 1 2 1 1
# 0 1 2 [3,4] 5 6
# 1 [0 1] [0 1] [1 1] 1 1
# 0    1     2   3 4  5 6

class ACTEncoder(nn.Module):
    """Add a sentence-level note-allocation summary to every decoder position.

    For each source token, the cumulative share of notes covered (quantised to
    ``args.portion_type_num`` buckets) is embedded. The embeddings go through
    causal Conv1d blocks and are mean-pooled over the source axis into one
    vector per sentence. That vector is broadcast over the target positions.
    """

    def __init__(self, args, portion_embedding):
        super(ACTEncoder, self).__init__()
        self.portion_type = args.portion_type_num
        self.portion_embedding = portion_embedding
        self.src_allocate_encoder = torch.nn.ModuleList()
        for _ in range(args.predictor_layers):
            self.src_allocate_encoder += [nn.Sequential(
                nn.ConstantPad1d((args.kernel_size - 1, 0), 0),  # left pad => causal conv
                nn.Conv1d(args.decoder_embed_dim, args.decoder_embed_dim, args.kernel_size, stride=1, padding=0),
                nn.ReLU(),
                LayerNorm(args.decoder_embed_dim, dim=1),
                nn.Dropout(args.predictor_dropout)
            )]
        # Collapses the source axis: [B, H, Tsrc] -> [B, H, 1].
        self.src_allocate_encoder += [nn.AdaptiveAvgPool1d(output_size=1)]

    def forward(self, x, src_alignments, src_alignments_masks, prev_output_alignment_embeddings, infer=False):
        '''
        x: [B, T, H] decoder states; the summary is added only where x != 0
        src_alignments: [B, Tsrc] note counts per source token
        src_alignments_masks: [B, Tsrc]; masked positions map to bucket 0
        prev_output_alignment_embeddings: added to x (must broadcast to [B, T, H])
        infer: inference mode is not implemented

        Returns [B, T, H].
        Raises NotImplementedError when ``infer`` is True.
        '''
        if infer:
            raise NotImplementedError("ACTEncoder only supports teacher forcing (infer=False)")

        # Teacher forcing.
        x_mask = x != 0
        total_notes = torch.sum(src_alignments, dim=-1).unsqueeze(-1)
        # clamp(min=1) only affects rows with zero notes, which would otherwise
        # give NaN (0/0) and an invalid embedding index after the cast.
        cond = (torch.cumsum(src_alignments, dim=-1) / total_notes.clamp(min=1) * self.portion_type).round().long()
        cond = cond * src_alignments_masks            # [B, Tsrc]
        cond = self.portion_embedding(cond)           # [B, Tsrc, H]

        cond = cond.transpose(1, 2)                   # [B, H, Tsrc]
        for conv_layer in self.src_allocate_encoder:
            cond = conv_layer(cond)                   # ends as [B, H, 1]
        summary = cond.transpose(1, 2).repeat(1, x.shape[1], 1)  # [B, T, H]
        return x + prev_output_alignment_embeddings + summary * x_mask


class SimpleDecoder(nn.Module):
    """Classify the per-token note delta directly from ACTEncoder features.

    The classifier has ``args.max_delta_note + 1`` outputs. In inference mode
    only the last position is classified.
    """

    def __init__(self, args, portion_embedding):
        super(SimpleDecoder, self).__init__()
        self.cond_encoder = ACTEncoder(args, portion_embedding)
        num_classes = args.max_delta_note + 1
        self.delta_note_classifier = nn.Linear(args.decoder_embed_dim, num_classes)

    def forward(self, x, src_alignments, src_alignments_masks, tgt_alignments, tgt_alignments_masks, prev_output_alignment_embeddings, infer=False):
        # The conditioning encoder is always run in teacher-forcing mode,
        # whatever ``infer`` is; tgt_alignments(_masks) are not used.
        features = self.cond_encoder(x, src_alignments, src_alignments_masks, prev_output_alignment_embeddings, infer=False)
        if infer:
            features = features[:, -1:, :]
        logits = self.delta_note_classifier(features)
        return logits, {}



class ACT_unit(nn.Module):
    """Adaptive-computation-time style unit that decides how many melody notes a
    single target token spans.

    Each step re-encodes a per-row hidden state from (conditioning vector,
    previous hidden state, embedded remaining-note portion), adds an embedding
    of the step number, and emits a halting probability via ``ponder_linear``.
    A row stops once its accumulated halting probability reaches
    ``1 - args.act_stop_epsilon``. All rows stop after ``args.max_delta_note``
    steps.
    """
    def __init__(self, args, portion_embedding):
        super(ACT_unit, self).__init__()
        self.stop_epsilon = args.act_stop_epsilon
        self.max_ponder = args.max_delta_note
        # Input is [cond, hx, portion embedding] concatenated, hence 3 * H.
        self.unit_encoder = torch.nn.ModuleList([nn.Linear(args.decoder_embed_dim * 3, args.decoder_embed_dim),
                                                 nn.Mish(),
                                                 nn.Linear(args.decoder_embed_dim, args.decoder_embed_dim)
                                                 ])
        self.portion_type = args.portion_type_num
        # The shared portion embedding table is used for both the remaining-note
        # bucket and the step-number bucket.
        self.step_embedding = portion_embedding
        # One logit per row: the halting probability (after sigmoid).
        self.ponder_linear = nn.Linear(args.decoder_embed_dim, 1)
        if hasattr(args, 'without_portion'):
            self.without_portion = args.without_portion
        else:
            self.without_portion = False

    def unit_encode(self, cond, hx, remained_delta, step):
        """One encoder step for the currently active rows.

        ``remained_delta`` is a float portion index (rounded before lookup) and
        ``step`` the per-row step count, mapped to
        ``round(step / max_ponder * portion_type)`` before lookup.
        """

        cond = torch.cat((cond, hx, self.step_embedding(remained_delta.round().long())), dim=-1)
        for layer in self.unit_encoder:
            cond = layer(cond)
        cond = cond + self.step_embedding((step / self.max_ponder * self.portion_type).round().long())
        return cond

    def forward(self, x, act_cond, remained_delta, total_notes, hx=None):
        """Run the halting loop for one target position.

        Args:
            x: (B, H) tensor (``B, H = x.size()``); only its shape, dtype and
                device are used.
            act_cond: conditioning vectors indexed per row (presumably (B, H)).
            remained_delta: notes still unassigned per row (presumably (B,));
                clamped at 0 and decremented by one per step for active rows.
            total_notes: per-row note totals used to turn ``remained_delta``
                into a portion bucket (presumably (B,)).
            hx: optional initial hidden state (B, H). It is modified in place.
                Defaults to all ones.

        Returns:
            (hx, steps, ponder_cost):
            - the halting-probability-weighted sum of hidden states divided by
              the number of steps taken;
            - the number of steps taken minus one;
            - ``1 - accumulated alpha`` recorded at each row's last active step.
            NOTE(review): if max_delta_note is 0 the loop never runs and the
            division by step_count is a division by zero.
        """
        def bool_to_idx(idx):
            return idx.nonzero().squeeze(1)

        B, H = x.size()
        n_selector = x.data.new(B).byte() # active-row selector N(t); filled with 1 below

        if hx is None:
            hx = Variable(x.new(B, H).zero_())
            hx = hx.fill_(1)


        accum_alpha = Variable(x.data.new(B).zero_())
        accum_hx = Variable(x.data.new(B, H).zero_())

        n_selector = n_selector.fill_(1)

        step_ponder_cost = Variable(x.data.new(B).zero_())
        step_count = Variable(x.data.new(B).zero_())
        _remained_delta = remained_delta.clone()
        _remained_delta = torch.clamp(_remained_delta, min=0)
        for act_step in range(self.max_ponder):
            idx = bool_to_idx(n_selector)

            step_count[idx] += 1.0
            hx[idx] = self.unit_encode(act_cond[idx], hx[idx], _remained_delta[idx] / total_notes[idx] * self.portion_type, step_count[idx])
            alpha = F.sigmoid(self.ponder_linear(hx[idx]).squeeze(1))
            accum_alpha[idx] += alpha
            # Weight of this step: alpha, reduced so the accumulated total does
            # not exceed 1.
            p = alpha - (accum_alpha[idx] - 1).clamp(min=0)
            accum_hx[idx] += p.unsqueeze(1) * hx[idx]

            step_ponder_cost[idx] = 1.0 - accum_alpha[idx]

            _remained_delta[idx] -= 1
            _remained_delta[idx] = torch.clamp(_remained_delta[idx], min=0)

            # Rows stay active until their accumulated halting probability
            # reaches 1 - stop_epsilon.
            n_selector = (accum_alpha < 1 - self.stop_epsilon).data

            if not n_selector.any():
                break

        hx = accum_hx / step_count.clone().unsqueeze(1)
        # step_ponder_cost holds 1 - (sum of alphas) at each row's last step.
        return hx, step_count - 1, step_ponder_cost



class Grouping_unit(ACT_unit):
    """Per-token unit that classifies each step into one of three classes
    instead of emitting an ACT halting probability.

    In this class's code, class 1 keeps a row running (when ``max_steps`` is
    not given), and class 0 undoes that step's increment of the row's step
    count. The per-step class probabilities are returned for supervision.
    """

    def __init__(self, args, portion_embedding):
        super(Grouping_unit, self).__init__(args, portion_embedding)
        # The ACT halting threshold is not used by this unit.
        del self.stop_epsilon
        # Three class logits per step (softmax in forward).
        self.ponder_linear = nn.Linear(args.decoder_embed_dim, 3)

    def forward(self, x, act_cond, remained_delta, total_notes, hx=None, max_steps=None):
        """Run up to ``max_delta_note + 1`` steps for one target position.

        Args:
            x: (B, H) tensor (``B, H = x.size()``); only its shape, dtype and
                device are used.
            act_cond: conditioning vectors indexed per row (presumably (B, H)).
            remained_delta: notes still unassigned per row (presumably (B,));
                clamped at 0 and decremented by one per step for active rows.
            total_notes: per-row note totals used to turn ``remained_delta``
                into a portion bucket (presumably (B,)).
            hx: optional initial hidden state (B, H). It is modified in place
                and returned. Defaults to zeros.
            max_steps: optional per-row step budget (teacher forcing). When
                given, rows stay active while ``step_count < max_steps``.
                Otherwise rows stay active while the argmax class is 1. All
                rows take the first step either way.

        Returns:
            (hx, step_count, step_prob): final hidden states; per-row step
            counts (steps whose argmax class is 0 are not counted); and the
            softmax probabilities of every step as (B, max_delta_note + 1, 3),
            zero for steps a row did not take.
        """
        def bool_to_idx(idx):
            return idx.nonzero().squeeze(1)

        B, H = x.size()
        n_selector = x.data.new(B).zero_().bool() # active-row selector N(t); filled with True below
        if hx is None:
            hx = Variable(x.new(B, H).zero_())

        n_selector = n_selector.fill_(1)

        step_count = Variable(x.data.new(B).zero_())
        step_prob = Variable(x.data.new(self.max_ponder + 1, B, 3).zero_())

        remained_delta = torch.clamp(remained_delta, min=0)
        _remained_delta = remained_delta.clone()

        for act_step in range(self.max_ponder + 1):
            idx = bool_to_idx(n_selector)

            hx[idx] = self.unit_encode(act_cond[idx], hx[idx], _remained_delta[idx] / total_notes[idx] * self.portion_type, step_count[idx])

            p = F.softmax(self.ponder_linear(hx[idx]).squeeze(1), dim=-1)

            step_count[idx] += 1.0
            _remained_delta[idx] -= 1
            _remained_delta[idx] = torch.clamp(_remained_delta[idx], min=0)

            if max_steps is None:
                # Inference: keep going while the model predicts class 1.
                n_selector[idx] = (p.argmax(dim=-1) == 1).data
            else:
                # Teacher forcing: follow the given step budget for all rows.
                n_selector = step_count < max_steps

            # A step classified as 0 is not counted.
            step_count[idx] -= (p.argmax(dim=-1) == 0).long().data

            step_prob[act_step, idx, :] = p

            if not n_selector.any():
                break

        # step_prob: [steps, B, 3] -> [B, steps, 3]
        return hx, step_count, step_prob.transpose(0, 1)


# Per-token unit class used by GroupingDecoder, keyed by ``args.grouping_arch``.
grouping_decoder_arch = dict(
    act=ACT_unit,
    grouping=Grouping_unit,
)


class GroupingDecoder(nn.Module):
    """Predict, for each target token, how many melody notes it is aligned to.

    A per-token unit chosen by ``args.grouping_arch`` (looked up in
    ``grouping_decoder_arch``) is run once per target position. In training
    (``infer=False``) every position is processed with teacher forcing. In
    inference only the last position is processed.
    """

    def __init__(self, args, portion_embedding):
        super(GroupingDecoder, self).__init__()
        self.arch = args.grouping_arch
        unit_cls = grouping_decoder_arch[args.grouping_arch]
        self.act_unit = unit_cls(args, portion_embedding)
        self.act_encoder = ACTEncoder(args, portion_embedding)
        # When set, x is used directly as the conditioning (no ACTEncoder).
        self.without_portion = getattr(args, 'without_portion', False)

    def forward(self, x, src_alignments, src_alignments_masks, tgt_alignments, tgt_alignments_masks, prev_output_alignment_embeddings, last_hx=None, infer=False):
        '''
        x: [B, T, H],
        src_alignments: [B, Tsrc]
        tgt_alignments: [B, Ttgt],
        last_hx: unused

        Returns (pred_alignments, extra).
        - Training: pred_alignments is [B, T]. extra['ponder_cost'] holds the
          ACT ponder cost ('act') or the step probabilities packed per note
          ('grouping').
        - Inference: pred_alignments is [B, 1], and extra also has 'hx'
          ('act' only, else None).
        '''
        B, T, _ = x.shape
        if not self.without_portion:
            cond = self.act_encoder(x, src_alignments, src_alignments_masks, prev_output_alignment_embeddings, infer=False)
        else:
            cond = x

        total_notes = torch.sum(src_alignments, dim=-1).unsqueeze(-1)
        flat_total = total_notes.squeeze(-1)
        tgt_cumsum = torch.cumsum(tgt_alignments, dim=-1)
        # Notes not yet covered after each target token (0 on masked positions).
        remained_delta = torch.clamp((total_notes - tgt_cumsum) * tgt_alignments_masks, min=0)
        is_act = self.arch == 'act'

        if infer:
            pred_alignments = torch.zeros([B, 1]).to(x.device)
            last_inputs = (x[:, -1], cond[:, -1], remained_delta[:, -1], flat_total)
            if is_act:
                ponder_cost = torch.zeros([B, 1]).to(x.device)
                temp_hx, pred_alignments[:, 0], ponder_cost[:, 0] = self.act_unit(*last_inputs, hx=None)
            elif self.arch == 'grouping':
                temp_hx, pred_alignments[:, 0], step_probs = self.act_unit(*last_inputs, hx=None)
            extra = {'ponder_cost': ponder_cost if is_act else step_probs,
                     'hx': temp_hx if is_act else None}
            return pred_alignments, extra

        # Teacher forcing over all target positions.
        pred_alignments = torch.zeros_like(tgt_alignments)
        temp_hx = None
        if is_act:
            ponder_cost = torch.zeros_like(tgt_alignments, dtype=torch.float)
        elif self.arch == 'grouping':
            # One probability row per melody note, filled span by span below.
            step_probs = torch.zeros([B, tgt_cumsum.max(), 3]).to(tgt_cumsum.device)

        for t in range(T):
            step_inputs = (x[:, t], cond[:, t], remained_delta[:, t], flat_total)
            if is_act:
                temp_hx, pred_alignments[:, t], ponder_cost[:, t] = self.act_unit(*step_inputs, hx=None)
            elif self.arch == 'grouping':
                n_steps = tgt_alignments[:, t]
                # The hidden state is carried over from the previous token.
                temp_hx, pred_alignments[:, t], temp_step_probs = self.act_unit(*step_inputs, hx=temp_hx, max_steps=n_steps)
                span_end = tgt_cumsum[:, t]
                span_start = span_end - n_steps
                for i in range(B):
                    step_probs[i, span_start[i]:span_end[i]] = temp_step_probs[i, :n_steps[i]]

        return pred_alignments, {'ponder_cost': ponder_cost if is_act else step_probs}






