CenterLoss在Mnist数据集上的实现
centerloss,顾名思义,中心损失函数,它的原理主要是在softmax loss的基础上,通过对训练集的每个类别在特征空间分别维护一个类中心,在训练过程,增加样本经过网络映射后在特征空间与类中心的距离约束,从而兼顾了类内聚合与类间分离,centerloss只是一个辅助损失函数,softmaxloss才是主打,但softmaxloss只能简单的将类分开,还得加上centerloss这一个强力辅助才能保证特征之间不仅具有可分性,同时也具有可判别性。
我们都知道,对于分类来说,希望类内距小,类间距大,那centerloss+softmaxloss就有这种功能。
简单复习一下softmax函数:
关于softmax的这个函数,有一些基本特性:是归一化指数函数,本质是离散概率分布,常用于多分类,值域为[0,1],输出结果之和为1。
那接着就来看看softmaxloss这个损失函数:
其中Sj为sigmoid输出的值,yj为标签对应独热编码的值(0或者1)
因此softmaxloss可以化简为:
log函数大家都知道,是一个定义域为[0,+∞],值域在[-∞,∞]的增函数,那么softmaxloss定义域在[0,1],取log就是在[-∞,1],那么取-log整个函数最终就变成了定义域在[0,1],值域在[0,+∞]的减函数,并且过(1,0)这个点,这一点正好符合我们梯度下降(当概率为1,损失下降到0),因此我们就可以使用softmaxloss来一步步降低分类的损失。
关于cneterloss,可以先看看公式:
N表示mini-batch的大小,xi表示输出特征,C表示对应的i个类中心,因此centerloss就是希望一个batch中的每个样本的feature离feature 的中心的距离的平方和要越小越好,也就是类内距离要越小越好。
反向传播:
α是学习率,也就是步长,设置一般取值0.5。
这里有一个问题就是centerloss学习率取值为0.5,那如果用同一个优化器进行优化,必然会造成softmaxloss的梯度爆炸,导致整个模型崩溃。
因此这里我们想到用两个优化器进行优化,分别优化centerloss和softmaxloss,这一点可以在代码里看到。
两个损失函数共同作用,softmaxloss负责大致分开各数据,centerloss使类内距越来越小,各司其职,达到把特征区分到最佳效果。其中λ是一个超参数,表示训练时更加倾向于哪个的损失,我在训练时候,λ选择2。
下面看看训练的效果吧:
我只训练了39轮,其实还是能看出来效果还是挺好的。
大概提一下训练过程中的坑吧,因为这些中心点是随机的,有可能随机到的中心点不好,数据点久久不能分开,建议中止训练重新开始或者直接删除参数重新训练。还有就是λ的值对结果影响挺大的,小心调参。
代码:
import torch import torch.nn as nn class CLNet(nn.Module): def __init__(self): super().__init__() self.center = nn.Parameter(torch.randn(100, 2), requires_grad=True) # (10, 2) def forward(self, feature, label, lambdas=2): center_exp = self.center.index_select(dim=0, index=label.long()) # (100, 2) count = torch.histc(label, bins=int(max(label).item() + 1), min=int(min(label).item()), max=int(max(label).item())) # (10,) count_exp = count.index_select(dim=0, index=label.long()) # (100,) loss = lambdas / 2 * torch.mean(torch.div(torch.sum(torch.pow(feature - center_exp, 2), dim=1), count_exp)) return loss
import torch from Net_Model import Net from centerloss import CLNet import torch.nn as nn from torchvision import transforms, datasets import os class Trainer: def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.s_net = Net().to(self.device) self.c_net = CLNet().to(self.device) self.s_save_path = "models/softmax_net.pth" self.c_save_path = "models/center_net.pth" self.nll_loss = nn.NLLLoss() self.s_optimizer = torch.optim.SGD(self.s_net.parameters(), lr=0.0005, momentum=0.9, weight_decay=0.0005) self.c_optimizer = torch.optim.SGD(self.c_net.parameters(), lr=0.5) self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.s_optimizer, gamma=0.95, last_epoch=-1) self.mean, self.std = self.mean_std() self.dataLoader = self.data_loader() def mean_std(self): sets = datasets.MNIST("./MNIST", train=True, download=False, transform=transforms.ToTensor()) loader = torch.utils.data.DataLoader(sets, batch_size=len(sets), shuffle=True) data = next(iter(loader))[0] mean = round(torch.mean(data, dim=(0, 2, 3)).item(), 3) std = round(torch.std(data, dim=(0, 2, 3)).item(), 3) return mean, std def data_loader(self): transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((self.mean,), (self.std,)) ]) dataSet = datasets.MNIST("./MNIST", train=True, download=False, transform=transform) dataLoader = torch.utils.data.DataLoader(dataSet, batch_size=100, shuffle=True, num_workers=4) return dataLoader def train_test(self): if os.path.exists(self.s_save_path) and os.path.exists(self.c_save_path): self.s_net.load_state_dict(torch.load(self.s_save_path)) self.c_net.load_state_dict(torch.load(self.c_save_path)) else: print("NO Param") epoch = 0 while True: feature_loader = [] label_loader = [] for i, (x, y) in enumerate(self.dataLoader): x = x.to(self.device) y = y.to(self.device) feature, output = self.s_net(x) nll_loss = self.nll_loss(output, y) y = y.float() center_loss = self.c_net(feature, y, 2) loss = nll_loss + center_loss self.s_optimizer.zero_grad() self.c_optimizer.zero_grad() loss.backward() self.s_optimizer.step() self.c_optimizer.step() feature_loader.append(feature) label_loader.append(y) if i % 100 == 0: print("epoch:", epoch, "i:", i, "loss:", loss.item(), "softmax_loss:", nll_loss.item(), "center_loss:", center_loss.item()) features = torch.cat(feature_loader, dim=0) labels = torch.cat(label_loader, dim=0) self.s_net.visualize(features.data.cpu().numpy(), labels.data.cpu().numpy(), epoch) torch.save(self.s_net.state_dict(), self.s_save_path) torch.save(self.c_net.state_dict(), self.c_save_path) self.scheduler.step(None) epoch += 1 if epoch == 40: break if __name__ == '__main__': Trainer=Trainer() Trainer.train_test()