Disclosure of Invention
In order to overcome the defects of related products in the prior art, the invention provides a cross-domain small sample CT image semantic segmentation system and method based on meta-learning.
The invention provides a cross-domain small sample CT image semantic segmentation system based on meta-learning, which comprises the following modules:
a data processing module: configured to sample from labeled source domain picture data, combine the samples into a small sample segmentation task set, and sample unlabeled pictures from the target domain as training data;
a feature extraction module: configured to obtain the middle-layer features and prototype features of the source domain picture data using a convolutional neural network, and to obtain the middle-layer features of the target domain pictures using the same network;
a segmentation prediction module: configured to calculate the segmentation result of the small sample segmentation task using cosine similarity;
a loss calculation module: configured to calculate the segmentation loss from the segmentation result and the ground-truth annotations, calculate the discrepancy loss between source domain and target domain features using the maximum mean discrepancy (MMD) algorithm to perform domain alignment, and compute the weighted loss to optimize the model.
In some embodiments of the present invention, the data processing module is specifically configured to:
sampling from the labeled source domain picture data and combining the samples into a small sample segmentation task set, which serves as the model input for meta-training, so that the small sample segmentation task can be completed in the source domain; the small sample segmentation tasks for meta-training are:
C = {(S_i, Q_i)}_{i=1}^{N}
wherein C denotes the small sample segmentation task set, N denotes the number of sampled small sample segmentation tasks, S denotes a support set, and Q denotes a query set; the support set contains K samples X_s with labels M_s, namely S = {(X_s^k, M_s^k)}_{k=1}^{K}; the query set Q includes a query picture X_q and a label M_q used to calculate the loss during training;
and sampling unlabeled picture data from the target domain and adding it to each corresponding small sample task for aligning the data domains, so that at this point:
C = {(S_i, Q_i, T_i)}_{i=1}^{N}
wherein T_i denotes the unlabeled picture data added to the i-th small sample segmentation task.
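As an illustrative sketch (not part of the claimed implementation), the task-set sampling described above can be expressed in Python; the pool arguments and the `sample_tasks` helper are hypothetical names:

```python
import random

def sample_tasks(source_pool, target_pool, n_tasks, k_shot, n_unlabeled):
    """Build a meta-training task set C = {(S_i, Q_i, T_i)}.

    source_pool: list of (image, mask) pairs from the labeled source domain
    target_pool: list of unlabeled target-domain images
    """
    tasks = []
    for _ in range(n_tasks):
        episode = random.sample(source_pool, k_shot + 1)  # without replacement
        support = episode[:k_shot]      # K labeled support samples (X_s, M_s)
        query = episode[k_shot]         # query pair (X_q, M_q)
        unlabeled = random.sample(target_pool, n_unlabeled)  # T_i for alignment
        tasks.append({"S": support, "Q": query, "T": unlabeled})
    return tasks
```

Because the episode is drawn without replacement, the query sample never appears in its own support set.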
In some embodiments of the present invention, the feature extraction module is specifically configured to:
the middle-layer features of the support picture X_s, the query picture X_q and the target domain picture X_t are obtained through the feature extraction network E, the formulas being respectively:
F_s = E(X_s), F_q = E(X_q), F_t = E(X_t),
then a target-class prototype and a background-class prototype are extracted from the support feature F_s by mask-guided global average pooling, where i takes 1 when extracting the target-class prototype and 0 when extracting the background-class prototype, with the formula:
P_i = Σ_{x,y} F_s(x, y) · δ[M_s(x, y) = i] / Σ_{x,y} δ[M_s(x, y) = i]
wherein P denotes the prototype feature, x and y denote spatial coordinates, and δ[·] is the indicator function, which takes 1 when its argument is true and 0 otherwise; if there are multiple support set samples, a prototype is computed for each and the prototypes are averaged.
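The mask-guided prototype pooling described above can be sketched in NumPy; the function names are illustrative, and the arrays stand in for the CNN feature and label tensors:

```python
import numpy as np

def extract_prototype(feat, mask, cls):
    """Masked global average pooling:
    P_i = sum F_s(x,y) * δ[M_s(x,y)=i] / sum δ[M_s(x,y)=i].

    feat: (c, h, w) support feature F_s; mask: (h, w) label M_s; cls: 1 fg / 0 bg.
    """
    ind = (mask == cls).astype(feat.dtype)            # indicator δ[M_s(x, y) = cls]
    return (feat * ind).sum(axis=(1, 2)) / (ind.sum() + 1e-8)  # (c,) prototype

def average_prototypes(feats, masks, cls):
    """K-shot case: compute one prototype per support sample, then average."""
    return np.mean([extract_prototype(f, m, cls) for f, m in zip(feats, masks)],
                   axis=0)
```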
In some embodiments of the present invention, the partition prediction module is specifically configured to:
the metric-based segmentation between the prototype features P and the picture feature F_q can be completed by a parametric or a non-parametric metric; here the non-parametric cosine similarity is adopted to measure the similarity between the prototype features and the picture feature, after which the segmentation is completed. The cosine similarity between the foreground and background prototype features and the picture feature is calculated as:
Pred_i(x, y) = α · cos(P_i, F_q(x, y)) = α · (P_i · F_q(x, y)) / (‖P_i‖ ‖F_q(x, y)‖)
wherein Pred denotes the similarity value and α is a scaling multiplier;
the segmentation result is then obtained through the argmax function, with the formula:
M'_q(x, y) = argmax_i Pred_i(x, y)
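A minimal NumPy sketch of the cosine-similarity metric and argmax segmentation described above; the function name is illustrative, and the default α = 20 follows the value used in the implementation details:

```python
import numpy as np

def cosine_segment(feat, prototypes, alpha=20.0):
    """Pred_i(x,y) = alpha * cos(P_i, F_q(x,y)); segmentation via argmax over i.

    feat: (c, h, w) query feature F_q; prototypes: [P_bg, P_fg], each (c,).
    Returns the predicted mask M'_q of shape (h, w).
    """
    c, h, w = feat.shape
    f = feat.reshape(c, -1)
    f = f / (np.linalg.norm(f, axis=0, keepdims=True) + 1e-8)  # unit pixel vectors
    scores = []
    for p in prototypes:
        p = p / (np.linalg.norm(p) + 1e-8)                     # unit prototype
        scores.append(alpha * (p @ f))                         # cosine map per class
    return np.stack(scores).argmax(axis=0).reshape(h, w)       # M'_q(x, y)
```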
in some embodiments of the present invention, the loss calculation module is specifically configured to:
the segmentation loss supervises the small sample segmentation task of the source domain, and its calculation is based on the cross-entropy loss function, specifically:
L_seg = −(1/(h·w)) Σ_{x,y} Σ_i δ[M_q(x, y) = i] · log(softmax(Pred(x, y))_i)
wherein the middle-layer features F_q ∈ R^{c×h×w} and F_t ∈ R^{c×h×w}, c is the number of channels, and h and w are the height and width of the feature map, respectively;
the middle-layer features F_q and F_t are reshaped to F_q ∈ R^{hw×c} and F_t ∈ R^{hw×c}; then m and n channel feature vectors are randomly drawn from F_q and F_t, respectively, as data domain samples:
f = {f_1, …, f_m}, g = {g_1, …, g_n},
and the alignment loss is calculated from the drawn channel feature vectors with the maximum mean discrepancy algorithm:
L_mmd = (1/m²) Σ_{i,j} k(f_i, f_j) + (1/n²) Σ_{i,j} k(g_i, g_j) − (2/(mn)) Σ_{i,j} k(f_i, g_j)
wherein k denotes the kernel matrix computed with a Gaussian kernel function;
the final weighted total loss is calculated as:
L = L_seg + β·L_mmd
wherein β is the weighting multiplier.
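The alignment and total-loss computation described above can be sketched as follows; the multi-bandwidth Gaussian kernel and the helper names are illustrative assumptions (the bandwidth list matches the values given in the implementation details):

```python
import numpy as np

def gaussian_kernel_sum(a, b, bandwidths=(2, 5, 10, 20, 40, 80)):
    """Sum of Gaussian kernels k(a_i, b_j) over several bandwidths."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)  # pairwise squared dists
    return sum(np.exp(-d2 / (2.0 * s ** 2)) for s in bandwidths)

def mmd_loss(f, g):
    """Maximum mean discrepancy between channel-vector samples f (m, c), g (n, c)."""
    m, n = len(f), len(g)
    return (gaussian_kernel_sum(f, f).sum() / m ** 2
            + gaussian_kernel_sum(g, g).sum() / n ** 2
            - 2.0 * gaussian_kernel_sum(f, g).sum() / (m * n))

def total_loss(l_seg, l_mmd, beta=1.0):
    """Weighted total loss L = L_seg + beta * L_mmd."""
    return l_seg + beta * l_mmd
```

By construction, the MMD of a sample set against itself is zero, and it grows as the two sets' distributions drift apart.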
The invention also provides a cross-domain small sample CT image semantic segmentation method based on meta-learning, which is applied to any of the above cross-domain small sample CT image semantic segmentation systems based on meta-learning and comprises the following steps:
step S1: sampling from labeled source domain picture data, combining the labeled source domain picture data into a small sample segmentation task set, and sampling unlabeled pictures from a target domain as training data;
step S2: the middle layer characteristics and prototype characteristics of the source domain picture data are obtained by using a convolutional neural network, and the middle layer characteristics of the target domain picture are obtained by using the convolutional neural network;
step S3: calculating a segmentation result of the small sample segmentation task by using the cosine similarity;
step S4: calculating the segmentation loss using the segmentation result and the ground-truth annotations, calculating the discrepancy loss between source domain and target domain features using the maximum mean discrepancy algorithm to perform domain alignment, and computing the weighted loss to optimize the model.
In certain embodiments of the present invention, step S1 specifically includes:
sampling from the labeled source domain picture data and combining the samples into a small sample segmentation task set, which serves as the model input for meta-training, so that the small sample segmentation task can be completed in the source domain; the small sample segmentation tasks for meta-training are:
C = {(S_i, Q_i)}_{i=1}^{N}
wherein C denotes the small sample segmentation task set, N denotes the number of sampled small sample segmentation tasks, S denotes a support set, and Q denotes a query set; the support set contains K samples X_s with labels M_s, namely S = {(X_s^k, M_s^k)}_{k=1}^{K}; the query set Q includes a query picture X_q and a label M_q used to calculate the loss during training;
and sampling unlabeled picture data from the target domain and adding it to each corresponding small sample task for aligning the data domains, so that at this point:
C = {(S_i, Q_i, T_i)}_{i=1}^{N}
wherein T_i denotes the unlabeled picture data added to the i-th small sample segmentation task.
In certain embodiments of the present invention, step S2 specifically includes:
the middle-layer features of the support picture X_s, the query picture X_q and the target domain picture X_t are obtained through the feature extraction network E, the formulas being respectively:
F_s = E(X_s), F_q = E(X_q), F_t = E(X_t),
then a target-class prototype and a background-class prototype are extracted from the support feature F_s by mask-guided global average pooling, where i takes 1 when extracting the target-class prototype and 0 when extracting the background-class prototype, with the formula:
P_i = Σ_{x,y} F_s(x, y) · δ[M_s(x, y) = i] / Σ_{x,y} δ[M_s(x, y) = i]
wherein P denotes the prototype feature, x and y denote spatial coordinates, and δ[·] is the indicator function, which takes 1 when its argument is true and 0 otherwise; if there are multiple support set samples, a prototype is computed for each and the prototypes are averaged.
In certain embodiments of the present invention, step S3 specifically includes:
the metric-based segmentation between the prototype features P and the picture feature F_q can be completed by a parametric or a non-parametric metric; here the non-parametric cosine similarity is adopted to measure the similarity between the prototype features and the picture feature, after which the segmentation is completed. The cosine similarity between the foreground and background prototype features and the picture feature is calculated as:
Pred_i(x, y) = α · cos(P_i, F_q(x, y)) = α · (P_i · F_q(x, y)) / (‖P_i‖ ‖F_q(x, y)‖)
wherein Pred denotes the similarity value and α is a scaling multiplier;
the segmentation result is then obtained through the argmax function, with the formula:
M'_q(x, y) = argmax_i Pred_i(x, y)
in certain embodiments of the present invention, step S4 specifically includes:
the segmentation loss supervises the small sample segmentation task of the source domain, and its calculation is based on the cross-entropy loss function, specifically:
L_seg = −(1/(h·w)) Σ_{x,y} Σ_i δ[M_q(x, y) = i] · log(softmax(Pred(x, y))_i)
wherein the middle-layer features F_q ∈ R^{c×h×w} and F_t ∈ R^{c×h×w}, c is the number of channels, and h and w are the height and width of the feature map, respectively;
the middle-layer features F_q and F_t are reshaped to F_q ∈ R^{hw×c} and F_t ∈ R^{hw×c}; then m and n channel feature vectors are randomly drawn from F_q and F_t, respectively, as data domain samples:
f = {f_1, …, f_m}, g = {g_1, …, g_n},
and the alignment loss is calculated from the drawn channel feature vectors with the maximum mean discrepancy algorithm:
L_mmd = (1/m²) Σ_{i,j} k(f_i, f_j) + (1/n²) Σ_{i,j} k(g_i, g_j) − (2/(mn)) Σ_{i,j} k(f_i, g_j)
wherein k denotes the kernel matrix computed with a Gaussian kernel function;
the final weighted total loss is calculated as:
L = L_seg + β·L_mmd
wherein β is the weighting multiplier.
Compared with the prior art, the invention has the following advantages:
Compared with the prior art, the cross-domain small sample CT image semantic segmentation system based on meta-learning realizes the small sample segmentation task on the source domain through a meta-learning algorithm and aligns the source domain with the CT image domain using the maximum mean discrepancy algorithm. During cross-domain alignment, channel vectors are randomly drawn from the middle-layer features output by the backbone network in both the source domain and the target domain, and alignment is performed with the maximum mean discrepancy algorithm. For the small sample segmentation task in the source domain, a backbone network is designed whose output is the middle-layer features, and prototypes are used in place of the three-dimensional features for the metric operation; feature segmentation is performed with the non-parametric cosine similarity function. When the small sample segmentation model trained on the source domain is transferred to the target domain for testing, its prediction performance degrades only slightly. This alleviates the problem that the number of labeled CT medical images is insufficient to support meta-training in the source domain, provides CT image semantic segmentation for the medical field, rapidly locates target regions, and supports more accurate disease diagnosis and treatment decisions, giving the invention good application prospects.
Detailed Description
In order to enable those skilled in the art to better understand the present invention, the technical solutions in the embodiments of the present invention are described clearly and completely below with reference to the accompanying drawings. The described embodiments are only some, not all, of the embodiments of the invention, and the preferred embodiments are shown in the drawings. The invention may be embodied in many different forms and is not limited to the embodiments described herein; these embodiments are provided so that this disclosure will be thorough and complete.
Referring to fig. 1, the cross-domain small sample CT image semantic segmentation system based on meta-learning includes:
the data processing module 100: the method comprises the steps of sampling from source domain picture data with labels, combining the source domain picture data with labels into a small sample segmentation task set, and sampling unlabeled pictures from a target domain as training data; the function of the data processing module comprises two parts of source domain data processing and target domain data processing.
For source domain data processing, the data processing module 100 samples from the labeled source domain picture data and combines the samples into a small sample segmentation task set, which serves as the model input during meta-training so that the model acquires the general ability to complete small sample segmentation tasks in the source domain. Specifically, according to the N-way-K-shot setting, the small sample segmentation tasks for meta-training are:
C = {(S_i, Q_i)}_{i=1}^{N}
wherein C denotes the small sample segmentation task set, N denotes the number of sampled small sample segmentation tasks, S denotes a support set, and Q denotes a query set; the support set contains K samples X_s with labels M_s, namely S = {(X_s^k, M_s^k)}_{k=1}^{K}; the query set Q comprises 1 query picture X_q and a label M_q used to calculate the loss during training.
For target domain data processing, the data processing module 100 samples unlabeled picture data from the target domain and adds it to each corresponding small sample task for aligning the data domains, so that at this point:
C = {(S_i, Q_i, T_i)}_{i=1}^{N}
wherein T_i denotes the unlabeled picture data added to the i-th small sample segmentation task.
Feature extraction module 200: configured to obtain the middle-layer features and prototype features of the source domain picture data using a convolutional neural network, and to obtain the middle-layer features of the target domain pictures using the same network. The feature extraction module 200 is a deep neural network based on a convolution structure, used to extract the middle-layer features and prototype features of the picture data (support pictures, query pictures and target domain pictures).
The middle-layer features F are extracted through the feature extraction network E, the formulas being respectively:
F_s = E(X_s), F_q = E(X_q), F_t = E(X_t)  (6-1)
wherein X_t is the target domain picture data; a target-class prototype (i takes 1) and a background-class prototype (i takes 0) are then extracted from the support feature F_s by mask-guided global average pooling, with the formula:
P_i = Σ_{x,y} F_s(x, y) · δ[M_s(x, y) = i] / Σ_{x,y} δ[M_s(x, y) = i]  (6-2)
wherein P denotes the prototype feature, x and y in (x, y) denote spatial coordinates, and δ[·] is the indicator function, which takes 1 when its argument is true and 0 otherwise; if there are multiple support set samples, a prototype is computed for each and the prototypes are averaged.
Because the semantic information of CT medical pictures is not complex, and the source domain and target domain data lie in different domains with disjoint categories, deep features easily bias the model toward the source domain categories, i.e. cause overfitting. Middle-layer features, by contrast, tend to relate to common parts of objects rather than to specific categories, are less prone to overfitting, and generalize more easily to unseen categories. Moreover, because noise is often introduced during the imaging of medical pictures, prototypes are used to reduce the interference of noise on segmentation prediction and to improve robustness. In the embodiment of the invention, after the multi-task loss function is optimized, the feature extraction module can: extract the middle-layer features and prototype features of the picture data; map the picture features of the support set and the query set into the same embedding space to facilitate feature metric segmentation; and align the distribution of the target domain to that of the source domain, completing domain alignment.
Segmentation prediction module 300: configured to calculate the segmentation result of the small sample segmentation task using cosine similarity, completing the small sample segmentation prediction in the source domain through a feature metric segmentation module. The metric operation between the prototype features P and the picture feature F_q can be performed by a parametric metric (e.g. a convolution structure) or a non-parametric metric (e.g. cosine similarity). To avoid the overfitting problem of parametric metrics and their susceptibility to noise, the embodiment of the invention measures the similarity between the prototypes and the feature with the non-parametric cosine similarity and then completes the segmentation. The cosine similarity between the foreground and background prototype features and the picture feature is calculated as:
Pred_i(x, y) = α · cos(P_i, F_q(x, y))  (6-3)
wherein Pred denotes the similarity value and α is a scaling multiplier;
the segmentation result is then obtained through the argmax function, with the formula:
M'_q(x, y) = argmax_i Pred_i(x, y)  (6-4)
Loss calculation module 400: configured to calculate the segmentation loss using the segmentation result and the ground-truth annotations, calculate the discrepancy loss between source domain and target domain features using the maximum mean discrepancy algorithm to perform domain alignment, and compute the weighted loss to optimize the model.
The loss calculation module 400 performs the segmentation loss calculation and the alignment loss calculation, respectively.
In particular, the segmentation loss supervises the small sample segmentation task of the source domain, and its calculation is based on the cross-entropy loss function, specifically:
L_seg = −(1/(h·w)) Σ_{x,y} Σ_i δ[M_q(x, y) = i] · log(softmax(Pred(x, y))_i)  (6-5)
The segmentation task is a dense classification task in which the three-dimensional feature shape must be maintained throughout; the middle-layer features F_q ∈ R^{c×h×w} and F_t ∈ R^{c×h×w}, where c is the number of channels and h and w are the height and width of the feature map, respectively. After multiple convolution operations, the channel feature (vector) at each pixel of F_q and F_t is an abstract local feature that reflects the data distribution of its data domain. Since the alignment process does not focus on image content, even a small number of unlabeled target domain pictures can provide enough channel-vector samples. Therefore, when calculating the alignment loss, the channel vector at each pixel of the middle-layer feature is taken as one data domain sample, which both reflects the cross-domain difference and effectively limits the number of target domain samples required.
The alignment loss calculation is based on the maximum mean discrepancy algorithm and supervises the alignment task between the source domain and the target domain. First, the middle-layer features F_q and F_t are reshaped to F_q ∈ R^{hw×c} and F_t ∈ R^{hw×c}; then m and n channel feature vectors are randomly drawn from F_q and F_t, respectively, as data domain samples:
f = {f_1, …, f_m}, g = {g_1, …, g_n}  (6-6)
and the alignment loss is calculated from the drawn channel feature vectors:
L_mmd = (1/m²) Σ_{i,j} k(f_i, f_j) + (1/n²) Σ_{i,j} k(g_i, g_j) − (2/(mn)) Σ_{i,j} k(f_i, g_j)  (6-7)
wherein k denotes the kernel matrix computed with a Gaussian kernel function;
the final weighted total loss is calculated as:
L = L_seg + β·L_mmd  (6-8)
wherein β is the weighting multiplier.
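The channel-vector sampling step described above (reshaping a (c, h, w) feature to (h·w, c) and drawing pixel channel vectors as domain samples) can be sketched as follows; the function name is illustrative:

```python
import numpy as np

def sample_channel_vectors(feat, num, rng):
    """Reshape a middle-layer feature (c, h, w) to (h*w, c) and draw `num`
    pixel channel vectors as samples of that data domain's distribution."""
    c = feat.shape[0]
    flat = feat.reshape(c, -1).T                     # (h*w, c)
    idx = rng.choice(flat.shape[0], size=num, replace=False)
    return flat[idx]                                 # (num, c)
```

The two returned sample sets (m vectors from F_q, n vectors from F_t) are exactly the f and g fed to the MMD loss of formula (6-7).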
In some embodiments of the present invention, referring to fig. 2, the system further includes a test flow module 500 for applying the semantic segmentation system. During testing, no loss needs to be calculated, so the input target domain pictures do not need to be aligned. The test flow module 500 is specifically configured to:
after training is completed and the system M is deployed, given a small sample test task {S_i, Q_i}, extract the picture features F_s and F_q of the support set and the query set using formula (6-1);
calculate the foreground and background prototypes P using formula (6-2);
calculate the cosine similarity Pred between the prototypes P and the feature F_q using formula (6-3);
obtain the segmentation result M'_q of the query picture using formula (6-4).
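The four-step test flow above can be sketched end-to-end in NumPy; the `encoder` stub and function names are illustrative placeholders, not the deployed system:

```python
import numpy as np

def proto(feat, mask, cls):
    """Masked global average pooling, as in formula (6-2)."""
    ind = (mask == cls).astype(float)
    return (feat * ind).sum(axis=(1, 2)) / (ind.sum() + 1e-8)

def predict(encoder, support_img, support_mask, query_img, alpha=20.0):
    """Test flow: (6-1) extract features, (6-2) prototypes,
    (6-3) cosine similarity, (6-4) argmax -> predicted query mask."""
    f_s, f_q = encoder(support_img), encoder(query_img)               # (6-1)
    protos = [proto(f_s, support_mask, 0), proto(f_s, support_mask, 1)]  # (6-2)
    c, h, w = f_q.shape
    f = f_q.reshape(c, -1)
    f = f / (np.linalg.norm(f, axis=0, keepdims=True) + 1e-8)
    scores = [alpha * (p / (np.linalg.norm(p) + 1e-8)) @ f for p in protos]  # (6-3)
    return np.stack(scores).argmax(axis=0).reshape(h, w)              # (6-4)
```

With an identity encoder and a toy 1×2 "image", the query pixel matching the foreground prototype is labeled 1 and the other pixel 0.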
On the basis of the above embodiments, the present invention also provides the following specific embodiment, which is based on an actual setup comprising two public datasets: the source (training) dataset is PASCAL VOC 2012 and the target (test) dataset is a 2D CT lung lobe dataset from the Lung Nodule Analysis (LUNA) challenge. The PASCAL VOC 2012 dataset is a natural color-image dataset; the CT dataset is a black-and-white gray-scale dataset that, while free of visual distortion, is often noisy due to the CT imaging process. The two therefore form a cross-domain dataset pair with large inter-domain differences. The experiment was performed under the 1-way-1-shot setting, and the implementation details are described next.
First, the structure of the feature extraction module is determined. The structure is adapted from the deep convolutional neural network ResNet-50, with its composition shown in the following table. On top of the ResNet-50 network, the feature extraction module of the system does not use the deep features (the sixth layer of ResNet-50); instead, the middle-layer features (the third- and fourth-layer features) are concatenated and then reduced to 512 dimensions as the final feature output, so that the module can combine multi-level information. During training, the pre-trained parameters of ResNet-50 are loaded to speed up training.
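A sketch of the middle-layer fusion described above, assuming the backbone keeps the third- and fourth-layer features at the same spatial resolution (e.g. via dilated convolutions); the 1×1 convolution is modeled as a single matrix multiply, and the function name is illustrative:

```python
import numpy as np

def fuse_mid_layers(f3, f4, w_reduce):
    """Concatenate two middle-layer feature maps along channels, then reduce to
    512 dims with a 1x1 convolution (here: matrix w_reduce of shape (512, c3+c4))."""
    assert f3.shape[1:] == f4.shape[1:], "spatial sizes must match before fusion"
    f = np.concatenate([f3, f4], axis=0)             # (c3 + c4, h, w)
    c, h, w = f.shape
    return (w_reduce @ f.reshape(c, -1)).reshape(-1, h, w)  # (512, h, w)
```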
As shown in the following table, the feature extraction module structure in the implementation example is:
In the training and testing stages, the resolution of the input pictures is 321×321 pixels, and the feature map output by the feature extraction module is 1/8 of the input image size. During training, random cropping and flipping are applied to the training images. The entire system was trained with an SGD optimizer for 30000 iterations, with a learning rate of 0.001, a batch size of 1, a momentum of 0.99 and a weight decay of 0.0005. The multiplier α in formula (6-3) is fixed to 20; the Gaussian kernel bandwidths used for the kernel matrix k in formula (6-7) are [2, 5, 10, 20, 40, 80]; and the weighting multiplier β in formula (6-8) is fixed to 1. When drawing the middle-layer channel vectors, m and n are both set to 500 to save GPU memory.
After training, the system was tested on randomly selected small sample segmentation tasks; the test results are shown in fig. 3, where the lung lobe regions are marked with different colors for visual comparison. The experimental results show that the system trained in this embodiment achieves accurate CT lung lobe image segmentation under cross-domain, small sample conditions.
Compared with the prior art, the cross-domain small sample CT image semantic segmentation system based on meta-learning realizes the small sample segmentation task on the source domain through a meta-learning algorithm and aligns the source domain with the CT image (target) domain using the maximum mean discrepancy algorithm. During cross-domain alignment, channel vectors are randomly drawn from the middle-layer features output by the backbone network in both the source domain and the target domain, and alignment is performed with the maximum mean discrepancy algorithm. For the small sample segmentation task in the source domain, a backbone network is designed whose output is the middle-layer features, and prototypes are used in place of the three-dimensional features for the metric operation; feature segmentation is performed with the non-parametric cosine similarity function. When the small sample segmentation model trained on the source domain is transferred to the target domain for testing, its prediction performance degrades only slightly. This alleviates the problem that the number of labeled CT medical images is insufficient to support meta-training in the source domain, provides CT image semantic segmentation for the medical field, rapidly locates target regions, and supports more accurate disease diagnosis and treatment decisions, giving the invention good application prospects.
Referring to fig. 4, the invention further provides a cross-domain small sample CT image semantic segmentation method based on meta-learning, which is applied to the cross-domain small sample CT image semantic segmentation system based on meta-learning in any of the above embodiments, and includes the following steps:
step S1: sampling from labeled source domain picture data, combining the labeled source domain picture data into a small sample segmentation task set, and sampling unlabeled pictures from a target domain as training data;
step S2: the middle layer characteristics and prototype characteristics of the source domain picture data are obtained by using a convolutional neural network, and the middle layer characteristics of the target domain picture are obtained by using the convolutional neural network;
step S3: calculating a segmentation result of the small sample segmentation task by using the cosine similarity;
step S4: calculating the segmentation loss using the segmentation result and the ground-truth annotations, calculating the discrepancy loss between source domain and target domain features using the maximum mean discrepancy algorithm to perform domain alignment, and computing the weighted loss to optimize the model.
In certain embodiments of the present invention, step S1 specifically includes:
sampling from the labeled source domain picture data and combining the samples into a small sample segmentation task set, which serves as the model input for meta-training, so that the small sample segmentation task can be completed in the source domain; the small sample segmentation tasks for meta-training are:
C = {(S_i, Q_i)}_{i=1}^{N}
wherein C denotes the small sample segmentation task set, N denotes the number of sampled small sample segmentation tasks, S denotes a support set, and Q denotes a query set; the support set contains K samples X_s with labels M_s, namely S = {(X_s^k, M_s^k)}_{k=1}^{K}; the query set Q includes a query picture X_q and a label M_q used to calculate the loss during training;
and sampling unlabeled picture data from the target domain and adding it to each corresponding small sample task for aligning the data domains, so that at this point:
C = {(S_i, Q_i, T_i)}_{i=1}^{N}
wherein T_i denotes the unlabeled picture data added to the i-th small sample segmentation task.
In certain embodiments of the present invention, step S2 specifically includes:
the middle-layer features of the support picture X_s, the query picture X_q and the target domain picture X_t are obtained through the feature extraction network E, the formulas being respectively:
F_s = E(X_s), F_q = E(X_q), F_t = E(X_t),
then a target-class prototype and a background-class prototype are extracted from the support feature F_s by mask-guided global average pooling, where i takes 1 when extracting the target-class prototype and 0 when extracting the background-class prototype, with the formula:
P_i = Σ_{x,y} F_s(x, y) · δ[M_s(x, y) = i] / Σ_{x,y} δ[M_s(x, y) = i]
wherein P denotes the prototype feature, x and y denote spatial coordinates, and δ[·] is the indicator function, which takes 1 when its argument is true and 0 otherwise; if there are multiple support set samples, a prototype is computed for each and the prototypes are averaged.
In certain embodiments of the present invention, step S3 specifically includes:
the metric-based segmentation between the prototype features P and the picture feature F_q can be completed by a parametric or a non-parametric metric; here the non-parametric cosine similarity is adopted to measure the similarity between the prototype features and the picture feature, after which the segmentation is completed. The cosine similarity between the foreground and background prototype features and the picture feature is calculated as:
Pred_i(x, y) = α · cos(P_i, F_q(x, y)) = α · (P_i · F_q(x, y)) / (‖P_i‖ ‖F_q(x, y)‖)
wherein Pred denotes the similarity value and α is a scaling multiplier;
the segmentation result is then obtained through the argmax function, with the formula:
M'_q(x, y) = argmax_i Pred_i(x, y)
in certain embodiments of the present invention, step S4 specifically includes:
the segmentation loss supervises the small sample segmentation task of the source domain, and its calculation is based on the cross-entropy loss function, specifically:
L_seg = −(1/(h·w)) Σ_{x,y} Σ_i δ[M_q(x, y) = i] · log(softmax(Pred(x, y))_i)
wherein the middle-layer features F_q ∈ R^{c×h×w} and F_t ∈ R^{c×h×w}, c is the number of channels, and h and w are the height and width of the feature map, respectively;
the middle-layer features F_q and F_t are reshaped to F_q ∈ R^{hw×c} and F_t ∈ R^{hw×c}; then m and n channel feature vectors are randomly drawn from F_q and F_t, respectively, as data domain samples:
f = {f_1, …, f_m}, g = {g_1, …, g_n},
and the alignment loss is calculated from the drawn channel feature vectors with the maximum mean discrepancy algorithm:
L_mmd = (1/m²) Σ_{i,j} k(f_i, f_j) + (1/n²) Σ_{i,j} k(g_i, g_j) − (2/(mn)) Σ_{i,j} k(f_i, g_j)
wherein k denotes the kernel matrix computed with a Gaussian kernel function;
the final weighted total loss is calculated as:
L = L_seg + β·L_mmd
wherein β is the weighting multiplier.
What is not described in detail in this specification is prior art known to those skilled in the art. Although the present invention has been described in detail with reference to the foregoing embodiments, those skilled in the art may still modify the technical solutions described therein or substitute equivalents for some of their features. All equivalent structures made using the content of the specification and drawings of the invention, whether applied directly or indirectly in other related technical fields, likewise fall within the scope of the invention.