from fairseq.tasks.multilingual_translation_with_melody import(
    load_melody,
    load_align,
    wav_process,
    image_process)
from fairseq.models.bart.musicxml_utils import save_wav
import sys
import os
import multiprocessing
from multiprocessing import Pool
import torch
import numpy as np
from PIL import Image
# Tensor-sharing strategy for torch.multiprocessing; overridable via the
# TORCH_SHARE_STRATEGY environment variable (default 'file_system').
torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))
# Use the 'spawn' start method for the process pools created in this file.
# NOTE(review): with force=True, set_start_method does not raise for an
# already-set context, so this except clause looks purely defensive.
# NOTE(review): under 'spawn', pool workers re-import this module as
# __mp_main__, so the unguarded top-level statements here (including building
# the DiffSinger model below) presumably run again in every worker -- confirm
# that this cost is intended.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
   pass

# Make the DiffSinger checkout importable. The relative path is resolved
# against the current working directory, not against this file.
sys.path.append('./DiffSinger')
from utils.hparams import set_hparams, hparams
from inference.svs.ds_e2e import DiffSingerE2EInfer

# Load the DiffSinger experiment config (presumably populating the imported
# ``hparams``) and build the inference object that the __main__ block passes
# to wav_process.
set_hparams(exp_name='0831_opencpop_ds1000', config='usr/configs/midi/e2e/opencpop/ds1000.yaml', print_hparams=False)
wav_infer_ins = DiffSingerE2EInfer(hparams)


def get_musicxml_image(sample, out_fd=None, log_src=False, log_dir=None):
    """Render the predicted lyrics/melody of every item in ``sample`` to an image.

    Rendering is delegated to ``image_process``, run in an 8-worker process pool.

    Args:
        sample: list of dicts; each must provide 'id', 'notes', 'durs',
            'pred_lyrics' and 'pred_alignments'.
        out_fd: optional output directory. When given, ``<out_fd>/pred`` is
            created and passed to ``image_process`` as its output directory;
            otherwise ``None`` is passed.
        log_src: accepted for API compatibility; currently unused (no
            source-side images are rendered).
        log_dir: forwarded to ``image_process``; ``<log_dir>/tmp_musicxml_log``
            is created first. NOTE(review): when ``log_dir`` is None this
            creates a directory literally named ``None/tmp_musicxml_log`` in the
            working directory -- confirm whether image_process relies on it.

    Returns:
        ``(src_musicxml_images, pred_musicxml_images)``. The src list is always
        empty. The pred list holds one ``np.ndarray`` per sample, in input
        order, decoded with ``PIL.Image.open`` from whatever ``image_process``
        returned (presumably an image path -- TODO confirm).
    """
    os.makedirs(f'{log_dir}/tmp_musicxml_log', exist_ok=True)
    pred_out_fd = None
    if out_fd is not None:
        pred_out_fd = os.path.join(out_fd, 'pred')
        # makedirs also creates ``out_fd`` itself as the parent.
        os.makedirs(pred_out_fd, exist_ok=True)

    src_musicxml_images = []  # source-side rendering is not implemented
    rendered = []

    # The context manager terminates the workers even if a task raises in
    # future.get(), so no pool processes are leaked on failure.
    with Pool(8) as pool:
        pred_futures = []
        for cnt, _sample in enumerate(sample):
            lyrics = _sample['pred_lyrics']
            eos_pos = len(lyrics)
            # One extra alignment entry is appended; presumably it pairs with
            # the EOS position passed right after it -- TODO confirm in image_process.
            args = (_sample['id'], cnt, lyrics, _sample['notes'], _sample['durs'],
                    _sample['pred_alignments'] + [1], eos_pos, log_dir, pred_out_fd)
            pred_futures.append(pool.apply_async(image_process, args=args))

        pool.close()
        for future in pred_futures:
            rendered.append(future.get())
            # Release cached CUDA memory held by this (parent) process.
            torch.cuda.empty_cache()
        pool.join()

    pred_musicxml_images = []
    for image_source in rendered:
        # Close the underlying file once the pixels are copied into the array.
        with Image.open(image_source) as image:
            pred_musicxml_images.append(np.array(image))

    return src_musicxml_images, pred_musicxml_images


# Duration-token id -> note length.
# NOTE(review): the values suggest lengths in quarter-note beats (3 -> 1.0,
# 1 -> 4.0); entries such as 23 -> 6.9 and 30 -> 0.25 (same as 5) look odd --
# confirm against the data that produced these ids.
_TO_ACTUAL_DUR = {0: 0, 1: 4.0, 2: 2.0, 3: 1.0, 4: 0.5, 5: 0.25, 6: 0.125,
                  7: 3.0, 8: 1.5, 9: 0.75, 10: 0.375, 11: 0.1875,
                  12: 1.3333333333, 13: 0.66666666666666, 14: 0.3333333333, 15: 0.166666666, 16: 0.0833333333333,
                  17: 2.6666666666, 18: 0.09375, 19: 0.03125, 20: 0.2145833333, 21: 0.07083333333,
                  22: 0.3541666666, 23: 6.9, 24: 0.0625, 25: 0.041666666, 26: 0.2, 27: 0.1,
                  28: 1.75, 29: 6.0, 30: 0.25}


def _as_host_sequence(values):
    """Return ``values`` as a NumPy array if it is a torch tensor, else unchanged."""
    if torch.is_tensor(values):
        return values.cpu().numpy()
    return values


def _tolerant_correct(alignments, num_notes):
    """Nudge ``alignments`` in place so its sum moves toward ``num_notes``.

    If the sum is too large, entries greater than 1 are decremented (sweeping
    left to right, repeatedly) until the sum matches or no entry exceeds 1.
    If the sum is too small, the whole deficit is added to the last non-zero
    entry. The list is mutated, so callers holding the same list see the
    corrected values.
    """
    _sum = sum(alignments)
    if _sum > num_notes:
        while _sum > num_notes:
            flag = False
            for i in range(len(alignments)):
                if alignments[i] > 1:
                    flag = True

                if alignments[i] > 1 and _sum > num_notes:
                    alignments[i] -= 1
                    _sum -= 1
                elif _sum <= num_notes:
                    break
            if not flag:
                break

    elif _sum < num_notes:
        for i in range(len(alignments) - 1, -1, -1):
            if alignments[i] != 0:
                alignments[i] += num_notes - _sum
                break


def _group_notes(notes, durs, lyrics, alignments):
    """Split the melody into one note group per lyric token with a non-zero alignment.

    ``alignments[k]`` is the number of consecutive note positions consumed by
    lyric ``k``; a 0 entry consumes nothing and produces no group. Only the
    first ``min(len(lyrics), len(alignments))`` entries are used. Positions past
    the end of ``notes`` and notes equal to 0 (presumably rests) are left out of
    the group but still consume a position.
    """
    pitches, lengths = [], []
    pos = 0
    for count in alignments[:min(len(lyrics), len(alignments))]:
        if count != 0:
            kept = [k for k in range(pos, pos + count) if k < len(notes) and notes[k] != 0]
            pitches.append([notes[k] for k in kept])
            lengths.append([durs[k] for k in kept])
        pos += count
    return {'pitch': pitches, 'dur': lengths}


def _align_one(notes, durs, lyrics, alignments, count_rests):
    """Correct ``alignments`` in place, then group ``notes``/``durs`` per lyric.

    ``count_rests`` selects the correction target: ``len(notes)`` when True,
    the number of non-zero notes when False.
    """
    notes = _as_host_sequence(notes)
    durs = [_TO_ACTUAL_DUR[dur] for dur in _as_host_sequence(durs)]
    if count_rests:
        num_notes = len(notes)
    else:
        num_notes = sum(1 for note in notes if note != 0)
    _tolerant_correct(alignments, num_notes)
    return _group_notes(notes, durs, lyrics, alignments)


def get_aligned_notes(sample, log_src=False):
    """Group each sample's melody notes under the lyric tokens they are aligned to.

    Args:
        sample: list of dicts with 'notes', 'durs' (duration-token ids, see
            ``_TO_ACTUAL_DUR``), 'pred_lyrics' and 'pred_alignments'; when
            ``log_src`` is true, also 'src_lyrics' and 'src_alignments'.
            ``notes``/``durs`` may be Python sequences or torch tensors.
        log_src: also build the source-side grouping.

    Returns:
        ``(src_aligned_notes, tgt_aligned_notes)``: one dict per sample of the
        form ``{'pitch': [[...], ...], 'dur': [[...], ...]}`` with one inner
        list per lyric token that has a non-zero alignment.
        ``src_aligned_notes`` is empty unless ``log_src`` is true.

    Side effect: the alignment lists inside ``sample`` are corrected in place,
    so later users of the same ``sample`` (e.g. get_musicxml_image in the
    __main__ block) see the corrected alignments.

    NOTE(review): the pred side corrects toward ``len(notes)`` (rests included)
    while the src side corrects toward the non-zero note count, yet both index
    ``notes`` positionally including rests -- confirm which convention the
    alignment data follows.
    """
    tgt_aligned_notes = [
        _align_one(_sample['notes'], _sample['durs'], _sample['pred_lyrics'],
                   _sample['pred_alignments'], count_rests=True)
        for _sample in sample
    ]
    src_aligned_notes = []
    if log_src:
        src_aligned_notes = [
            _align_one(_sample['notes'], _sample['durs'], _sample['src_lyrics'],
                       _sample['src_alignments'], count_rests=False)
            for _sample in sample
        ]
    return src_aligned_notes, tgt_aligned_notes


if __name__ == '__main__':
    # Line indices into the test split of the songs to render.
    ids = [18, 19, 20, 21, 27, 28, 43, 44, 45, 46,
           92, 93, 94, 95, 102, 109, 125, 126, 146, 147, 151, 170, 209, 210,
           269, 270, 271, 275, 356, 357, 361, 362, 363, 364, 365,
           382, 406, 408, 409, 410, 411, 445, 446, 453, 455, 459,
           465, 466, 467, 486, 487, 592, 593, 594, 595, 596, 597, 599]

    ckpts = {'en-zh': 'Pretrain_xdae_multilingual_translation_decoder_length_en_zh_lr5e-4_m30_r0.5_4096_upf5_M0,1,2,3',
             'zh-en': 'Pretrain_xdae_multilingual_translation_decoder_length_zh_en_lr5e-4_m30_r0.5_4096_upf5_M0,1,2,3',}

    # Translation direction; set src_lang = 'en', tgt_lang = 'zh' for the
    # reverse checkpoint.
    src_lang = 'zh'
    tgt_lang = 'en'

    temp_ckpt = ckpts[f'{src_lang}-{tgt_lang}']
    results_dir = f'checkpoints/{temp_ckpt}/inference_results'
    data_dir = f'data/bin/ft_{src_lang}_{tgt_lang}'

    # The first space-separated token of every line is dropped (presumably a
    # tag/prefix token written by the inference step -- TODO confirm).
    with open(f'{results_dir}/{src_lang}.test.src', 'r', encoding='UTF-8') as f:
        source_lyrics = [x.strip().split(' ')[1:] for x in f]
    with open(f'{results_dir}/{tgt_lang}.test.hyp.org', 'r', encoding='UTF-8') as f:
        pred_lyrics = [x.strip().split(' ')[1:] for x in f]

    melody = load_melody(f'{data_dir}/test.melody')
    source_alignments = load_align(f'{data_dir}/test.{src_lang}-{tgt_lang}.alignment.{src_lang}')

    # Predicted lyrics have no alignment yet: start with one note per token;
    # get_aligned_notes then corrects the list in place to cover the melody.
    sample = [
        {'id': _id, 'notes': melody[_id]['notes'], 'durs': melody[_id]['durs'],
         'pred_lyrics': pred_lyrics[_id], 'pred_alignments': [1] * len(pred_lyrics[_id]),
         'src_lyrics': source_lyrics[_id], 'src_alignments': source_alignments[_id]}
        for _id in ids
    ]

    _, pred_aligned_notes = get_aligned_notes(sample, log_src=False)
    # Called for its side effect of writing the rendered score images.
    get_musicxml_image(sample, log_src=False, out_fd=f'checkpoints/{temp_ckpt}/inference_musicxml')

    # Index by ``ids`` directly so entry i always lines up with
    # pred_aligned_notes[i] and ids[i], whatever order ``ids`` is in.
    pred_lyrics = [pred_lyrics[_id] for _id in ids]

    # The context manager terminates the workers if any synthesis task raises.
    with Pool(8) as pred_pool:
        pred_futures = [
            pred_pool.apply_async(
                wav_process,
                args=(aligned_notes,
                      ''.join(lyrics) if tgt_lang == 'zh' else ' '.join(lyrics),
                      [1] * len(lyrics),
                      wav_infer_ins,
                      tgt_lang))
            for aligned_notes, lyrics in zip(pred_aligned_notes, pred_lyrics)
        ]
        pred_pool.close()
        for song_id, future in zip(ids, pred_futures):
            temp_wav = future.get()
            if temp_wav is not None:
                save_wav(temp_wav, path=f'{results_dir}/pred_{int(song_id)}.wav', sr=24000)
        pred_pool.join()

    os.system(f'ossutil -c /home/xiji.lcx/workspace/.ossutilconfig cp -u -r checkpoints/{temp_ckpt}/inference_musicxml oss://alitranx-public/xiji.lcx/SongTranslation/results/checkpoints/{temp_ckpt}/inference_musicxml')
    os.system(f'ossutil -c /home/xiji.lcx/workspace/.ossutilconfig cp -u -r checkpoints/{temp_ckpt}/inference_results oss://alitranx-public/xiji.lcx/SongTranslation/results/checkpoints/{temp_ckpt}/inference_results')