import itertools
import os
from tqdm import tqdm

from build_corpus import split_initials_finals, wenzi2pinyin
import random

# Candidate probabilities; one is drawn per call and applied to every
# position of that call's tone list.
_CHANGE_PROBABILITIES = (0.5, 0.6, 0.7, 0.8, 0.9, 1)

# Private generator seeded exactly once: results stay reproducible from run
# to run, but successive calls continue the stream instead of replaying the
# same draws, and the process-wide ``random`` state is left untouched.
_TONE_RNG = random.Random(42)


def random_change_tones(tones, rng=None):
    """Randomly reset entries of ``tones`` to 0, in place.

    One probability is drawn from ``_CHANGE_PROBABILITIES`` per call; each
    position is then independently overwritten with 0 when a fresh uniform
    draw falls below that probability. A drawn probability of 1 therefore
    zeroes every entry.

    Args:
        tones: Mutable sequence of tone values. It is modified in place.
        rng: Optional object exposing ``choice(seq)`` and ``random()``
            (for example a ``random.Random`` instance). Defaults to a
            module-private generator seeded once with 42.

    Returns:
        The same ``tones`` object that was passed in, after modification.
    """
    if rng is None:
        rng = _TONE_RNG
    change_possibility = rng.choice(_CHANGE_PROBABILITIES)
    for i, _ in enumerate(tones):
        if rng.random() < change_possibility:
            tones[i] = 0
    return tones

def convert_pinyin(file,hanzi_dir,pinyin_dir,new_file):
    """Convert a text file to space-separated split-pinyin tokens, line by line.

    Reads ``hanzi_dir/file``. For each line, surrounding whitespace and all
    spaces are removed, the text is passed to ``wenzi2pinyin``, the returned
    tones are passed through ``random_change_tones``, and every
    (pinyin, tone) pair is expanded with ``split_initials_finals``. The
    resulting tokens of a line are joined with single spaces, and all lines
    are written, newline-separated, to ``pinyin_dir/new_file``.

    Args:
        file: Name of the input file inside ``hanzi_dir``.
        hanzi_dir: Directory containing the input file.
        pinyin_dir: Directory the output file is written to.
        new_file: Name of the output file inside ``pinyin_dir``.
    """
    print(file)
    source_path = os.path.join(hanzi_dir, file)
    with open(source_path, 'r', encoding="utf-8") as src:
        raw_lines = src.readlines()

    converted = []
    for raw in tqdm(raw_lines):
        text = raw.strip().replace(" ", "")
        pinyins, tones = wenzi2pinyin(text)
        tones = random_change_tones(tones)

        pieces = []
        cursor = 0  # index into ``text`` of the span the next pinyin covers
        for pinyin, tone in zip(pinyins, tones):
            if '\u4e00' <= text[cursor] <= '\u9fa5':
                # Character inside the U+4E00..U+9FA5 range: consumes exactly
                # one character of the text.
                width = 1
                surface = text[cursor]
            else:
                # Anything else: the entry is taken to span as many characters
                # as the pinyin string is long.
                # NOTE(review): presumably wenzi2pinyin returns such runs
                # verbatim - confirm against build_corpus.
                width = len(pinyin)
                surface = text[cursor:cursor + width]
            pieces.append(split_initials_finals(pinyin, tone, surface))
            cursor += width

        converted.append(" ".join(itertools.chain.from_iterable(pieces)))

    target_path = os.path.join(pinyin_dir, new_file)
    with open(target_path, "w", encoding="utf-8") as dst:
        dst.write("\n".join(converted))

if __name__ == "__main__":
    # Directories referenced only by the disabled batch loop at the bottom;
    # the active call below does not read them.
    hanzi_dir = "./data/test_data/split_random_wo_tones/hanzi"
    pinyin_dir = "./data/test_data/split_random_wo_tones/pinyin2"

    # Disabled: read ./data/voc/yunmu.txt into a list of stripped lines.
    # with open("./data/voc/yunmu.txt","r",encoding="utf-8") as f:
    #     yunmus=f.readlines()
    #     yunmus=[a.strip() for a in yunmus]

    # Active job: convert the dev split, reading from and writing to the
    # same directory.
    convert_pinyin(
        file="dev_hanzi.txt",
        hanzi_dir="./data/dev",
        pinyin_dir="./data/dev",
        new_file="dev_pinyin_split.txt",
    )

    # Disabled: batch-convert every file in hanzi_dir.
    # NOTE(review): this call passes three arguments, but convert_pinyin
    # takes four (new_file is missing) - update before re-enabling.
    # for file in os.listdir(hanzi_dir):
    #     convert_pinyin(file,hanzi_dir,pinyin_dir)