Disclosure of Invention
In view of the shortcomings of the prior art, the technical problem to be solved by the invention is to provide a cross-domain task deep learning identification method capable of improving accuracy in subject identification tasks against different backgrounds.
In order to solve the technical problem, the cross-domain task deep learning identification method provided by the invention comprises the following steps of:
step one, generating a data set with main body characteristics irrelevant to background characteristics:
three MNIST-based data sets are produced; in each data set, every digit class has exactly one background color, the same digit in the three sets corresponds to different background colors and digit body colors, and within each data set the body color is the same for all instances of a given digit and differs between digits;
step two, building a network model:
building a gated parameter enhanced network model comprising convolutional layers, max pooling layers and fully connected layers; adding baseline regularization (GRR) after the first convolutional layer of the network; adding a gating factor α at the hidden layer after every two convolutional layers; adding a BN layer after each layer of the network; activating the last convolutional layer with a linear activation function and all remaining layers with the ReLU function; and performing classification with softmax after the fully connected layer. The dimensionality of the data is (B, C, W, H), where B is the batch size, C the number of channels, H the height and W the width;
step three, constructing a loss function: inputting the three data sets into the network model built in step two to obtain three outputs respectively, and solving the FRP penalty term from the three model outputs, specifically:
$\mathrm{FRP}_e = \left\| \nabla_{w \mid w=1.0}\, R_e(w \cdot \Phi) \right\|^2$

where $R_e(w \cdot \Phi)$ is the cross-entropy loss in environment $e$ and $w = 1.0$ is a fixed dummy classifier; the penalty is the squared norm of the derivative of $R_e(w \cdot \Phi)$ with respect to $w$. The GRR terms of the three environments are then calculated respectively, wherein the GRR terms are specifically as follows:
wherein $D$ represents the data distribution, $E[(f_\theta(X)-y)^2]$ represents the variance, $E_{X \sim D(X \mid y=k)}[(f_\theta(X)-\mu_k)^2]$ is the mean square error, and β is the GRR parameter; the larger the value of β, the better the suppression of unstable features;
the loss function uses cross entropy: the cross-entropy loss of each of the three data sets is calculated, and the loss function, the GRR regularization term and the FRP penalty term are then added to generate a new learning paradigm, specifically:

$L = \sum_{e=1}^{3} \left( R_e(\Phi) + \mathrm{GRR}_e + p \cdot \mathrm{FRP}_e \right)$

where p represents the FRP penalty coefficient;
step four, training and storing parameters:
training adopts a multi-scale training method; the initial weights, the learning rate, the GRR parameter β and the FRP penalty coefficient p are set, and the training weight parameters are saved every epoch;
and step five, inputting a sample to be recognized into the classifier trained in step four and outputting the recognition result.
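The new learning paradigm of step three can be sketched numerically. The following is a minimal illustrative sketch, not the patent's actual implementation: `cross_entropy`, `frp_penalty` and `cdi_objective` are hypothetical names, the gradient with respect to the scalar dummy classifier w is estimated by central differences, and the GRR term is omitted for brevity.

```python
import numpy as np

def cross_entropy(logits, labels):
    # mean cross-entropy over a batch; labels are integer class ids
    z = logits - logits.max(axis=1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_p[np.arange(len(labels)), labels].mean()

def frp_penalty(logits, labels, eps=1e-4):
    # squared derivative of R_e(w * Phi) w.r.t. the scalar dummy
    # classifier w, evaluated at w = 1.0 (central-difference estimate)
    g = (cross_entropy((1 + eps) * logits, labels)
         - cross_entropy((1 - eps) * logits, labels)) / (2 * eps)
    return g ** 2

def cdi_objective(env_batches, p=0.01):
    # env_batches: list of (logits, labels) pairs, one per environment;
    # the per-environment GRR term would be added here analogously
    return sum(cross_entropy(lg, lb) + p * frp_penalty(lg, lb)
               for lg, lb in env_batches)
```

Minimizing this sum over all three environments penalizes features whose optimal classifier scaling differs between environments, which is what suppresses the background cue.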
The invention has the following beneficial effects: compared with other existing methods (such as CLP, ALP, PGD and VIB), the CDI method provided by the invention suppresses the influence of the background on subject identification well, and its accuracy and stability are much higher than those of the other existing methods.
Detailed Description
The invention is further described with reference to the drawings and the detailed description.
A cross-domain identification method (CDI) is provided to solve the problem of identification errors caused by background replacement. The method estimates a nonlinear, invariant causal predictor from multiple training environments, so that the model makes predictions only from the features of the subject. Experimental comparison shows that CDI outperforms other methods on the task of identifying subjects against different backgrounds.
The method comprises the following implementation steps:
Step one, producing the data set
Although the MNIST images are grayscale, we color each handwritten digit in a way that is strongly (but spuriously) associated with the class label, producing three data sets in total; in each, every digit class has exactly one background color, but the same digit in the three sets corresponds to different background colors. Each data set is built up to 60,000 images, with a 9:1 ratio of training to validation data. Constructing the data sets this way allows color to be removed as a predictive feature and keeps the correlation between label and color from dominating the correlation between label and digit, yielding better generalization.
Step two, building a network model
A gated parameter enhanced network model (CPEN) was constructed, comprising 24 convolutional layers, 4 max pooling layers and 2 fully connected layers. Baseline regularization (GRR) is added after the first convolutional layer of the network; a gating factor α, with initial value 1.5, is added after every two convolutional layers at the hidden layer; and a BN layer is added after each layer of the network. Of the two fully connected layers, the first has 4096 neurons and the last has 10 outputs. A linear activation function is used only at the last convolutional layer; all remaining layers are activated with the ReLU function. Classification is performed with softmax after the fully connected layers. The dimensions of the data are (B, C, W, H): B is the batch size (initialized to 128), C the number of channels (initially 3, i.e., the RGB channels), H the height and W the width.
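The gated hidden unit can be sketched as below. This is a minimal sketch under stated assumptions: the operation order (batch norm, then ReLU, then multiplication by the gate α) is one plausible reading of the description, and `gated_activation` is an illustrative name, not the patent's code.

```python
import numpy as np

def gated_activation(x, alpha=1.5, eps=1e-5):
    """x: feature map of shape (B, C, H, W). Per-channel batch
    normalization (no learned affine terms), ReLU activation, then
    scaling of the whole output by the gating factor alpha."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)   # BN over batch and space
    return alpha * np.maximum(x_hat, 0.0)     # ReLU, then gate
```

Because α multiplies the block output, a small learned α can shut off channels that carry background-only information, which is the stated purpose of the gate.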
Step three, loss function
Data from the three environments are input into the CPEN network to obtain three outputs respectively, and the FRP penalty term is solved from the three model outputs; the penalty term is calculated by formula (1):
$\mathrm{FRP}_e = \left\| \nabla_{w \mid w=1.0}\, R_e(w \cdot \Phi) \right\|^2 \quad (1)$

where $R_e(w \cdot \Phi)$ is the cross-entropy loss in environment $e$ and $w = 1.0$ is a fixed dummy classifier; the penalty is the squared norm of the derivative of $R_e(w \cdot \Phi)$ with respect to $w$. Then, the GRR terms of the three environments are calculated respectively, as in formula (2):
wherein $D$ represents the data distribution, $E[(f_\theta(X)-y)^2]$ represents the variance, $E_{X \sim D(X \mid y=k)}[(f_\theta(X)-\mu_k)^2]$ is the mean square error, and β is the GRR parameter; the larger the value of β, the better the suppression of unstable features. Fig. 4 shows the calculated sensitivity: baseline regularization acts as a filter, suppressing the sensitivity of the two models to weakly correlated features (π close to 0.5).
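Since formula (2) itself is not reproduced in the text, the following is only one plausible reading of the GRR term, based on the quantities defined above: β times the within-class mean squared deviation of the model output from its class mean μ_k. `grr_term` is a hypothetical helper, not the patent's formula.

```python
import numpy as np

def grr_term(outputs, labels, beta=10.0):
    """outputs: (N,) model outputs f_theta(X); labels: (N,) integer
    classes. Returns beta * sum over classes k of the mean squared
    deviation of f_theta(X) from the class mean mu_k."""
    total = 0.0
    for k in np.unique(labels):
        fk = outputs[labels == k]
        mu_k = fk.mean()                      # class mean mu_k
        total += ((fk - mu_k) ** 2).mean()    # within-class MSE
    return beta * total
```

Under this reading, outputs that are constant within each class incur zero GRR cost, so the term rewards features that are stable given the label.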
The loss function uses cross-entropy loss, i.e., the cross-entropy losses of the three environments. The loss function, the GRR regularization term and the FRP penalty term are then added to generate a new learning paradigm, as in formula (3):

$L = \sum_{e=1}^{3} \left( R_e(\Phi) + \mathrm{GRR}_e + p \cdot \mathrm{FRP}_e \right) \quad (3)$

where p represents the FRP penalty coefficient. As shown in fig. 5, test results comparing GRR alone with FRP added on top of GRR (GF) indicate that GF is more effective than GRR alone.
Step four, training and saving parameters
Training adopts a multi-scale method: the scale of the preprocessed images is changed randomly every 10 batches among six training scales, 576×576, 512×512, 448×448, 416×416, 384×384 and 320×320, to increase generalization performance. The initial training weights are the pre-trained weights of VGG16. The learning rate is set to 0.01 for the first 10 batches to accelerate convergence, and fixed at 0.0001 thereafter to converge to the optimal result. During training, the GRR parameter β and the FRP penalty coefficient p are set to 10 and 0.01 respectively. A total of 50 epochs are trained, and the training weight parameters are saved after every epoch.
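The multi-scale and learning-rate schedules above can be sketched as follows; `scale_schedule` and `learning_rate` are illustrative helper names, assuming one schedule entry per batch.

```python
import random

SCALES = [576, 512, 448, 416, 384, 320]   # the six training scales

def scale_schedule(num_batches, rng):
    """Draw a new random input resolution every 10 batches."""
    scales, current = [], None
    for b in range(num_batches):
        if b % 10 == 0:
            current = rng.choice(SCALES)
        scales.append(current)
    return scales

def learning_rate(batch_idx):
    # 0.01 for the first 10 batches to speed up convergence,
    # then fixed at 1e-4 to converge to the optimal result
    return 0.01 if batch_idx < 10 else 1e-4
```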
Step five, loading a plurality of cross-domain data sets for test evaluation
A colored handwritten digit data set (C-MNIST), the MNIST-M data set made of MNIST digits blended with random color patches from the BSDS500 data set, and the Google Street View House Numbers data set (SVHN) are commonly used cross-domain test data sets. The saved training weights of the 50 epochs are loaded and the three digit data sets are tested separately; test accuracy is the number of correctly predicted samples (i.e., digits recognized correctly whatever the background or subject color) divided by the total number of samples. To demonstrate that the method works not only on digit data sets, additional training and testing were also performed on CIFAR-10.
The flow of cross-domain recognition is shown in fig. 1; the CDI algorithm is applied in the model learning module. A CPEN neural network model is designed for cross-domain images, as shown in fig. 2. The network adopts a multi-scale training method to increase generalization ability; GRR is added at the first layer of the network, which reduces the influence of the background on subject prediction and effectively prevents overfitting; and the gating factors allow the network to pass on only parameters useful for identifying the subject, the output of each convolutional layer being multiplied by the parameter α, which enhances subject identification in cross-domain tasks.
The generated color data sets are designed for three environments, each containing 20,000 pictures, as shown in figs. 3(a)-3(b). In each environment, every digit corresponds to one foreground and one background color; the color of a handwritten digit picture is strongly correlated with its label, but this correlation changes across environments. Data are drawn from the three environments simultaneously to generate the FRP penalty term and the cross-entropy and GRR losses of each environment; a new objective function is formed, and finally the model is updated by minimizing this objective.
As shown in fig. 6, a processing method of adaptive batch normalization (AdaBN) was considered when improving the algorithm; with AdaBN the stability of prediction improves and the accuracy no longer changes abruptly, while after AdaBN is turned off the accuracy curve becomes unstable. In this comparative experiment, the FRP coefficient p is fixed at 0.01.
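The core of AdaBN can be sketched as below: at test time, the BN statistics computed on the source domain are replaced by statistics of the target-domain activations. This is a minimal sketch, assuming flattened (N, C) features; `adabn_update` and `bn_apply` are illustrative names.

```python
import numpy as np

def adabn_update(target_features):
    """target_features: (N, C) activations from the target domain.
    Returns the (mean, std) pair a BN layer would switch to."""
    mu = target_features.mean(axis=0)
    sigma = target_features.std(axis=0)
    return mu, sigma

def bn_apply(x, mu, sigma, eps=1e-5):
    # normalize with the (target-domain) statistics
    return (x - mu) / (sigma + eps)
```

Recomputing the statistics on the target domain is what keeps the normalized activations in the range the downstream layers were trained on, which matches the observed stabilization of the accuracy curve.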
With other existing methods, training accuracy can reach 95.6%, but performance is not ideal when testing on data sets with shifted color distributions. As shown in fig. 7, the CDI method is compared with clean logit pairing (CLP), adversarial logit pairing (ALP), projected gradient descent (PGD) and the variational information bottleneck (VIB), showing that CDI performs better in cross-domain task identification.
The test accuracy of the CDI method on the C-MNIST, SVHN, MNIST-M and CIFAR-10 data sets is shown in Table 1, which shows that the CDI method performs well in cross-domain task identification and meets the expected requirements.
TABLE 1 test accuracy of CDI on different data sets
| Data set | C-MNIST | SVHN | MNIST-M | CIFAR-10 |
| Accuracy (%) | 93.88 | 79.75 | 90.40 | 87.94 |