from http.client import NO_CONTENT
import pandas as pd
import torchtext
from torchtext import data
from Tokenize import tokenize
from Batch import MyIterator, batch_size_fn
import os
import dill as pickle
from pypinyin import Style, pinyin
from pypinyin.core import lazy_pinyin
from tqdm import tqdm 
from torchtext import vocab
from torchtext.vocab import build_vocab_from_iterator
import io



def wenzi2pinyin(text):
    """Return the pinyin of *text* as one string with no separators.

    The text is converted with ``lazy_pinyin`` using ``Style.NORMAL`` and the
    resulting syllables are concatenated directly.

    The previous implementation also ran a ``Style.TONE3`` conversion to build
    a list of tone numbers, but that list was never used or returned, so the
    extra conversion has been dropped.
    """
    return "".join(lazy_pinyin(text, style=Style.NORMAL))

def yield_tokens(file_path):
    """Lazily yield the whitespace-separated tokens of each line of a file.

    The file at *file_path* is read as UTF-8. Every line, including blank
    ones, produces exactly one list; a blank line produces an empty list.
    """
    with io.open(file_path, encoding='utf-8') as handle:
        yield from (raw_line.strip().split() for raw_line in handle)

def _read_lines(path):
    """Return the text lines stored at *path* (a file, or a directory of files).

    Each file is decoded as UTF-8, stripped of leading/trailing whitespace and
    split on newlines. Directory entries are visited in sorted name order:
    ``os.listdir`` gives no ordering guarantee, and source/target directories
    must be walked in the same order for their lines to stay paired.
    """
    if os.path.isdir(path):
        file_paths = [os.path.join(path, name)
                      for name in sorted(os.listdir(path))]
    else:
        file_paths = [path]

    lines = []
    for file_path in file_paths:
        # The context manager closes each handle; extend() avoids rebuilding
        # the accumulated list for every file.
        with open(file_path, encoding='utf-8') as f:
            lines.extend(f.read().strip().split('\n'))
    return lines


def _load_corpus(path, label):
    """Print a progress message and load *path*, exiting on an I/O failure."""
    print("loading " + label)
    try:
        return _read_lines(path)
    except OSError as err:
        # Only genuine I/O problems are reported here; other exceptions
        # (e.g. decoding errors) propagate with their real traceback.
        print(f"error: '{path}' could not be read: {err}")
        raise SystemExit(1)


def read_data(opt):
    """Replace the paths in ``opt.src_data`` / ``opt.trg_data`` with their lines.

    Each attribute that is not ``None`` is treated as a path to a UTF-8 text
    file, or to a directory whose files are concatenated in sorted name order,
    and is overwritten in place with the resulting list of lines. Attributes
    that are ``None`` are left untouched.

    Args:
        opt: options object with ``src_data`` and ``trg_data`` attributes.

    Raises:
        SystemExit: with status 1, after printing a message, when a path
            cannot be read. This applies to both the source and target side.
    """
    if opt.src_data is not None:
        opt.src_data = _load_corpus(opt.src_data, "src_data")

    if opt.trg_data is not None:
        opt.trg_data = _load_corpus(opt.trg_data, "trg_data")

    # Either side may legitimately be None; report it as 0 lines rather than
    # crashing on len(None).
    src_len = 0 if opt.src_data is None else len(opt.src_data)
    trg_len = 0 if opt.trg_data is None else len(opt.trg_data)
    print("len of src_data:{} ; len of trg_data:{}".format(src_len, trg_len))

def my_tokenize(text):
    """Tokenize *text* into a list of its individual characters."""
    return [char for char in text]

def my_tokenize2(text):
    """Tokenize *text* by splitting on single space characters.

    Consecutive spaces produce empty-string tokens, and an empty input
    produces a list holding one empty string.
    """
    tokens = text.split(" ")
    return tokens

def create_fields(opt):
    """Create the source and target ``data.Field`` objects.

    When neither ``opt.src_voc`` nor ``opt.trg_voc`` is set, both fields use
    the character-level ``my_tokenize``; otherwise both use the space-splitting
    ``my_tokenize2``. If ``opt.pkl_dir`` is set, the freshly built fields are
    replaced by the ones unpickled from ``<pkl_dir>/SRC.pkl`` and
    ``<pkl_dir>/TRG.pkl``.

    Args:
        opt: options object providing ``src_voc``, ``trg_voc`` and ``pkl_dir``.

    Returns:
        A ``(SRC, TRG)`` tuple of fields.

    Raises:
        SystemExit: with status 1, after printing a message, when the saved
            field files cannot be loaded.
    """
    print("loading spacy tokenizers...")

    # NOTE(review): supplying a vocabulary file for only one side still
    # switches BOTH fields to space-splitting -- confirm this is intended.
    if opt.src_voc is None and opt.trg_voc is None:
        tokenizer = my_tokenize
    else:
        tokenizer = my_tokenize2
    TRG = data.Field(tokenize=tokenizer)
    SRC = data.Field(tokenize=tokenizer)

    if opt.pkl_dir is not None:
        print("loading presaved fields...")
        try:
            # NOTE: unpickling can execute arbitrary code; pkl_dir must only
            # ever point at trusted files.
            with open(f'{opt.pkl_dir}/SRC.pkl', 'rb') as f:
                SRC = pickle.load(f)
            with open(f'{opt.pkl_dir}/TRG.pkl', 'rb') as f:
                TRG = pickle.load(f)
        except Exception as err:
            # Report the directory that was actually searched (pkl_dir) and
            # the underlying cause, then stop with a failing exit status.
            print(f"error opening SRC.pkl and TRG.pkl field files, "
                  f"please ensure they are in {opt.pkl_dir}/ ({err})")
            raise SystemExit(1)

    return (SRC, TRG)

def create_dataset(opt, SRC, TRG):
    """Build the training iterator and, if needed, the field vocabularies.

    The source/target lines in ``opt.src_data`` / ``opt.trg_data`` are paired
    in a DataFrame, written to a temporary CSV in the working directory,
    loaded as a ``data.TabularDataset`` and wrapped in a ``MyIterator``.

    When ``opt.load_weights`` is ``None`` a vocabulary is attached to each
    field: from the file ``opt.src_voc`` / ``opt.trg_voc`` when given,
    otherwise built from the dataset. If ``opt.checkpoint > 0`` the fields are
    then pickled into a newly created ``weights`` directory.

    Side effects on *opt*: sets ``src_pad`` and ``trg_pad`` to the ``'<pad>'``
    index of each vocabulary, and ``train_len`` to ``get_len(train_iter)``.

    Args:
        opt: options object (reads ``src_data``, ``trg_data``, ``batchsize``,
            ``device``, ``load_weights``, ``src_voc``, ``trg_voc``,
            ``checkpoint``).
        SRC: source field.
        TRG: target field.

    Returns:
        The ``MyIterator`` over the training dataset.

    Raises:
        SystemExit: with status 1 when checkpointing is requested and the
            ``weights`` directory already exists.
    """
    print("creating dataset and iterator... ")

    raw_data = {'src': list(opt.src_data), 'trg': list(opt.trg_data)}
    df = pd.DataFrame(raw_data, columns=["src", "trg"])

    # Preview a few random pairs; cap the sample size so a corpus with fewer
    # than five rows does not raise.
    print(df.sample(min(5, len(df))))

    # NOTE: filtering pairs by opt.max_strlen is disabled. The mask used to be
    # computed here but was never applied to the DataFrame, so it was dropped.

    temp_csv = 'translate_transformer_temp.csv'
    data_fields = [('src', SRC), ('trg', TRG)]
    try:
        df.to_csv(temp_csv, index=False)
        train = data.TabularDataset('./' + temp_csv, format='csv',
                                    fields=data_fields, skip_header=True)

        train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
                                repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                                batch_size_fn=None, train=True, shuffle=True, augment=True,
                                change_possibility=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
    finally:
        # Always clean up the scratch file, even if loading failed midway.
        if os.path.exists(temp_csv):
            os.remove(temp_csv)

    if opt.load_weights is None:
        if opt.src_voc is not None:
            SRC.vocab = build_vocab_from_iterator(yield_tokens(opt.src_voc))
        else:
            SRC.build_vocab(train)

        if opt.trg_voc is not None:
            TRG.vocab = build_vocab_from_iterator(yield_tokens(opt.trg_voc))
        else:
            TRG.build_vocab(train)

        if opt.checkpoint > 0:
            try:
                os.mkdir("weights")
            except FileExistsError:
                # Refuse to overwrite fields saved by an earlier run.
                print("weights folder already exists, run program with -load_weights weights to load them")
                raise SystemExit(1)
            with open('weights/SRC.pkl', 'wb') as f:
                pickle.dump(SRC, f)
            with open('weights/TRG.pkl', 'wb') as f:
                pickle.dump(TRG, f)

    opt.src_pad = SRC.vocab.stoi['<pad>']
    opt.trg_pad = TRG.vocab.stoi['<pad>']

    # Counting requires one full pass over the iterator.
    opt.train_len = get_len(train_iter)
    print("train len:{}".format(opt.train_len))

    return train_iter

def get_len(train):
    """Return the number of items produced by iterating *train* once.

    The previous implementation returned the last ``enumerate`` index, which
    is one less than the number of items, and raised ``UnboundLocalError``
    when *train* was empty. This version returns the true count, and 0 for
    an empty iterable.

    Note that the iterable is fully consumed by the count.
    """
    return sum(1 for _ in train)
