import itertools
import os
from tqdm import tqdm

from build_corpus import split_initials_finals, wenzi2pinyin
import random

def random_change_tones(tones, rate=30, seed=42):
    """Randomly overwrite a fraction of tone labels in place.

    Each position is independently replaced, with probability ``rate``/100,
    by a value drawn uniformly from ``[0, 1, 2, 3, 4]`` (which may equal the
    value it already had).

    A private ``random.Random(seed)`` is created on every call. Every call
    with the same input length therefore yields the same perturbation
    pattern, as before. Unlike the previous ``random.seed(42)``, it no longer
    resets the module-level ``random`` generator that the rest of the
    process may rely on.

    Args:
        tones: Mutable sequence of tone labels; modified in place.
        rate: Percent chance (0-100) that any given position is replaced.
        seed: Seed for the private generator.

    Returns:
        The same ``tones`` object, so the call can be used in an assignment.
    """
    options = [0, 1, 2, 3, 4]
    # Same Mersenne Twister sequence as random.seed(seed), without touching global state.
    rng = random.Random(seed)
    for i, _ in enumerate(tones):
        if rng.randint(0, 99) < rate:
            tones[i] = rng.choice(options)
    return tones

# Input directory of Hanzi text files, and the output directory where a
# same-named pinyin file is written for each of them.
hanzi_dir, pinyin_dir = (
    "./data/test_data/split_random_wo_tones/hanzi",
    "./data/test_data/split_random_wo_tones/pinyin",
)

# Read yunmu.txt (presumably the pinyin finals inventory): one entry per line,
# surrounding whitespace stripped.
# NOTE(review): `yunmus` is not referenced by the code visible in this script.
with open("./data/voc/yunmu.txt", "r", encoding="utf-8") as f:
    yunmus = [entry.strip() for entry in f]

# Convert every file in hanzi_dir into a same-named file in pinyin_dir.
# Each input line is treated as one sentence. The matching output line holds
# the flattened, space-separated pieces that split_initials_finals returned.
os.makedirs(pinyin_dir, exist_ok=True)  # the "w" open below fails if the directory is missing
for file in sorted(os.listdir(hanzi_dir)):  # sorted: deterministic processing/progress order
    src_path = os.path.join(hanzi_dir, file)
    if not os.path.isfile(src_path):
        continue  # skip sub-directories and other non-regular entries instead of crashing
    print(file)
    with open(src_path, 'r', encoding="utf-8") as f:
        contents = f.readlines()
    result = []
    for line in tqdm(contents):
        # All spaces are removed; `i` below indexes into this space-free string.
        sent = line.strip().replace(" ", "")
        pinyins, tones = wenzi2pinyin(sent)
        # Uncomment to overwrite ~30% of tones at random (see random_change_tones).
        # tones = random_change_tones(tones)
        pnyns = []
        i = 0  # cursor into `sent`, advanced in step with each (pinyin, tone) pair
        for pinyin, tone in zip(pinyins, tones):
            # NOTE(review): U+4E00..U+9FA5 is narrower than the CJK Unified
            # Ideographs block (which runs to U+9FFF), so later-added ideographs
            # take the else branch. Left unchanged; confirm against wenzi2pinyin
            # before widening.
            if '\u4e00' <= sent[i] <= '\u9fa5':
                # A Hanzi token consumes exactly one source character.
                pnyns.append(split_initials_finals(pinyin, tone, sent[i]))
                i += 1
            else:
                # A non-Hanzi token consumes len(pinyin) source characters. This
                # assumes wenzi2pinyin returns such runs verbatim -- TODO confirm.
                pnyns.append(split_initials_finals(pinyin, tone, sent[i:i + len(pinyin)]))
                i += len(pinyin)
        # Each split_initials_finals result is an iterable of strings; flatten and join.
        result.append(" ".join(itertools.chain.from_iterable(pnyns)))
    with open(os.path.join(pinyin_dir, file), "w", encoding="utf-8") as f:
        f.write("\n".join(result))