# -*- coding: utf-8 -*-
#!/usr/bin/python2
"""
Before running this code, make sure that you've downloaded the Leipzig Chinese Corpus
(http://corpora2.informatik.uni-leipzig.de/downloads/zho_news_2007-2009_1M-text.tar.gz).
Extract it and copy `zho_news_2007-2009_1M-sentences.txt` to the `data/` folder.

This code should generate a file which looks like this:
2[Tab]zhegeyemianxianzaiyijingzuofei...。[Tab]这__个_页_面___现___在__已_经___作__废__...。

In each line, the id, the pinyin, and the Chinese sentence are separated by tabs.
Note that _ stands for a blank.

Created in Aug. 2017, kyubyong. kbpark.linguist@gmail.com
"""
from __future__ import print_function
import codecs
import os
from threading import Semaphore
import regex  # pip install regex
from xpinyin import Pinyin  # pip install xpinyin
import traceback
from pypinyin.style._utils import get_initials, get_finals
from pypinyin import Style, pinyin
from pypinyin.core import lazy_pinyin
import itertools


def wenzi2pinyin(text):
    '''
    Convert text to pinyin items plus one tone number per item.

    Args:
      text: A string. The text to convert.

    Returns:
      A tuple (pinyin_list, tones_list). pinyin_list is the output of
      lazy_pinyin with Style.NORMAL. tones_list holds one int per item of
      the Style.TONE3 output: the item's trailing digit when it ends in a
      digit, otherwise 0.
    '''
    toned_items = lazy_pinyin(text, style=Style.TONE3)

    tones_list = []
    for item in toned_items:
        last_char = item[-1]
        # NOTE(review): an item without a trailing digit gets tone 0 --
        # presumably neutral tone or non-Chinese text; confirm with pypinyin.
        if last_char.isdigit():
            tones_list.append(int(last_char))
        else:
            tones_list.append(0)

    pinyin_list = lazy_pinyin(text, style=Style.NORMAL)
    return pinyin_list, tones_list


def split_initials_finals(pinyin, tone, char):
    '''
    Split one pinyin item into its output tokens.

    Args:
      pinyin: A string. The pinyin item (or raw text for non-hanzi input).
      tone: The tone number appended to the final.
      char: A string. The source text the item was produced from.

    Returns:
      A list of strings. When `char` does not compare within
      '\\u4e00'..'\\u9fff', the individual characters of `pinyin`.
      Otherwise [initial, final + tone], or just [final + tone] when the
      initial is empty.
    '''
    strict = True  # passed through to the pypinyin initial/final helpers

    is_hanzi = '\u4e00' <= char <= '\u9fff'
    if not is_hanzi:
        # Non-hanzi text is emitted one character per token.
        return list(pinyin)

    # "v" is replaced by "ü" before the syllable is split.
    syllable = pinyin.replace("v", "ü")
    initial = get_initials(syllable, strict)
    final_with_tone = get_finals(syllable, strict) + str(tone)

    if initial != "":
        return [initial, final_with_tone]
    return [final_with_tone]


def align(sent):
    '''
    Align a sentence with its pinyin so both strings have equal length.

    Each Chinese character (in '\\u4e00'..'\\u9fa5') is followed by "_"
    placeholders so that it occupies as many positions as its pinyin has
    letters. Other characters are copied through one at a time until the
    next Chinese character (or the end of the sentence) is reached.

    Args:
      sent: A string. A sentence.

    Returns:
      A tuple of pinyin and chinese sentence, both as strings.

    Raises:
      AssertionError: If the two resulting strings differ in length.
    '''
    # The module-level name `pinyin` is the pypinyin function and has no
    # `get_pinyin` method, so calling it here raised AttributeError. The
    # xpinyin converter provides `get_pinyin`; it is created once and cached
    # on the function so repeated calls do not rebuild it.
    converter = getattr(align, "_converter", None)
    if converter is None:
        converter = align._converter = Pinyin()
    pnyns = converter.get_pinyin(sent, " ").split()

    # Loop-invariant: number of characters in the sentence without spaces.
    n_chars = len(sent.replace(" ", ""))

    hanzis = []
    x = 0  # extra offset into `sent` accumulated by non-Chinese runs
    for i, p in enumerate(pnyns):
        if i + x < n_chars:
            char = sent[i + x]
            if '\u4e00' <= char <= '\u9fa5':
                hanzis.extend([char] + ["_"] * (len(p) - 1))
            else:
                # NOTE(review): assumes xpinyin returns a run of non-Chinese
                # characters as one unsplit item -- confirm.
                while not '\u4e00' <= char <= '\u9fa5':
                    hanzis.append(char)
                    x = x + 1
                    if x + i >= n_chars:
                        break
                    char = sent[i + x]
                # The loop index itself advances by one on the next pass.
                x = x - 1

    pnyns = "".join(pnyns)
    hanzis = "".join(hanzis)

    assert len(pnyns) == len(
        hanzis), "The hanzis and the pinyins must be the same in length."
    return pnyns, hanzis


def align2(sent):
    '''
    Build space-separated, token-aligned pinyin and hanzi strings.

    The pinyin side is made of the tokens returned by
    split_initials_finals for every item from wenzi2pinyin. The hanzi side
    carries, for each Chinese character (in '\\u4e00'..'\\u9fa5'), the
    character followed by "_" placeholders so that it spans as many tokens
    as its pinyin; tokens of non-Chinese items are copied through as they
    are.

    If the two sides end up with a different number of tokens, the
    sentence and both strings are printed; no exception is raised.

    Args:
      sent: A string. A sentence.

    Returns:
      A tuple of pinyin and chinese sentence, each a string of tokens
      joined by single spaces.
    '''
    pinyins, tones = wenzi2pinyin(sent)
    pnyns = []
    # `i` is the position in `sent` that the current pinyin item starts at.
    i = 0
    for pinyin, tone in zip(pinyins, tones):
        if '\u4e00' <= sent[i] <= '\u9fa5':
            # A Chinese character consumes exactly one position of `sent`.
            pnyns.append(split_initials_finals(pinyin, tone, sent[i]))
            i += 1
        else:
            # NOTE(review): assumes a non-Chinese item is the source text
            # itself, so it covers len(pinyin) positions of `sent` -- confirm
            # against lazy_pinyin's handling of non-Chinese runs.
            pnyns.append(split_initials_finals(
                pinyin, tone, sent[i:i+len(pinyin)]))
            i += len(pinyin)
    hanzis = []
    # `x` is the extra offset so that `i+x` indexes into `sent` for the i-th
    # entry of `pnyns`.
    x = 0
    for i in range(len(pnyns)):
        if i+x < len(sent.replace(" ", "")):
            char = sent[i+x]
            p = pnyns[i]
            if '\u4e00' <= char <= '\u9fa5':
                # The character, then one "_" for each additional token.
                hanzis.extend([char] + ["_"] * (len(p) - 1))
            else:
                # Copy the tokens through and skip the text they cover.
                for q in p:
                    hanzis.append(q)
                    x += len(q)
                # The loop index itself advances by one on the next pass.
                x -= 1

    # Flatten the per-item token lists into one space-separated string.
    pnyns = " ".join(list(itertools.chain.from_iterable(pnyns)))
    hanzis = " ".join(hanzis)

    # A token-count mismatch is only reported for inspection.
    if len(pnyns.split(" ")) != len(hanzis.split(" ")):
        print(sent)
        print(pnyns)
        print(hanzis)

    # assert len(pnyns.split(" ")) == len(hanzis.split(" ")), "The hanzis and the pinyins must be the same in length."
    return pnyns, hanzis


def clean(text):
    '''
    Strip unsupported characters and collapse repeated punctuation.

    Only spaces, characters in '\\u4e00'..'\\u9fa5' and the punctuation
    marks 。，！？ are kept. Within a run of consecutive punctuation marks
    only the first one is kept.

    Args:
      text: A string. The raw sentence.

    Returns:
      The cleaned string, or "" when fewer than 10 characters remain.
    '''
    import re  # stdlib; the explicit range needs no third-party engine

    # The previous pattern wrapped the range in \p{...}, which is not a
    # Unicode property name; a plain character range expresses the intent.
    text = re.sub(r"[^ \u4e00-\u9fa5。，！？]", "", text)

    marks = "。，！？"
    kept = []
    prev_was_mark = False
    for char in text:
        is_mark = char in marks
        # Drop a mark only when it directly follows another mark.
        if not (is_mark and prev_was_mark):
            kept.append(char)
        prev_was_mark = is_mark

    text_new = "".join(kept)
    if len(text_new) < 10:
        return ""
    return text_new


def build_corpus(src_file, pinyin_file, hanzi_file):
    '''
    Convert every line of a text file into aligned pinyin/hanzi lines.

    Each line of `src_file` is stripped and has its spaces removed; empty
    results are skipped. The remaining sentences are passed to align2 and
    the two returned strings are collected and written, one per line
    (joined with "\\n", no trailing newline), to `pinyin_file` and
    `hanzi_file`.

    A sentence whose conversion raises is reported with a traceback and
    skipped. A progress number is printed every 10000 lines; lines that
    raised are not counted.

    Args:
      src_file: Path of the UTF-8 input file, one sentence per line.
      pinyin_file: Path of the UTF-8 output file for the pinyin strings.
      hanzi_file: Path of the UTF-8 output file for the hanzi strings.
    '''
    pinyin_list = []
    hanzi_list = []
    with codecs.open(src_file, 'r', 'utf-8') as fin:
        i = 1
        for line in fin:
            sent = line.strip().replace(" ", "")
            if sent:
                try:
                    pnyns, hanzis = align2(sent)
                except Exception:
                    # Only ordinary errors are skipped, so that Ctrl-C and
                    # SystemExit still stop a long run.
                    traceback.print_exc()
                    continue  # it's okay as we have a pretty big corpus!
                pinyin_list.append(pnyns)
                hanzi_list.append(hanzis)

            if i % 10000 == 0:
                print(i)
            i += 1
    with open(pinyin_file, 'w', encoding='utf-8') as f:
        f.write("\n".join(pinyin_list))
    with open(hanzi_file, 'w', encoding='utf-8') as f:
        f.write("\n".join(hanzi_list))


if __name__ == "__main__":
    # Script entry point: convert the dev split into its pinyin and hanzi
    # token files.
    # NOTE: to convert a whole directory instead, call build_corpus once per
    # file returned by os.listdir(), joining the file name onto the source,
    # pinyin and hanzi directories.
    source_path = "./data/dev/dev_hanzi.txt"
    pinyin_path = "./data/dev/dev_pinyin_split.txt"
    hanzi_path = "./data/dev/dev_hanzi_split.txt"
    build_corpus(source_path, pinyin_path, hanzi_path)
