# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import multiprocessing
import os

import torch
import itertools

import numpy as np
from fairseq import utils
from fairseq import search
from fairseq.models.bart.musicxml_utils import render, out_to_wav

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        SummaryWriter = None

from ly.musicxml import create_musicxml
# import partitura
from librosa import midi_to_note
from PIL import Image


from fairseq.data import (
    indexed_dataset,
    AppendTokenDataset,
    ConcatDataset,
    StripTokenDataset,
    TruncateDataset,
    Dictionary,
    PrependTokenDataset,
    PrependDataset,
    ResamplingDataset,
    SortDataset,
    MultiLanguagePairWithMelodyDataset,
    data_utils,
)
from fairseq.tasks.denoising import DenoisingTask
from fairseq.tasks import register_task
from multiprocessing import Pool
logger = logging.getLogger(__name__)

# Tensor sharing strategy for worker processes; overridable through the
# TORCH_SHARE_STRATEGY environment variable.
torch.multiprocessing.set_sharing_strategy(
    os.environ.get('TORCH_SHARE_STRATEGY', 'file_system')
)

# Worker pools in this module use the 'spawn' start method; a RuntimeError
# from set_start_method is deliberately ignored.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    pass

# Duration-token id -> note length in quarter-note units (token 3 -> 1.0).
# Each entry equals image_process's tick table divided by 48 (that table maps
# token 3 to 48 ticks, i.e. a quarter). Entries 22 and 23 are 1.7/48 and
# 6.9/48 so they agree with that table.
# NOTE(review): entry 30 is 0.25 here but 24 ticks (= 0.5) in image_process;
# left unchanged - confirm which value is intended.
_ALIGNED_NOTE_DURATIONS = {
    0: 0, 1: 4.0, 2: 2.0, 3: 1.0, 4: 0.5, 5: 0.25, 6: 0.125,
    7: 3.0, 8: 1.5, 9: 0.75, 10: 0.375, 11: 0.1875,
    12: 1.3333333333, 13: 0.66666666666666, 14: 0.3333333333, 15: 0.166666666, 16: 0.0833333333333,
    17: 2.6666666666, 18: 0.09375, 19: 0.03125, 20: 0.2145833333, 21: 0.07083333333,
    22: 0.0354166666, 23: 0.14375, 24: 0.0625, 25: 0.041666666, 26: 0.2, 27: 0.1,
    28: 1.75, 29: 6.0, 30: 0.25,
}


def _as_int_list(values):
    """Return a new list of Python ints copied from a 1-D tensor, array or sequence."""
    if hasattr(values, 'tolist'):
        values = values.tolist()
    return [int(v) for v in values]


def _tolerant_correct(alignments, num_notes):
    """Nudge the per-token note counts in ``alignments`` (a list, edited in place)
    towards summing to ``num_notes``.

    Too many notes: repeatedly take one note from every token holding more than
    one, scanning left to right, until the sum fits or no token can shrink.
    Too few notes: give the whole shortfall to the last token with a non-zero count.
    """
    total = sum(alignments)
    if total > num_notes:
        while total > num_notes:
            shrinkable = False
            for i in range(len(alignments)):
                if alignments[i] > 1:
                    shrinkable = True
                if alignments[i] > 1 and total > num_notes:
                    alignments[i] -= 1
                    total -= 1
                elif total <= num_notes:
                    break
            if not shrinkable:
                # Every token already holds at most one note; nothing left to trim.
                break
    elif total < num_notes:
        for i in range(len(alignments) - 1, -1, -1):
            if alignments[i] != 0:
                alignments[i] += num_notes - total
                break


def _group_notes_by_token(notes, durs, alignments, limit):
    """Split a melody into per-token note groups.

    The first ``limit`` counts of ``alignments`` are consumed in order; a token
    with count ``c`` takes the next ``c`` melody positions. Tokens with count 0
    take no notes and produce no group. Positions whose pitch is 0 or that lie
    past the end of ``notes`` are left out of their group.

    Returns:
        ``{'pitch': [[...], ...], 'dur': [[...], ...]}`` with one inner list per group.
    """
    pitches, lengths = [], []
    pos = 0
    for count in alignments[:limit]:
        if count != 0:
            kept = [j for j in range(pos, pos + count) if j < len(notes) and notes[j] != 0]
            pitches.append([notes[j] for j in kept])
            lengths.append([durs[j] for j in kept])
        pos += count
    return {'pitch': pitches, 'dur': lengths}


def get_aligned_notes(sample, hypos, dictionary, log_src=False):
    """Group each batch item's melody notes by the lyric token they are aligned to.

    For every item, the best hypothesis ``hypos[i][0]`` supplies ``'tokens'``
    and per-token note counts ``'pred_alignments'``. The last count belongs to
    eos and is dropped. The counts are reconciled with the number of non-zero
    pitches by ``_tolerant_correct``. This runs on a private copy, so the
    caller's tensors are not modified. Only tokens before the last eos are grouped.

    With ``log_src`` the same is done for the source side, using
    ``sample["net_input"]['src_alignments']``. The number of groups is bounded
    by the decoded source tokens, skipping the first two tokens.

    Returns:
        ``(src_aligned_notes, tgt_aligned_notes)``: lists with one
        ``{'pitch': [...], 'dur': [...]}`` dict per item. Durations are in
        quarter-note units. The source list is empty unless ``log_src``.
    """
    net_input = sample["net_input"]
    src_aligned_notes = []
    tgt_aligned_notes = []

    for notes, durs, hypo in zip(net_input['notes'], net_input['durs'], hypos):
        best = hypo[0]
        notes = notes.cpu().numpy()
        durs = [_ALIGNED_NOTE_DURATIONS[dur] for dur in durs.cpu().numpy()]
        # Copy the counts and drop the one predicted for eos.
        alignments = _as_int_list(best['pred_alignments'])[:-1]
        eos = int((best['tokens'] == dictionary.eos()).nonzero()[-1])
        _tolerant_correct(alignments, int((notes != 0).sum()))
        tgt_aligned_notes.append(
            _group_notes_by_token(notes, durs, alignments, min(eos, len(alignments))))

    if log_src:
        for notes, durs, source, alignments in zip(net_input['notes'], net_input['durs'],
                                                   sample["source"], net_input['src_alignments']):
            # Skip the first two source tokens before decoding (the original
            # comment calls them <bos>/<eos>; confirm which tokens these are).
            lyrics = dictionary.string(source.cpu()[2:],
                                       extra_symbols_to_ignore=[dictionary.pad()]).split(' ')
            notes = notes.cpu().numpy()
            durs = [_ALIGNED_NOTE_DURATIONS[dur] for dur in durs.cpu().numpy()]
            alignments = _as_int_list(alignments)
            _tolerant_correct(alignments, int((notes != 0).sum()))
            src_aligned_notes.append(
                _group_notes_by_token(notes, durs, alignments, min(len(lyrics), len(alignments))))

    return src_aligned_notes, tgt_aligned_notes


def wav_process(notes, lyrics, alignment, wav_infer_ins, lang):
    """Synthesize audio for one lyric/melody pair with ``out_to_wav``.

    Re-initialises the (presumably DiffSinger) hparams before synthesis and
    releases cached CUDA memory. Returns whatever ``out_to_wav`` returns. On
    TypeError, RuntimeError or IndexError it prints diagnostics (``alignment``
    is used only here) and returns None.
    """
    # `utils.hparams` is imported lazily; __init__ puts ./DiffSinger on sys.path
    # before importing the same module.
    from utils.hparams import set_hparams
    set_hparams(
        exp_name='0831_opencpop_ds1000',
        config='usr/configs/midi/e2e/opencpop/ds1000.yaml',
        print_hparams=False,
    )
    torch.cuda.empty_cache()

    try:
        return out_to_wav(lyrics, notes, infer_ins=wav_infer_ins, lang=lang)
    except (TypeError, RuntimeError, IndexError) as err:
        for item in (err, 'out to wav failed', lyrics, alignment, notes, '==================='):
            print(item)
        return None



def image_process(id, cnt, lyrics, notes, durs, alignments, eos_pos, log_dir, out_fd):
    """Build a MusicXML score for one lyric/melody alignment and render it to PNG.

    Lyric tokens are walked in order. A token aligned to ``k`` notes emits the
    next ``k`` melody positions, with its lyric attached to the first emitted
    note. Tokens aligned to 0 notes are held back and attached to the following
    token's notes, either as separate 'begin'/'middle' syllables or joined with
    spaces.

    Args:
        id: sample id, used in output file names when ``out_fd`` is given.
        cnt: position in the batch, used in the temporary file names.
        lyrics: list of lyric token strings.
        notes: per-position MIDI pitches; pitch 0 is skipped (treated as a rest).
        durs: per-position duration-token ids (keys of ``to_actual_dur``).
        alignments: notes per lyric token. The last entry is dropped
            (originally commented as the eos entry).
        eos_pos: upper bound on the number of lyric tokens processed.
        log_dir: directory whose ``tmp_musicxml_log`` subdirectory receives
            the temporary MusicXML/PNG files.
        out_fd: if not None, the PNG and a MusicXML copy are written there as
            ``pred_{id}.png`` / ``pred_{id}.musicxml``.

    Returns:
        Path of the rendered PNG file.
    """
    # Duration-token id -> ticks, with a quarter note (token 3) = 48 ticks.
    to_actual_dur = {0:0, 1:192, 2:96, 3:48, 4:24, 5:12, 6: 6,
                    7: 144, 8: 72, 9: 36, 10: 18, 11: 9,
                    12: 64, 13:32, 14: 16, 15: 8 ,16: 4,
                    17: 128, 18: 4.5, 19: 1.5, 20: 10.3, 21: 3.4,
                    22: 1.7, 23: 6.9, 24: 3.0, 25: 2.0, 26: 9.6, 27:4.8,
                    28: 84, 29:288, 30: 24}

    # Length in measure divisions (divs=4 per quarter) -> MusicXML note type.
    to_actual_type = {0.5: '32th', 1.0: '16th', 2.0 : 'eighth', 4.0: 'quarter', 8.0: 'half', 16.0: 'whole'}

    def convert_and_add_new_note(musicxml_obj, notes, durs, note_index):
        # Pitch 0 and positions past the end of the melody add no note.
        if note_index >= len(notes) or notes[note_index] == 0:
            return
        # NOTE(review): only the first character (step) and the last character
        # (octave) of librosa's note name are used. Accidentals such as the
        # sharp in 'C#4' are dropped, and octave -1 would parse as 1 - confirm.
        temp_note = midi_to_note(notes[note_index])
        note_name = temp_note[0]
        note_area = int(temp_note[-1])
        musicxml_obj.new_note(note_name, note_area, to_actual_type.get(durs[note_index], 'quarter'), durs[note_index])

    temp_musicxml = create_musicxml.CreateMusicXML()
    temp_musicxml.create_part(name='Song', abbr='Sg.', midi='Piano')
    temp_musicxml.create_measure(clef=('G', 2, False), mustime=[4, 4], divs=4)

    # ticks / 12 -> divisions: a quarter (48 ticks) becomes 4.0, matching divs=4.
    durs = [to_actual_dur[dur] / 12 for dur in durs]
    # Drop the alignment entry that belongs to eos.
    alignments = alignments[:-1]
    temp_pos = 0
    syllabic_cnt = 0
    index = 0
    while index < min(eos_pos, len(lyrics), len(alignments)):
        if alignments[index] == 0:
            # No notes of its own: held back for the next token's notes.
            syllabic_cnt += 1
        elif syllabic_cnt > 0:
            if alignments[index] > syllabic_cnt:
                # Enough notes: one pending syllable per note, then this token.
                convert_and_add_new_note(temp_musicxml, notes, durs, temp_pos)
                temp_musicxml.add_lyric(lyrics[index - syllabic_cnt], 'begin', nr=1)
                for i in range(1, syllabic_cnt):
                    convert_and_add_new_note(temp_musicxml, notes, durs, temp_pos + i)
                    temp_musicxml.add_lyric(lyrics[index - syllabic_cnt + i], 'middle', nr=1)
                for i in range(alignments[index] - syllabic_cnt):
                    convert_and_add_new_note(temp_musicxml, notes, durs, temp_pos + syllabic_cnt + i)
                    if i == 0:
                        temp_musicxml.add_lyric(lyrics[index], 'end', nr=1,
                                                ext=True if alignments[index] - syllabic_cnt > 1 else False)
            elif alignments[index] > 1:
                # Pending syllables share the first note; this token starts on the second.
                convert_and_add_new_note(temp_musicxml, notes, durs, temp_pos)
                temp_musicxml.add_lyric(' '.join(lyrics[index - syllabic_cnt: index]), 'begin', nr=1)
                for i in range(alignments[index] - 1):
                    convert_and_add_new_note(temp_musicxml, notes, durs, temp_pos + i + 1)
                    # Attach on the first extra note (i == 0), so the lyric is
                    # emitted even when exactly two notes are aligned.
                    if i == 0:
                        temp_musicxml.add_lyric(lyrics[index], 'end', nr=1,
                                                ext=True if alignments[index] > 2 else False)
            else:
                # Single note: all pending syllables and this token share it.
                convert_and_add_new_note(temp_musicxml, notes, durs, temp_pos)
                temp_musicxml.add_lyric(' '.join(lyrics[index - syllabic_cnt: index + 1]), 'single', nr=1)
            syllabic_cnt = 0
        else:
            for i in range(alignments[index]):
                convert_and_add_new_note(temp_musicxml, notes, durs, temp_pos + i)
                if i == 0:
                    temp_musicxml.add_lyric(lyrics[index], 'single', nr=1,
                                            ext=True if alignments[index] > 1 else False)

        temp_pos += alignments[index]
        index += 1

    temp_musicxml = temp_musicxml.musicxml()
    tmp_dir = f'{log_dir}/tmp_musicxml_log'
    temp_musicxml_name = f'{tmp_dir}/{cnt}.tmp.musicxml'
    # Written unconditionally, so render() below always finds the file.
    temp_musicxml.write(temp_musicxml_name)
    if out_fd is not None:
        temp_image = f'{out_fd}/pred_{id}.png'
        render(temp_musicxml_name, out_fn=temp_image)
        temp_musicxml.write(f'{out_fd}/pred_{id}.musicxml')
    else:
        temp_image = f'{tmp_dir}/{cnt}.tmp.png'
        render(temp_musicxml_name, out_fn=temp_image)
    return temp_image


def _render_score_images(pool, jobs):
    """Run ``image_process(*job)`` for every job on ``pool`` and load the PNGs as numpy arrays."""
    futures = [pool.apply_async(image_process, args=job) for job in jobs]
    image_paths = []
    for future in futures:
        image_paths.append(future.get())
        # Releases this (parent) process's cached CUDA memory; the rendering
        # itself runs in the worker processes.
        torch.cuda.empty_cache()
    images = []
    for path in image_paths:
        # The context manager closes the file once the pixels are copied out.
        with Image.open(path) as img:
            images.append(np.array(img))
    return images


def get_musicxml_image(sample, hypos, dictionary, out_fd=None, log_src=False, log_dir=None):
    """Render predicted (and optionally source) lyric/melody alignments as score images.

    Every item is rendered by ``image_process`` on a pool of 8 worker
    processes.
    - Predictions use the best hypothesis ``hypos[i][0]`` (tokens and
      ``'pred_alignments'``).
    - With ``log_src``, sources use ``sample["source"]`` (first two tokens
      skipped) and ``sample["net_input"]['src_alignments']``.
    - Temporary files go to ``{log_dir}/tmp_musicxml_log``.
    - When ``out_fd`` is given, outputs are also kept under ``out_fd/pred``
      and ``out_fd/src``.

    Returns:
        ``(src_musicxml_images, pred_musicxml_images)``: lists of numpy image
        arrays. The source list is empty unless ``log_src``.
    """
    os.makedirs(f'{log_dir}/tmp_musicxml_log', exist_ok=True)
    pred_out_fd = None
    if out_fd is not None:
        pred_out_fd = os.path.join(out_fd, 'pred')
        os.makedirs(f'{out_fd}', exist_ok=True)
        os.makedirs(pred_out_fd, exist_ok=True)

    net_input = sample["net_input"]
    pred_lyrics = [dictionary.string(temp_sent[0]['tokens'].cpu(), extra_symbols_to_ignore=[dictionary.pad()]).split(' ') for temp_sent in hypos]
    eos_pos = [(temp_sent[0]['tokens'] == dictionary.eos()).nonzero() for temp_sent in hypos]
    pred_alignments = [temp_sent[0]['pred_alignments'].cpu().numpy() for temp_sent in hypos]

    pred_jobs = []
    for cnt, (notes, durs, lyrics, alignments) in enumerate(
            zip(net_input['notes'], net_input['durs'], pred_lyrics, pred_alignments)):
        pred_jobs.append((
            sample['id'][cnt].detach().cpu().numpy(),
            cnt,
            lyrics,
            notes.detach().cpu().numpy(),
            durs.detach().cpu().numpy(),
            alignments,
            eos_pos[cnt][-1].detach().cpu().numpy(),
            log_dir,
            pred_out_fd,
        ))

    src_musicxml_images = []
    # Leaving the `with` terminates the workers, even if a job raises.
    with Pool(8) as pool:
        # The prediction images must be fully loaded before the source jobs
        # start, because both phases write the same {cnt}.tmp.* temp files.
        pred_musicxml_images = _render_score_images(pool, pred_jobs)

        if log_src:
            src_out_fd = None
            if out_fd is not None:
                src_out_fd = os.path.join(out_fd, 'src')
                os.makedirs(src_out_fd, exist_ok=True)
            src_jobs = []
            for cnt, (notes, durs, source, alignments) in enumerate(
                    zip(net_input['notes'], net_input['durs'], sample["source"],
                        net_input['src_alignments'])):
                # Skip the first two source tokens before decoding (the original
                # comment calls them <bos>/<eos>; confirm which tokens these are).
                lyrics = dictionary.string(source.cpu()[2:], extra_symbols_to_ignore=[dictionary.pad()]).split(' ')
                if torch.is_tensor(alignments):
                    # Ship plain CPU arrays to the workers, as the prediction path does.
                    alignments = alignments.detach().cpu().numpy()
                src_jobs.append((
                    sample['id'][cnt].detach().cpu().numpy(),
                    cnt,
                    lyrics,
                    notes.detach().cpu().numpy(),
                    durs.detach().cpu().numpy(),
                    alignments,
                    len(lyrics),
                    log_dir,
                    src_out_fd,
                ))
            src_musicxml_images = _render_score_images(pool, src_jobs)

    return src_musicxml_images, pred_musicxml_images


def load_melody(melody_data_path):
    """Read a melody file with one ``notes|durs|is_slurs`` line per example.

    Each field is a whitespace-separated list. Pitches are clamped to the MIDI
    range [0, 127]. Slur flags "0"/"false" (any case) parse as False and any
    other token as True. ``bool(token)`` cannot be used here, because every
    non-empty string, including "0", is truthy.

    Returns:
        list of dicts with keys 'notes', 'durs' and 'is_slur'.

    Raises:
        ValueError: if a line does not have exactly three '|'-separated
            fields or a note/duration is not an integer.
    """
    def parse_flag(token):
        return token.lower() not in ('0', 'false')

    melody = []
    with open(melody_data_path, 'r', encoding='utf-8') as f:
        for line in f:
            notes, durs, is_slurs = line.split('|')
            melody.append({
                'notes': [min(max(int(x), 0), 127) for x in notes.split()],
                'durs': [int(x) for x in durs.split()],
                'is_slur': [parse_flag(x) for x in is_slurs.split()],
            })
    return melody

def load_align(alignment_data_path):
    """Read one line of whitespace-separated integer alignment counts per example.

    Blank lines yield an empty list, so line ``i`` of the file always maps to
    entry ``i`` of the result.
    """
    with open(alignment_data_path, 'r', encoding='utf-8') as f:
        return [[int(x) for x in line.split()] for line in f]


def load_multi_langpair_with_alignments_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=True,
    truncate_source=False,
    add_language_token=False,
    add_length_token=False,
    domain=None,
    common_eos=None,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    tag_num=0,
    max_delta_note=15,
):
    """Build a MultiLanguagePairWithMelodyDataset for one split.

    Steps:
    - Load the indexed datasets ``{split}{k}.{src}-{tgt}.{lang}``; the reversed
      ``{tgt}-{src}`` prefix is also accepted. Shards k = 1, 2, ... are added
      while ``combine`` is set.
    - Optionally truncate the source, prepend <bos>, a length token to the
      target, and domain / target-language tags to the source.
    - Attach the alignments ``{split}.{src}-{tgt}.alignment.{lang}`` (when
      ``load_alignments``) and the melody ``{split}.melody``.

    ``common_eos`` is accepted for interface compatibility but is not used.

    Raises:
        FileNotFoundError: if the first shard of ``split`` does not exist.
    """
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    def length_getter(dataset, idx):
        # Index of the "[LEN<n>]" tag, where n is the item length minus one.
        item = dataset[idx]
        length_token = '[LEN{}]'.format(len(item) - 1)
        return tgt_dict.index(length_token)

    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        elif k > 0:
            break
        else:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, data_path)
            )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        if truncate_source:
            # Leave room for the tag tokens and the re-appended eos.
            src_dataset = AppendTokenDataset(
                TruncateDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - tag_num - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    if add_length_token:
        tgt_dataset = PrependDataset(tgt_dataset, prepend_getter=length_getter)

    # add tags
    if domain is not None:
        src_dataset = PrependTokenDataset(src_dataset, src_dict.index("[{}]".format(domain)))
    if add_language_token:
        # The source is tagged with the *target* language token ("[2<tgt>]").
        src_dataset = PrependTokenDataset(
            src_dataset, tgt_dict.index('[2{}]'.format(tgt))
        )

    # Defined up front so load_alignments=False no longer raises NameError below.
    # NOTE(review): confirm MultiLanguagePairWithMelodyDataset accepts None here.
    src_align_dataset = None
    tgt_align_dataset = None
    # NOTE(review): alignments and melody are read for `split` only, even when
    # `combine` concatenates additional shards.
    if load_alignments:
        align_path = os.path.join(data_path, "{}.{}-{}.alignment.".format(split, src, tgt))
        src_align_dataset = load_align(align_path + src)
        tgt_align_dataset = load_align(align_path + tgt)

    melody = load_melody(os.path.join(data_path, "{}.melody".format(split)))

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return MultiLanguagePairWithMelodyDataset(
        split,
        src_dataset,
        src_align_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_align_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        melody,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
        max_delta_note=max_delta_note,
    )

@register_task('xdae_multilingual_translation_with_melody')
class XDAEMultilingualTranslationWithMelodyTask(DenoisingTask):
    @staticmethod
    def add_args(parser):
        """Add this task's command-line arguments on top of DenoisingTask's."""
        DenoisingTask.add_args(parser)
        # for pretrain
        parser.add_argument(
            "--multilang-sampling-alpha",
            type=float,
            default=1.0,
            help="smoothing alpha for sample ratios across multiple datasets",
        )
        parser.add_argument("--downsample-by-min", default=False, action="store_true",
                            help="Downsample all large dataset by the length of smallest dataset")
        parser.add_argument("--add-lang-token", default=False, action="store_true")
        parser.add_argument("--add-length-token", default=False, action="store_true")
        parser.add_argument("--with-len", default=False, action="store_true")
        parser.add_argument('--prepend-bos', default=False, action='store_true')

        parser.add_argument('--placeholder', type=int,
                            help="placeholder for more special ids such as language ids",
                            default=-1)
        parser.add_argument("--add-tgt-len-tags", type=int, default=0,
                            help="number of length tags to add")
        parser.add_argument('--word-shuffle', type=float, default=0,
                            help="Randomly shuffle input words (0 to disable)")
        parser.add_argument("--word-dropout", type=float, default=0,
                            help="Randomly dropout input words (0 to disable)")
        parser.add_argument("--word-blank", type=float, default=0,
                            help="Randomly blank input words (0 to disable)")

        parser.add_argument('--validation-inference-interval', type=int,
                            help="interval for inference validation",
                            default=5)

        parser.add_argument('--distance-reward', default=False, action='store_true')
        parser.add_argument(
            "--max-delta-note",
            default=90,
            type=int,
            # The previous help text ("ACT max halting steps") was copied from an unrelated option.
            help="max-delta-note value forwarded to the melody dataset",
        )

        parser.add_argument('--sampled-data', default=False, action='store_true')
        parser.add_argument(
            "--langs", type=str, help="language ids we are considering", default=None
        )
        parser.add_argument(
            "--no-whole-word-mask-langs",
            type=str,
            default="",
            metavar="N",
            help="languages without spacing between words dont support whole word masking",
        )
        parser.add_argument('--finetune-langs', type=str,
                            help="language pairs to finetune',', for example, 'en-zh,zh-en'",
                            default=None)
        parser.add_argument('--finetune-data', type=str,
                            help="finetuning data path",
                            default=None)
        parser.add_argument('--finetune-domain', type=str,
                            help="finetuning data domain",
                            default=None)
        parser.add_argument('--generate-with-melody',
                            help="whether to use melody alignment generator",
                            action='store_true')
        parser.add_argument('--common-eos', type=str,
                            help="common end of sentence tag for all tasks/langs",
                            default=None)
        parser.add_argument('--domains', type=str,
                            help="domains to pretrain ',', for example, 'LYRICS,WMT'",
                            default=None)
        parser.add_argument("--use-domain-eos", action="store_true",
                            help="use domain tag as end of sentence",
                            default=False)
        # for generation
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language')
        parser.add_argument('--load-alignments', action='store_true',
                            help='load the binarized alignments')
        # Parsed into booleans later by setup_task (utils.eval_bool).
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        parser.add_argument('--truncate-source', action='store_true', default=False,
                            help='truncate source to max-source-positions')

        ## options for reporting BLEU during validation
        parser.add_argument('--eval-bleu', action='store_true',
                            help='evaluation with BLEU scores')
        #parser.add_argument('--eval-bleu-detok', type=str, default="space",
                            #help='detokenizer before computing BLEU (e.g., "moses"); '
                                 #'required if using --eval-bleu; use "space" to '
                                 #'disable detokenization; see fairseq.data.encoders '
                                 #'for other options')
        #parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
                            #help='args for building the tokenizer, if needed')
        #parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
                            #help='if setting, we compute tokenized BLEU instead of sacrebleu')
        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
                            help='remove BPE before computing BLEU')
        #parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
                            #help='generation args for BLUE scoring, '
                                 #'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
        parser.add_argument('--eval-bleu-print-samples', action='store_true',
                            help='print sample generations during validation')
        parser.add_argument('--eval-inference', action='store_true',
                            help='validation with inference')
        parser.add_argument('--eval-align-dist', action='store_true',
                            help='evaluate alignment distribution distance in validation')
        parser.add_argument('--eval-inference-start-step', type=int,
                            default=1000,
                            help='validation with inference start step')
        parser.add_argument('--with-backtrans-data', action='store_true',
                            help='whether to train with back-translation data')
        parser.add_argument('--only-backtrans-data', action='store_true',
                            help='whether to train with back-translation data only')


    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task.

        Steps:
        - Load the dictionary from ``dict.txt`` under the data path, falling
          back to ``dict.{first language}.txt``.
        - Add ``<mask>``.
        - With ``--add-lang-token``, also add the common-eos, domain,
          language, length and placeholder tag symbols.
        - Convert ``--left-pad-source``/``--left-pad-target`` to booleans.

        Raises:
            AssertionError: if the data path is empty, or if
                ``--finetune-domain`` is set but not listed in ``--domains``.
        """
        path = args.data

        if args.with_backtrans_data or args.only_backtrans_data:
            # Only the first ':'-separated path is used for language/dictionary discovery.
            path = path.split(':')[0]
        assert len(path) > 0

        if args.langs is None:
            if args.sampled_data:
                languages = list(cls.get_languages(cls, path))
            else:
                # Every subdirectory of the data path is taken to be a language.
                languages = sorted([
                    name for name in os.listdir(path)
                    if os.path.isdir(os.path.join(path, name))
                ])
        else:
            languages = args.langs.split(",")

        dict_path = path
        if os.path.exists(os.path.join(dict_path, "dict.txt")):
            dictionary = Dictionary.load(os.path.join(dict_path, "dict.txt"))
        else:
            dictionary = Dictionary.load(os.path.join(dict_path, f"dict.{languages[0]}.txt"))

        domains = args.domains.split(',') if args.domains is not None else None
        # Test `is None` first: `x in None` raises TypeError, which made the
        # default configuration (neither --domains nor --finetune-domain) crash.
        assert args.finetune_domain is None or (
            domains is not None and args.finetune_domain in domains
        ), args.finetune_domain
        dictionary.add_symbol('<mask>')
        if args.add_lang_token:
            if args.common_eos is not None:
                dictionary.add_symbol('[{}]'.format(args.common_eos))
            if domains is not None:
                for d in domains:
                    dictionary.add_symbol(f"[{d}]")
            for lang in languages:
                dictionary.add_symbol('[2{}]'.format(lang))
            for i in range(args.add_tgt_len_tags):
                dictionary.add_symbol('[LEN{}]'.format(i + 1))
            for i in range(args.placeholder):
                dictionary.add_symbol('[placeholder{}]'.format(i))

        logger.info("dictionary: {} types".format(len(dictionary)))
        if not hasattr(args, "shuffle_instance"):
            args.shuffle_instance = False

        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        return cls(args, dictionary)

    def __init__(self, args, dictionary):
        """Initialise task state and, if requested, the validation helpers.

        Args:
            args: parsed task arguments.
            dictionary: shared dictionary produced by ``setup_task``.
        """
        super().__init__(args, dictionary)
        self.dictionary = dictionary
        self.seed = args.seed
        # Set to True once source/reference samples have been written to TensorBoard.
        self.src_logged = False
        # Id of the '<mask>' symbol (setup_task adds it to the dictionary).
        self.mask_idx = dictionary.index('<mask>')
        self.langs = args.langs
        self.args = args
        # Per-folder cache of directory listings used by get_dataset_path.
        self.path_cache = {}
        finetune_langs = args.finetune_langs
        self.ft_langs = finetune_langs.split(",") if finetune_langs is not None else None

        # The writer itself is created lazily in log_tensorboard.
        self.tensorboard_writer = None
        self.tensorboard_logdir = args.tensorboard_logdir
        if getattr(self.args, "eval_align_dist", False):
            from fairseq import scoring
            self.align_dist_scorer = scoring.build_scorer('align_dist', None)

        tensorboard_enabled = bool(args.tensorboard_logdir) and SummaryWriter is not None
        self.tensorboard_dir = (
            os.path.join(args.tensorboard_logdir, "valid_extra") if tensorboard_enabled else ""
        )

        if args.validation_inference_interval > 0:
            # DiffSinger is imported from a ./DiffSinger checkout; its inference
            # object is handed to wav_process to render audio for logged samples.
            import sys
            sys.path.append('./DiffSinger')
            from utils.hparams import set_hparams, hparams
            from inference.svs.ds_e2e import DiffSingerE2EInfer
            set_hparams(exp_name='0831_opencpop_ds1000', config='usr/configs/midi/e2e/opencpop/ds1000.yaml', print_hparams=False)
            self.wav_infer_ins = DiffSingerE2EInfer(hparams)


    def _get_sample_prob(self, dataset_lens):
        """Return temperature-smoothed sampling probabilities for the datasets.

        Each dataset's share of the total size is raised to
        ``--multilang-sampling-alpha`` and renormalised. An alpha below 1
        flattens the distribution, which upsamples small (low-resource)
        datasets.

        Args:
            dataset_lens: numpy array of dataset sizes.

        Returns:
            numpy array of probabilities summing to 1.
        """
        alpha = self.args.multilang_sampling_alpha
        weights = (dataset_lens / dataset_lens.sum()) ** alpha
        return weights / weights.sum()

    def get_languages(self, data_folder):
        """Return the set of language codes found in the entry names of ``data_folder``.

        The language is taken to be the second-to-last ``.``-separated field of
        each name (e.g. ``train.0.en.bin`` -> ``en``). Entries without any dot
        (e.g. ``README`` or a plain sub-directory) carry no language and are
        skipped instead of raising ``IndexError``.

        NOTE(review): a shared ``dict.txt`` yields ``"dict"`` here. This is kept
        for compatibility with existing dictionaries; confirm it is intended.

        ``self`` is not used; ``setup_task`` calls this as
        ``cls.get_languages(cls, path)``.

        Returns:
            set of language-code strings.
        """
        return {
            parts[-2]
            for parts in (name.split('.') for name in os.listdir(data_folder))
            if len(parts) >= 2
        }

    def get_dataset_path(self, split, data_folder, epoch, lgs=None, is_pair=False):
        """Choose, for each language, which numbered shard of ``split`` to read this epoch.

        Shards are ``.bin`` files in ``data_folder`` whose names contain
        ``split``. The shard index is drawn from a permutation seeded by
        ``(seed + epoch // n_shards)`` and the language's position, so within
        each window of ``n_shards`` epochs every shard is used once.

        Args:
            split: split name to match in file names.
            data_folder: directory holding the shards (listing is cached).
            epoch: current epoch number.
            lgs: languages (or ``src-tgt`` pairs when ``is_pair``) to resolve.
                When omitted they are derived from the file names and sorted.
            is_pair: count shards using the ``.<pair>.<src>.bin`` suffix.

        Returns:
            dict mapping each language to the ``<data_folder>/<split>.<index>.<lg>`` prefix.

        Raises:
            FileNotFoundError: if a language has no matching shard.
        """
        if data_folder in self.path_cache:
            files = self.path_cache[data_folder]
        else:
            # Entries are deliberately not filtered with os.path.isfile, for speed.
            files = list(os.listdir(data_folder))
            self.path_cache[data_folder] = files

        files = [name for name in files if split in name and ".bin" in name]

        if lgs is None:
            # Sorted: str set iteration order is hash-randomized per process, and the
            # language's position (lg_index) feeds the shard-selection seed below.
            lgs = sorted({name.split('.')[-2] for name in files})

        paths = {}
        for lg_index, lg in enumerate(lgs):
            if is_pair:
                pair = lg.split('-')
                suffix = ".{0}.{1}.bin".format(lg, pair[0])
            else:
                suffix = ".{0}.bin".format(lg)
            split_count = sum(1 for name in files if suffix in name)
            if split_count == 0:
                raise FileNotFoundError(
                    "no '{0}' shards for '{1}' found in {2}".format(split, lg, data_folder)
                )
            big_step, small_step = divmod(epoch, split_count)
            with data_utils.numpy_seed((self.args.seed + big_step) * 100 + lg_index):
                index = np.random.permutation(split_count)[small_step]
            paths[lg] = os.path.join(data_folder, "{0}.{1}.{2}".format(split, index, lg))
        return paths

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        One aligned dataset is loaded per ``--finetune-langs`` pair from
        ``--finetune-data``. For the training split, when
        ``--with-backtrans-data`` / ``--only-backtrans-data`` is set,
        back-translation datasets are also loaded from the second
        ``:``-separated entry of ``--data``.

        Training data is resampled, either by language-pair size or by an
        epoch-dependent finetune/back-translation ratio, and shuffled. For the
        other splits, each pair is additionally registered as its own
        sub-split ``<split>_<src>-<tgt>[_<domain>]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): current epoch; seeds resampling and shuffling
            combine (bool): forwarded to the dataset loader

        Raises:
            ValueError: if ``--finetune-langs`` was not provided.
        """
        if self.ft_langs is None:
            raise ValueError("load_dataset requires --finetune-langs to be set")

        use_backtrans = self.args.with_backtrans_data or self.args.only_backtrans_data
        is_train = split == getattr(self.args, "train_subset", "train")

        path = self.args.data
        bt_path = None
        if use_backtrans:
            # --data is "<path>:<back-translation path>"
            path, bt_path = path.split(':')
        # pretrained dataset path
        lang_splits = [split]

        ft_path = self.args.finetune_data if path is not None else None
        ft_datasets = []
        bt_datasets = []
        bt_dataset = None

        for pair in self.ft_langs:
            src, tgt = pair.split("-")
            if ft_path is not None:
                ft_datasets.append(
                    self._load_aligned_pair_dataset(ft_path, split, src, tgt, combine)
                )
            if use_backtrans and is_train:
                bt_dataset = self._load_aligned_pair_dataset(bt_path, split, src, tgt, combine)
                bt_datasets.append(bt_dataset)

        if is_train:
            if len(ft_datasets) > 1:
                # NOTE(review): in this branch back-translation data is not resampled, and
                # only the last-loaded pair's bt_dataset reaches the concat below; confirm intended.
                dataset_lengths = np.array([len(d) for d in ft_datasets], dtype=float)

                sample_probs = self._get_sample_prob(dataset_lengths)
                logger.info("Sample probability by language pair: {}".format({
                        pair: "{0:.4f}".format(sample_probs[id])
                        for id, pair in enumerate(self.ft_langs)
                    })
                )
                size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
                logger.info("Up/Down Sampling ratio by language for finetuning: {}".format({
                        pair: "{0:.2f}".format(size_ratio[id])
                        for id, pair in enumerate(self.ft_langs)
                    })
                )

                resampled_lang_datasets = [
                    ResamplingDataset(
                        d,
                        size_ratio=size_ratio[i],
                        seed=self.args.seed,
                        epoch=epoch,
                        replace=size_ratio[i] >= 1.0,
                    )
                    for i, d in enumerate(ft_datasets)
                ]
                ft_dataset = ConcatDataset(resampled_lang_datasets)
            else:
                ft_dataset = ft_datasets[0] if len(ft_datasets) > 0 else None
                if use_backtrans:
                    bt_dataset = bt_datasets[0] if len(bt_datasets) > 0 else None

                    # NOTE(review): divides by --max-epoch, so this fails if it is 0 (unbounded).
                    bt_size_ratio = max((self.args.max_epoch // 2 - epoch), 0.05) / self.args.max_epoch
                    ft_size_ratio = max((self.args.max_epoch - epoch), 5) / 5

                    logger.info("Up/Down Sampling ratio for back translation and finetune: {}".format({
                        'finetune': ft_size_ratio,
                        'back translation': bt_size_ratio
                    }))
                    if ft_dataset is not None and self.args.with_backtrans_data:
                        ft_dataset = ResamplingDataset(
                            ft_dataset,
                            size_ratio=ft_size_ratio,
                            seed=self.args.seed,
                            epoch=epoch,
                            replace=ft_size_ratio >= 1.0,
                        )
                        if bt_size_ratio > 0:
                            bt_dataset = ResamplingDataset(
                                bt_dataset,
                                size_ratio=bt_size_ratio,
                                seed=self.args.seed,
                                epoch=epoch,
                                # Sample with replacement only when upsampling, like every
                                # other ResamplingDataset in this method.
                                replace=bt_size_ratio >= 1.0,
                            )
                        else:
                            bt_dataset = None
        else:
            ft_dataset = ConcatDataset(ft_datasets)

            domain_name = "_{}".format(self.args.finetune_domain) if self.args.finetune_domain is not None else ""
            for lang_id, lang_dataset in enumerate(ft_datasets):
                split_name = split + "_" + self.ft_langs[lang_id] + domain_name
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            if hasattr(self.args, "valid_subset"):
                if split in self.args.valid_subset:
                    # Expose the per-pair sub-splits as validation subsets as well.
                    self.args.valid_subset = self.args.valid_subset.replace(
                        split, ','.join(lang_splits)
                    )

        if self.args.with_backtrans_data and bt_dataset is not None and ft_dataset is not None:
            ft_dataset = ConcatDataset([ft_dataset, bt_dataset])
        elif self.args.only_backtrans_data and bt_dataset is not None:
            ft_dataset = bt_dataset

        if is_train:
            # Shuffle (seeded per epoch) then sort by size for efficient batching.
            with data_utils.numpy_seed(self.args.seed + epoch):
                shuffle = np.random.permutation(len(ft_dataset))
            self.datasets[split] = SortDataset(
                ft_dataset,
                sort_order=[
                    shuffle,
                    ft_dataset.sizes,
                ],
            )
        else:
            self.datasets[split] = ft_dataset

    def _load_aligned_pair_dataset(self, data_path, split, src, tgt, combine):
        """Load the ``src``-``tgt`` dataset of ``split`` (with alignments) from ``data_path``.

        Shared by fine-tuning and back-translation loading so both use identical options.
        """
        return load_multi_langpair_with_alignments_dataset(
            data_path,
            split,
            src,
            self.source_dictionary,
            tgt,
            self.target_dictionary,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            load_alignments=True,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=getattr(self.args, 'max_source_positions', 512),
            max_target_positions=getattr(self.args, 'max_target_positions', 512),
            # NOTE(review): attribute is spelled 'preprend_bos'; verify it matches the registered arg.
            prepend_bos=getattr(self.args, 'preprend_bos', False),
            add_language_token=self.args.add_lang_token,
            add_length_token=self.args.add_length_token,
            domain=self.args.finetune_domain,
            common_eos=self.args.common_eos,
            max_delta_note=self.args.max_delta_note,
        )


    def valid_step(self, sample, model, criterion):
        """Compute the validation loss and, periodically, log inference results.

        When ``--eval-inference`` is on, ``model._num_updates`` is past
        ``--eval-inference-start-step`` and is a multiple of
        ``--validation-inference-interval``, the batch is decoded. An optional
        alignment-distance score is computed, and the batch is logged to
        TensorBoard if it contains sample id 0.

        Returns:
            ``(loss, sample_size, logging_output)`` from the criterion.
        """
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = criterion(model, sample)
        result = (loss, sample_size, logging_output)

        inference_enabled = (
            getattr(self.args, "eval_inference", False)
            and model._num_updates > getattr(self.args, "eval_inference_start_step", 0)
        )
        if not inference_enabled:
            return result

        interval = self.args.validation_inference_interval
        if interval <= 0 or model._num_updates % interval != 0:
            return result

        hypos, _ = self.valid_step_with_inference(sample, model)
        align_dist = None
        if getattr(self.args, "eval_align_dist", False):
            scorer = self.align_dist_scorer
            scorer.empty()
            for hypo, ref_align in zip(hypos, sample['net_input']['tgt_alignments']):
                scorer.add(ref_align.cpu().numpy(), hypo[0]['pred_alignments'].cpu().numpy())
            align_dist = (scorer.score(), scorer.result_overlap_histogram())

        picked_id = 0
        if self.tensorboard_dir and (sample["id"] == picked_id).any():
            self.log_tensorboard(sample, hypos, model._num_updates, alignment_dist=align_dist)

        return result


    def valid_step_with_inference(self, sample, model):
        """Decode ``sample`` with ``self.generator``, starting from the dictionary BOS.

        Returns:
            ``(hypos, losses)``; ``losses`` is currently always an empty dict.
        """
        with torch.no_grad():
            start_token = self.dictionary.bos()
            predictions = self.generator.generate([model], sample, bos_token=start_token)
        return predictions, dict()


    def log_tensorboard(self, sample, hypos, num_updates, alignment_dist=None):
        """Write inference visualisations for one validation batch to TensorBoard.

        For every hypothesis this logs:
        - the rendered MusicXML image,
        - the decoded lyrics,
        - the predicted alignments,
        - audio synthesized by ``wav_process`` in a worker pool.

        Source and ground-truth items are logged only on the first call
        (tracked by ``self.src_logged``).

        Args:
            sample: validation batch; reads ``source``, ``target`` and
                ``net_input['src_alignments']`` / ``net_input['tgt_alignments']``.
            hypos: generator output; ``hypos[i][0]`` holds ``tokens`` and ``pred_alignments``.
            num_updates: global step used for all TensorBoard entries.
            alignment_dist: optional ``(distance, histogram_image)`` pair.
        """
        if self.tensorboard_writer is None:
            self.tensorboard_writer = SummaryWriter(self.tensorboard_dir)
        tb_writer = self.tensorboard_writer
        # Source / ground-truth items are only produced on the first call.
        log_src = not self.src_logged

        if alignment_dist is not None:
            tb_writer.add_image("align hist", alignment_dist[1], global_step=num_updates, dataformats='HWC')
            tb_writer.add_scalar("align distance", alignment_dist[0], global_step=num_updates)

        src_musicxml_images, pred_musicxml_images = get_musicxml_image(sample, hypos, self.target_dictionary, log_src=log_src, log_dir=self.tensorboard_logdir)
        src_aligned_notes, pred_aligned_notes = get_aligned_notes(sample, hypos, self.target_dictionary, log_src=log_src)

        def _reference_hypos():
            # Wrap the references in the hypothesis layout the helpers expect
            # (a fresh list per call, in case a helper modifies its input).
            return [[{'tokens': sample['target'][i],
                      'pred_alignments': sample['net_input']['tgt_alignments'][i], }]
                    for i in range(len(sample['target']))]

        if log_src:
            src_lyrics = [self.dictionary.string(temp_sent.cpu()[2:], extra_symbols_to_ignore=[self.dictionary.pad()]) for temp_sent in sample['source']]
            gt_lyrics = [self.dictionary.string(temp_sent.cpu(), extra_symbols_to_ignore=[self.dictionary.pad()]) for temp_sent in sample['target']]
            _, gt_musicxml_images = get_musicxml_image(sample, _reference_hypos(),
                                                       self.target_dictionary,
                                                       log_src=False,
                                                       log_dir=self.tensorboard_logdir)
            _, gt_aligned_notes = get_aligned_notes(sample, _reference_hypos(), self.target_dictionary,
                                                    log_src=False)
        pred_lyrics = [self.dictionary.string(temp_sent[0]['tokens'].cpu(), extra_symbols_to_ignore=[self.dictionary.pad()]) for temp_sent in hypos]
        pred_alignments = [temp_sent[0]['pred_alignments'].cpu().numpy() for temp_sent in hypos]

        pred_pool = src_pool = gt_pool = None
        try:
            pred_pool = Pool(8)
            # Each Pool spawns worker processes, so the source / reference pools
            # are only created when those items are actually synthesized.
            if log_src:
                src_pool = Pool(4)
                gt_pool = Pool(4)
            pred_futures = []
            src_futures = []
            gt_futures = []

            for i in range(len(pred_musicxml_images)):
                tb_writer.add_image(
                    f"infer_sample_{i}", pred_musicxml_images[i], num_updates, dataformats="HWC"
                )
                tb_writer.add_text(f"infer_sample_{i}", pred_lyrics[i], num_updates)
                tb_writer.add_text(f"infer_align_{i}", ' '.join([str(x) for x in hypos[i][0]['pred_alignments'].cpu().numpy()]), num_updates)

                pred_futures.append(pred_pool.apply_async(wav_process,
                                                          args=(pred_aligned_notes[i],
                                                                pred_lyrics[i],
                                                                pred_alignments[i],
                                                                self.wav_infer_ins,
                                                                self.ft_langs[0].split("-")[-1])))

                if log_src:
                    tb_writer.add_image(
                        f"src_sample_{i}", src_musicxml_images[i], num_updates, dataformats="HWC"
                    )
                    tb_writer.add_text(f"src_sample_{i}", src_lyrics[i], num_updates)
                    tb_writer.add_text(f"src_align_{i}", ' '.join([str(x) for x in sample['net_input']['src_alignments'][i].cpu().numpy()]), num_updates)
                    tb_writer.add_image(
                        f"gt_sample_{i}", gt_musicxml_images[i], num_updates, dataformats="HWC"
                    )
                    tb_writer.add_text(f"gt_sample_{i}", gt_lyrics[i], num_updates)
                    tb_writer.add_text(f"gt_align_{i}", ' '.join([str(x) for x in sample['net_input']['tgt_alignments'][i].cpu().numpy()]), num_updates)
                    src_futures.append(src_pool.apply_async(wav_process,
                                                            args=(src_aligned_notes[i],
                                                                  src_lyrics[i],
                                                                  sample["net_input"]['src_alignments'][i],
                                                                  self.wav_infer_ins,
                                                                  self.ft_langs[0].split("-")[0])))
                    gt_futures.append(gt_pool.apply_async(wav_process,
                                                          args=(gt_aligned_notes[i],
                                                                gt_lyrics[i],
                                                                sample['net_input']['tgt_alignments'][i],
                                                                self.wav_infer_ins,
                                                                self.ft_langs[0].split("-")[-1])))

            pred_pool.close()
            for i, future in enumerate(pred_futures):
                torch.cuda.empty_cache()
                temp_wav = future.get()
                # wav_process may return None (no audio produced); skip those.
                if temp_wav is not None:
                    tb_writer.add_audio(f"infer_{i}", temp_wav, num_updates, sample_rate=24000)
                torch.cuda.empty_cache()
            pred_pool.join()

            if log_src:
                src_pool.close()
                gt_pool.close()
                for i, future in enumerate(src_futures):
                    torch.cuda.empty_cache()
                    temp_wav = future.get()
                    if temp_wav is not None:
                        tb_writer.add_audio(f"src_{i}", temp_wav, num_updates, sample_rate=24000)
                    torch.cuda.empty_cache()
                src_pool.join()
                for i, future in enumerate(gt_futures):
                    torch.cuda.empty_cache()
                    temp_wav = future.get()
                    if temp_wav is not None:
                        tb_writer.add_audio(f"gt_{i}", temp_wav, num_updates, sample_rate=24000)
                    torch.cuda.empty_cache()
                gt_pool.join()
                self.src_logged = True
        finally:
            # Reap worker processes even if rendering, logging or synthesis raised.
            for pool in (pred_pool, src_pool, gt_pool):
                if pool is not None:
                    pool.terminate()


    def inference_step(self, generator, models, sample, prefix_tokens=None, constraints=None, path='', save_results=False):
        """Decode ``sample`` and render the results; optionally save synthesized audio.

        Rendered MusicXML goes to ``<tensorboard_logdir>/inference_musicxml``.
        When ``save_results`` is true, audio for each hypothesis is produced by
        ``wav_process`` in a worker pool and written to
        ``<tensorboard_logdir>/inference_results/pred_<sample id>.wav``.

        ``constraints`` and ``path`` are accepted for interface compatibility
        but are not used here.

        Returns:
            ``(hypos, src_musicxml_images, pred_musicxml_images)``.
        """
        with torch.no_grad():
            bos_token = self.dictionary.bos()
            hypos = generator.generate(models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token)
        src_musicxml_images, pred_musicxml_images = get_musicxml_image(sample, hypos, self.target_dictionary, out_fd=os.path.join(self.tensorboard_logdir, "inference_musicxml"), log_src=True, log_dir=self.tensorboard_dir)
        _, pred_aligned_notes = get_aligned_notes(sample, hypos, self.target_dictionary,
                                                  log_src=not self.src_logged)
        pred_lyrics = [self.dictionary.string(temp_sent[0]['tokens'].cpu()) for temp_sent in hypos]
        from fairseq.models.bart.musicxml_utils import save_wav
        results_dir = f'{self.tensorboard_logdir}/inference_results/'
        os.makedirs(results_dir, exist_ok=True)

        if save_results:
            pred_alignments = [temp_sent[0]['pred_alignments'].cpu() for temp_sent in hypos]
            # The context manager terminates the workers even if synthesis or saving raises.
            with Pool(16) as pred_pool:
                pred_futures = [
                    pred_pool.apply_async(wav_process, args=(pred_aligned_notes[i], pred_lyrics[i], pred_alignments[i], self.wav_infer_ins, self.ft_langs[0].split("-")[-1]))
                    for i in range(len(pred_lyrics))
                ]
                pred_pool.close()
                for i, future in enumerate(pred_futures):
                    temp_wav = future.get()
                    sample_id = int(sample['id'][i])
                    # wav_process may return None (no audio produced); nothing is saved then.
                    if temp_wav is not None:
                        save_wav(temp_wav, path=f'{results_dir}pred_{sample_id}.wav', sr=24000)
                pred_pool.join()

        return (hypos, src_musicxml_images, pred_musicxml_images)


    def build_model(self, args):
        """Build the model via the parent task; attach a generator for inference-time validation.

        ``self.generator`` is reset to ``None`` and only rebuilt when
        ``--eval-inference`` is set.
        """
        built = super().build_model(args)
        self.generator = None
        needs_generator = getattr(args, "eval_inference", False)
        if needs_generator:
            self.generator = self.build_generator([built], args)
        return built


    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
    ):
        """Build the sequence generator (or reference scorer) used for decoding.

        The generator family is chosen from the flags:
        - melody/alignment decoding (``generate_with_melody``,
          ``--eval-inference`` or ``print_alignment``) imports from
          ``sequence_generator_with_constraints``;
        - otherwise it imports from ``sequence_generator_with_prefix``.

        The search strategy follows the same precedence as fairseq's default
        task. ``extra_gen_cls_kwargs`` is copied, so the caller's dict is never
        modified.

        Returns:
            a ``SequenceScorer`` when ``score_reference`` is set, else an
            instance of the selected generator class.

        Raises:
            ValueError: if mutually exclusive search options are combined.
        """
        eos = self.source_dictionary.eos()
        # Copy: entries are popped/added below and must not leak into the caller's dict.
        extra_gen_cls_kwargs = dict(extra_gen_cls_kwargs or {})

        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        if getattr(args, 'generate_with_melody', False) or getattr(self.args, "eval_inference", False) or getattr(args, "print_alignment", False) or extra_gen_cls_kwargs.get('generate_with_melody', False):
            from fairseq.sequence_generator_with_constraints import (
                SequenceGenerator,
                SequenceGeneratorWithMelodyAlignments,
                SequenceGeneratorWithAlignment,
            )
        else:
            from fairseq.sequence_generator_with_prefix import (
                SequenceGenerator,
            )

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        if seq_gen_cls is None:
            if getattr(args, "generate_with_melody", False) or getattr(self.args, "eval_inference", False) or extra_gen_cls_kwargs.get('generate_with_melody', False):
                seq_gen_cls = SequenceGeneratorWithMelodyAlignments
                # Only a selector flag; not a constructor argument of the generator.
                extra_gen_cls_kwargs.pop("generate_with_melody", None)
            elif getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            eos=eos,
            **extra_gen_cls_kwargs,
        )

if __name__ == '__main__':
    # Manual smoke test for get_aligned_notes on a hand-written sample.
    # Usage: python <this file> [DICT_PATH]
    # When DICT_PATH is omitted, the author's original local vocabulary path is used.
    import sys

    dict_path = sys.argv[1] if len(sys.argv) > 1 else '/Users/xiji/Desktop/dataset/BPE/vocab.all.truncated.ini'
    test_dict = Dictionary.load(dict_path)
    test_sample = {'net_input':{'notes': [torch.LongTensor([69,69,72,72,77,77,76,76,74,69,74,74])],
                   'durs': [torch.LongTensor([4,4,4,4,4,4,4,4,4,4,4,2])]}}
    test_hypos = [[{'tokens': test_dict.encode_line('The rain drops are kissing grass so fresh and green <eos>s'),
                    'pred_alignments':torch.LongTensor([1,1,1,1,2,1,1,1,1,2]),}]]
    # get_musicxml_image(test_sample, test_hypos, test_dict)
    get_aligned_notes(test_sample, test_hypos, test_dict)