import shutil
import os
import platform
import subprocess
import logging
from tempfile import TemporaryDirectory
from pathlib import Path
from g2p_en import G2p
from librosa import midi_to_note
import numpy as np
from scipy.io import wavfile
from itertools import chain

LOGGER = logging.getLogger(__name__)


def find_musescore3():
    """Return the path of a MuseScore 3 command-line executable, or ``None``.

    On macOS the default MuseScore 3 application bundle is preferred when it
    exists. Otherwise ``PATH`` is searched for, in order, ``musescore``,
    ``musescore3``, ``mscore`` and ``mscore3``.

    Returns:
        The executable path as returned by :func:`shutil.which`, or ``None``
        when no candidate is found.
    """
    if platform.system() == "Darwin":
        # The .app bundle is normally not on PATH. shutil.which() on a full
        # path only checks that the file exists and is executable. If it is
        # missing, fall through to the PATH search instead of giving up.
        bundled = shutil.which("/Applications/MuseScore 3.app/Contents/MacOS/mscore")
        if bundled is not None:
            return bundled

    for name in ("musescore", "musescore3", "mscore", "mscore3"):
        result = shutil.which(name)
        if result is not None:
            return result

    return None


def render(xml_fh, dpi=360, out_fn=None):
    """Render a MusicXML score to a PNG image with MuseScore 3.

    Args:
        xml_fh: Path (``str`` or path-like) of the score file to render.
        dpi: Resolution passed to MuseScore's ``-r`` option.
        out_fn: Destination the rendered image is copied to. The image is
            produced inside a temporary directory that is deleted before this
            function returns, so with ``out_fn=None`` the rendering is
            discarded.

    Returns:
        ``out_fn`` when an image was produced and copied there. ``None`` when
        MuseScore is not found, cannot be executed, exits with a non-zero
        code, produces no image, or ``out_fn`` is ``None``.
    """
    mscore_exec = find_musescore3()
    if not mscore_exec:
        return None

    xml_path = os.fspath(xml_fh)
    with TemporaryDirectory() as tmpdir:
        # Base name up to the first '.', e.g. 'dir/song.v2.musicxml' -> 'song'.
        # Path() also understands the platform's separators, which a plain
        # split on '/' does not.
        stem = Path(xml_path).name.split(".")[0]
        img_path = Path(tmpdir) / f"score{stem}.png"
        cmd = [
            mscore_exec,
            "-T", "10",
            "-r", str(dpi),
            "-o", os.fspath(img_path),
            xml_path,
        ]
        try:
            ps = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError as err:
            # Covers a missing executable (FileNotFoundError) as well as one
            # that cannot be executed (PermissionError).
            LOGGER.error('Executing "%s" returned  %s.', " ".join(cmd), err)
            return None

        if ps.returncode != 0:
            # errors="replace": MuseScore output is not guaranteed to be UTF-8,
            # and a decode error here would hide the real failure.
            LOGGER.error(
                "Command %s failed with code %s; stdout: %s; stderr: %s",
                cmd,
                ps.returncode,
                ps.stdout.decode("UTF-8", errors="replace"),
                ps.stderr.decode("UTF-8", errors="replace"),
            )
            return None

        # MuseScore is expected to append a page number to PNG output
        # ('<name>-1.png' for the first page); the bare name is accepted as a
        # fallback in case no number is appended. Only one page is kept.
        candidates = (img_path.with_name(f"{img_path.stem}-1{img_path.suffix}"), img_path)
        for candidate in candidates:
            if candidate.is_file():
                if out_fn is not None:
                    shutil.copy(candidate, out_fn)
                    return out_fn
                break

        return None



def phonemes_to_units(phonemes):
    """Split an ARPAbet phoneme sequence into syllable-like units, one per vowel.

    A phoneme counts as a vowel when it is one of the 15 ARPAbet vowels with
    a stress digit (``'AA0'`` ... ``'UW2'``). Each unit contains exactly one
    vowel.
    - The first unit starts at index 0, so any leading consonants are kept.
    - The last unit runs to the end of the sequence, so any trailing
      consonants are kept.
    - Between two vowels, the consonant right before a vowel starts that
      vowel's unit; all other intervening consonants stay with the previous
      unit.
    - Adjacent vowels are simply split apart.

    Args:
        phonemes: Sequence of ARPAbet phoneme strings (e.g. g2p_en output).

    Returns:
        List of slices of ``phonemes``, one per vowel. The list is empty when
        ``phonemes`` contains no vowel.

    Example:
        ``['P', 'ER0', 'S', 'W', 'EY1', 'ZH', 'AH0', 'N']`` ->
        ``[['P', 'ER0', 'S'], ['W', 'EY1'], ['ZH', 'AH0', 'N']]``
    """
    vowels = frozenset(
        base + stress
        for base in ('AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'EH', 'ER', 'EY',
                     'IH', 'IY', 'OW', 'OY', 'UH', 'UW')
        for stress in '012'
    )
    vowel_pos = [i for i, ph in enumerate(phonemes) if ph in vowels]
    last = len(vowel_pos) - 1

    units = []
    for k, pos in enumerate(vowel_pos):
        if k == 0:
            start = 0
        else:
            # Take one onset consonant unless the previous phoneme is itself a vowel.
            start = pos if pos - 1 == vowel_pos[k - 1] else pos - 1
        if k == last:
            end = len(phonemes)
        else:
            nxt = vowel_pos[k + 1]
            # Leave the next vowel's onset consonant (if any) to the next unit.
            end = nxt if nxt - 1 == pos else nxt - 1
        units.append(phonemes[start:end])
    return units

# Grapheme-to-phoneme converter shared by out_to_wav for English lyrics
# (called once per word to obtain ARPAbet phonemes).
# NOTE(review): constructed at import time, so importing this module pays the
# cost of G2p() construction even when only render/save_wav are used —
# consider lazy initialization if that cost matters.
g2p_model = G2p()
# Approximate mapping from ARPAbet phones (g2p_en output) to the Mandarin
# phone set fed to the singing model. Phones missing from this table (for
# example spaces or punctuation emitted by g2p) are dropped.
_EN_ZH_PH_DICT = {"AA0": 'a', "AA1": 'a', "AA2": 'a', "EY0": 'ei', "EY1": 'ei', "EY2": 'ei',
                  "AE0": 'ai', "AE1": 'ai', "AE2": 'ai', "IH0": 'i', "IH1": 'i', "IH2": 'i',
                  "AH0": 'a', "AH1": 'a', "AH2": 'a', "AO0": 'ao', "AO1": 'ao', "AO2": 'ao',
                  "AW0": 'ao', "AW1": 'ao', "AW2": 'ao', "IY0": 'i', "IY1": 'i', "IY2": 'i',
                  "AY0": 'ai', "AY1": 'ai', "AY2": 'ai', "OW0": 'ou', "OW1": 'ou', "OW2": 'ou',
                  "OY0": 'o', "OY1": 'o', "OY2": 'o', "UH0": 'u', "UH1": 'u', "UH2": 'u',
                  "UW0": 'u', "UW1": 'u', "UW2": 'u',
                  "EH0": 'ai', "EH1": 'ai', "EH2": 'ai', "ER0": 'er', "ER1": 'er', "ER2": 'er',
                  "B": 'b', "CH": 'ch', "D": 'd', "DH": 'z', "F": 'f', "G": 'g', "HH": 'h',
                  "JH": 'j', "K": 'k', "L": 'l', "M": 'm', "N": 'n', "NG": 'ng', "P": 'p', "R": 'r',
                  "S": 's', "SH": 'sh', "T": 't', "TH": 'sh', "V": 'u', "W": 'u', "Y": 'i', "Z": 'z', "ZH": 'r'}

# Mapped phones that can carry a syllable's nucleus when the last syllable of
# a word is held over several notes. This includes 'ei' (mapped from EY) and
# the nasal finals created by merging a trailing 'ng'. Without them the
# nucleus search failed and produced phone/pitch/slur sequences of different
# lengths.
_ZH_NUCLEI = frozenset(('a', 'o', 'e', 'i', 'u', 'v', 'ai', 'ao', 'ou', 'er', 'uei',
                        'ei', 'ang', 'eng', 'ing', 'ong'))

# Highest MIDI pitch sent to the model; melodies reaching above it are
# transposed down as a whole.
_MAX_PITCH = 72


def _pitch_shift(pitches, ceiling=_MAX_PITCH):
    """Return how far to transpose *pitches* down so that none exceeds *ceiling* (0 if none do)."""
    top = max(pitches, default=None)
    if top is None or top < ceiling:
        return 0
    return top - ceiling


def _merge_bpe_tokens(tokens):
    """Join sub-word tokens ending in '@@' onto the following token ('per@@', 'sion' -> 'persion')."""
    words = []
    pending = ''
    for token in tokens:
        pending += token.replace('@@', '')
        if '@@' in token:
            continue
        words.append(pending)
        pending = ''
    if pending:
        # A trailing '@@' token with nothing after it still forms a word.
        words.append(pending)
    return words


def _merge_ng(unit):
    """Attach 'ng' to a preceding 'a'/'e'/'i'/'o' (e.g. 'i'+'ng' -> 'ing'); drop any other 'ng'."""
    merged = []
    for j, ph in enumerate(unit):
        if ph == 'ng':
            # unit[j - 1] is not 'ng', so it is the last element of `merged`.
            if j > 0 and unit[j - 1] in ('a', 'e', 'i', 'o'):
                merged[-1] += ph
            continue
        merged.append(ph)
    return merged


def _word_to_units(word):
    """Convert one English word into syllable units of mapped Mandarin phones."""
    units = phonemes_to_units(list(g2p_model(word)))
    units = [[_EN_ZH_PH_DICT.get(ph, '') for ph in unit] for unit in units]
    return [_merge_ng([ph for ph in unit if ph != '']) for unit in units]


def _nucleus_index(unit):
    """Index of the phone to hold across extra notes in *unit*.

    Falls back to the last phone when no nucleus is recognized, so the
    caller's parallel sequences keep equal lengths (assumes *unit* is
    non-empty, which holds for units built by _word_to_units).
    """
    # NOTE(review): 'u'/'i' mapped from the consonants W, V and Y are also in
    # _ZH_NUCLEI, so for a glide onset (e.g. 'you' -> ['i', 'u']) the glide is
    # picked instead of the vowel — confirm whether the nucleus should be
    # tracked from the ARPAbet vowel instead.
    for i, ph in enumerate(unit):
        if ph in _ZH_NUCLEI:
            return i
    return len(unit) - 1


def _align_word(units, pitches, durs):
    """Distribute one word's syllable units over its notes.

    Units and notes are paired one to one.
    - If the word has more units than notes, all remaining units are sung on
      the last note.
    - If it has more notes than units, the nucleus of the last unit is
      repeated on each remaining note and marked as slurred ('1').

    Returns:
        Four parallel lists: phones, pitches, durations and slur flags ('0'/'1').
    """
    phones, out_pitch, out_dur, slur = [], [], [], []
    n_units, n_notes = len(units), len(pitches)
    for k in range(min(n_units, n_notes)):
        if k == n_notes - 1:
            # Last note: every remaining unit is sung on it.
            for unit in units[k:]:
                phones.extend(unit)
                out_pitch.extend([pitches[k]] * len(unit))
                out_dur.extend([durs[k]] * len(unit))
                slur.extend(['0'] * len(unit))
            break
        if k == n_units - 1:
            # Last unit but notes remain: hold its nucleus across all of them.
            unit = units[k]
            v = _nucleus_index(unit)
            tail = len(unit) - v - 1
            rest = n_notes - k
            phones.extend(unit[:v] + [unit[v]] * rest + unit[v + 1:])
            out_pitch.extend([pitches[k]] * v + list(pitches[k:]) + [pitches[-1]] * tail)
            out_dur.extend([durs[k]] * v + list(durs[k:]) + [durs[-1]] * tail)
            slur.extend(['0'] * (v + 1) + ['1'] * (rest - 1) + ['0'] * tail)
            break
        unit = units[k]
        phones.extend(unit)
        out_pitch.extend([pitches[k]] * len(unit))
        out_dur.extend([durs[k]] * len(unit))
        slur.extend(['0'] * len(unit))
    return phones, out_pitch, out_dur, slur


def out_to_wav(lyrics, notes, infer_ins, lang='zh', velocity=70):
    """Build a singing-synthesis request from lyrics and notes and run it.

    Args:
        lyrics: For ``lang='en'``: space-separated tokens, where a token ending
            in ``'@@'`` is joined to the next one (``'per@@ sua@@ sion'`` is
            one word). A list of such tokens is also accepted. For any other
            ``lang`` it is passed to the model unchanged as ``'text'``.
        notes: Dict with ``'pitch'`` (MIDI note numbers) and ``'dur'`` entries,
            each a list holding one inner list per lyric word/unit. It is not
            modified.
        infer_ins: Object whose ``infer_once(inp)`` performs the synthesis.
        lang: ``'en'`` selects phoneme-level input built through g2p; any
            other value selects word-level input.
        velocity: Non-English durations are multiplied by ``60 / velocity``
            (presumably a tempo in BPM — TODO confirm). English durations are
            used as given.

    Returns:
        Whatever ``infer_ins.infer_once`` returns.
    """
    base_dur = 60 / velocity

    if lang == 'en':
        base_dur = 1.0
        tokens = lyrics.split(' ') if isinstance(lyrics, str) else list(lyrics)

        ph_seq, pitch_seq, dur_seq, slur_seq = [], [], [], []
        for index, word in enumerate(_merge_bpe_tokens(tokens)):
            units = _word_to_units(word)
            if not units:
                # Nothing to sing (e.g. an empty token); its note entry is
                # never read, so a shorter notes list is tolerated here.
                continue
            phones, pitches, durs, slurs = _align_word(
                units, notes['pitch'][index], notes['dur'][index])
            ph_seq.extend(phones)
            pitch_seq.extend(pitches)
            dur_seq.extend(durs)
            slur_seq.extend(slurs)

        shift = _pitch_shift(pitch_seq)
        inp = {'text': ' '.join(tokens),
               'ph_seq': ' '.join(ph_seq),
               'note_seq': ' '.join(midi_to_note(p - shift) for p in pitch_seq),
               'note_dur_seq': ' '.join(str(d * base_dur) for d in dur_seq),
               'is_slur_seq': ' '.join(slur_seq),
               'input_type': 'phoneme'}
    else:
        # Build new strings instead of overwriting notes['pitch'] / notes['dur'],
        # so the caller's dict stays usable for another call.
        shift = _pitch_shift(list(chain.from_iterable(notes['pitch'])))
        inp = {
            'text': lyrics,
            'notes': ' | '.join(' '.join(midi_to_note(p - shift) for p in group)
                                for group in notes['pitch']),
            'notes_duration': ' | '.join(' '.join(str(d * base_dur) for d in group)
                                         for group in notes['dur']),
            'input_type': 'word'
        }

    return infer_ins.infer_once(inp)

def save_wav(wav, path, sr, norm=False):
    """Write a waveform to *path* as 16-bit PCM WAV.

    Args:
        wav: Array-like of float samples, nominally in [-1, 1] (they are
            scaled by 32767). The caller's array is not modified.
        path: Output file path.
        sr: Sample rate passed to :func:`scipy.io.wavfile.write`.
        norm: If True, first scale so the largest absolute sample is 1.0.
            A silent (all-zero) or empty signal is left unscaled instead of
            being divided by zero.
    """
    samples = np.asarray(wav)
    if norm and samples.size:
        peak = np.abs(samples).max()
        if peak > 0:
            samples = samples / peak
    # Multiply out of place so the caller's array is untouched, and clip so
    # out-of-range samples saturate instead of wrapping on the int16 cast.
    # proposed by @dsmiller
    pcm = np.clip(samples * 32767, -32768, 32767).astype(np.int16)
    wavfile.write(path, sr, pcm)


if __name__ == '__main__':
    class _EchoInfer:
        """Stand-in inference object: prints the request instead of synthesizing."""

        def infer_once(self, inp):
            print(inp)
            return None

    # Smoke test of the English front end; 'per@@ sua@@ sion' forms one word,
    # so there are four words and four note groups. Lyrics are passed as a
    # space-separated string, which out_to_wav splits itself.
    out_to_wav('I love you per@@ sua@@ sion',
               {'pitch': [[65], [68], [70], [72, 73]],
                'dur': [[0.4], [0.5], [0.6], [0.25, 0.26]]},
               _EchoInfer(), lang='en')

