# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# ---------------------------------------------------------------------------
# Fine-tune `mbart_base_with_melody` on en-zh data with fairseq-train.
#
# Usage: <script> <gpu_ids> <lr> <max_tokens> <max_epoch> [true|only]
#   $1 TAG           comma-separated GPU ids, exported as CUDA_VISIBLE_DEVICES
#                    (e.g. 0,1,2,3); also appended to the experiment name as M<TAG>
#   $2 lr            value for --lr
#   $3 max_tokens    value for --max-tokens
#   $4 mepoch        value for --max-epoch
#   $5 do_backtrans  "true": train on FT_BIN + BT_BIN with --with-backtrans-data
#                    "only": train on FT_BIN + BT_BIN with --only-backtrans-data
#                    anything else / omitted: train on FT_BIN only
# ---------------------------------------------------------------------------
TAG=$1
lr=$2
max_tokens=$3
mepoch=$4
do_backtrans=$5

OUTPUT_DIR=checkpoints      # root for per-experiment checkpoint/log/tensorboard dirs
FT_BIN=data/bin/ft_en_zh/   # fine-tuning data directory
BT_BIN=data/bin/bt_en_zh/   # back-translation data directory

langs=en,zh
ft_langs=en-zh
ft_domain=LYRICS

TBS=1024          # NOTE(review): not passed to fairseq-train anywhere in this script
max_pos=2048      # --max-source-positions / --max-target-positions
update_freq=1     # --update-freq
warmup=10         # --warmup-updates

# Data-noising settings (names suggest word-level noise and span masking).
# None of these are currently passed to fairseq-train.
# Earlier values: word_shuffle=3, word_dropout=0.1, word_blank=0.1.
word_shuffle=0
word_dropout=0.0
word_blank=0.0
mask_rate=0.3
poisson_lbd=3.5

without_portion=0   # 1 -> prefix the experiment name with NP_ and pass --without-portion
stop_epsilon=0.05   # --act-stop-epsilon

task=xdae_multilingual_translation_with_melody

# Fail fast, before any directory is created, when a required positional
# argument is missing; otherwise fairseq-train would receive e.g. --lr with
# no value and the run directory would get a malformed name.
usage="usage: $0 <gpu_ids> <lr> <max_tokens> <max_epoch> [true|only]"
: "${TAG:?$usage}" "${lr:?$usage}" "${max_tokens:?$usage}" "${mepoch:?$usage}"

# Translate the back-translation mode (5th argument) into:
#   bt_tag        fragment inserted into the experiment name
#   ALL_BIN       colon-separated data directories handed to fairseq-train
#   do_backtrans  (overwritten) extra fairseq-train flag, possibly empty
#   estep         only referenced by the commented-out --eval-inference-start-step
case "$do_backtrans" in
  true)
    bt_tag="_backtrans"
    ALL_BIN="$FT_BIN:$BT_BIN"
    do_backtrans="--with-backtrans-data"
    estep=800
    ;;
  only)
    bt_tag="_only_bt"
    ALL_BIN="$FT_BIN:$BT_BIN"
    do_backtrans="--only-backtrans-data"
    estep=800
    ;;
  *)
    bt_tag=""
    ALL_BIN=$FT_BIN
    do_backtrans=''
    estep=400
    ;;
esac

# NOTE(review): the "mtoken" field records max_pos, not max_tokens ($3), so runs
# that differ only in max_tokens share one directory -- confirm this is intended.
# Left unchanged so that existing run directories are still found on resume.
EXP="FT_GP_eps${stop_epsilon}_${ft_langs}${bt_tag}_lr${lr}_m${mepoch}_mtoken${max_pos}_upf${update_freq}_M${TAG}"

SUFFIX=""
if [[ "$without_portion" -eq 1 ]]; then
  EXP="NP_$EXP"
  SUFFIX="$SUFFIX --without-portion"
  echo "$EXP"
fi

# SAVE must be assigned (including any NP_ prefix) before the checkpoint test
# below; testing while SAVE is still unset would look for /checkpoint_last.pt,
# never find it, and reset training state even when resuming a run.
SAVE=${OUTPUT_DIR}/$EXP
LOG=$SAVE/log

# Only a fresh run (no checkpoint_last.pt in SAVE) gets fairseq's --reset-*
# flags; a resumed run keeps its saved optimizer/scheduler/dataloader/meter state.
if [[ ! -f "$SAVE/checkpoint_last.pt" ]]; then
  SUFFIX="$SUFFIX --reset-dataloader --reset-lr-scheduler --reset-meters --reset-optimizer"
fi

mkdir -p -- "$SAVE" "$LOG" || { echo "cannot create $LOG" >&2; exit 1; }

NOW=$(date '+%F_%H_%M_%S')

# Checkpoint handed to the task via --pretrained-mt-ckpt-dir.
pretrained_ckpt='checkpoints/Pretrain_all_musescore_filtered_single_tag_lr5e-4_m50_r0.5_mtoken2048_upf5_M0,1,2,3,4,5/checkpoint_best.pt'

# Without pipefail the pipeline's status is tee's, so a crashed fairseq-train
# would still make this script exit 0.
set -o pipefail

# $do_backtrans and $SUFFIX are intentionally unquoted: each holds zero or more
# space-separated flags that must become separate arguments, and an empty
# value must vanish rather than become an empty-string argument.
# shellcheck disable=SC2086
CUDA_VISIBLE_DEVICES="$TAG" fairseq-train "$ALL_BIN" \
    --act-stop-epsilon "$stop_epsilon" \
    --adam-eps 1e-06 \
    --adam-betas '(0.9, 0.98)' \
    --add-lang-token \
    --alignment-lambda 0.5 \
    --alignment-decoder-type 'grouping' \
    --attention-dropout 0.1 \
    --arch mbart_base_with_melody \
    --criterion label_smoothed_cross_entropy_with_alignment \
    --ddp-backend no_c10d \
    --decoder-layers 12 \
    --domains LYRICS,WMT \
    --dropout 0.1 \
    --dur-type-num 30 \
    --encoder-layers 12 \
    --eval-align-dist \
    --finetune-data "$FT_BIN" \
    --finetune-domain "$ft_domain" \
    --finetune-langs "$ft_langs" \
    --grouping-arch 'act' \
    --keep-interval-updates 1 \
    --kernel-size 3 \
    --langs "$langs" \
    --layernorm-embedding \
    --length-control-type 'simple' \
    --log-format simple --log-interval 5 \
    --lr-scheduler inverse_sqrt \
    --lr "$lr" \
    --max-tokens "$max_tokens" \
    --max-epoch "$mepoch" \
    --max-delta-note 85 \
    --max-source-positions "$max_pos" \
    --max-target-positions "$max_pos" \
    --multi-alignment-weight 4.0 \
    --note-num 128 \
    --optimizer adam \
    --pretrained-mt-ckpt-dir "$pretrained_ckpt" \
    --predictor-dropout 0.5 \
    --predictor-layers 5 \
    --save-dir "$SAVE" --save-interval 1 \
    --save-interval-updates 100000 \
    --skip-invalid-size-inputs-valid-test \
    --sample-break-mode eos \
    --share-all-embeddings \
    --stop-min-lr 1e-09 \
    --task "$task" \
    --tensorboard-logdir "$SAVE" \
    --validation-inference-interval 1 \
    --warmup-init-lr 1e-07 \
    --warmup-updates "$warmup" \
    --weight-decay 0.01 \
    --update-freq "$update_freq" \
    $do_backtrans $SUFFIX 2>&1 | tee "$LOG/log_$NOW.txt"

# Disabled options, kept for reference. estep (set earlier in this script from
# the back-translation mode) is only consumed by --eval-inference-start-step.
#   --with-backtrans-data $do_backtrans \
#   --eval-inference \
#   --eval-inference-start-step $estep \