# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from dataclasses import dataclass

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from PIL import Image
import gc


@dataclass
class AlignmentDistributionScorerConfig(FairseqDataclass):
    """Configuration for the ``align_dist`` scorer; it declares no options of its own."""


@register_scorer("align_dist", dataclass=AlignmentDistributionScorerConfig)
class AlignmentDistributionScorer(BaseScorer):
    """Compare the distribution of predicted vs. reference alignment lengths.

    Each :meth:`add` call contributes one reference and one predicted
    sequence of alignment lengths (the ``K(j)(Delta_j)`` values on the plot
    x-axis). All collected values are pooled into two normalized histograms
    with one unit-wide bin per length ``0..self.bins``. :meth:`score` reports
    their index-weighted overlap. The ``result_*_histogram`` methods render
    the histograms to RGBA images.

    NOTE(review): lengths are presumably non-negative integers; the
    binning assumes integer-valued lengths -- confirm against callers.
    """

    def __init__(self, args):
        super().__init__(args)
        # Create every accumulator here, including ``ref``/``pred``, so the
        # scorer does not rely on the base class to have created them.
        self.empty()

    def empty(self):
        """Discard all collected sequences and any cached histograms."""
        self.ref = []
        self.pred = []
        self.scores = []
        self.pred_hist = None
        self.ref_hist = None
        self.bins = 0

    def add(self, ref, pred):
        """Record one reference/prediction pair of alignment-length sequences.

        Trailing zeros are stripped from both sequences. A reference with no
        non-zero entry is not recorded, but its prediction still is. The
        prediction is then shrunk with :meth:`_tolerant_correct` toward the
        reference total before being stored.

        Args:
            ref: 1-D array-like of reference alignment lengths.
            pred: 1-D array-like of predicted alignment lengths.

        The inputs are copied, so the caller's arrays are never modified.
        """
        # np.array copies: the in-place correction below must not write into
        # the caller's buffer (a numpy slice is a view), and stored entries
        # must not change if the caller reuses its arrays. It also gives
        # tensors/lists numpy's ``nonzero()`` semantics (tuple of index arrays).
        ref = np.array(ref)
        pred = np.array(pred)

        nonzero = ref.nonzero()[0]
        if len(nonzero) > 0:
            ref = ref[: nonzero[-1] + 1]
            self.ref.append(ref)
        nonzero = pred.nonzero()[0]
        if len(nonzero) > 0:
            pred = pred[: nonzero[-1] + 1]

        self._tolerant_correct(pred, ref.sum())
        self.pred.append(pred)

    @staticmethod
    def _tolerant_correct(alignments, num_notes):
        """Shrink ``alignments`` in place until its sum is at most ``num_notes``.

        The method sweeps left to right, decrementing every entry greater than
        1 by one, and repeats the sweep until the total reaches ``num_notes``.
        An entry is only decremented while it is greater than 1, so no entry is
        pushed below 1. If the target is unreachable, the method stops once no
        entry can be reduced.
        """
        total = alignments.sum()
        while total > num_notes:
            reduced = False
            for i in range(len(alignments)):
                if total <= num_notes:
                    break
                if alignments[i] > 1:
                    reduced = True
                    alignments[i] -= 1
                    total -= 1
            if not reduced:
                break

    def score(self):
        """Return the index-weighted overlap of the two histograms.

        Over the bins ``i`` where the reference frequency ``r_i`` is non-zero,
        this computes ``sum(i * min(r_i, p_i)) / sum(i * r_i)``. The result lies
        in ``[0, 1]`` and bin 0 carries zero weight. ``self.scores`` is set to the
        per-bin ``min(r_i, p_i)`` values.

        Raises:
            ZeroDivisionError: (from ``np.average``) if the reference mass lies
                entirely in bin 0.
        """
        self._prepare()

        mask = self.ref_hist != 0
        # Each bin is weighted by its index, i.e. the alignment length it holds.
        weights = np.arange(len(self.ref_hist))[mask]
        overlap = np.minimum(self.ref_hist, self.pred_hist)[mask]
        self.scores = list(overlap)
        return np.average(overlap, weights=weights) / np.average(
            self.ref_hist[mask], weights=weights
        )

    def result_string(self, order=4):
        """Format :meth:`score` for logging (``order`` is accepted but unused)."""
        return f"Alignment Distribution Distance: {self.score():.4f}"

    def result_pred_histogram(self):
        """Render the predicted-length histogram; return an (H, W, 4) RGBA array."""
        self._ensure_prepared()
        return self._render_histograms([(self.pred_hist, {})], ['pred'])

    def result_ref_histogram(self):
        """Render the reference-length histogram; return an (H, W, 4) RGBA array."""
        self._ensure_prepared()
        return self._render_histograms([(self.ref_hist, {})], ['GT'])

    def result_overlap_histogram(self, path=None):
        """Render both histograms overlaid; optionally also save them to ``path``.

        Returns:
            The rendered figure as an (H, W, 4) RGBA uint8 array.
        """
        self._ensure_prepared()
        return self._render_histograms(
            [
                (self.ref_hist, dict(alpha=0.5, label='ref')),
                (self.pred_hist, dict(alpha=0.5, label='pred')),
            ],
            ['GT', 'pred'],
            path=path,
        )

    def _ensure_prepared(self):
        """Build the histograms if they have not been computed yet."""
        if self.ref_hist is None or self.pred_hist is None:
            self._prepare()

    def _render_histograms(self, series, legend, path=None):
        """Draw each ``(hist, bar_kwargs)`` in ``series`` as bars and rasterize it.

        Args:
            series: sequence of ``(histogram, kwargs)`` pairs passed to ``plt.bar``.
            legend: legend labels, in the same order as ``series``.
            path: if given, the figure is also saved there via ``plt.savefig``.

        Returns:
            The rendered figure as an (H, W, 4) RGBA uint8 array.
        """
        figure = plt.figure(num=1, dpi=720)
        try:
            plt.xlabel(r'$K(j)(\Delta_j)$', fontsize=18)
            plt.ylabel('frequency', fontsize=18)
            plt.xlim((0, 10))
            plt.ylim((0.0, 1.0))
            plt.xticks(np.arange(0, 10, 1))
            for hist, bar_kwargs in series:
                # Bar i sits at x = i: bin i holds alignment length i.
                plt.bar(np.arange(len(hist)), hist, **bar_kwargs)
            plt.legend(legend, fontsize=18)
            if path is not None:
                plt.savefig(path)
            return self._get_image(figure.canvas)
        finally:
            # Always release the figure, even if rendering fails.
            figure.clf()
            plt.close(figure)
            gc.collect()

    def _prepare(self):
        """Pool all recorded sequences into normalized per-length histograms.

        Sets ``self.bins`` to ``max(10, largest reference length)`` and
        ``self.ref_hist`` / ``self.pred_hist`` to frequency arrays of length
        ``self.bins + 1``. Index ``i`` holds the fraction of values equal to
        length ``i`` (or in ``[i, i + 1)`` for non-integer values). Predicted
        values above ``self.bins`` (or below 0) are dropped before normalizing.
        """
        temp_pred = np.concatenate(self.pred) if len(self.pred) > 0 else np.array([0])
        temp_ref = np.concatenate(self.ref) if len(self.ref) > 0 else np.array([1])
        self.bins = int(max(10, temp_ref.max()))

        # np.histogram closes its LAST bin on the right. With edges 0..bins the
        # values bins - 1 and bins would share one bin. Use one extra edge so
        # every length 0..bins gets its own bin. Then drop predictions above
        # ``bins`` so they are still excluded, as with the old range.
        edges = np.arange(self.bins + 2)
        temp_pred = temp_pred[temp_pred <= self.bins]
        self.pred_hist = self._normalize(np.histogram(temp_pred, bins=edges)[0])
        self.ref_hist = self._normalize(np.histogram(temp_ref, bins=edges)[0])

    @staticmethod
    def _normalize(counts):
        """Scale ``counts`` to sum to 1; all-zero counts stay zeros (not NaN)."""
        total = counts.sum()
        return counts / total if total > 0 else counts.astype(float)

    def _get_image(self, plt_canvas):
        """Draw ``plt_canvas`` and return its pixels as an (H, W, 4) RGBA uint8 array."""
        plt_canvas.draw()
        width, height = plt_canvas.get_width_height()
        # buffer_rgba() replaces the deprecated tostring_argb() and is already
        # in RGBA order, so no channel roll is needed. Rows come first, hence
        # (height, width, 4). copy() detaches the pixels from the canvas,
        # because the figure is closed right after this returns.
        rgba = np.frombuffer(plt_canvas.buffer_rgba(), dtype=np.uint8)
        return rgba.reshape(height, width, 4).copy()