深度学习(五)之原型网络

在本文中,将介绍一些关于小样本学习的相关知识,以及介绍如何使用pytorch构建一个原型网络(Prototypical Networks[1]),并应用于miniImageNet 数据集。

实验环境:

pytorch:1.11.0
代码地址:https://github.com/xiaohuiduan/deeplearning-study/tree/main/小样本学习

小样本学习引入

在这一节将简要的对小样本学习(FSL)相关的知识进行介绍。由于我并不是专门研究小样本的(我学习FSL也只是为了完成我的课程作业),因此,如果本文存在任何问题,欢迎进行批评指正😎。

邮箱📫:xiaohuiduan@hunnu.edu.cn

首先将模型看成一个黑盒子,不去关注它的内部结构,而是关注其inputoutput

在分类模型[2]中,input是一张猫or狗的图片,output则为0/1(代表其为猫或者狗;实际上输出的是两者的预测概率)。

深度学习(五)之原型网络

但是关于上面的模型,存在一个问题,训练这样的模型需要大量的数据,根据@petewarden[3]的说法,训练一个分类图片的网络,每个类别需要大约1000张。但是,很多场景我们并没有足够的数据进行训练,也就是说数据集的样本比较少,这时候针对小样本数据集可以有两种处理方式[4]

  • 数据增强:比如说对图像进行旋转,裁剪等等。

  • 数据建模:如使用小样本学习的方法进行对数据进行建模。

小样本学习是什么

以VGG模型预测猫狗分类举例,模型对于一张新的图片的预测可以形象的解释为:

图片中的动物因为有着尖尖的耳朵,有着较长的胡须,鼻子那一个地方不是很突出,因此,我(VGG)判断它是一只猫o(=•ェ•=)m。

但是在小样本学习中,不是这样的,小样本学习对于一张新的图片的预测可以形象的解释为:

我手中有2张图片,图片A和图片B。对于新的图片,我(模型)也不知道他是啥,但是我发现它跟图片B长得很相似,因此我(模型)判断这张新的图片和图片B是同一个类别。

在上述解释中,图片A和B称之为support sets,而新的图片称之为query sets

小样本的分类模型与传统的深度学习分类模型(如VGG)有着不同,这里引用论文[1:1]的一句话:

Few-shot classification is a task in which a classifier must be adapted to accommodate
new classes not seen in training, given only a few examples of each of these classes. A naive approach,
such as re-training the model on the new data, would severely overfit.

也就是说,小样本分类并不是像VGG一样,test的数据是以前训练过的类别,对于小样本分类来说,进行test的数据是一些新的类,并且这些类别的样本很少,因此,没法对其进行re-training,否则会造成过拟合。

小样本学习方法

小样本学习可以认为是一个N-way K-shot的分类问题(不确定是不是所有的小样本分类任务都被认为是N-way K-shot分类问题)。

无论对于测试集还是训练集,都需要进行如下的划分,将数据集划分为两个部分:左边DataSet代表数据集,右边分别代表Support set,右边代表Query Set。在train或者test数据集中:

  1. 首先对于所有类别,随机选择其中N(图中N=3)个类别(图中,选择了类别2,类别3和类别5)。
  2. 在step 1中选择的类别样本中,随机选择K(图中K=3)个样本(绿色的部分),构成Support Set。也就是说,Support Set中拥有K*N个样本。
  3. 然后在所选择类别的剩余样本中,选择X(这里X=1)个样本(红色的部分),构成Query Set。也就是说,Query Set中拥有X*N个样本。

在训练集中,以上步骤构成的Support Set和Query Set会被input到Model中进行训练,称之为一个eposides(相当于mini-batch)。

深度学习(五)之原型网络

对于模型来说,其目的则为判断Query Set中的样本与哪一个支持集最相似

原型网络(Prototype Network)

原理简述

Prototype Network的原理很简单,可以简单的概括为:将support set中的图片(data_1,data_2,cdots,data_n)映射到某一个向量空间(c_1,c_2,cdots,c_n);对于Query set中的某一张图片(query_i)使用同一个映射函数,也映射到到向量空间(x_i),然后判断(x_i)(c_1,c_2,cdots,c_n)的距离(余弦距离or欧氏距离),选择距离最近向量所对应的类别作为(query_i)所属的类别。

示意图[5]如下所示:

深度学习(五)之原型网络

如果了解NLP中word2vec的话,会发现,其与word2vec的Embedding思想是很相似的。

算法流程

算法流程图[1:2]如下所示,红色框和绿色框中的过程已经在前文进行介绍,这里主要是来介绍一下loss的计算方式。

深度学习(五)之原型网络

实际上,loss的计算方式就是一个交叉熵损失函数,pytorch中CrossEntropyLoss的计算方法如下所示,class代表(x)实际所属类别(x[j])代表模型对于(x)所属类别(j)的概率预测。

[operatorname{loss}(x, text { class })=-log left(frac{exp (x[text { class }])}{sum_{j} exp (x[j])}right)=-x[text { class }]+log left(sum_{j} exp (x[j])right) ]

但是,在算法流程图中,大家会发现,其loss计算的正负号刚好与上面公式中的相反,解释如下:

以欧式距离为例,距离越远((d)则越大),则代表两者的相似度越低。如果不加负号的话,进行softmax计算,距离越远的则predict概率越大,这明显是错误的。因此,加了一个负号之后,距离越远,进行softmax之后,输出则越小,predict的概率也变小,这才是合理的。

以上,便是原型网络的算法流程。

算法实现

数据集处理

mini-Imagenet是一个专门用于训练小样本学习的训练集,数据集中一共有100个类别,每个类别600张图片,一共有60000张图片。数据集可以从mini-ImageNet | Kaggles上面下载。在下载文件中,一共有4个文件,一个是数据集图片的压缩包,另外3个csv文件分别代表了训练、验证和测试集相关的信息。其中训练集有64个类,验证集16个类,测试集20个类。

csv的部分数据如下所示,filename代表了图片的名字,label代表了图片对应的标签。

深度学习(五)之原型网络

因此,可以构建一个label所对应filename的字典

def read_csv(csv_path):     dict = collections.defaultdict(list)     df = pd.read_csv(csv_path)     for index,row in df.iterrows():         dict[row["label"]].append(row["filename"])     return dict train_dict = read_csv(train_csv_path) val_dict = read_csv(val_csv_path) test_dict = read_csv(test_csv_path) 

同时,构建data与labels的对应关系:

from PIL import Image import numpy as np from torchvision import transforms from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True  resize_transform = transforms.Resize(84) # 提前对图片进行缩放,以节省内存空间,将最短的边变成84 def build_data(data_dict):     datas = []     labels = []     label_index = 0     for label in data_dict.keys(): # 对图片的标签进行迭代         for path in data_dict[label]: # 对标签对应的文件名进行迭代             img_path = os.path.join(img_root_dir,path)              img = Image.open(img_path) # 读取文件             img = resize_transform(img) # 进行缩放             datas.append(img)              labels.append(label_index)          label_index += 1     return {"datas":datas,"labels":labels} 

随机产生Support set 和 Query set

在下面代码中,CategoriesSampler的作用是为了产生index,然后供给dataloader使用。

class CategoriesSampler():     """         目的是为了随机产生K_way*(N_support+N_query)个图片对应的index     """     def __init__(self, data, n_batch, K_way, N_per):         self.n_batch = n_batch         self.K_way = K_way         self.N_per = N_per         labels = np.array(data["labels"]) # [0,0,0,0,1,1,1,1,2,2,2,2……]         self.index = [] # 记录label对应的索引位置         for i in range(max(labels)+1):             ind = np.argwhere(labels == i).reshape(-1)             self.index.append(torch.from_numpy(ind))         def __len__(self):         return self.n_batch          def __iter__(self):         for i_batch in range(self.n_batch):               batch = []             classes = torch.randperm(len(self.index))[:self.K_way] # 随机选择K个类别构成support set和query set             for c in classes:                 l = self.index[c] # 类别c对应的图片数组的索引,如 l = [5,6,7,8,9]                 pos = torch.randperm(len(l))[:self.N_per] # 如 pos = [4,1,0]                 batch.append(l[pos]) # 如 l[pos] = [9,6,5]              batch = torch.stack(batch).reshape(-1)             yield batch 

同时,定义Dataset类,如下:

class MiniImageNet(Dataset):      def __init__(self, data):          self.datas = data["datas"]         self.labels = data["labels"]         self.transform = transforms.Compose([             transforms.CenterCrop(84),             transforms.ToTensor(),             transforms.Normalize(mean=[0.485, 0.456, 0.406],                                  std=[0.229, 0.224, 0.225])         ])      def __len__(self):         return len(self.datas)      def __getitem__(self, i):         img, label = self.datas[i], self.labels[i]         return self.transform(img), label 

使用示例如下,在MiniImageNet的__getitem__函数中,其参数iCategoriesSampler__iter__函数所产生。

batch_sampler = CategoriesSampler(datas,eval_step,K_way,N_shot+N_query) data_loader = DataLoader(dataset=data_set, batch_sampler=batch_sampler,                                             num_workers=16, pin_memory=True) 

实际上,上面的代码就是为了实现如下图所示的功能:

深度学习(五)之原型网络

Embedding网络构建

前文说到,需要将图片进行向量化表示,在下面的代码中,便可以将一张图片(shape=(3times84times84))变成一个1600维的向量。(网络结构来自于论文)

class CNN_Net(nn.Module):     """         用于特征提取     """      def __init__(self, input_dim):         super(CNN_Net, self).__init__()                  self.input_dim = input_dim         def conv_block(in_channel,out_channel):             return nn.Sequential(                 nn.Conv2d(in_channel, out_channel, 3,padding=1),                 nn.BatchNorm2d(out_channel),                 nn.ReLU(),                 nn.MaxPool2d(2)             )         self.encoder = nn.Sequential(             conv_block(input_dim,64),             conv_block(64,64),             conv_block(64,64),             conv_block(64,64),         )     def forward(self, x):         x = self.encoder(x)         x = x.view(x.size(0), -1)         return x 

损失函数

关于计算损失的关键函数如下所示(参考了论文作者的源代码[6]):

def cal_euc_distance(self, query_z, center,K_way, N_query):     """         计算query_z与center的距离         query_z : (K_way*N_query,z_dim)         center : (K_way,z_dim)     """     center = center.unsqueeze(0).expand(         K_way*N_query, K_way, self.z_dim)  # (K_way*N_query,K_way,z_dim)     query_z = query_z.unsqueeze(1).expand(         K_way*N_query, K_way, self.z_dim)  # (K_way*N_query,K_way,z_dim)      return torch.pow(query_z-center, 2).sum(2)  # (K_way*N_query,K_way)  def loss_acc(self, query_z, center, K_way, N_query):     """         计算loss和acc         query_z : (K_way*N_query,z_dim)         center : (K_way,z_dim)     """     target_inds = torch.arange(0, K_way).view(K_way, 1).expand(         K_way, N_query).long().to(self.device) # shape=(K_way, N_query)          distance = self.cal_euc_distance(query_z, center,K_way, N_query)    # (K_way*N_query,K_way)      predict_label = torch.argmin(distance, dim=1)  # (K_way*N_query) 预测出来的label      acc = torch.eq(target_inds.contiguous().view(-1),                     predict_label).float().mean() # 准确率      loss = F.log_softmax(-distance, dim=1).view(K_way,                                                 N_query, K_way)  # (K_way,N_query,K_way)     loss = -          loss.gather(dim=2, index=target_inds.unsqueeze(2)).view(-1).mean()     return loss, acc  def set_forward_loss(self, K_way, N_shot, N_query,sample_datas):     """         sample_datas: shape(K_way*(N_shot+N_query),3,84,84)     """      z = self.cnn_net(sample_datas) # shape=(K_way*(N_shot+N_query),z_dim) ,将support set和query set都进行向量化表示     z = z.view(K_way,N_shot+N_query,-1) # shape = (K_way,N_shot+N_query,1600)          support_z = z[:,:N_shot] # support set的向量化表示 shape=(K_way,N_shot,1600)     query_z = z[:,N_shot:].contiguous().view(K_way*N_query,-1) # Query set的向量化表示 shape=(K_way*N_query,1600)          center = torch.mean(support_z, dim=1) # 计算support set的向量均值,shape=(K_way,1600)     return self.loss_acc(query_z, center,K_way,N_query) 

关于实验中具体的参数设计,可以参考论文[1:3]或者Github上面的源代码。在原论文中,对于实验的设计讲得非常清楚。

实验结果

下面的表格为测试集的acc(当验证集acc为最大值时测试集所对应的acc):

N-shot=1 N-shot=5
K_way=5 0.4313 0.6684

深度学习(五)之原型网络

总结

总的来说,原型网络是一个容易理解的网络模型,思想简单,易于实现。

References


  1. [1703.05175] Prototypical Networks for Few-shot Learning (arxiv.org) ↩︎ ↩︎ ↩︎ ↩︎

  2. 深度学习(二)之猫狗分类 - 段小辉 - 博客园 (cnblogs.com) ↩︎

  3. How many images do you need to train a neural network? « Pete Warden's blog ↩︎

  4. 刘颖, 雷研博, 范九伦, 王富平, 公衍超, 田奇. 基于小样本学习的图像分类技术综述. 自动化学报, 2021, 47(2): 297−315 ↩︎

  5. 【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_Jnchin的博客-CSDN博客_原型网络小样本 ↩︎

  6. jakesnell/prototypical-networks: Code for the NeurIPS 2017 Paper "Prototypical Networks for Few-shot Learning" (github.com) ↩︎

发表评论

相关文章