import argparse
import shutil
import time
import paddle
from paddlespeech.cli import ASRExecutor

from PaddlePaddle_DeepSpeech2.data_utils.audio_process import AudioInferProcess
from PaddlePaddle_DeepSpeech2.utils.predict import Predictor
from PaddlePaddle_DeepSpeech2.utils.audio_vad import crop_audio_vad
import os

# Narration sizing factor: the suggested length (number of characters) of an
# inserted narration placeholder is int(silent gap * normal_speed), i.e. this is
# presumably "characters per second of silence" — confirm the time unit.
normal_speed: int = 4


# from data_utils.audio_process import AudioInferProcess
# from utils.predict import Predictor
# from utils.audio_vad import crop_audio_vad
# from utils.utility import add_arguments, print_arguments

# parser = argparse.ArgumentParser(description=__doc__)
# add_arg = functools.partial(add_arguments, argparser=parser)
# add_arg('wav_path', str, './dataset/test.wav', "预测音频的路径")
# add_arg('is_long_audio', bool, False, "是否为长语音")
# add_arg('use_gpu', bool, False, "是否使用GPU预测")
# add_arg('enable_mkldnn', bool, False, "是否使用mkldnn加速")
# add_arg('to_an', bool, True, "是否转为阿拉伯数字")
# add_arg('beam_size', int, 300, "集束搜索解码相关参数，搜索的大小，范围:[5, 500]")
# add_arg('alpha', float, 1.2, "集束搜索解码相关参数，LM系数")
# add_arg('beta', float, 0.35, "集束搜索解码相关参数，WC系数")
# add_arg('cutoff_prob', float, 0.99, "集束搜索解码相关参数，剪枝的概率")
# add_arg('cutoff_top_n', int, 40, "集束搜索解码相关参数，剪枝的最大值")
# add_arg('mean_std_path', str, './PaddlePaddle_DeepSpeech2/dataset/mean_std.npz', "数据集的均值和标准值的npy文件路径")
# add_arg('vocab_path', str, './PaddlePaddle_DeepSpeech2/dataset/zh_vocab.txt', "数据集的词汇表文件路径")
# add_arg('model_dir', str, './PaddlePaddle_DeepSpeech2/models/infer/', "导出的预测模型文件夹路径")
# add_arg('lang_model_path', str, './PaddlePaddle_DeepSpeech2/lm/zh_giga.no_cna_cmn.prune01244.klm',
#         "集束搜索解码相关参数，语言模型文件路径")
# add_arg('decoding_method', str, 'ctc_greedy', "结果解码方法，有集束搜索(ctc_beam_search)、贪婪策略(ctc_greedy)",
#         choices=['ctc_beam_search', 'ctc_greedy'])
# args = parser.parse_args()
# print_arguments(args)


def predict_long_audio_with_paddle(wav_path, pre_time, state):
    """Transcribe a long audio file segment by segment and build a narration table.

    The audio is split into speech segments with ``crop_audio_vad`` and each segment
    is recognised with PaddleSpeech's ``ASRExecutor`` (pretrained
    ``conformer_wenetspeech`` model, ``lang='zh'``, 16000 Hz sample rate).

    For every segment whose recognition result is non-empty, a row
    ``[round(start + pre_time, 2), round(end + pre_time, 2), text, '']`` is appended.
    Before such a row, a narration placeholder ``["", "", "", "插入旁白，推荐字数为N"]``
    is inserted when the segment is the first one (index 0) or starts at least 1 time
    unit after the end of the previous recognised segment, where
    ``N = int(gap * normal_speed)``.

    Args:
        wav_path: Path of the audio file to recognise.
        pre_time: Offset added to every reported start/end time.
        state: Mutable sequence; ``state[0]`` is overwritten with a progress fraction
            in ``(0, 0.99]`` before each segment is recognised. It never reaches 1.0
            here — presumably the caller marks completion itself (confirm).

    Returns:
        list: The narration table rows described above.

    Side effects:
        Deletes the ``crop_audio`` directory next to ``wav_path`` (the split segment
        files, presumably written there by ``crop_audio_vad`` — confirm) once
        recognition finishes, and also when it raises.
    """
    asr_executor = ASRExecutor()
    # The device does not change during the loop, so query it once.
    device = paddle.get_device()
    start = time.time()
    recognised_texts = []
    narratages = []
    # End time of the last segment that produced text; 0 before the first one.
    last_time = 0
    try:
        # Split the long audio; time_stamps[i] holds (start, end) of audios_path[i].
        audios_path, time_stamps = crop_audio_vad(wav_path)
        total = len(audios_path)
        for i, audio_path in enumerate(audios_path):
            seg_start, seg_end = time_stamps[i][0], time_stamps[i][1]
            print("{}开始处理{}".format(device, audio_path))

            # Progress = fraction of segments started. It must be based on the number
            # of segments (not the length of the path string) and is capped at 0.99.
            if state[0] is None or state[0] < 0.99:
                state[0] = min((i + 1) / total, 0.99)
            else:
                state[0] = 0.99

            text = asr_executor(
                model='conformer_wenetspeech',
                lang='zh',
                sample_rate=16000,
                config=None,  # config and ckpt_path set to None -> use the pretrained model.
                ckpt_path=None,
                audio_file=audio_path,
                force_yes=True,
                device=device,
            )
            if text:
                recognised_texts.append(text)
                # For i == 0 the gap equals seg_start because last_time is still 0.
                gap = seg_start - last_time
                if i == 0 or gap >= 1:
                    # Leave room for a narration sized proportionally to the silence.
                    recommend_lens = int(gap * normal_speed)
                    narratages.append(["", "", "", "插入旁白，推荐字数为%d" % recommend_lens])
                narratages.append(
                    [round(seg_start + pre_time, 2), round(seg_end + pre_time, 2), text, ''])
                last_time = seg_end
            print(
                "第%d个分割音频 对应时间为%.2f-%.2f 识别结果: %s" % (i, seg_start + pre_time, seg_end + pre_time, text))
        print("最终结果，消耗时间：%d, 识别结果: %s" % (
            round((time.time() - start) * 1000), '，'.join(recognised_texts)))
    finally:
        # Remove the split-out segment files even if recognition failed part-way.
        save_path = os.path.join(os.path.dirname(wav_path), 'crop_audio')
        if os.path.exists(save_path):
            shutil.rmtree(save_path)

    return narratages


# # 使用网上已有的模型进行识别（效果差）
# def predict_audio_with_paddle():
#     start = time.time()
#     text = asr_executor(
#         model='conformer_wenetspeech',
#         lang='zh',
#         sample_rate=16000,
#         config=None,  # Set `config` and `ckpt_path` to None to use pretrained model.
#         ckpt_path=None,
#         audio_file=args.wav_path,
#         force_yes=False,
#         device=paddle.get_device()
#     )
#     print("消耗时间：%dms, 识别结果: %s" % (round((time.time() - start) * 1000), text))
#
#
# def predict_long_audio():
#     start = time.time()
#     # 分割长音频
#     audios_path = crop_audio_vad(args.wav_path)
#     texts = ''
#     scores = []
#     # 执行识别
#     for i, audio_path in enumerate(audios_path):
#         score, text = predictor.predict(audio_path=audio_path, to_an=args.to_an)
#         texts = texts + '，' + text
#         scores.append(score)
#         print("第%d个分割音频, 得分: %d, 识别结果: %s" % (i, score, text))
#     print("最终结果，消耗时间：%d, 得分: %d, 识别结果: %s" % (round((time.time() - start) * 1000), sum(scores) / len(scores), texts))
#
#
# def predict_audio():
#     start = time.time()
#     score, text = predictor.predict(audio_path=args.wav_path, to_an=args.to_an)
#     print("消耗时间：%dms, 识别结果: %s, 得分: %d" % (round((time.time() - start) * 1000), text, score))


if __name__ == "__main__":
    # Running this file directly performs no work. The former command-line
    # dispatch (long vs. short audio, driven by an argparse ``args`` object) is
    # disabled; import this module and call predict_long_audio_with_paddle()
    # from the calling code instead.
    ...
