import torch
from torchtext import data
import numpy as np
from torch.autograd import Variable
import copy
import random
from tqdm import tqdm
import time

def nopeak_mask(size, opt):
    """Return a causal ("no-peek") boolean mask of shape ``(1, size, size)``.

    Entry ``[0, i, j]`` is True when ``j <= i``: position ``i`` may attend to
    itself and to earlier positions, never to later ones.

    Args:
        size: sequence length.
        opt: options object; the mask is moved to CUDA only when
            ``opt.device == 0``.
    """
    # Strictly-upper triangle (k=1) marks the "future" positions ...
    future = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    # ... so the allowed positions are exactly where that triangle is zero.
    allowed = Variable(torch.from_numpy(future) == 0)
    return allowed.cuda() if opt.device == 0 else allowed


def create_masks(src, trg, opt):
    """Build the attention masks for a source batch and, optionally, a target batch.

    Args:
        src: token-id tensor for the source batch; positions equal to
            ``opt.src_pad`` are masked out.
        trg: token-id tensor for the target batch, or ``None`` to skip the
            target mask. ``trg.size(1)`` is taken as the target sequence
            length, so a batch-first ``(batch, seq_len)`` layout is assumed.
        opt: options object providing ``src_pad`` and ``trg_pad`` (and
            whatever ``nopeak_mask`` reads from it).

    Returns:
        ``(src_mask, trg_mask)``. ``src_mask`` is ``src != opt.src_pad`` with
        a singleton axis inserted before the last dimension. ``trg_mask`` is
        the target padding mask AND-ed with the causal no-peek mask, or
        ``None`` when ``trg`` is ``None``.
    """
    src_mask = (src != opt.src_pad).unsqueeze(-2)

    if trg is None:
        return src_mask, None

    trg_mask = (trg != opt.trg_pad).unsqueeze(-2)
    size = trg.size(1)  # get seq_len for matrix
    # ``Tensor.cuda()`` returns a copy instead of moving in place, so the
    # mask must be rebound. ``.to(trg.device)`` puts it next to ``trg`` no
    # matter where ``nopeak_mask`` created it, so the ``&`` below never
    # mixes CPU and CUDA tensors.
    np_mask = nopeak_mask(size, opt).to(trg.device)
    trg_mask = trg_mask & np_mask
    return src_mask, trg_mask


def create_masks2(src, opt):
    """Return only the source padding mask for ``src``.

    The result is True wherever a token differs from ``opt.src_pad``, with a
    singleton axis inserted before the last dimension. It is the same value
    as the first element returned by ``create_masks(src, None, opt)``.
    """
    keep = src.ne(opt.src_pad)
    return keep.unsqueeze(-2)

# Patch on torchtext's batching process that makes it more efficient,
# adapted from http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks


class MyIterator(data.Iterator):
    """torchtext ``Iterator`` with pooled batching and optional source-side augmentation.

    Training batches are built by splitting the data into large pools,
    sorting each pool by ``sort_key``, cutting it into batches and shuffling
    those batches.

    When ``augment`` is true, every call to :meth:`data` rebuilds
    ``self.dataset.examples`` from a pristine deep copy of the original
    examples. For each original example and each probability ``p`` in
    ``change_possibility``, one new example is emitted. In it, every ``src``
    token found in the yunmu list is, with probability ``p``, rewritten by
    replacing its last character with ``"0"``. The dataset therefore grows by
    a factor of ``len(change_possibility)``.
    """

    def __init__(self, dataset, batch_size, sort_key=None, device=None,
                 batch_size_fn=None, train=True,
                 repeat=False, shuffle=None, sort=None,
                 sort_within_batch=None, augment=False, change_possibility=None,
                 yunmu_path="./data/voc/yunmus.txt"):
        """Create the iterator, snapshot the original examples and load the yunmu list.

        Args:
            dataset, batch_size, sort_key, device, batch_size_fn, train,
            repeat, shuffle, sort, sort_within_batch: forwarded unchanged,
                positionally, to ``data.Iterator.__init__``.
            augment: whether :meth:`data` regenerates augmented examples on
                each call.
            change_possibility: sequence of per-token rewrite probabilities;
                one augmented copy of every example is produced per entry.
                Defaults to ``[0.5, 0.6, 0.7, 0.8, 0.9, 1]``.
            yunmu_path: UTF-8 text file with one yunmu token per line
                (presumably pinyin finals ending in a tone digit — confirm).
        """
        super().__init__(dataset, batch_size, sort_key, device,
                         batch_size_fn, train,
                         repeat, shuffle, sort,
                         sort_within_batch)

        self.augment = augment
        # ``None`` sentinel rather than a list default: a list default would
        # be a single object shared by every instance.
        if change_possibility is None:
            change_possibility = [0.5, 0.6, 0.7, 0.8, 0.9, 1]
        self.change_possibility = change_possibility

        print("start copy...")
        start = time.time()
        # ``data()`` overwrites ``self.dataset.examples`` on every augmented
        # epoch, so an untouched copy of the originals is kept here.
        self.ori_examples = copy.deepcopy(self.dataset.examples)
        print("end copy..., cost total:{}s".format(time.time()-start))

        with open(yunmu_path, "r", encoding="utf-8") as f:
            self.yunmus = [line.strip() for line in f]

    def data(self):
        """Override of ``data.Iterator.data`` that adds the augmentation step.

        If ``self.augment`` is set, first regenerate ``self.dataset.examples``
        from ``self.ori_examples`` (see the class docstring). Then return the
        examples in the dataset in order, sorted, or shuffled.
        """
        if self.augment:
            print("augmenting data...")
            # Membership is tested once per source token of every copy;
            # a set makes each test O(1) instead of scanning the list.
            yunmu_set = set(self.yunmus)
            augmented = []
            for ex in tqdm(self.ori_examples):
                for p in self.change_possibility:
                    new_ex = copy.deepcopy(ex)
                    for i, char in enumerate(ex.src):
                        # random.random() is drawn for every token (before
                        # the membership test) so the RNG stream is
                        # consumed identically for all tokens.
                        if random.random() < p and char in yunmu_set:
                            new_ex.src[i] = char[:-1] + "0"
                    augmented.append(new_ex)
            self.dataset.examples = augmented

        if self.sort:
            xs = sorted(self.dataset, key=self.sort_key)
        elif self.shuffle:
            xs = [self.dataset[i]
                  for i in self.random_shuffler(range(len(self.dataset)))]
        else:
            xs = self.dataset
        return xs

    def create_batches(self):
        """Populate ``self.batches`` for the coming epoch.

        Training: a lazy generator. It splits ``self.data()`` into pools of
        ``batch_size * 100`` (``data.batch`` is called without a size
        function here), sorts each pool by ``sort_key``, cuts it into
        batches with ``batch_size_fn`` and yields those batches in shuffled
        order. Batches thus hold similar-length examples while their order
        stays random.

        Evaluation: an eager list of batches taken in ``self.data()`` order,
        each batch sorted by ``sort_key``.
        """
        if self.train:
            def pool(d, random_shuffler):
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(
                        sorted(p, key=self.sort_key),
                        self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)

        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))


# Running maxima shared by successive ``batch_size_fn`` calls while one
# minibatch is being built. A module-level ``global`` statement is a no-op and
# would leave these names undefined, so initialise them explicitly.
max_src_in_batch, max_tgt_in_batch = 0, 0


def batch_size_fn(new, count, sofar):
    """Keep augmenting batch and calculate total number of tokens + padding.

    Args:
        new: the example just added; its ``src`` and ``trg`` must support
            ``len``.
        count: number of examples in the batch including ``new``; ``1``
            resets the running maxima.
        sofar: size returned by the previous call (unused).

    Returns:
        ``max(count * longest_src, count * longest_trg)``, the padded element
        count of the larger side of the batch.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    # +2 is presumably room for start/end tokens added to targets — confirm.
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)
