import argparse
import pickle
import os
import time
from Models import get_model, get_model_token_classification
import random
import itertools
from tqdm import tqdm
from translate_file2 import translate_sentence

from Process import create_fields

def translate(pylist, opt, model, SRC, TRG):
    """Translate a sequence of text pieces, halving it while it is too long.

    The pieces in ``pylist`` are concatenated with ``"".join``. If the joined
    text has at most ``opt.max_len`` characters it is passed to
    ``translate_sentence`` in one call. Otherwise the sequence is split in half
    and each half is translated recursively. The two results are concatenated
    with ``+``.

    Args:
        pylist: a sequence of strings, or a single string (which is then
            treated as a sequence of characters).
        opt: parsed options. ``opt.max_len`` is read here, and the whole object
            is forwarded to ``translate_sentence``.
        model, SRC, TRG: forwarded unchanged to ``translate_sentence``.

    Returns:
        The concatenation of the ``translate_sentence`` results for every chunk.
        NOTE(review): presumably a ``str``, since the caller writes it to a text
        file — confirm against ``translate_file2.translate_sentence``.
    """
    text = "".join(pylist)
    if len(text) <= opt.max_len:
        return translate_sentence(text, model, opt, SRC, TRG)

    if len(pylist) == 1:
        # A lone over-long piece cannot be halved at the list level: the old
        # code recursed on pylist[:0] and pylist[0:] (the same list) forever.
        # Split the piece's characters instead.
        pylist = text
    if len(pylist) <= 1:
        # Only one character is left (possible only when max_len < 1), so
        # nothing can be split further; translate it as-is.
        return translate_sentence(text, model, opt, SRC, TRG)

    # len(pylist) >= 2 here, so both halves are non-empty and strictly shorter,
    # which guarantees the recursion terminates.
    mid = len(pylist) // 2
    return (translate(pylist[:mid], opt, model, SRC, TRG)
            + translate(pylist[mid:], opt, model, SRC, TRG))

def main():
    """Translate every pickled input in ``-test_dir`` and write the results.

    Each regular file in ``-test_dir`` is unpickled and passed whole to
    ``translate``. The result is written as UTF-8 text to ``-result_dir``
    under the same base name with a ``.txt`` extension.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-load_weights', required=True)
    parser.add_argument('-pkl_dir', required=True)
    parser.add_argument('-k', type=int, default=3)
    parser.add_argument('-max_len', type=int, default=80)
    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-heads', type=int, default=8)
    # Was type=int, which made argparse reject any fractional dropout
    # (e.g. "-dropout 0.2") given on the command line.
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-floyd', action='store_true')
    parser.add_argument('-test_dir', type=str, required=True)
    parser.add_argument('-result_dir', type=str, required=True)
    parser.add_argument('-src_voc')
    parser.add_argument('-trg_voc')

    opt = parser.parse_args()

    # 0 when CUDA is allowed, -1 when -no_cuda is given.
    opt.device = -1 if opt.no_cuda else 0

    # Validate with parser.error rather than assert, which is stripped under -O.
    if opt.k <= 0:
        parser.error("-k must be greater than 0")
    if opt.max_len <= 10:
        parser.error("-max_len must be greater than 10")

    SRC, TRG = create_fields(opt)
    model = get_model_token_classification(opt, len(SRC.vocab), len(TRG.vocab))

    os.makedirs(opt.result_dir, exist_ok=True)
    # Sorted so runs process files in a reproducible order.
    for fname in tqdm(sorted(os.listdir(opt.test_dir))):
        in_path = os.path.join(opt.test_dir, fname)
        if not os.path.isfile(in_path):
            continue  # skip sub-directories and other non-regular entries

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only point -test_dir at trusted data.
        with open(in_path, "rb") as fin:
            contents = pickle.load(fin)

        translation = translate(contents, opt, model, SRC, TRG)

        out_name = os.path.splitext(fname)[0] + ".txt"
        with open(os.path.join(opt.result_dir, out_name), 'w', encoding='utf-8') as fout:
            fout.write(translation)

if "__main__" == __name__:
    # Script entry point: run the batch translation only when executed
    # directly, never as a side effect of importing this module.
    main()

# data_path="./data/pkl/label_50_pkl"

# for file in os.listdir(data_path):
#     i = pickle.load(open(os.path.join(data_path,file),"rb"))
#     print(i)
#     print()