import argparse
import time
import torch
from Models import get_model, get_model_token_classification
from Process import *
import torch.nn.functional as F
from Optim import CosineWithRestarts
from Batch import create_masks
import pdb
import dill as pickle
import argparse
from Models import get_model
from Beam import beam_search
# from nltk.corpus import wordnet
from torch.autograd import Variable
import re
import time
import random
import distance

from cer_multi import ishan


def get_result(src, model, SRC, TRG, opt):
    """Decode the model's per-position predictions for ``src`` into a string.

    Args:
        src: Tensor of source-vocabulary indices. Only the first row of the
            resulting prediction is decoded (assumes a ``(batch, seq)``
            layout -- TODO confirm against the model).
        model: Callable invoked as ``model(src, src_mask)``; the last
            dimension of its output is arg-maxed to pick a target index.
        SRC: Source field; ``SRC.vocab.stoi['<pad>']`` supplies the index
            that is masked out.
        TRG: Target field; ``TRG.vocab.itos`` maps indices back to tokens.
        opt: Unused here; kept so existing callers keep working.

    Returns:
        str: the predicted tokens of the first sequence concatenated, with
        every ``"_"`` character removed.
    """
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    # Pure inference: skip autograd bookkeeping so no graph is retained.
    with torch.no_grad():
        scores = model(src, src_mask)
        # Softmax is monotonic per row, so the argmax of the raw scores is
        # the same index the softmaxed scores would give.
        preds = torch.argmax(scores, dim=-1)
    itos = TRG.vocab.itos
    # .tolist() yields plain Python ints, which index ``itos`` directly.
    return ''.join(itos[tok] for tok in preds[0].tolist()).replace("_", "")


def translate_sentence(sentence, model, opt, SRC, TRG):
    """Convert one raw input sentence into the model's decoded output string.

    Args:
        sentence: Raw input accepted by ``SRC.preprocess``.
        model: Network to evaluate; it is switched to eval mode here.
        opt: Options object; ``opt.floyd`` and ``opt.device`` are read.
        SRC: Source field providing ``preprocess`` and ``vocab.stoi``.
        TRG: Target field, forwarded to ``get_result``.

    Returns:
        The string produced by ``get_result`` for the indexed sentence.
    """
    model.eval()
    stoi = SRC.vocab.stoi
    indexed = []
    for tok in SRC.preprocess(sentence):
        idx = stoi[tok]
        # Tokens mapping to index 0 (presumably the unknown-token index --
        # verify against the vocab) are dropped unless opt.floyd is set.
        if idx != 0 or opt.floyd:
            indexed.append(idx)

    # Tensors carry autograd state themselves; the Variable wrapper is a
    # deprecated no-op and is not needed.
    batch = torch.LongTensor([indexed])
    if opt.device == 0:
        batch = batch.cuda()

    return get_result(batch, model, SRC, TRG, opt)


def translate(i, opt, model, SRC, TRG):
    """Translate the single input ``i`` and return the result as a string.

    The input is handed to ``translate_sentence`` as one sentence; the
    single result is space-joined, which leaves it unchanged.
    """
    pieces = [translate_sentence(i, model, opt, SRC, TRG)]
    return ' '.join(pieces)


def _read_lines(path):
    """Return the whitespace-stripped contents of the UTF-8 file at ``path``, split on newlines."""
    with open(path, encoding='utf-8') as handle:
        return handle.read().strip().split('\n')


def _edit_distance_totals(predictions, references):
    """Return ``(total_edit_distance, num_reference_chars)`` over paired lines."""
    total_edit_distance, num_chars = 0, 0
    # NOTE: zip stops at the shorter sequence, so unmatched trailing lines
    # are ignored.
    for pred, expected in tqdm(zip(predictions, references)):
        pred = ishan(pred)
        expected = ishan(expected)
        total_edit_distance += distance.levenshtein(expected, pred)
        num_chars += len(expected)
    return total_edit_distance, num_chars


def main():
    """Evaluate saved token-classification checkpoints and print their CER.

    For each index 1..60, the first file in ``-pkl_dir`` whose name contains
    ``token_classification_split_3_<index>_`` is loaded, every line of
    ``dev_pinyin_split.txt`` in ``-dev_dir`` is translated, and the
    predictions are scored against ``dev_hanzi.txt`` with a Levenshtein
    distance after both sides pass through ``ishan``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-load_weights', required=True)
    parser.add_argument('-pkl_dir', required=True)
    parser.add_argument('-k', type=int, default=3)
    parser.add_argument('-max_len', type=int, default=80)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-n_layers', type=int, default=6)
    # parser.add_argument('-src_lang', required=True)
    # parser.add_argument('-trg_lang', required=True)
    parser.add_argument('-heads', type=int, default=8)
    # Dropout is a fraction: type=int rejected any value given on the CLI.
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-floyd', action='store_true')
    parser.add_argument("-dev_dir", type=str, required=True)
    parser.add_argument('-src_voc')
    parser.add_argument('-trg_voc')

    opt = parser.parse_args()

    opt.device = 0 if opt.no_cuda is False else -1

    # Explicit checks instead of assert, which is stripped under ``python -O``.
    if opt.k <= 0:
        parser.error('-k must be greater than 0')
    if opt.max_len <= 10:
        parser.error('-max_len must be greater than 10')

    SRC, TRG = create_fields(opt)

    # The dev set is identical for every checkpoint, so read it once.
    contents = _read_lines(os.path.join(opt.dev_dir, "dev_pinyin_split.txt"))
    gt = _read_lines(os.path.join(opt.dev_dir, "dev_hanzi.txt"))

    for index in range(1, 61):
        marker = "token_classification_split_3_" + str(index) + "_"
        for model_name in os.listdir(opt.pkl_dir):
            if marker not in model_name:
                continue

            print("model_name:{}".format(model_name))

            opt.load_weights = os.path.join(opt.pkl_dir, model_name)

            model = get_model_token_classification(
                opt, len(SRC.vocab), len(TRG.vocab))

            translates = [translate(line, opt, model, SRC, TRG)
                          for line in tqdm(contents)]

            # with open(os.path.join(opt.dev_dir,model_name),'w',encoding='utf-8') as f:
            #     f.write("\n".join(translates))

            total_edit_distance, num_chars = _edit_distance_totals(
                translates, gt)

            # An empty reference would otherwise raise ZeroDivisionError.
            if num_chars:
                cer = round(float(total_edit_distance) / num_chars, 5)
            else:
                cer = float('nan')

            print("Total CER: {}/{}={}\n".format(total_edit_distance,
                                                 num_chars,
                                                 cer))
            # Only the first matching checkpoint per index is evaluated.
            break


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
