Movatterモバイル変換


[0]ホーム

URL:


CN114626481A - A multi-scale metric few-shot learning method based on class features - Google Patents

A multi-scale metric few-shot learning method based on class features
Download PDF

Info

Publication number
CN114626481A
CN114626481ACN202210314022.XACN202210314022ACN114626481ACN 114626481 ACN114626481 ACN 114626481ACN 202210314022 ACN202210314022 ACN 202210314022ACN 114626481 ACN114626481 ACN 114626481A
Authority
CN
China
Prior art keywords
class
sample
measurement
feature
few
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210314022.XA
Other languages
Chinese (zh)
Other versions
CN114626481B (en
Inventor
吴磊
管林林
王晓敏
吴少智
龚海刚
刘明
陈坚武
单文煜
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Quzhou Haiyi Technology Co ltd
Yangtze River Delta Research Institute of UESTC Huzhou
Original Assignee
Quzhou Haiyi Technology Co ltd
Yangtze River Delta Research Institute of UESTC Huzhou
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Quzhou Haiyi Technology Co ltd, Yangtze River Delta Research Institute of UESTC HuzhoufiledCriticalQuzhou Haiyi Technology Co ltd
Priority to CN202210314022.XApriorityCriticalpatent/CN114626481B/en
Publication of CN114626481ApublicationCriticalpatent/CN114626481A/en
Application grantedgrantedCritical
Publication of CN114626481BpublicationCriticalpatent/CN114626481B/en
Activelegal-statusCriticalCurrent
Anticipated expirationlegal-statusCritical

Links

Images

Classifications

Landscapes

Abstract

The invention relates to a class feature-based multi-scale measurement and few-sample learning method, which comprises the following steps: s1, preprocessing data; s2, a characteristic embedding step; s3, class feature extraction: integrating a plurality of similar sample characteristics of the support set through a dynamic routing mechanism, and updating the weight vector of the input vector in an iterative mode to obtain similar integral characteristics; s4, multi-scale measurement: and (3) carrying out similarity measurement on the support set class characteristics and the query set samples by fusing three measurement criteria. The invention adopts a dynamic routing mechanism to generate class integral characteristics, and compared with a direct weighted average algorithm, the class integral characteristics obtained by the algorithm are more representative. In the measurement module, an attention mechanism is introduced into a measurement method of a parameter network, and the similarity between sample characteristics is jointly determined by combining the advantages and disadvantages of various measurement modes, so that a CFMMN network model with better expressive force is obtained.

Description

Translated fromChinese
一种基于类特征的多尺度度量少样本学习方法A multi-scale metric few-shot learning method based on class features

技术领域technical field

本发明涉及人工智能技术领域,尤其涉及一种基于类特征的多尺度度量少样本学习方法。The invention relates to the technical field of artificial intelligence, in particular to a multi-scale measurement few-sample learning method based on class features.

背景技术Background technique

针对不同领域的学习任务,需要根据任务的具体要求得到大量的标注数据样本,但这是十分不易的,需要耗费巨大的人力和财力代价。所以样本集数据资源少是目前深度学习中亟待解决的难点问题。人类具有从极少量样本中进行有效学习的能力,在此启发下,少样本学习的概念应运而生。少样本学习试图通过对先验知识的转化,快速地归纳有限的监督经验,通过类比方式模仿人类从少数例子中获取知识的能力。For learning tasks in different fields, it is necessary to obtain a large number of labeled data samples according to the specific requirements of the task, but this is very difficult and requires huge human and financial costs. Therefore, the lack of sample set data resources is a difficult problem to be solved urgently in deep learning. Humans have the ability to learn effectively from a very small number of samples, inspired by this, the concept of few-shot learning came into being. Few-shot learning attempts to quickly generalize limited supervised experience through the transformation of prior knowledge, imitating the ability of humans to acquire knowledge from a few examples by analogy.

在现有的少样本学习算法中,少样本学习模型分为三个大的类别:基于数据增强的方法、基于模型优化的方法和基于算法优化的方法。基于数据增强的少样本学习算法是从通过对实际问题分析后直接出发得到的一类解决方式,试图扩充当前缺失数据资源的数据集,从而提供丰富的有监督信息。基于模型优化的少样本学习算法模型的主要思想是通过对有限的有监督样本资源的高层语义信息进行充分挖掘之后,进一步缩小参数的优化空间,进而降低模型参数的优化难度。基于度量的少样本方法不仅仅只专注于如何得到更多的原始数据,而是从如何更好地利用好有限的样本数据的角度出发,巧妙地将少样本问题转化为寻找一个更准确的嵌入表达和一个更优的距离度量策略。基于算法优化的少样本学习方法试图在假设空间中探索更合适的搜索策略。基于算法优化的方法中许多模型选择探索一个更优秀的初始化参数模型,进而用更少的迭代次数得到更优的模型结果。Among existing few-shot learning algorithms, few-shot learning models fall into three broad categories: data augmentation-based methods, model-based optimization methods, and algorithm-based optimization methods. The few-shot learning algorithm based on data augmentation is a kind of solution directly obtained from the analysis of the actual problem, trying to expand the data set of the current missing data resources, thereby providing rich supervised information. The main idea of the few-shot learning algorithm model based on model optimization is to further reduce the optimization space of parameters by fully mining the high-level semantic information of limited supervised sample resources, thereby reducing the optimization difficulty of model parameters. The metric-based few-shot method does not only focus on how to get more original data, but from the perspective of how to make better use of limited sample data, it cleverly transforms the few-shot problem into finding a more accurate embedding expression and a more optimal distance metric strategy. Few-shot learning methods based on algorithmic optimization attempt to explore more suitable search strategies in the hypothesis space. Many models in algorithmic optimization-based methods choose to explore a better initialized parameter model, thereby obtaining better model results with fewer iterations.

在现有模型中,针对N-way K-shot(K>1)少样本任务,仅仅通过将同类样本的特征图进行简单加权或是求平均的方式得到该类别在高维空间的类整体特征。样本特征均经过前期的卷积神经网络训练得到,但由于卷积神经网络的空间局限性,特征图上某个位置的权重信息仅仅与映射到原图像位置的相关临近区域得到,并不能很好地融合整张图像的全部信息。所以当存在同一个类不同样本位置信息相差较大的情况时,会引入较大的噪声信息将原有的有效信息作用力消除,导致最后的图像分类结果不准确。In the existing model, for the N-way K-shot (K>1) few-sample task, the overall class features of the class in the high-dimensional space are obtained by simply weighting or averaging the feature maps of similar samples. . The sample features are obtained through the previous convolutional neural network training, but due to the spatial limitations of the convolutional neural network, the weight information of a certain position on the feature map is only obtained from the relevant adjacent area mapped to the original image position, which is not very good. to fuse all the information of the whole image. Therefore, when there is a large difference in the position information of different samples of the same class, large noise information will be introduced to eliminate the original effective information force, resulting in inaccurate final image classification results.

发明内容SUMMARY OF THE INVENTION

本发明的目的在于克服现有技术的缺点,提供了一种基于类特征的多尺度度量少样本学习方法,解决了现有少样本学习中存在类特征提取不准确的问题。The purpose of the present invention is to overcome the shortcomings of the prior art, to provide a multi-scale measurement few-sample learning method based on class features, and to solve the problem of inaccurate class feature extraction in the existing few-sample learning.

本发明的目的通过以下技术方案来实现:一种基于类特征的多尺度度量少样本学习方法,所述少样本学习方法包括:The object of the present invention is achieved by the following technical solutions: a multi-scale measurement few-sample learning method based on class features, the few-sample learning method includes:

S1、数据预处理步骤:通过随机固定角度的方式增强数据,以扩充数据量和为同一类增加不同角度的图像样本,并通过N-way K-shot方法得到支撑集和查询集;S1. Data preprocessing step: Enhance data by random fixed angle to expand data volume and add image samples from different angles for the same class, and obtain support set and query set by N-way K-shot method;

S2、特征嵌入步骤:通过特征嵌入网络

Figure BDA0003568283660000021
对支撑集和查询集中的样本xi进行嵌入后得到特征
Figure BDA0003568283660000022
S2, Feature Embedding Step: Embedding the network through features
Figure BDA0003568283660000021
Features are obtained after embedding the samplesxi in the support set and query set
Figure BDA0003568283660000022

S3、类特征提取步骤:通过动态路由机制融合支撑集同类的多个样本特征,并通过迭代的方式数输入向量的权重向量进行更新得到类整体特征;S3. Class feature extraction step: fuse multiple sample features of the same type in the support set through a dynamic routing mechanism, and update the weight vector of the input vector in an iterative manner to obtain the class overall feature;

S4、多尺度度量步骤:通过融合有参网络度量、余弦距离度量和欧式距离度量三种度量准则对支撑集类特征与查询集样本之间进行相似度度量。S4, multi-scale measurement step: measure the similarity between the support set class features and the query set samples by integrating three measurement criteria: the parametric network measurement, the cosine distance measurement and the Euclidean distance measurement.

所述通过N-way K-shot方法得到支撑集和查询集包括:The support set and query set obtained by the N-way K-shot method include:

从数据集中随机抽取N个类,每个类抽取k个样本作为支撑集,支撑集中的样本用于生成N个类的原型;N classes are randomly selected from the data set, k samples are selected from each class as the support set, and the samples in the support set are used to generate prototypes of N classes;

再从N个类剩余的样本中每类抽取k个样本作为查询集,查询集用于计算网络的准确率,一样模型性能。Then, k samples are extracted from each of the remaining samples of the N classes as a query set, and the query set is used to calculate the accuracy of the network, which is the same as the model performance.

所述类特征提取步骤的具体内容包括:The specific content of the class feature extraction step includes:

对特征嵌入步骤中得到的支撑集样本特征向量eij进行变换得到

Figure BDA0003568283660000023
其中,Ws、bs为转换矩阵和偏置项,Squash函数是一个非线性函数将向量压缩,使其长度在0到1之间对向量的长度进行归一化;Transform the support set sample feature vector eij obtained in the feature embedding step to obtain
Figure BDA0003568283660000023
Among them, Ws and bs are the transformation matrix and the bias term, and the Squash function is a nonlinear function that compresses the vector so that its length is between 0 and 1 to normalize the length of the vector;

通过迭代的方式对输入向量

Figure BDA0003568283660000024
的权重向量更新后得到类整体特征。Iterate over the input vector
Figure BDA0003568283660000024
The weight vector of is updated to obtain the overall feature of the class.

具体的迭代过程包括:The specific iterative process includes:

dij=softmax(bi)dij =softmax(bi )

Figure BDA0003568283660000025
Figure BDA0003568283660000025

Figure BDA0003568283660000026
Figure BDA0003568283660000026

Figure BDA0003568283660000027
Figure BDA0003568283660000027

其中,dij表示的是输入向量与输出类特征ci间的关联关系,bij的初始值为0,经过Softmax函数后就变为均匀分布,ci为输出的第i类支撑集样本的类特征。Among them, dij represents the relationship between the input vector and the output class feature ci , the initial value of bij is 0, and it becomes a uniform distribution after passing through the Softmax function, and ci is the output of the i-th support set sample. class features.

所述多尺度度量步骤具体包括:The multi-scale measurement step specifically includes:

根据所述特征嵌入步骤得到查询集样本特征eq和类特征提取步骤得到第i类支撑集的类特征ci,通过欧氏距离得到第i类支撑集样本与第q个查询样本间的匹配分数

Figure BDA0003568283660000031
Figure BDA0003568283660000032
According to the feature embedding step, the query set sample feature eq is obtained, and the class feature extraction step obtains the class feature ci of the ith support set, and the matching between the ith support set sample and the q th query sample is obtained through the Euclidean distance. Fraction
Figure BDA0003568283660000031
for
Figure BDA0003568283660000032

以及得到余弦相似度方法作为度量准则是的匹配分数

Figure BDA0003568283660000033
and get the matching score of the cosine similarity method as a metric
Figure BDA0003568283660000033

当度量方式为带有注意力机制的有参网络是,通过优化学习得到网络中具体的参数,进而得到匹配分数

Figure BDA0003568283660000034
其中C(.,.)为拼接级联函数,MAttention(.)表示带有注意力机制的度量准则,fφ表示带有激活函数的全连接网络;When the measurement method is a participant network with an attention mechanism, the specific parameters in the network are obtained through optimization learning, and then the matching score is obtained.
Figure BDA0003568283660000034
where C(.,.) is the splicing cascade function, MAttention (.) represents the metric with the attention mechanism, and fφ represents the fully connected network with the activation function;

选择三种度量方式相加的匹配得分最大的类别i作为该查询样本xq的类别标签

Figure BDA0003568283660000035
Select the category i with the largest matching score added by the three measures as the category label of the query sample xq
Figure BDA0003568283660000035

所述少样本学习方法还包括设置损失函数步骤;损失函数为带有间距的损失函数,其计算公式为

Figure BDA0003568283660000036
其中,m+表示间隔,α表示权重系数,1iq表示指示函数,riq表示查询样本与第i类支撑集样本的匹配得分;The few-sample learning method also includes the step of setting a loss function; the loss function is a loss function with a gap, and its calculation formula is
Figure BDA0003568283660000036
Among them, m+ represents the interval, α represents the weight coefficient, 1iq represents the indicator function, and riq represents the matching score between the query sample and the i-th support set sample;

损失函数计算公式表示了查询样本与其他所有类级别特征间的互相制约的结果,同类样本间产生向内的拉力和非同类样本间产生向外的推力,

Figure BDA0003568283660000037
表示了查询样本q与第i类支撑类特征之间的拉力,优化的目标就是减少同类样本间的距离;
Figure BDA0003568283660000038
Figure BDA0003568283660000039
约束了非同类样本间的最小距离不能小于阈值m+。The loss function calculation formula expresses the result of the mutual restriction between the query sample and all other class-level features, and there is an inward pulling force between similar samples and an outward pushing force between non-similar samples.
Figure BDA0003568283660000037
Represents the tension between the query sample q and the i-th type of support features, and the goal of optimization is to reduce the distance between similar samples;
Figure BDA0003568283660000038
Figure BDA0003568283660000039
It is restricted that the minimum distance between non-homogeneous samples cannot be less than the threshold m+ .

本发明具有以下优点:一种基于类特征的多尺度度量少样本学习方法,在先前基于度量的少样本学习的启发下,重点关注N-way K-shot(K>1)少样本分类任务,采用动态路由机制生成类整体特征,相比于直接加权平均的算法,通过该算法得到的类整体特征更具有代表性。在度量模块中,在有参网络的度量方法中引入了注意力机制,另外结合了多种度量方式的优劣,共同决定样本特征间相似度,从而得到了表现力更好的CFMMN网络模型。The present invention has the following advantages: a multi-scale metric few-shot learning method based on class features, inspired by the previous metric-based few-shot learning, focusing on N-way K-shot (K>1) few-shot classification tasks , the dynamic routing mechanism is used to generate the overall class features. Compared with the direct weighted average algorithm, the overall class features obtained by this algorithm are more representative. In the measurement module, the attention mechanism is introduced into the measurement method of the network with parameters, and the advantages and disadvantages of various measurement methods are combined to jointly determine the similarity between the sample features, so as to obtain a better expressive CFMMN network model.

附图说明Description of drawings

图1为基于类特征的多尺度度量少样本学习算法流程图;Figure 1 is a flowchart of a multi-scale metric few-shot learning algorithm based on class features;

图2为基于类特征的多尺度度量少样本学习CFMMN网络模型;Figure 2 is a CFMMN network model based on multi-scale measurement and few-shot learning based on class features;

图3为特征嵌入模块网络结构;Fig. 3 is the network structure of the feature embedding module;

图4为有参度量网络结构;Fig. 4 is a network structure with parameter measurement;

图5为少样本5-way 1-shot任务示例;Figure 5 is an example of a few-sample 5-way 1-shot task;

图6为Omniglot数据集实验结果对比图;Figure 6 is a comparison chart of the experimental results of the Omniglot dataset;

图7为mini-ImageNet数据集实验结果对比图。Figure 7 is a comparison chart of the experimental results of the mini-ImageNet dataset.

具体实施方式Detailed ways

为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。通常在此处附图中描述和示出的本申请实施例的组件可以以各种不同的配置来布置和设计。因此,以下结合附图中提供的本申请的实施例的详细描述并非旨在限制要求保护的本申请的保护范围,而是仅仅表示本申请的选定实施例。基于本申请的实施例,本领域技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本申请保护的范围。下面结合附图对本发明做进一步的描述。In order to make the purposes, technical solutions and advantages of the embodiments of the present application more clear, the technical solutions in the embodiments of the present application will be clearly and completely described below with reference to the accompanying drawings in the embodiments of the present application. Obviously, the described embodiments are only It is a part of the embodiments of the present application, but not all of the embodiments. The components of the embodiments of the present application generally described and illustrated in the drawings herein may be arranged and designed in a variety of different configurations. Accordingly, the following detailed description of the embodiments of the present application provided in conjunction with the accompanying drawings is not intended to limit the scope of protection of the present application as claimed, but merely represents selected embodiments of the present application. Based on the embodiments of the present application, all other embodiments obtained by those skilled in the art without creative work fall within the protection scope of the present application. The present invention will be further described below with reference to the accompanying drawings.

本发明工作主要围绕少样本学习算法中的N-way K-shot(K>1)少样本分类任务,以及类特征选择与提取这一核心问题展开,另外综合了多种度量方式的优劣生成一个多尺度度量方法,最后选择少样本学习领域广泛使用的两个数据集:Omniglot数据集和mini-ImageNet数据集,运用本发明中的算法进行少样本图像分类测试。如图1所示为基于类特征的多尺度度量少样本学习算法流程图,如图2所示为该算法的网络模型。该方法主要步骤如下:The work of the present invention mainly focuses on the N-way K-shot (K>1) few-sample classification task in the few-sample learning algorithm, and the core problem of class feature selection and extraction. A multi-scale measurement method, and finally select two data sets widely used in the field of few-sample learning: Omniglot data set and mini-ImageNet data set, and use the algorithm in the present invention to perform a few-sample image classification test. Figure 1 shows the flow chart of the multi-scale metric few-shot learning algorithm based on class features, and Figure 2 shows the network model of the algorithm. The main steps of this method are as follows:

步骤一:数据预处理。由于少样本学习数据资源缺失的特殊性,首先采取最简单的随机旋转90°、180°和270°的方式增强数据,一方面能够扩充数据量,另一方面为同一个类增加不同角度的图像样本有利于测试模型类特征提取的有效性。另外,少样本图像分类模型的训练迭代其实是通过若干“任务”完成的,就需要先将图像数据生成任务再输入模型,这种训练方法又称为N-way,K-shot方法。先从数据集中随机抽取N个类,每个类抽取k个样本作为支撑集,支撑集中的样本用于生成N个类的原型,再从N个类剩余的样本中每类抽取k个样本作为查询集,查询集用于计算网络的准确率,以验证模型性能。每个任务都包含了少量的类,并且每类都包含少量的样本,这样的任务设置模拟了少样本图像分类的情景。Step 1: Data preprocessing. Due to the particularity of the lack of data resources for few-sample learning, the simplest way to randomly rotate 90°, 180° and 270° is to enhance the data first. On the one hand, the amount of data can be expanded, and on the other hand, images of different angles can be added to the same class. Samples are useful for testing the effectiveness of model-like feature extraction. In addition, the training iteration of the few-sample image classification model is actually completed through several "tasks", and the image data needs to be generated first and then input to the model. This training method is also called N-way, K-shot method. Firstly, N classes are randomly selected from the data set, and k samples are selected from each class as the support set. The samples in the support set are used to generate the prototypes of the N classes, and then k samples from each class are selected from the remaining samples of the N classes as the support set. Query set, the query set is used to calculate the accuracy of the network to verify the model performance. Each task contains a small number of classes, and each class contains a small number of samples, such a task setting simulates the scenario of few-shot image classification.

步骤二:特征嵌入。对于给定的有N个类每个类有k个样本的支撑集

Figure BDA0003568283660000041
Figure BDA0003568283660000042
和查询集
Figure BDA0003568283660000043
样本xi经过特征嵌入网络
Figure BDA0003568283660000044
后得到特征hi:Step 2: Feature embedding. For a given support set of N classes with k samples per class
Figure BDA0003568283660000041
Figure BDA0003568283660000042
and queryset
Figure BDA0003568283660000043
The samplexi passes through the feature embedding network
Figure BDA0003568283660000044
Then get the feature hi :

Figure BDA0003568283660000045
Figure BDA0003568283660000045

其中,嵌入模块具体网络

Figure BDA0003568283660000046
结构如图3所示。关系网络的嵌入模块由四个卷积块结构组成,每个卷积块结构包含卷积核为64*3*3的卷积层、一个Batch归一化层和一个ReLU层。其中第一个和第二个卷积块后紧跟着一个2*2的最大池化层以调整网络规格,后面两个卷积块后则未引入池化层。Among them, the specific network of the embedded module
Figure BDA0003568283660000046
The structure is shown in Figure 3. The embedding module of the relational network consists of four convolutional block structures, each of which contains a convolutional layer with a convolution kernel of 64*3*3, a Batch normalization layer and a ReLU layer. The first and second convolution blocks are followed by a 2*2 max pooling layer to adjust the network specifications, and no pooling layer is introduced after the last two convolution blocks.

步骤三:类特征提取。在完成对支撑集S和查询集Q进行特征提取之后,采取动态路由机制融合支撑集同类的多个样本特征。为了使该模型能够适应更多形式支撑集样本的任务输入,对经过步骤二后得到的支撑集样本特征向量eij先作如下变换:Step 3: Class feature extraction. After the feature extraction of the support set S and the query set Q is completed, a dynamic routing mechanism is adopted to fuse the features of multiple samples of the same type of the support set. In order to make the model adapt to the task input of more forms of support set samples, the following transformations are made to the support set sample feature vector eij obtained after step 2:

Figure BDA0003568283660000051
Figure BDA0003568283660000051

其中,Ws、bs为转换矩阵和偏置项,Squash函数的行为类似于sigmoid,它是一个非线性函数将向量压缩,使其长度在0到1之间对向量的长度进行归一化。对于任一向量si,Squash函数的计算如下:Among them, Ws , bs are the transformation matrix and the bias term. The behavior of the Squash function is similar to sigmoid. It is a nonlinear function that compresses the vector so that its length is between 0 and 1. Normalizes the length of the vector . For any vector si , the Squash function is computed as follows:

Figure BDA0003568283660000052
Figure BDA0003568283660000052

通过迭代的方式对输入向量

Figure BDA0003568283660000053
的权重向量更新后得到类整体特征,具体的迭代过程为:Iterate over the input vector
Figure BDA0003568283660000053
After updating the weight vector of , the overall characteristics of the class are obtained. The specific iterative process is:

dij=softmax(bi)dij =softmax(bi )

Figure BDA0003568283660000054
Figure BDA0003568283660000054

Figure BDA00035682836600000511
Figure BDA00035682836600000511

Figure BDA0003568283660000055
Figure BDA0003568283660000055

其中,dij表示的是输入向量与输出类特征ci间的关联关系,bij的初始值为0,经过Softmax函数后就变为均匀分布了,ci为输出的第i类支撑集样本的类特征。如果当前样本特征是属于某个类别的话,相似会更高,权重也将在下次迭代过程中变大,如果不是的话,权重向量就应该会更小。总的来说,通过多次迭代后,同一个类下的各个样本贡献程度会通过学习后变得不一样。迭代轮次结束后,也就会得到类级别上的特征,一般情况下3轮就可以完成。Among them, dij represents the relationship between the input vector and the output class feature ci , the initial value of bij is 0, and it becomes a uniform distribution after passing through the Softmax function, and ci is the output type i support set sample class features. If the current sample feature belongs to a certain category, the similarity will be higher, and the weight will become larger in the next iteration process, if not, the weight vector should be smaller. In general, after multiple iterations, the contribution levels of each sample under the same class will become different after learning. After the iteration round is over, the features at the class level will also be obtained, which can be completed in 3 rounds under normal circumstances.

步骤四:多尺度度量。得到支撑集中N个类别的类特征之后,就需要采取合适的方式进行支撑集类特征与查询样本之间的相似度度量。一般来说,特征相似度度量方法有余弦距离和欧氏距离两种。如图2的网络模型所示,本算法将融合有参网络度量、余弦距离度量和欧氏距离度量三种度量准则,具体为:。Step 4: Multi-scale measurement. After obtaining the class features of N categories in the support set, it is necessary to adopt an appropriate method to measure the similarity between the class features of the support set and the query samples. Generally speaking, there are two types of feature similarity measurement methods: cosine distance and Euclidean distance. As shown in the network model in Figure 2, this algorithm will integrate three metric criteria: parameter network metric, cosine distance metric and Euclidean distance metric, specifically:

经过前置特征嵌入网络得到查询样本特征为eq和经过动态路由模块得到第i类支撑集的类特征ci,那么通过欧氏距离得到第i类支撑集样本与第q个查询样本间的匹配分数

Figure BDA0003568283660000056
为:Through the pre-feature embedding network, the query sample feature is obtained as eq and the class feature ci of the ith support set is obtained through the dynamic routing module, then the Euclidean distance is used to obtain the relationship between the ith support set sample and the q th query sample. match score
Figure BDA0003568283660000056
for:

Figure BDA0003568283660000057
Figure BDA0003568283660000057

如果将余弦相似度方法作为度量准则,匹配分数

Figure BDA0003568283660000058
为:If the cosine similarity method is used as the metric, the matching score
Figure BDA0003568283660000058
for:

Figure BDA0003568283660000059
Figure BDA0003568283660000059

如图4所示,当度量方式为带有注意力机制的有参网络时需要通过优化学习得到网络中具体的参数,最终得到的匹配分数为:As shown in Figure 4, when the measurement method is a parameterized network with an attention mechanism, the specific parameters in the network need to be obtained through optimization learning, and the final matching score is:

Figure BDA00035682836600000510
Figure BDA00035682836600000510

其中C(.,.)为拼接级联函数,MAttention(.)表示带有注意力机制的度量准则,fφ表示带有激活函数的全连接网络。具体来说,注意力机制是将拼接后的特征矩阵P经过3个1x1卷积核后生成三个新的特征图A、B、C后,经过注意力层的计算方法如下式:where C(., .) is the concatenation cascade function, MAttention (.) represents the metric with attention mechanism, and fφ represents the fully connected network with activation function. Specifically, the attention mechanism is to generate three new feature maps A, B, and C after passing the spliced feature matrix P through three 1x1 convolution kernels. The calculation method of the attention layer is as follows:

Figure BDA0003568283660000061
Figure BDA0003568283660000061

H(A,B)=soft max(ATB)H(A, B)=soft max(ATB)

上述式子通过残差中和思想得到了带有注意力权值的特征图PAttentionOut,在网络中引入注意力不仅能够综合考查支撑集中各个类特征,还能够找到类特征与查询特征间更具有针对性的部分进行度量学习。The above formula obtains the feature map PAttentionOut with attention weights through the idea of residual neutralization. Introducing attention into the network can not only comprehensively examine the features of each class in the support set, but also find out the difference between class features and query features. Targeted part for metric learning.

综合以上三种度量方式的得到匹配分数

Figure BDA0003568283660000062
共同决定该少样本任务分类的最终结果,选择这三种度量方式相加的匹配得分最大的类别i作为该查询样本xq的类别标签:Combining the above three metrics to get the matching score
Figure BDA0003568283660000062
The final result of the few-shot task classification is jointly determined, and the category i with the largest matching score added by these three measurement methods is selected as the category label of the query sample xq :

Figure BDA0003568283660000063
Figure BDA0003568283660000063

步骤五:损失函数设计。损失函数在本发明中作为学习准则与优化问题相联系,即通过最小化损失函数求解来评估模型。本发明针对于CFMMN的少样本学习场景专门设计了一种带有间隔的损失函数计算方法:Step 5: Loss function design. The loss function is associated with the optimization problem in the present invention as a learning criterion, ie the model is evaluated by solving the minimization of the loss function. The present invention specially designs a loss function calculation method with interval for the few-sample learning scenario of CFMMN:

Figure BDA0003568283660000064
Figure BDA0003568283660000064

其中,m+表示间隔,α表示权重系数,1iq表示指示函数,riq表示查询样本与第i类支撑集样本的匹配得分。上式表示了查询样本与其他所有类级别特征间的互相制约的结果,同类样本间产生向内的拉力和非同类样本间产生向外的推力。第一项表示了查询样本q与第i类支撑类特征之间的拉力,优化的目标就是试图减少同类样本间的距离;式中第二项约束了非同类样本间的最小距离不能小于阈值m+Among them, m+ represents the interval, α represents the weight coefficient, 1iq represents the indicator function, and riq represents the matching score between the query sample and the i-th support set sample. The above formula expresses the result of the mutual restriction between the query sample and all other class-level features, and there is an inward pulling force between similar samples and an outward pushing force between non-homogeneous samples. The first term represents the tension between the query sample q and the i-th type of support feature. The goal of optimization is to try to reduce the distance between similar samples; the second term in the formula restricts the minimum distance between non-similar samples to not be less than the threshold m+ .

本发明中基于类特征的多尺度度量少样本学习网络模型是由任务模式展开训练和测试的,所以需要对原始数据集D采样构建任务。首先,将原始数据集D分割为训练数据集和测试数据集,分别对应少样本学习的训练和测试阶段。对训练数据集和测试数据集随机采样生成多个任务,其中单个任务又包含支撑样本集和查询样本集两个部分,该任务的查询样本标签一定是包含于支撑样本标签中的,也就是说该模型的目的是通过对大量训练任务的学习后,在测试任务中查询集样本的属于支撑集中哪一个类别标签。如果一个任务的支撑样本集有N个类,每个类有K个样本,就称此类任务为N-way K-shot任务,如下图5所示是一个典型的5-way 1-shot任务。The multi-scale measurement few-sample learning network model based on the class feature in the present invention is trained and tested by the task mode, so it is necessary to sample the original data set D to construct the task. First, the original dataset D is divided into training datasets and test datasets, which correspond to the training and testing phases of few-shot learning, respectively. Randomly sample the training data set and the test data set to generate multiple tasks, in which a single task contains two parts: the support sample set and the query sample set. The query sample label of the task must be included in the support sample label, that is to say The purpose of the model is to query which category label of the support set the sample belongs to in the test task after learning a large number of training tasks. If the supporting sample set of a task has N classes, and each class has K samples, such a task is called an N-way K-shot task. As shown in Figure 5 below, it is a typical 5-way 1-shot task. .

为了评估本发明中基于类特征的多尺度度量少样本学习CFMMN网络模型的性能,分别在Omniglot和mini-ImageNet数据集上进行了5-way 1-shot、5-way 5-shot、20-way1-shot和20-way 5-shot的实验,并与其他算法进行比较分析。本发明中采用训练集:测试集=8:2的模式,模型评估的标准为在测试集上查询样本标签的准确率Acc,下面列举出的MN、PN和RN网络模型的基线都是与本文相同的。In order to evaluate the performance of the class feature-based multi-scale metric few-shot learning CFMMN network model in the present invention, 5-way 1-shot, 5-way 5-shot, 20-way Experiments on way1-shot and 20-way 5-shot, and comparative analysis with other algorithms. In the present invention, the training set: test set=8:2 mode is adopted, and the standard of model evaluation is the accuracy rate Acc of querying the sample label on the test set. The baselines of the MN, PN and RN network models listed below are the same as those of this paper. identical.

CFMMN网络模型在Omniglot数据集上的实验结果如图6所示:5-way 1-shot任务的准确率达到99.34%±0.27%,5-way 5-shot任务的准确率达到99.55%±0.19%,与MN相比分别提升1.74%、1.25%,与PN相比分别提升2.04%、0.65%,与RN相比分别提升0.44%、0.51%;20-way 1-shot任务在RN基础上提升1.82%,20-way 5-shot任务在RN基础上提升0.51%。The experimental results of the CFMMN network model on the Omniglot dataset are shown in Figure 6: the accuracy of the 5-way 1-shot task reaches 99.34% ± 0.27%, and the accuracy of the 5-way 5-shot task reaches 99.55% ± 0.19% , compared with MN, increased by 1.74%, 1.25%, compared with PN, increased by 2.04%, 0.65%, respectively, compared with RN, increased by 0.44%, 0.51%; 20-way 1-shot task on the basis of RN increased by 1.82 %, the 20-way 5-shot task improves by 0.51% on the RN basis.

CFMMN网络模型在mini-ImageNet数据集上的实验结果如图7所示:5-way 1-shot任务和5-way 5-shot任务与RN相比分类准确率提升5.35%、6.74%。The experimental results of the CFMMN network model on the mini-ImageNet dataset are shown in Figure 7: Compared with the RN, the classification accuracy of the 5-way 1-shot task and the 5-way 5-shot task is improved by 5.35% and 6.74%.

以上所述仅是本发明的优选实施方式,应当理解本发明并非局限于本文所披露的形式,不应看作是对其他实施例的排除,而可用于各种其他组合、修改和环境,并能够在本文所述构想范围内,通过上述教导或相关领域的技术或知识进行改动。而本领域人员所进行的改动和变化不脱离本发明的精神和范围,则都应在本发明所附权利要求的保护范围内。The foregoing are only preferred embodiments of the present invention, and it should be understood that the present invention is not limited to the forms disclosed herein, and should not be construed as an exclusion of other embodiments, but may be used in various other combinations, modifications, and environments, and Modifications can be made within the scope of the concepts described herein, from the above teachings or from skill or knowledge in the relevant field. However, modifications and changes made by those skilled in the art do not depart from the spirit and scope of the present invention, and should all fall within the protection scope of the appended claims of the present invention.

Claims (6)

Translated fromChinese
1.一种基于类特征的多尺度度量少样本学习方法,其特征在于:所述少样本学习方法包括:1. A multi-scale measurement few-sample learning method based on class features, characterized in that: the few-sample learning method comprises:S1、数据预处理步骤:通过随机固定角度的方式增强数据,以扩充数据量和为同一类增加不同角度的图像样本,并通过N-way K-shot方法得到支撑集和查询集;S1. Data preprocessing step: Enhance data by random fixed angle to expand data volume and add image samples from different angles for the same class, and obtain support set and query set by N-way K-shot method;S2、特征嵌入步骤:通过特征嵌入网络
Figure FDA0003568283650000011
对支撑集和查询集中的样本xi进行嵌入后得到特征
Figure FDA0003568283650000012
S2, Feature Embedding Step: Embedding the network through features
Figure FDA0003568283650000011
Features are obtained after embedding the samplesxi in the support set and query set
Figure FDA0003568283650000012
S3、类特征提取步骤:通过动态路由机制融合支撑集同类的多个样本特征,并通过迭代的方式数输入向量的权重向量进行更新得到类整体特征;S3. Class feature extraction step: fuse multiple sample features of the same type in the support set through a dynamic routing mechanism, and update the weight vector of the input vector in an iterative manner to obtain the class overall feature;S4、多尺度度量步骤:通过融合有参网络度量、余弦距离度量和欧式距离度量三种度量准则对支撑集类特征与查询集样本之间进行相似度度量。S4, multi-scale measurement step: measure the similarity between the support set class features and the query set samples by integrating three measurement criteria: the parametric network measurement, the cosine distance measurement and the Euclidean distance measurement.2.根据权利要求1所述的一种基于类特征的多尺度度量少样本学习方法,其特征在于:所述通过N-way K-shot方法得到支撑集和查询集包括:2. a kind of multi-scale measurement few-sample learning method based on class feature according to claim 1, is characterized in that: described obtaining support set and query set by N-way K-shot method comprises:从数据集中随机抽取N个类,每个类抽取k个样本作为支撑集,支撑集中的样本用于生成N个类的原型;N classes are randomly selected from the data set, k samples are selected from each class as the support set, and the samples in the support set are used to generate prototypes of N classes;再从N个类剩余的样本中每类抽取k个样本作为查询集,查询集用于计算网络的准确率,一样模型性能。Then, k samples are extracted from each of the remaining samples of the N classes as a query set, and the query set is used to calculate the accuracy of the network, which is the same as the model performance.3.根据权利要求1所述的一种基于类特征的多尺度度量少样本学习方法,其特征在于:所述类特征提取步骤的具体内容包括:3. A kind of multi-scale measurement few-sample learning method based on class feature according to claim 1, is characterized in that: the specific content of described class feature extraction step comprises:对特征嵌入步骤中得到的支撑集样本特征向量eij进行变换得到
Figure FDA0003568283650000013
其中,Ws、bs为转换矩阵和偏置项,Squash函数是一个非线性函数将向量压缩,使其长度在0到1之间对向量的长度进行归一化;
Transform the support set sample feature vector eij obtained in the feature embedding step to obtain
Figure FDA0003568283650000013
Among them, Ws and bs are the transformation matrix and the bias term, and the Squash function is a nonlinear function that compresses the vector so that its length is between 0 and 1 to normalize the length of the vector;
通过迭代的方式对输入向量
Figure FDA0003568283650000014
的权重向量更新后得到类整体特征。
Iterate over the input vector
Figure FDA0003568283650000014
The weight vector of is updated to obtain the overall feature of the class.
4.根据权利要求3所述的一种基于类特征的多尺度度量少样本学习方法,其特征在于:具体的迭代过程包括:4. a kind of multi-scale measurement few-sample learning method based on class feature according to claim 3, is characterized in that: the specific iterative process comprises:dij=softmax(bi)dij =softmax(bi )
Figure FDA0003568283650000015
Figure FDA0003568283650000015
Figure FDA0003568283650000016
Figure FDA0003568283650000016
Figure FDA0003568283650000017
Figure FDA0003568283650000017
其中,dij表示的是输入向量与输出类特征ci间的关联关系,bij的初始值为0,经过Softmax 函数后就变为均匀分布,ci为输出的第i类支撑集样本的类特征。Among them, dij represents the relationship between the input vector and the output class feature ci , the initial value of bij is 0, and it becomes a uniform distribution after the Softmax function, and ci is the output of the i-th support set sample. class features.
5.根据权利要求1所述的一种基于类特征的多尺度度量少样本学习方法,其特征在于:所述多尺度度量步骤具体包括:5. A class feature-based multi-scale measurement few-sample learning method according to claim 1, wherein the multi-scale measurement step specifically comprises:根据所述特征嵌入步骤得到查询集样本特征eq和类特征提取步骤得到第i类支撑集的类特征ci,通过欧氏距离得到第i类支撑集样本与第q个查询样本间的匹配分数
Figure FDA0003568283650000021
According to the feature embedding step, the query set sample feature eq is obtained, and the class feature extraction step obtains the class feature ci of the ith support set, and the matching between the ith support set sample and the q th query sample is obtained through the Euclidean distance. Fraction
Figure FDA0003568283650000021
for
Figure FDA0003568283650000022
Figure FDA0003568283650000022
以及得到余弦相似度方法作为度量准则是的匹配分数
Figure FDA0003568283650000023
and get the matching score of the cosine similarity method as a metric
Figure FDA0003568283650000023
当度量方式为带有注意力机制的有参网络是,通过优化学习得到网络中具体的参数,进而得到匹配分数
Figure FDA0003568283650000024
其中C(.,.)为拼接级联函数,MAttention(.)表示带有注意力机制的度量准则,fφ表示带有激活函数的全连接网络;
When the measurement method is a participant network with an attention mechanism, the specific parameters in the network are obtained through optimization learning, and then the matching score is obtained.
Figure FDA0003568283650000024
where C(.,.) is the splicing cascade function, MAttention (.) represents the metric with the attention mechanism, and fφ represents the fully connected network with the activation function;
选择三种度量方式相加的匹配得分最大的类别i作为该查询样本xq的类别标签Select the category i with the largest matching score added by the three measures as the category label of the query sample xq
Figure 1
Figure 1
.
6.根据权利要求1-5中任意一项所述的一种基于类特征的多尺度度量少样本学习方法,其特征在于:所述少样本学习方法还包括设置损失函数步骤;损失函数为带有间距的损失函数,其计算公式为
Figure FDA0003568283650000026
其中,m+表示间隔,α表示权重系数,1iq表示指示函数,riq表示查询样本与第i类支撑集样本的匹配得分;
6. A multi-scale measurement few-sample learning method based on class features according to any one of claims 1-5, characterized in that: the few-sample learning method further comprises the step of setting a loss function; the loss function is The loss function with spacing, which is calculated as
Figure FDA0003568283650000026
Among them, m+ represents the interval, α represents the weight coefficient, 1iq represents the indicator function, and riq represents the matching score between the query sample and the i-th support set sample;
损失函数计算公式表示了查询样本与其他所有类级别特征间的互相制约的结果,同类样本间产生向内的拉力和非同类样本间产生向外的推力,
Figure FDA0003568283650000027
表示了查询样本q与第i类支撑类特征之间的拉力,优化的目标就是减少同类样本间的距离;
Figure FDA0003568283650000028
Figure FDA0003568283650000029
约束了非同类样本间的最小距离不能小于阈值m+
The loss function calculation formula expresses the result of the mutual restriction between the query sample and all other class-level features, and there is an inward pulling force between similar samples and an outward pushing force between non-similar samples.
Figure FDA0003568283650000027
Represents the tension between the query sample q and the i-th type of support features, and the goal of optimization is to reduce the distance between similar samples;
Figure FDA0003568283650000028
Figure FDA0003568283650000029
It is restricted that the minimum distance between non-homogeneous samples cannot be less than the threshold m+ .
CN202210314022.XA2022-03-282022-03-28 A multi-scale metric few-shot learning method based on class featuresActiveCN114626481B (en)

Priority Applications (1)

Application NumberPriority DateFiling DateTitle
CN202210314022.XACN114626481B (en)2022-03-282022-03-28 A multi-scale metric few-shot learning method based on class features

Applications Claiming Priority (1)

Application NumberPriority DateFiling DateTitle
CN202210314022.XACN114626481B (en)2022-03-282022-03-28 A multi-scale metric few-shot learning method based on class features

Publications (2)

Publication NumberPublication Date
CN114626481Atrue CN114626481A (en)2022-06-14
CN114626481B CN114626481B (en)2025-04-18

Family

ID=81903889

Family Applications (1)

Application NumberTitlePriority DateFiling Date
CN202210314022.XAActiveCN114626481B (en)2022-03-282022-03-28 A multi-scale metric few-shot learning method based on class features

Country Status (1)

CountryLink
CN (1)CN114626481B (en)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication numberPriority datePublication dateAssigneeTitle
CN118898187A (en)*2024-10-082024-11-05江苏金卓新材料科技有限公司 A finite element simulation method and system for predicting the compacting performance of metal welding powder

Citations (4)

* Cited by examiner, † Cited by third party
Publication numberPriority datePublication dateAssigneeTitle
CN109685135A (en)*2018-12-212019-04-26电子科技大学A kind of few sample image classification method based on modified metric learning
US20200285896A1 (en)*2019-03-092020-09-10Tongji UniversityMethod for person re-identification based on deep model with multi-loss fusion training strategy
CN111985581A (en)*2020-09-092020-11-24福州大学Sample-level attention network-based few-sample learning method
US11205098B1 (en)*2021-02-232021-12-21Institute Of Automation, Chinese Academy Of SciencesSingle-stage small-sample-object detection method based on decoupled metric

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication numberPriority datePublication dateAssigneeTitle
CN109685135A (en)*2018-12-212019-04-26电子科技大学A kind of few sample image classification method based on modified metric learning
US20200285896A1 (en)*2019-03-092020-09-10Tongji UniversityMethod for person re-identification based on deep model with multi-loss fusion training strategy
CN111985581A (en)*2020-09-092020-11-24福州大学Sample-level attention network-based few-sample learning method
US11205098B1 (en)*2021-02-232021-12-21Institute Of Automation, Chinese Academy Of SciencesSingle-stage small-sample-object detection method based on decoupled metric

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
年福东;束建华;吕刚;: "基于自适应特征比较的少样本学习算法", 西安文理学院学报(自然科学版), no. 04, 15 October 2020 (2020-10-15)*
张子振;刘明;朱德江;: "融合注意力机制和高效网络的糖尿病视网膜病变识别与分类", 中国图象图形学报, no. 08, 12 August 2020 (2020-08-12)*

Cited By (2)

* Cited by examiner, † Cited by third party
Publication numberPriority datePublication dateAssigneeTitle
CN118898187A (en)*2024-10-082024-11-05江苏金卓新材料科技有限公司 A finite element simulation method and system for predicting the compacting performance of metal welding powder
CN118898187B (en)*2024-10-082024-12-06江苏金卓新材料科技有限公司 A finite element simulation method and system for predicting the compacting performance of metal welding powder

Also Published As

Publication numberPublication date
CN114626481B (en)2025-04-18

Similar Documents

PublicationPublication DateTitle
CN110263227B (en)Group partner discovery method and system based on graph neural network
CN104732240B (en)A kind of Hyperspectral imaging band selection method using neural network sensitivity analysis
CN111931505A (en)Cross-language entity alignment method based on subgraph embedding
CN114239826B (en)Neural network pruning method, medium and electronic equipment
CN112232395A (en)Semi-supervised image classification method for generating confrontation network based on joint training
Huang et al.Solution Path for Pin-SVM Classifiers With Positive and Negative $\tau $ Values
CN101901251B (en) Cluster Structure Analysis and Identification Method of Complex Network Based on Markov Process Metastability
CN113712511B (en) A Stable Pattern Discrimination Method for Brain Imaging Fusion Features
CN112528873A (en)Signal semantic recognition method based on multi-stage semantic representation and semantic calculation
CN107169566A (en)Dynamic neural network model training method and device
CN112836629A (en)Image classification method
Chen et al.A novel localized and second order feature coding network for image recognition
CN117349743A (en) A data classification method and system based on multi-modal data hypergraph neural network
CN113052130A (en)Hyperspectral image classification method based on depth residual error network and edge protection filtering
CN118968164A (en) Graph anomaly detection method using graph data enhancement effect under random configuration network
CN117292249A (en)Underwater sonar image open set classification method, system, equipment and medium
CN110991603B (en)Local robustness verification method of neural network
CN112836763A (en) A kind of graph structure data classification method and apparatus
CN114626481A (en) A multi-scale metric few-shot learning method based on class features
CN112541530B (en)Data preprocessing method and device for clustering model
de Sá et al.A novel approach to estimated Boulingand-Minkowski fractal dimension from complex networks
CN113836260A (en)Total nitrogen content prediction method based on deep learning of knowledge enhancement
CN119167146A (en) A method and device for automatic modulation recognition of small sample signals
CN115631388B (en)Image classification method and device, electronic equipment and storage medium
CN112836007A (en) A Relational Meta-Learning Method Based on Contextual Attention Network

Legal Events

DateCodeTitleDescription
PB01Publication
PB01Publication
SE01Entry into force of request for substantive examination
SE01Entry into force of request for substantive examination
GR01Patent grant
GR01Patent grant

[8]ページ先頭

©2009-2025 Movatter.jp