from fairseq.tasks.multilingual_translation_with_melody import get_aligned_notes
from fairseq.data import Dictionary
import torch
from fairseq.models.bart.musicxml_utils import out_to_wav, save_wav

if __name__ == "__main__":
    # Manual smoke test: align a short melody with a hypothesis lyric via
    # get_aligned_notes, synthesize the aligned result with DiffSinger, and
    # write the audio to ./test.wav.

    # NOTE: machine-specific absolute path; point it at the local vocab file.
    # A previous load from a second machine-specific path was immediately
    # overwritten by this one (and failed wherever that path was missing),
    # so only the effective load is kept.
    test_dict = Dictionary.load('/home/xiji.lcx/workspace/GagaST-code/data/vocab/vocab.all.truncated')

    # The lyric is used both to build the hypothesis tokens and as the text
    # handed to the synthesizer, so it is bound once.
    lyrics = '我 爱 你 <eos>'

    # Only this sample/hypothesis pair is used; an earlier English pair was
    # overwritten before being read and has been removed.
    # presumably 'notes' are MIDI pitch numbers and 'durs' the matching
    # per-note durations — TODO confirm against get_aligned_notes.
    test_sample = {
        'net_input': {
            'notes': [torch.LongTensor([69, 69, 72, 72, 77, 76])],
            'durs': [torch.LongTensor([4, 4, 4, 4, 4, 2])],
        }
    }
    # NOTE(review): pred_alignments sums to 8 while the sample holds 6 notes;
    # confirm this mismatch is intentional for the test.
    test_hypos = [[{
        'tokens': test_dict.encode_line(lyrics),
        'pred_alignments': torch.LongTensor([2, 3, 3]),
    }]]
    # get_musicxml_image(test_sample, test_hypos, test_dict)

    import sys

    # The DiffSinger checkout is expected at ./DiffSinger relative to the
    # current working directory; it must be on sys.path before its
    # top-level `utils` and `inference` packages can be imported.
    sys.path.append('./DiffSinger')
    from utils.hparams import set_hparams, hparams
    from inference.svs.ds_cascade import DiffSingerCascadeInfer

    # NOTE(review): the config path is relative — presumably resolved against
    # the working directory; verify where the checkpoints directory must live.
    set_hparams(
        exp_name='0303_opencpop_ds58_midi',
        config='checkpoints/0303_opencpop_ds58_midi/config.yaml',
        print_hparams=False,
    )
    wav_infer_ins = DiffSingerCascadeInfer(hparams)

    # Only the second returned value is consumed below.
    _src, tgt = get_aligned_notes(test_sample, test_hypos, test_dict)
    print(tgt)

    # Synthesize the first aligned entry and save it at a 24 kHz sample rate.
    wav = out_to_wav(lyrics, tgt[0], wav_infer_ins, lang='zh')
    save_wav(wav, './test.wav', sr=24000)
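    # Disabled scratch code: feeds hard-coded predicted/reference alignment
    # lists to the 'align_dist' scorer, prints its score and histograms, and
    # logs the overlap-histogram image to TensorBoard.
    # from fairseq import scoring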
    # import numpy as np
    # align_dist_scorer = scoring.build_scorer('align_dist', None)
    # align_dist_scorer.empty()
    # preds = [[1, 1, 1],
    #          [1, 2, 1, 1, 1, 1],
    #          [1, 1, 2, 3],
    #          [1, 1, 2, 1, 1, 1, 1, 1, 1, 1],
    #          [1, 1, 1, 2, 1, 1, 1,],
    #          [1, 1, 2, 3, 1, 1, 2, 1, 1, 1, 1, 1],
    #          [1, 1, 2, 2],
    #          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    #          [1, 1, 2, 1, 1, 1, 1, 2,],
    #          [1, 1, 1, 2, 2],
    #          [1, 1, 1, 1, 1, 1, 1, 2],
    #          [1, 1, 1, 2, 1, 1, 1, 2],
    #          [1, 1, 1, 1, 1, 2, 1, 1, 1, 1],
    #          [1, 1, 1, 2, 1, 1, 1, 1, 1, 1],
    #          [3, 1, 1, 1, 1],
    #          [3, 1, 1, 1, 1, 1],
    #          [1, 1, 1, 4, 7],
    #          [2, 1, 1, 1, 1, 1, 1],
    #          [1, 1, 2, 1, 2, 1, 1, 1, 1],
    #          [1, 1, 1, 4, 2, 2, 1, 1],
    #          [1, 1, 2, 1, 2, 1, 1, 4, 1, 4, 3, 1],
    #          [1, 1, 1, 2, 1, 1],
    #          [1, 1, 1, 1, 1, 1, 2, 1, 1],
    #          [1, 1, 2, 1, 1, 1, 2],
    #          [1, 1, 1, 1, 1, 2],
    #          [1, 1, 1, 1, 1, 1, 1, 1, 3],
    #          [1, 1, 1, 1, 1, 1, 1, 2],
    #          [2, 1, 1, 1, 1, 2, 1, 1, 1],
    #          [1, 1, 1, 1, 2, 1, 2, 2],
    #          [3, 1, 1, 1, 1]]
    # preds = [[1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],
    #          [1,1,1,1,1,1,1],]
    # refs = [[3, 1, 1, 1, 1, 3,],
    #         [1, 1, 1, 4, 7],
    #         [2, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 2, 1, 2, 1, 1, 1, 1],
    #         [1, 1, 1, 4, 2, 2, 1, 1],
    #         [1, 1, 2, 1, 2, 2, 1, 4, 1, 2, 3, 1],
    #         [1, 1, 1, 2, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 2, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 6],
    #         [1, 1, 1, 1, 1, 2, 1, 2, 3],
    #         [1, 1, 1, 1, 1, 1, 2, 1, 1],
    #         [2, 1, 1, 1, 2, 1, 1, 2],
    #         [1, 1, 1, 4, 7],
    #         [2, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 2, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 3, 2, 2, 1, 1],
    #         [1, 1, 2, 1, 1, 2, 1, 2, 1, 2, 3, 1],
    #         [1, 1, 1, 2, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 2],
    #         [1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    #         [1, 1, 1, 1],
    #         [1, 1, 1, 1]]
    #
    # for pred_align, ref_align in zip(preds, refs):
    #     align_dist_scorer.add(np.array(ref_align), np.array(pred_align))
    # _score = align_dist_scorer.score()
    # print(_score)
    #
    #
    # print(align_dist_scorer.pred_hist)
    # print(align_dist_scorer.ref_hist)
    # print(align_dist_scorer.bins)
    # try:
    #     from torch.utils.tensorboard import SummaryWriter
    # except ImportError:
    #     try:
    #         from tensorboardX import SummaryWriter
    #     except ImportError:
    #         SummaryWriter = None
    # tb_writer = SummaryWriter('./')
    # for i in range(100):
    #     tb_writer.add_image('test', align_dist_scorer.result_overlap_histogram(), i, dataformats="HWC")
