from fairseq.tasks.multilingual_translation_with_melody import load_align
from fairseq.scoring.alignment_distribution import AlignmentDistributionScorer, AlignmentDistributionScorerConfig
import sys
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt


def build_fallback_alignments(hyp_lines):
    """Build all-ones alignments for hypotheses that have no alignment file.

    Each line is stripped and split on single spaces; a line made of ``k``
    space-separated pieces yields a list of ``k - 1`` ones. An empty line
    therefore yields an empty list.

    Args:
        hyp_lines: iterable of hypothesis lines (``str``).

    Returns:
        A list with one list of ``1`` values per input line.
    """
    # split(' ') (not split()) is kept on purpose: consecutive spaces count
    # as extra pieces, exactly as in the original script.
    return [[1] * (len(line.strip().split(' ')) - 1) for line in hyp_lines]


def load_pred_alignments(result_dir, tgt_lang):
    """Load the predicted alignments stored in one ``inference_results`` dir.

    If ``{result_dir}/{tgt_lang}.test.alignment`` exists it is read with
    ``load_align``. Otherwise the alignments are synthesised from
    ``{result_dir}/{tgt_lang}.test.hyp.org`` via
    ``build_fallback_alignments``.

    Args:
        result_dir: path of the checkpoint's ``inference_results`` directory.
        tgt_lang: target language code used in the file names.

    Returns:
        The sequence of predicted alignments, one entry per test sentence.
    """
    align_path = f'{result_dir}/{tgt_lang}.test.alignment'
    if os.path.exists(align_path):
        return load_align(align_path)
    # Explicit encoding: the hypotheses may contain non-ASCII (e.g. Chinese)
    # text, so do not depend on the platform's default locale.
    # NOTE(review): assumes the .hyp.org files are UTF-8 -- confirm.
    with open(f'{result_dir}/{tgt_lang}.test.hyp.org', encoding='utf-8') as f:
        return build_fallback_alignments(f)


def evaluate_checkpoints(scorer, src_lang, tgt_lang, ckpts, upload_hist=False):
    """Score every checkpoint of one translation direction.

    For each checkpoint the scorer is reset, fed all (target, predicted)
    alignment pairs, its result string is printed, and the array returned by
    ``scorer.result_overlap_histogram`` is saved as ``align_hist.png`` in the
    checkpoint's ``inference_results`` directory.

    Args:
        scorer: an ``AlignmentDistributionScorer`` instance (reused across
            checkpoints; ``empty()`` is called before each one).
        src_lang: source language code.
        tgt_lang: target language code.
        ckpts: iterable of checkpoint directory names under ``checkpoints/``.
        upload_hist: when true, copy ``align_hist.pdf`` to OSS with
            ``ossutil`` after each checkpoint.

    Raises:
        ValueError: if the number of predicted alignments differs from the
            number of reference alignments.
    """
    # The reference alignments depend only on the language pair, so load
    # them once instead of once per checkpoint.
    target_alignments = load_align(
        f'data/bin/ft_{src_lang}_{tgt_lang}/test.{src_lang}-{tgt_lang}.alignment.{tgt_lang}'
    )

    for ckpt in ckpts:
        result_dir = f'checkpoints/{ckpt}/inference_results'
        scorer.empty()
        pred_alignments = load_pred_alignments(result_dir, tgt_lang)

        # Explicit check instead of `assert`: asserts vanish under `python -O`
        # and zip() below would then silently truncate to the shorter list.
        if len(target_alignments) != len(pred_alignments):
            raise ValueError(
                f'{ckpt}: {len(pred_alignments)} predicted alignments but '
                f'{len(target_alignments)} reference alignments'
            )

        for pred_alignment, target_alignment in zip(pred_alignments, target_alignments):
            scorer.add(np.array(target_alignment), np.array(pred_alignment))

        print(ckpt + ':')
        print(scorer.result_string())

        # presumably result_overlap_histogram also writes the PDF at the given
        # path (the upload below copies that file) -- verify against the scorer.
        hist_pdf = f'{result_dir}/align_hist.pdf'
        hist = scorer.result_overlap_histogram(hist_pdf)
        Image.fromarray(hist).save(f'{result_dir}/align_hist.png')

        if upload_hist:
            # Best effort: the exit status is ignored, as in the original.
            os.system(f'ossutil -c /home/xiji.lcx/workspace/.ossutilconfig cp -u {hist_pdf} oss://alitranx-public/xiji.lcx/SongTranslation/results/hists/{ckpt}.pdf')


def main():
    """Evaluate the en->zh checkpoints, then the zh->en checkpoints."""
    scorer = AlignmentDistributionScorer(AlignmentDistributionScorerConfig())

    en_zh_ckpts = [
        'FT_GP_eps0.05_en-zh_lr5e-6_m800_mtoken2048_upf1_M0,1',
        'FT_GP_eps0.05_en-zh_only_bt_lr1e-5_m100_mtoken2048_upf1_M0',
        'FT_GP_eps0.05_en-zh_backtrans_lr1e-5_m100_mtoken2048_upf1_M0',
        'FT_GP_eps0.05_en-zh_backtrans_lr5e-5_m150_mtoken2048_upf1_M0_old',
        'FT_bsl_eps_en-zh_only_bt_lr1e-5_m100_mtoken2048_upf1_M0',
        'FT_bsl_eps_en-zh_backtrans_lr5e-5_m150_mtoken2048_upf1_M0_old',
        'FT_bsl_eps_en-zh_backtrans_lr1e-5_m100_mtoken2048_upf1_M0',
        'FT_bsl_eps_en-zh_lr5e-6_m800_mtoken2048_upf1_M0,1',
        'Pretrain_xdae_multilingual_translation_decoder_length_en_zh_lr5e-4_m30_r0.5_4096_upf5_M0,1,2,3',
        'NP_FT_GP_eps0.05_en-zh_backtrans_lr5e-5_m150_mtoken2048_upf1_M0',
    ]
    zh_en_ckpts = [
        'FT_GP_eps0.05_zh-en_lr5e-6_m800_mtoken4096_upf1_M0,1',
        'FT_GP_eps0.05_zh-en_only_bt_lr1e-5_m50_mtoken4096_upf1_M0',
        'FT_GP_eps0.05_zh-en_backtrans_lr1e-5_m50_mtoken4096_upf1_M0',
        'FT_GP_eps0.05_zh-en_backtrans_lr5e-5_m150_mtoken4096_upf1_M0_old',
        'FT_bsl_eps_zh-en_lr5e-6_m800_mtoken4096_upf5_M0,1',
        'FT_bsl_eps_zh-en_only_bt_lr5e-5_m100_mtoken4096_upf1_M0',
        'FT_bsl_eps_zh-en_backtrans_lr1e-5_m100_mtoken4096_upf1_M0_old',
        'Pretrain_xdae_multilingual_translation_decoder_length_zh_en_lr5e-4_m30_r0.5_4096_upf5_M0,1,2,3',
        'NP_FT_GP_eps0.05_zh-en_backtrans_lr5e-5_m150_mtoken4096_upf1_M0',
    ]

    # NOTE(review): only the en->zh histograms were uploaded to OSS in the
    # original script; that asymmetry is preserved here -- confirm whether
    # zh->en should be uploaded as well.
    evaluate_checkpoints(scorer, 'en', 'zh', en_zh_ckpts, upload_hist=True)
    evaluate_checkpoints(scorer, 'zh', 'en', zh_en_ckpts, upload_hist=False)


if __name__ == '__main__':
    main()