from tqdm import tqdm
import re
import os


# Characters stripped from every hypothesis token before it is counted
# (ASCII and full-width punctuation plus a few symbols).
_PUNCTUATION_RE = re.compile(r'[\.,!?，。？！#$%^&*]')


def _read_lines(path):
    """Return the lines of the UTF-8 text file at `path`, newline removed."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.rstrip('\n') for line in f]


def _write_lines(path, lines):
    """Write `lines` to `path` as UTF-8, one entry per line."""
    with open(path, 'w', encoding='utf-8') as f:
        # Terminate every entry explicitly so that an empty entry still
        # occupies a line and the parallel output files stay in step.
        f.writelines(line + '\n' for line in lines)


def _alignment_from_tokens(tokens):
    """Build the per-token alignment counts for one token list.

    A token containing '@@' (presumably a BPE continuation piece) is given
    '0' and its count is carried over to the next token that closes the
    group. The last token always closes the group, even if it contains
    '@@', so the counts sum to ``len(tokens)``.
    """
    counts = []
    pending = 0
    last_index = len(tokens) - 1
    for index, token in enumerate(tokens):
        pending += 1
        if '@@' in token and index < last_index:
            counts.append('0')
        else:
            counts.append(str(pending))
            pending = 0
    return counts


def get_and_filter(data_dir, dst_dir, direction):
    """Filter parallel lyric/melody lines whose lengths agree and save them.

    Reads, line by line in parallel:
      * ``{data_dir}/{target_lang}.all.src``      -- "source" lyrics
      * ``{data_dir}/{source_lang}.all.hyp.org``  -- "target" (hypothesis) lyrics
      * ``{dst_dir}/all.alignment.bpe.{target_lang}`` -- integer alignments
      * ``{dst_dir}/all.melody``                  -- melody lines

    The first space-separated field of every lyric line is dropped
    (it looks like a sentence id -- TODO confirm against the file format).
    Hypothesis tokens additionally have punctuation removed, and tokens
    that become empty are discarded.

    A line is kept only when the number of hypothesis tokens, the sum of
    its alignment values and the number of melody entries before the first
    '|' are all equal. Kept lines are written to ``dst_dir`` as
    ``train.filtered.bpe.{target_lang}``, ``train.filtered.bpe.{source_lang}``,
    ``train.filtered.melody``, ``train.filtered.alignment.{target_lang}`` and
    ``train.filtered.alignment.{source_lang}`` (the last one is derived from
    the kept hypothesis tokens).

    Args:
        data_dir: Directory holding the ``.all.src`` / ``.all.hyp.org`` files.
        dst_dir: Directory holding the alignment/melody inputs; also receives
            all output files.
        direction: ``"<source_lang>_<target_lang>"``, e.g. ``"zh_en"``.

    Returns:
        None. Returns immediately, writing nothing, when
        ``{data_dir}/{target_lang}.all.src`` does not exist.
    """
    source_lang, target_lang = direction.split('_')
    source_path = os.path.join(data_dir, f'{target_lang}.all.src')
    if not os.path.exists(source_path):
        return

    # Line terminators are stripped on read so that a trailing '\n' can never
    # be mistaken for (or glued onto) a lyric token.
    source_lyrics = [
        ' '.join([tok for tok in line.split(' ') if tok != ''][1:])
        for line in _read_lines(source_path)
    ]

    target_tokens = []
    for line in _read_lines(os.path.join(data_dir, f'{source_lang}.all.hyp.org')):
        cleaned = [_PUNCTUATION_RE.sub('', tok) for tok in line.split(' ')[1:]]
        target_tokens.append([tok for tok in cleaned if tok != ''])

    # str.split() with no separator tolerates blank lines and repeated
    # whitespace, which would otherwise make int('') raise.
    alignments = [
        [int(value) for value in line.split()]
        for line in _read_lines(os.path.join(dst_dir, f'all.alignment.bpe.{target_lang}'))
    ]

    melody = _read_lines(os.path.join(dst_dir, 'all.melody'))
    # Only the part before the first '|' is counted against the alignment.
    melody_notes = [
        [tok for tok in line.split('|')[0].split(' ') if tok != '']
        for line in melody
    ]

    filtered_source_lyrics = []
    filtered_target_tokens = []
    filtered_melody = []
    filtered_alignments = []

    for index, (tokens, notes, alignment) in tqdm(enumerate(zip(target_tokens, melody_notes, alignments))):
        if not (len(tokens) == sum(alignment) == len(notes)):
            continue
        filtered_source_lyrics.append(source_lyrics[index])
        filtered_target_tokens.append(tokens)
        filtered_melody.append(melody[index])
        filtered_alignments.append(alignment)

    print(f'total lines for {source_lang}_{target_lang} after filtering:', len(filtered_source_lyrics))

    _write_lines(os.path.join(dst_dir, f'train.filtered.bpe.{target_lang}'),
                 filtered_source_lyrics)
    _write_lines(os.path.join(dst_dir, f'train.filtered.bpe.{source_lang}'),
                 [' '.join(tokens) for tokens in filtered_target_tokens])
    _write_lines(os.path.join(dst_dir, 'train.filtered.melody'),
                 filtered_melody)
    _write_lines(os.path.join(dst_dir, f'train.filtered.alignment.{target_lang}'),
                 [' '.join(str(value) for value in alignment)
                  for alignment in filtered_alignments])
    _write_lines(os.path.join(dst_dir, f'train.filtered.alignment.{source_lang}'),
                 [' '.join(_alignment_from_tokens(tokens))
                  for tokens in filtered_target_tokens])



if __name__ == '__main__':
    # Script entry point: filter the zh_en data using the inference results
    # stored under the checkpoint directory below.
    # Alternative invocation for the en_zh direction, kept for reference:
    # get_and_filter('checkpoints/Pretrain_xdae_multilingual_translation_decoder_length_zh_en_lr5e-4_m30_r0.5_4096_upf5_M0,1,2,3,4,5,6,7/inference_results', 'data/bpe/bt/en_zh', 'en_zh')
    inference_dir = 'checkpoints/Pretrain_xdae_multilingual_translation_decoder_length_en_zh_lr5e-4_m30_r0.5_4096_upf5_M0,1,2,3/inference_results'
    get_and_filter(inference_dir, 'data/bpe/bt/zh_en', 'zh_en')