import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.autograd import Variable
import math
import pdb
import torch.nn.utils.weight_norm as weightNorm
from collections import OrderedDict
import torch.nn.functional as F
from models.layersFw import Conv2d_fw,BatchNorm2d_fw,Linear_fw,BatchNorm1d_fw

def init_weights(m):
    """Initialise the parameters of module ``m`` in place.

    The scheme is selected by substring match on the module's class name, so
    any class whose name contains the keyword (e.g. ``Conv2d_fw``) is matched:

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-uniform weight, zero bias.
    * ``BatchNorm``: weight drawn from N(1.0, 0.02), zero bias.
    * ``Linear``: Xavier-normal weight, zero bias.

    Modules matching none of the keywords are left untouched.  A ``weight``
    or ``bias`` that is missing or ``None`` (e.g. ``bias=False`` layers or
    ``affine=False`` batch norm) is skipped instead of raising.

    Intended for use with ``module.apply(init_weights)``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
        if weight is not None:
            nn.init.kaiming_uniform_(weight)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
    elif classname.find('Linear') != -1:
        if weight is not None:
            nn.init.xavier_normal_(weight)
    else:
        # Not a layer type this initialiser knows about.
        return

    # Bias is optional on all of the supported layer types.
    if bias is not None:
        nn.init.zeros_(bias)

class feat_bootleneck(nn.Module):
    """Linear bottleneck that projects features to ``bottleneck_dim``.

    Args:
        feature_dim: size of the incoming feature vector.
        bottleneck_dim: size of the projected feature vector.
        type: when equal to ``"bn"`` the projection is followed by batch
            normalisation and dropout (p=0.5); any other value returns the
            raw linear projection.
    """

    def __init__(self, feature_dim, bottleneck_dim=256, type="bn"):
        super(feat_bootleneck, self).__init__()
        self.relu = nn.ReLU(inplace=True)  # kept as an attribute; not used in forward()
        self.dropout = nn.Dropout(p=0.5)
        self.bottleneck = Linear_fw(feature_dim, bottleneck_dim)
        self.bottleneck.apply(init_weights)
        self.type = type
        if type == "bn":
            # forward() calls self.bn in this mode; without this layer the
            # default configuration raised AttributeError.  It is created only
            # for the "bn" mode so the parameters/state_dict of other modes
            # are unchanged.
            # NOTE(review): a plain nn.BatchNorm1d is used because the
            # constructor signature of BatchNorm1d_fw is not visible here;
            # swap it in if fast-weight batch norm is required.
            self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)

    def forward(self, x):
        """Project ``x`` and, in "bn" mode, normalise it and apply dropout."""
        x = self.bottleneck(x)
        if self.type == "bn":
            x = self.bn(x)
            x = self.dropout(x)
        return x

class feat_classifier(nn.Module):
    """Two-layer classification head: Linear -> LeakyReLU -> Dropout -> Linear -> LeakyReLU.

    Args:
        class_num: number of output units of the last layer.
        input_dim: stored as ``self.in_features`` only; the first layer's
            input size is ``bottleneck_dim``, not ``input_dim``.
        bottleneck_dim: input size of the first linear layer.
        type: accepted for interface compatibility; currently unused.
    """

    def __init__(self, class_num=10,input_dim=256*4*4, bottleneck_dim=256, type="linear"):
        super(feat_classifier, self).__init__()
        self.hidden = 256
        self.class_num = class_num
        self.in_features = input_dim
        # The hidden layer has hidden // 2 units.
        self.lin1 = Linear_fw(bottleneck_dim, self.hidden // 2)
        self.lin3 = Linear_fw(self.hidden // 2, self.class_num)
        self.lin1.apply(init_weights)
        self.lin3.apply(init_weights)
        self.relu = F.leaky_relu

    def forward(self, x):
        """Return the (leaky-ReLU activated) class scores for ``x``."""
        x = self.relu(self.lin1(x), negative_slope=0.1)
        # F.dropout defaults to training=True, which would keep dropping
        # units in eval mode; tie it to the module's train/eval state.
        x = F.dropout(x, 0.5, training=self.training)
        x = self.relu(self.lin3(x), negative_slope=0.1)
        return x

class DTNBase(nn.Module):
    """Three-stage convolutional feature extractor.

    Every stage is ``Conv2d_fw(kernel 5, stride 2, padding 2) -> Dropout2d(0.5)
    -> ReLU``; the channel widths go 3 -> 64 -> 128 -> 256.  The output of the
    last stage is flattened to one vector per sample.
    """

    def __init__(self):
        super(DTNBase, self).__init__()
        widths = (3, 64, 128, 256)
        stages = []
        # Build the stages in order so the Sequential indices (0..8) match
        # conv, dropout, relu for each consecutive pair of widths.
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.append(Conv2d_fw(in_ch, out_ch, kernel_size=5, stride=2, padding=2))
            stages.append(nn.Dropout2d(0.5))
            stages.append(nn.ReLU())
        self.conv_params = nn.Sequential(*stages)
        # 256 channels on a 4x4 map; each stride-2 conv halves the spatial
        # size, so this corresponds to 32x32 inputs.
        self.in_features = 256 * 4 * 4

    def forward(self, x):
        """Run the conv stack and flatten everything but the batch dimension."""
        feature_map = self.conv_params(x)
        batch_size = feature_map.size(0)
        return feature_map.view(batch_size, -1)

class LeNetBase(nn.Module):
    """LeNet-style feature extractor for 3-channel images.

    Two 5x5 ``Conv2d_fw`` layers (3 -> 20 -> 50 channels), each followed by
    2x2 max pooling and ReLU; the second convolution is additionally followed
    by ``Dropout2d(p=0.5)`` before pooling.  The result is flattened to one
    vector per sample.
    """

    def __init__(self):
        super(LeNetBase, self).__init__()
        first_stage = [
            Conv2d_fw(3, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        ]
        second_stage = [
            Conv2d_fw(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        ]
        self.conv_params = nn.Sequential(*(first_stage + second_stage))
        # 1250 == 50 * 5 * 5 (a previous revision used 50 * 4 * 4).
        self.in_features = 1250

    def forward(self, x):
        """Run the conv stack and flatten everything but the batch dimension."""
        feature_map = self.conv_params(x)
        batch_size = feature_map.size(0)
        return feature_map.view(batch_size, -1)
        
import torchvision
from torchvision import models
import math
import pdb

def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    """Return a sigmoid-shaped schedule value for iteration ``iter_num``.

    Computes ``2 * (high - low) / (1 + exp(-alpha * iter_num / max_iter))
    - (high - low) + low``, which equals ``low`` at ``iter_num == 0`` and
    approaches ``high`` as ``alpha * iter_num / max_iter`` grows.

    Args:
        iter_num: current iteration count.
        high: value the schedule saturates towards.
        low: value of the schedule at iteration 0.
        alpha: steepness of the ramp.
        max_iter: iteration count used to normalise ``iter_num``.

    Returns:
        The coefficient as a built-in ``float``.
    """
    span = high - low
    # np.float was only an alias of the builtin and was removed in NumPy 1.24.
    return float(2.0 * span / (1.0 + np.exp(-alpha * iter_num / max_iter)) - span + low)


def init_weights(m):
    """Initialise the parameters of module ``m`` in place.

    The layer family is detected by substring match on the class name:

    * name contains ``Conv2d`` or ``ConvTranspose2d``: Kaiming-uniform weight;
    * name contains ``BatchNorm``: weight sampled from N(1.0, 0.02);
    * name contains ``Linear``: Xavier-normal weight.

    In each of those cases the bias is set to zero.  Other modules are left
    unchanged.  Parameters that are absent or ``None`` (``bias=False``,
    ``affine=False``) are skipped rather than causing an error.

    Intended for use with ``module.apply(init_weights)``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if classname.find('Conv2d') != -1 or classname.find('ConvTranspose2d') != -1:
        if weight is not None:
            nn.init.kaiming_uniform_(weight)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
    elif classname.find('Linear') != -1:
        if weight is not None:
            nn.init.xavier_normal_(weight)
    else:
        # Unsupported layer type: nothing to initialise.
        return

    # Layers may be constructed without a bias term.
    if bias is not None:
        nn.init.zeros_(bias)


# Maps an architecture name to the torchvision constructor that builds it.
resnet_dict = dict(
    ResNet18=models.resnet18,
    ResNet34=models.resnet34,
    ResNet50=models.resnet50,
    ResNet101=models.resnet101,
    ResNet152=models.resnet152,
)


def grl_hook(coeff):
    """Build a gradient hook that negates and scales the incoming gradient.

    Args:
        coeff: multiplier applied (with flipped sign) to the gradient.

    Returns:
        A one-argument function mapping ``grad`` to ``-coeff * grad`` computed
        on a clone, so the tensor passed in is not modified.
    """
    def _scaled_negation(grad):
        grad_copy = grad.clone()
        return -coeff * grad_copy

    return _scaled_negation


class ResNetFc(nn.Module):
    """Pretrained ResNet trunk with an optional new bottleneck / classifier.

    Args:
        resnet_name: key into ``resnet_dict`` selecting the architecture.
        use_bottleneck: add a linear bottleneck in front of the new
            classifier; only takes effect when ``new_cls`` is true.
        bottleneck_dim: output size of the bottleneck layer.
        new_cls: create freshly initialised head layers instead of reusing
            the pretrained ``fc`` layer.
        class_num: output size of the new classifier.

    ``forward`` returns features only; the ``fc`` layer is kept as an
    attribute but is not applied there.
    """

    def __init__(self, resnet_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
        super(ResNetFc, self).__init__()
        backbone = resnet_dict[resnet_name](pretrained=True)

        # Re-expose the trunk stages as attributes (in this order) and chain
        # the very same modules into a single feature extractor.
        stage_names = ("conv1", "bn1", "relu", "maxpool",
                       "layer1", "layer2", "layer3", "layer4", "avgpool")
        for stage in stage_names:
            setattr(self, stage, getattr(backbone, stage))
        self.feature_layers = nn.Sequential(*[getattr(self, stage) for stage in stage_names])

        self.use_bottleneck = use_bottleneck
        self.new_cls = new_cls
        backbone_width = backbone.fc.in_features
        if not new_cls:
            # Keep the pretrained classifier as-is.
            self.fc = backbone.fc
            self.__in_features = backbone_width
        elif self.use_bottleneck:
            self.bottleneck = nn.Linear(backbone_width, bottleneck_dim)
            self.fc = nn.Linear(bottleneck_dim, class_num)
            self.bottleneck.apply(init_weights)
            self.fc.apply(init_weights)
            self.__in_features = bottleneck_dim
        else:
            self.fc = nn.Linear(backbone_width, class_num)
            self.fc.apply(init_weights)
            self.__in_features = backbone_width

    def forward(self, x):
        """Return flattened trunk features, passed through the bottleneck if enabled."""
        features = self.feature_layers(x)
        features = features.view(features.size(0), -1)
        if self.new_cls and self.use_bottleneck:
            features = self.bottleneck(features)
        return features

    def output_num(self):
        """Return the feature size recorded at construction time."""
        return self.__in_features

    def get_parameters(self):
        """Return optimizer parameter groups carrying ``lr_mult`` / ``decay_mult`` entries.

        Newly created layers get ``lr_mult`` 10 and the trunk gets 1; without
        a new classifier all parameters form one group with ``lr_mult`` 1.
        """
        def make_group(module, lr_mult):
            return {"params": module.parameters(), "lr_mult": lr_mult, 'decay_mult': 2}

        if not self.new_cls:
            return [make_group(self, 1)]

        parameter_list = [make_group(self.feature_layers, 1)]
        if self.use_bottleneck:
            parameter_list.append(make_group(self.bottleneck, 10))
        parameter_list.append(make_group(self.fc, 10))
        return parameter_list

class RandomLayer(nn.Module):
    """Fuse several inputs through fixed random projections.

    Each input ``i`` is multiplied by its own random matrix of shape
    ``(input_dim_list[i], output_dim)`` drawn once at construction.  The
    projections are combined by element-wise product, with the first one
    divided by ``output_dim ** (1 / number_of_inputs)``.

    The matrices are plain tensors held in a Python list: they are not
    parameters or buffers, so they are absent from ``state_dict()`` and are
    not moved by ``.to()``; use :meth:`cuda` to move them.

    Args:
        input_dim_list: feature size of each input (``None`` means no inputs).
        output_dim: size of the fused output.
    """

    def __init__(self, input_dim_list=None, output_dim=1024):
        super(RandomLayer, self).__init__()
        # Avoid a shared mutable default argument.
        if input_dim_list is None:
            input_dim_list = []
        self.input_num = len(input_dim_list)
        self.output_dim = output_dim
        self.random_matrix = [torch.randn(dim, output_dim) for dim in input_dim_list]

    def forward(self, input_list):
        """Project every tensor in ``input_list`` and multiply the results.

        ``input_list[i]`` must be a 2-D tensor whose second dimension equals
        ``input_dim_list[i]`` (required by ``torch.mm``).
        """
        return_list = [torch.mm(input_list[i], matrix)
                       for i, matrix in enumerate(self.random_matrix)]
        return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
        for single in return_list[1:]:
            return_tensor = torch.mul(return_tensor, single)
        return return_tensor

    def cuda(self, device=None):
        """Move the module and its random matrices to the GPU.

        Returns ``self`` like ``nn.Module.cuda`` so that
        ``layer = layer.cuda()`` keeps working.
        """
        super(RandomLayer, self).cuda(device)
        self.random_matrix = [val.cuda(device) for val in self.random_matrix]
        return self


# class LRN(nn.Module):
#     def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True):
#         super(LRN, self).__init__()
#         self.ACROSS_CHANNELS = ACROSS_CHANNELS
#         if ACROSS_CHANNELS:
#             self.average = nn.AvgPool3d(kernel_size=(local_size, 1, 1),
#                                         stride=1,
#                                         padding=(int((local_size - 1.0) / 2), 0, 0))
#         else:
#             self.average = nn.AvgPool2d(kernel_size=local_size,
#                                         stride=1,
#                                         padding=int((local_size - 1.0) / 2))
#         self.alpha = alpha
#         self.beta = beta
#
#     def forward(self, x):
#         if self.ACROSS_CHANNELS:
#             div = x.pow(2).unsqueeze(1)
#             div = self.average(div).squeeze(1)
#             div = div.mul(self.alpha).add(1.0).pow(self.beta)
#         else:
#             div = x.pow(2)
#             div = self.average(div)
#             div = div.mul(self.alpha).add(1.0).pow(self.beta)
#         x = x.div(div)
#         return x
#
#
# class AlexNet(nn.Module):
#
#     def __init__(self, num_classes=1000):
#         super(AlexNet, self).__init__()
#         self.features = nn.Sequential(
#             nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
#             nn.ReLU(inplace=True),
#             LRN(local_size=5, alpha=0.0001, beta=0.75),
#             nn.MaxPool2d(kernel_size=3, stride=2),
#             nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2),
#             nn.ReLU(inplace=True),
#             LRN(local_size=5, alpha=0.0001, beta=0.75),
#             nn.MaxPool2d(kernel_size=3, stride=2),
#             nn.Conv2d(256, 384, kernel_size=3, padding=1),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2),
#             nn.ReLU(inplace=True),
#             nn.MaxPool2d(kernel_size=3, stride=2),
#         )
#         self.classifier = nn.Sequential(
#             nn.Linear(256 * 6 * 6, 4096),
#             nn.ReLU(inplace=True),
#             nn.Dropout(),
#             nn.Linear(4096, 4096),
#             nn.ReLU(inplace=True),
#             nn.Dropout(),
#             nn.Linear(4096, num_classes),
#         )
#
#     def forward(self, x):
#         x = self.features(x)
#         print(x.size())
#         x = x.view(x.size(0), 256 * 6 * 6)
#         x = self.classifier(x)
#         return x
#
#
# def alexnet(pretrained=False, **kwargs):
#     r"""AlexNet model architecture from the
#     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#     """
#     model = AlexNet(**kwargs)
#     if pretrained:
#         model_path = './alexnet.pth.tar'
#         pretrained_model = torch.load(model_path)
#         model.load_state_dict(pretrained_model['state_dict'])
#     return model
#
#
# # convnet without the last layer
# class AlexNetFc(nn.Module):
#     def __init__(self, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
#         super(AlexNetFc, self).__init__()
#         model_alexnet = alexnet(pretrained=True)
#         self.features = model_alexnet.features
#         self.classifier = nn.Sequential()
#         for i in range(6):
#             self.classifier.add_module("classifier" + str(i), model_alexnet.classifier[i])
#         self.feature_layers = nn.Sequential(self.features, self.classifier)
#
#         self.use_bottleneck = use_bottleneck
#         self.new_cls = new_cls
#         if new_cls:
#             if self.use_bottleneck:
#                 self.bottleneck = nn.Linear(4096, bottleneck_dim)
#                 self.fc = nn.Linear(bottleneck_dim, class_num)
#                 self.bottleneck.apply(init_weights)
#                 self.fc.apply(init_weights)
#                 self.__in_features = bottleneck_dim
#             else:
#                 self.fc = nn.Linear(4096, class_num)
#                 self.fc.apply(init_weights)
#                 self.__in_features = 4096
#         else:
#             self.fc = model_alexnet.classifier[6]
#             self.__in_features = 4096
#
#     def forward(self, x):
#         x = self.features(x)
#         x = x.view(x.size(0), -1)
#         x = self.classifier(x)
#         if self.use_bottleneck and self.new_cls:
#             x = self.bottleneck(x)
#         y = self.fc(x)
#         return x, y
#
#     def output_num(self):
#         return self.__in_features
#
#     def get_parameters(self):
#         if self.new_cls:
#             if self.use_bottleneck:
#                 parameter_list = [{"params": self.features.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
#                                   {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
#                                   {"params": self.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2}, \
#                                   {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
#             else:
#                 parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
#                                   {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
#                                   {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
#         else:
#             parameter_list = [{"params": self.parameters(), "lr_mult": 1, 'decay_mult': 2}]
#         return parameter_list
#
#
#
# vgg_dict = {"VGG11": models.vgg11, "VGG13": models.vgg13, "VGG16": models.vgg16, "VGG19": models.vgg19,
#             "VGG11BN": models.vgg11_bn, "VGG13BN": models.vgg13_bn, "VGG16BN": models.vgg16_bn,
#             "VGG19BN": models.vgg19_bn}
#
#
# class VGGFc(nn.Module):
#     def __init__(self, vgg_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
#         super(VGGFc, self).__init__()
#         model_vgg = vgg_dict[vgg_name](pretrained=True)
#         self.features = model_vgg.features
#         self.classifier = nn.Sequential()
#         for i in range(6):
#             self.classifier.add_module("classifier" + str(i), model_vgg.classifier[i])
#         self.feature_layers = nn.Sequential(self.features, self.classifier)
#
#         self.use_bottleneck = use_bottleneck
#         self.new_cls = new_cls
#         if new_cls:
#             if self.use_bottleneck:
#                 self.bottleneck = nn.Linear(4096, bottleneck_dim)
#                 self.fc = nn.Linear(bottleneck_dim, class_num)
#                 self.bottleneck.apply(init_weights)
#                 self.fc.apply(init_weights)
#                 self.__in_features = bottleneck_dim
#             else:
#                 self.fc = nn.Linear(4096, class_num)
#                 self.fc.apply(init_weights)
#                 self.__in_features = 4096
#         else:
#             self.fc = model_vgg.classifier[6]
#             self.__in_features = 4096
#
#     def forward(self, x):
#         x = self.features(x)
#         x = x.view(x.size(0), -1)
#         x = self.classifier(x)
#         if self.use_bottleneck and self.new_cls:
#             x = self.bottleneck(x)
#         y = self.fc(x)
#         return x, y
#
#     def output_num(self):
#         return self.__in_features
#
#     def get_parameters(self):
#         if self.new_cls:
#             if self.use_bottleneck:
#                 parameter_list = [{"params": self.features.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
#                                   {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
#                                   {"params": self.bottleneck.parameters(), "lr_mult": 10, 'decay_mult': 2}, \
#                                   {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
#             else:
#                 parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
#                                   {"params": self.classifier.parameters(), "lr_mult": 1, 'decay_mult': 2}, \
#                                   {"params": self.fc.parameters(), "lr_mult": 10, 'decay_mult': 2}]
#         else:
#             parameter_list = [{"params": self.parameters(), "lr_mult": 1, 'decay_mult': 2}]
#         return parameter_list
#
#
# # For SVHN dataset
# class DTN(nn.Module):
#     def __init__(self):
#         super(DTN, self).__init__()
#         self.conv_params = nn.Sequential(
#             nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
#             nn.BatchNorm2d(64),
#             nn.Dropout2d(0.1),
#             nn.ReLU(),
#             nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
#             nn.BatchNorm2d(128),
#             nn.Dropout2d(0.3),
#             nn.ReLU(),
#             nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
#             nn.BatchNorm2d(256),
#             nn.Dropout2d(0.5),
#             nn.ReLU()
#         )
#
#         self.fc_params = nn.Sequential(
#             nn.Linear(256 * 4 * 4, 512),
#             nn.BatchNorm1d(512),
#             nn.ReLU(),
#             nn.Dropout()
#         )
#
#         self.classifier = nn.Linear(512, 10)
#         self.__in_features = 512
#
#     def forward(self, x):
#         x = self.conv_params(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc_params(x)
#         y = self.classifier(x)
#         return x, y
#
#     def output_num(self):
#         return self.__in_features
#
#
# class LeNet(nn.Module):
#     def __init__(self):
#         super(LeNet, self).__init__()
#         self.conv_params = nn.Sequential(
#             nn.Conv2d(1, 20, kernel_size=5),
#             nn.MaxPool2d(2),
#             nn.ReLU(),
#             nn.Conv2d(20, 50, kernel_size=5),
#             nn.Dropout2d(p=0.5),
#             nn.MaxPool2d(2),
#             nn.ReLU(),
#         )
#
#         self.fc_params = nn.Sequential(nn.Linear(50 * 4 * 4, 500), nn.ReLU(), nn.Dropout(p=0.5))
#         self.classifier = nn.Linear(500, 10)
#         self.__in_features = 500
#
#     def forward(self, x):
#         x = self.conv_params(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc_params(x)
#         y = self.classifier(x)
#         return x, y
#
#     def output_num(self):
#         return self.__in_features


# class AdversarialNetwork(nn.Module):
#     def __init__(self, in_feature, hidden_size):
#         super(AdversarialNetwork, self).__init__()
#         self.ad_layer1 = nn.Linear(in_feature, hidden_size)
#         self.ad_layer2 = nn.Linear(hidden_size, hidden_size)
#         self.ad_layer3 = nn.Linear(hidden_size, 1)
#         self.relu1 = nn.ReLU()
#         self.relu2 = nn.ReLU()
#         self.dropout1 = nn.Dropout(0.5)
#         self.dropout2 = nn.Dropout(0.5)
#         self.sigmoid = nn.Sigmoid()
#         self.apply(init_weights)
#         self.iter_num = 0
#         self.alpha = 10
#         self.low = 0.0
#         self.high = 1.0
#         self.max_iter = 10000.0
#
#     def forward(self, x):
#         if self.training:
#             self.iter_num += 1
#         coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter)
#         x = x * 1.0
#         x.register_hook(grl_hook(coeff))
#         x = self.ad_layer1(x)
#         x = self.relu1(x)
#         x = self.dropout1(x)
#         x = self.ad_layer2(x)
#         x = self.relu2(x)
#         x = self.dropout2(x)
#         y = self.ad_layer3(x)
#         y = self.sigmoid(y)
#         return y
#
#     def output_num(self):
#         return 1
#
#     def get_parameters(self):
#         return [{"params": self.parameters(), "lr_mult": 10, 'decay_mult': 2}]
