from sklearn.feature_extraction.text import CountVectorizer
from tqdm import tqdm

def counting_ngram(bpe_data_file):
    """Collect the frequent 2-grams and 3-grams found in a text file.

    Each line of the file is treated as one document. A ``CountVectorizer``
    with ``ngram_range=(2, 3)`` and ``min_df=5`` learns the vocabulary, so
    only bigrams and trigrams that occur in at least five lines are kept.

    Args:
        bpe_data_file: Path of a UTF-8 text file. If the path ends with
            ``'zh'`` a token pattern accepting single-character tokens is
            used; otherwise a token must be at least two characters long.

    Returns:
        A list of five lists in which index ``k`` holds the n-grams made of
        ``k + 1`` space-separated tokens. With the n-gram range used here
        only indices 1 and 2 are ever filled.
    """
    ngrams = [[], [], [], [], []]
    with open(bpe_data_file, 'r', encoding='utf-8') as f:
        # Lines never contain an interior newline, so stripping is enough.
        data = [line.strip() for line in tqdm(f)]

    if bpe_data_file.endswith('zh'):
        # One or more word characters, apostrophes or '@'.
        token_pattern = r"(?u)\b[\w\'@\’]+"
    else:
        # A word character followed by at least one more allowed character.
        token_pattern = r'(?u)\b\w[\w@\'\’]+'

    vectorizer = CountVectorizer(
        token_pattern=token_pattern, ngram_range=(2, 3), min_df=5)
    # Only the learned vocabulary is needed; the count matrix is discarded.
    vectorizer.fit_transform(data)

    # get_feature_names() was removed in scikit-learn 1.2; prefer its
    # replacement and fall back to the old name on older installations.
    if hasattr(vectorizer, 'get_feature_names_out'):
        feature_names = vectorizer.get_feature_names_out()
    else:
        feature_names = vectorizer.get_feature_names()

    for vocab in tqdm(feature_names):
        # Bucket by token count: an n-gram of k tokens goes to index k - 1.
        ngrams[len(vocab.split(' ')) - 1].append(vocab)
    return ngrams


def get_ngram_IOU(ngrams_1, ngrams_2, mono_ngram):
    """Compute, per n-gram order, how much of ``ngrams_1`` is covered.

    The three arguments are walked in lockstep, one bucket per n-gram order.
    For each bucket the result is::

        |set(ngrams_1) & (set(ngrams_2) | set(mono_ngram))| / |set(ngrams_1)|

    Despite the function name this is a coverage ratio of ``ngrams_1``, not
    a symmetric intersection-over-union.

    Args:
        ngrams_1: Sequence of n-gram collections to be covered.
        ngrams_2: Sequence of n-gram collections used as reference.
        mono_ngram: Sequence of extra n-gram collections merged into the
            reference.

    Returns:
        A list of floats, one for every bucket in which all three
        collections are non-empty. Buckets with any empty collection are
        skipped entirely, so the list may be shorter than the inputs.
    """
    ratios = []
    for temp_ngram1, temp_ngram2, temp_mono in zip(ngrams_1, ngrams_2, mono_ngram):
        # Truthiness instead of `== []` so that empty tuples or sets are
        # skipped as well, rather than causing a division by zero below.
        if not (temp_ngram1 and temp_ngram2 and temp_mono):
            continue
        targets = set(temp_ngram1)
        reference = set(temp_ngram2).union(temp_mono)
        ratios.append(len(targets & reference) / len(targets))

    return ratios

if __name__ == '__main__':

    # N-grams of the monolingual corpus are merged into every reference set.
    mono_ngram = counting_ngram('data/bpe/mono/mono.bpe.en')
    # mono_ngram = [[], [], [], []]

    # (printed label, file whose n-grams are checked, reference file)
    evaluations = [
        ('zh-en train/test ngram overlapped ratio: ',
         'data/bpe/pair/zh_en/pair.test.bpe.en',
         'data/bpe/pair/zh_en/pair.train.bpe.en'),
        ('en-zh train/test ngram overlapped ratio: ',
         'data/bpe/pair/en_zh/pair.test.bpe.en',
         'data/bpe/pair/en_zh/pair.train.bpe.en'),
        ('newly annotated test ngram overlapped ratio: ',
         'data/bpe/ft/word.valid.bpe.en',
         'data/bpe/pair/zh_en/pair.train.bpe.en'),
    ]

    for label, eval_path, reference_path in evaluations:
        ngrams_1 = counting_ngram(eval_path)
        ngrams_2 = counting_ngram(reference_path)
        overlapped_ratio = get_ngram_IOU(ngrams_1, ngrams_2, mono_ngram)
        print(label, overlapped_ratio)
