Attention-based diagnostic prediction method for a bidirectional recurrent neural network
Technical Field
The invention relates to the technical field of predictive diagnosis, and in particular to a diagnostic prediction method for an attention-based bidirectional recurrent neural network.
Background
Electronic Health Records (EHRs) consist of longitudinal patient health data. Each patient's record comprises a sequence of visits over time, with each visit containing multiple medical codes, including diagnosis, medication and procedure codes; such records have been successfully applied to several predictive modeling tasks in healthcare. EHR data consists of a set of high-dimensional clinical variables (i.e., medical codes). One of the key tasks is to predict a patient's future diagnoses from past EHR data, i.e., diagnostic prediction.
The visits of each patient, and the medical codes within each visit, may be of different importance when predicting a diagnosis. The two most important and challenging problems in diagnostic prediction are therefore: 1. how to correctly model this temporal, high-dimensional EHR data so as to significantly improve prediction performance; 2. how to reasonably explain the importance of the visits and medical codes to the predicted outcome.
Diagnostic prediction is a challenging and significant task, and accurately predicting future diagnoses is a difficult and central problem for medical prediction models. Many existing diagnostic prediction approaches employ deep learning techniques, such as Recurrent Neural Networks (RNNs), to model the temporal and high-dimensional structure of EHR data. However, RNN-based approaches may not fully remember all previous visit information, resulting in erroneous predictions.
Disclosure of Invention
The present invention is directed to a diagnostic prediction method for an attention-based bidirectional recurrent neural network, so as to solve the aforementioned problems in the prior art.
In order to achieve this purpose, the technical solution adopted by the invention is as follows:
a diagnostic prediction method of an attention-based bidirectional recurrent neural network comprises the following steps:
S1, modeling the medical codes x_t in a patient's electronic health record as sequences of diagnoses, intervention codes, admission types and elapsed time; according to the visit information from time 1 to time t, embedding the medical code x_i ∈ {0,1}^|C| of the i-th visit into a vector representation v_i as follows:
v_i = ReLU(W_v x_i + b_c)
where |C| is the number of unique medical codes, m is the embedding dimension, W_v ∈ ℝ^(m×|C|) is the weight matrix of the medical codes, and b_c ∈ ℝ^m is the bias vector; ReLU is the rectified linear unit defined as ReLU(v) = max(v, 0), where max() is applied element-wise;
S2, constructing bidirectional recurrent neural networks according to different attention mechanisms:
will vector viInput Bidirectional Recurrent Neural Network (BRNN), and output hidden state hiAs a representation of the ith access, the set of hidden states is noted
Wherein,in order to be in a forward-hidden state,the state is a backward hidden state;
S3, computing the context state vector c_t from the relative importances α_ti and the hidden states h_i, as follows:
c_t = Σ_{i=1}^{t-1} α_ti h_i
where the attention weight vector α_t is obtained by the following softmax function:
α_t = Softmax([α_t1, α_t2, …, α_t(t-1)])
and each relative importance α_ti in α_t is calculated by an attention mechanism;
S4, combining the information of the context state vector c_t and the current hidden state h_t through a simple concatenation layer to generate the attention hidden state vector h̃_t, as follows:
h̃_t = tanh(W_c [c_t ; h_t])
where W_c ∈ ℝ^(r×4p) is a weight matrix;
S5, feeding the attention hidden state vector h̃_t through a softmax layer to generate the predicted (t+1)-th visit information, defined as:
ŷ_{t+1} = Softmax(W_s h̃_t + b_s)
where W_s and b_s are the parameters to be learned, and the output dimension is the number of unique categories.
Preferably, in S3, calculating each relative importance α_ti in α_t by an attention mechanism specifically comprises one of the following:
a location-based attention function, which calculates each weight from the current hidden state h_i alone, as follows:
α_ti = W_α^T h_i + b_α
where W_α ∈ ℝ^2p and b_α ∈ ℝ are the parameters to be learned;
or a general attention function, which uses a matrix W_α ∈ ℝ^(2p×2p) to capture the relationship between h_t and h_i, and calculates the weight as follows:
α_ti = h_t^T W_α h_i
or a concatenation-based attention function using a multi-layer perceptron (MLP), which first concatenates the current hidden state h_t and a previous state h_i, then obtains a latent vector by multiplying by a weight matrix W_α ∈ ℝ^(q×4q), where q is the latent dimension and tanh is chosen as the activation function; the attention weight is generated as follows:
α_ti = υ_α^T tanh(W_α [h_t ; h_i])
where υ_α ∈ ℝ^q is the parameter to be learned.
Preferably, after S5, the method further comprises a step of interpreting the prediction, specifically:
representing the medical codes with the non-negative matrix W_v; then sorting each dimension of the attention hidden state vector in descending order of value; and finally selecting the k codes with the largest values to obtain a clinical interpretation of each dimension,
where W_v(:, i) represents the i-th column (dimension) of W_v.
The invention has the following beneficial effects: the invention provides a diagnostic prediction method for an attention-based bidirectional recurrent neural network. High-dimensional medical codes (i.e., clinical variables) are first embedded into a low-dimensional code space, and the embedded representations are then input into an attention-based bidirectional recurrent neural network to generate hidden state representations. The hidden representations are fed through a softmax layer to predict the medical codes of future visits. Experimental data show that, when predicting future visit information with the method of this embodiment, the attention mechanism can assign different weights to previous visits, so that the method not only effectively completes the diagnostic prediction task but also reasonably explains the prediction results.
Drawings
FIG. 1 is a schematic flow chart of a diagnostic prediction method for an attention-based bidirectional recurrent neural network provided by the present invention;
FIG. 2 is a graph showing the results of attention mechanism analysis of five patients in the example embodiment.
Detailed Description
In order to make the objects, technical solutions and advantages of the present invention more apparent, the present invention is further described in detail below with reference to the accompanying drawings. It should be understood that the detailed description and specific examples, while indicating the invention, are intended for purposes of illustration only and are not intended to limit the scope of the invention.
As shown in fig. 1, the present invention provides a diagnostic prediction method for attention-based bidirectional recurrent neural network, comprising the following steps:
step 1) modeling the medical codes x_t in a patient's electronic health record as sequences of diagnoses, intervention codes, admission types and elapsed time;
step 2) constructing different bidirectional recurrent neural networks according to different attention mechanism states;
step 3) carrying out diagnosis prediction;
step 4) interpreting the prediction.
The detailed explanation about the above steps is as follows:
Electronic Health Records (EHRs) consist of longitudinal patient health data, and a patient's EHR data comprises a sequence of visits over time, with each visit containing multiple medical codes, including diagnosis, medication and procedure codes. All unique medical codes from the EHR data are denoted c_1, c_2, …, c_|C| ∈ C, where |C| is the number of unique medical codes. Suppose there are N patients, and the n-th patient has T^(n) visit records in the EHR data. The patient can be represented by a sequence of visits V_1, V_2, …, V_{T^(n)}. Each visit V_i comprises a subset of medical codes and is represented by a binary vector x_t ∈ {0,1}^|C|, where the i-th element is 1 if V_t contains the code c_i. Each visit V_t also has a corresponding coarse-grained category representation, whose dimension is the number of unique categories. Each diagnosis code can be mapped to a node of the International Classification of Diseases (ICD-9), and each procedure code can be mapped to a node of the Current Procedural Terminology. Since the input x_t is too sparse and high-dimensional, it is natural to learn a low-dimensional and meaningful embedding of it.
According to the visit information from time 1 to time t, the medical code x_i ∈ {0,1}^|C| of the i-th visit can be embedded into a vector representation v_i as follows:
v_i = ReLU(W_v x_i + b_c)
where m is the embedding dimension, W_v ∈ ℝ^(m×|C|) is the weight matrix of the medical codes, and b_c ∈ ℝ^m is the bias vector. ReLU is the rectified linear unit defined as ReLU(v) = max(v, 0), where max() is applied element-wise.
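As a minimal illustration (not part of the claimed method), the embedding step v_i = ReLU(W_v x_i + b_c) can be sketched in NumPy; the function name `embed_visit` and the toy sizes are hypothetical:

```python
import numpy as np

def embed_visit(x, W_v, b_c):
    """Embed a multi-hot visit vector x in {0,1}^|C| as v = ReLU(W_v x + b_c)."""
    return np.maximum(W_v @ x + b_c, 0.0)  # element-wise ReLU

# toy sizes (hypothetical): |C| = 6 unique codes, embedding dimension m = 4
rng = np.random.default_rng(0)
C, m = 6, 4
W_v = rng.normal(size=(m, C))   # weight matrix of the medical codes
b_c = np.zeros(m)               # bias vector
x = np.zeros(C)
x[[1, 3]] = 1.0                 # a visit containing codes c_2 and c_4
v = embed_visit(x, W_v, b_c)
print(v.shape)                  # (4,)
```

The multi-hot input stays sparse; only the two active codes contribute columns of W_v to the sum before the ReLU.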
The vector v_i is input into a Bidirectional Recurrent Neural Network (BRNN), and the output hidden state h_i is taken as the representation of the i-th visit; the set of hidden states is denoted [h_1, h_2, …, h_t]. The context state vector c_t is computed from the relative importances α_ti and the hidden states h_i as follows:
c_t = Σ_{i=1}^{t-1} α_ti h_i
The attention weight vector α_t is obtained by the following softmax function:
α_t = Softmax([α_t1, α_t2, …, α_t(t-1)])
A BRNN contains a forward and a backward Recurrent Neural Network (RNN). The forward RNN reads the input visit sequence from x_1 to x_T and computes a sequence of forward hidden states (h_1^f, …, h_T^f), where h_i^f ∈ ℝ^p and p is the dimension of the hidden state. The backward RNN reads the visit sequence in reverse order, i.e., from x_T to x_1, producing a sequence of backward hidden states (h_1^b, …, h_T^b). By concatenating the forward hidden state h_i^f and the backward hidden state h_i^b, the final latent vector representation h_i = [h_i^f ; h_i^b] ∈ ℝ^2p is obtained.
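The forward and backward passes can be sketched as follows, under the simplifying assumption of a vanilla tanh RNN cell (the description does not fix a particular cell); `rnn_pass` and all toy sizes are hypothetical:

```python
import numpy as np

def rnn_pass(X, W, U, b, reverse=False):
    """One-directional vanilla RNN: h_i = tanh(W x_i + U h_{i-1} + b)."""
    T, p = X.shape[0], b.shape[0]
    h = np.zeros(p)
    states = []
    steps = range(T - 1, -1, -1) if reverse else range(T)
    for i in steps:
        h = np.tanh(W @ X[i] + U @ h + b)
        states.append(h)
    if reverse:
        states = states[::-1]  # re-align so states[i] corresponds to visit i
    return np.stack(states)

# toy sequence (hypothetical): T = 3 visit embeddings of dim m = 4, hidden size p = 5
rng = np.random.default_rng(1)
T, m, p = 3, 4, 5
X = rng.normal(size=(T, m))
Wf, Uf, bf = rng.normal(size=(p, m)), rng.normal(size=(p, p)), np.zeros(p)
Wb, Ub, bb = rng.normal(size=(p, m)), rng.normal(size=(p, p)), np.zeros(p)
# concatenate forward and backward states: h_i = [h_i^f ; h_i^b] in R^{2p}
H = np.concatenate([rnn_pass(X, Wf, Uf, bf),
                    rnn_pass(X, Wb, Ub, bb, reverse=True)], axis=1)
print(H.shape)  # (3, 10)
```

Note the backward states are reversed back into visit order before concatenation, so row i of H always represents visit i.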
Three attention mechanisms can be used to calculate the relative importances α_ti in α_t and capture the relevant information:
The location-based attention function calculates each weight from the current hidden state h_i alone, as follows:
α_ti = W_α^T h_i + b_α
where W_α ∈ ℝ^2p and b_α ∈ ℝ are the parameters to be learned.
The general attention function uses a matrix W_α ∈ ℝ^(2p×2p) to capture the relationship between h_t and h_i, and calculates the weight as:
α_ti = h_t^T W_α h_i
Another way to calculate α_ti is the concatenation-based function, which uses a multi-layer perceptron (MLP). The current hidden state h_t and a previous state h_i are first concatenated; a latent vector is then obtained by multiplying by a weight matrix W_α ∈ ℝ^(q×4q), where q is the latent dimension. tanh is chosen as the activation function. The attention weight is generated as follows:
α_ti = υ_α^T tanh(W_α [h_t ; h_i])
where υ_α ∈ ℝ^q is the parameter to be learned.
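The three scoring functions can be sketched as follows; the helper names and toy dimensions are hypothetical, and only the score computations themselves follow the formulas above:

```python
import numpy as np

def score_location(h_i, w_alpha, b_alpha):
    """Location-based: the score depends only on the past hidden state h_i."""
    return w_alpha @ h_i + b_alpha

def score_general(h_t, h_i, W_alpha):
    """General: bilinear form h_t^T W_alpha h_i."""
    return h_t @ (W_alpha @ h_i)

def score_concat(h_t, h_i, W_alpha, v_alpha):
    """Concatenation-based: v_alpha^T tanh(W_alpha [h_t ; h_i])."""
    return v_alpha @ np.tanh(W_alpha @ np.concatenate([h_t, h_i]))

# toy sizes (hypothetical): concatenated hidden size 2p = 6, latent dim q = 5
rng = np.random.default_rng(2)
p2, q = 6, 5
h_t, h_i = rng.normal(size=p2), rng.normal(size=p2)
s1 = score_location(h_i, rng.normal(size=p2), 0.1)
s2 = score_general(h_t, h_i, rng.normal(size=(p2, p2)))
s3 = score_concat(h_t, h_i, rng.normal(size=(q, 2 * p2)), rng.normal(size=q))
print(round(float(s1), 4), round(float(s2), 4), round(float(s3), 4))
```

Each function returns one unnormalised score α_ti; collecting the scores over i = 1, …, t−1 and applying the softmax yields α_t.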
Given the context vector c_t and the current hidden state h_t, a simple concatenation layer is used to combine the information from the two vectors to generate the attention hidden state h̃_t, as follows:
h̃_t = tanh(W_c [c_t ; h_t])
where W_c ∈ ℝ^(r×4p) is the weight matrix. The attention vector h̃_t is fed through a softmax layer to generate the predicted (t+1)-th visit information, defined as:
ŷ_{t+1} = Softmax(W_s h̃_t + b_s)
where W_s and b_s are the parameters to be learned. The cross entropy between the real visit information y_t and the predicted visit ŷ_t is used to calculate the loss over all patients, as follows:
L = −(1/N) Σ_{n=1}^{N} (1/(T^(n)−1)) Σ_t ( y_t^T log ŷ_t + (1 − y_t)^T log(1 − ŷ_t) )
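The chain from attention weights to context vector to prediction can be sketched end-to-end; `predict_next` and the toy sizes are hypothetical, with precomputed attention scores passed in for brevity:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())  # shift for numerical stability
    return e / e.sum()

def predict_next(H, W_c, W_s, b_s, scores):
    """Attention over past hidden states, then a softmax prediction.

    H: (t, 2p) hidden states for visits 1..t; scores: (t-1,) raw attention scores.
    """
    alpha = softmax(scores)                  # alpha_t over visits 1..t-1
    c_t = alpha @ H[:-1]                     # context vector in R^{2p}
    h_tilde = np.tanh(W_c @ np.concatenate([c_t, H[-1]]))  # attention hidden state
    return softmax(W_s @ h_tilde + b_s)      # predicted code distribution

# toy sizes (hypothetical): t = 4 visits, 2p = 6, r = 5, 8 output categories
rng = np.random.default_rng(3)
t, p2, r, C = 4, 6, 5, 8
H = rng.normal(size=(t, p2))
y_hat = predict_next(H, rng.normal(size=(r, 2 * p2)), rng.normal(size=(C, r)),
                     np.zeros(C), rng.normal(size=t - 1))
print(round(float(y_hat.sum()), 6))  # probabilities sum to 1
```

Note W_c consumes the concatenation [c_t ; h_t] of length 4p, matching W_c ∈ ℝ^(r×4p) above.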
In healthcare, the interpretability of the learned representations of medical codes and visits is very important: it is necessary to understand the clinical meaning of each dimension of the medical code representation, and to analyze which visits are critical to the prediction. Since the proposed model is based on an attention mechanism, the importance of each visit to the prediction is easy to find by analyzing the attention scores. For the t-th prediction, if the attention score α_ti is very large, the probability that the (i+1)-th visit contains information relevant to the current prediction is high. The medical codes are first represented using the non-negative matrix W_v. Each dimension of the attention hidden state vector is then sorted in descending order of value. Finally, the k codes with the largest values are selected,
where W_v(:, i) represents the i-th column (dimension) of W_v. By analyzing the selected medical codes, a clinical interpretation of each dimension can be obtained.
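The top-k selection can be sketched as follows, under the assumption that the rows of W_v index embedding dimensions and the columns index medical codes; `top_k_codes` and the toy matrix are hypothetical:

```python
import numpy as np

def top_k_codes(W_v, dim, k):
    """Indices of the k medical codes with the largest weight in dimension
    `dim` of the non-negative code-embedding matrix W_v (m x |C|)."""
    return np.argsort(W_v[dim])[::-1][:k]

# toy matrix (hypothetical): m = 2 dimensions, |C| = 4 codes
W_v = np.array([[0.1, 0.9, 0.3, 0.7],
                [0.8, 0.2, 0.6, 0.4]])
print(top_k_codes(W_v, dim=0, k=2))  # codes ranked by weight in dimension 0
```

Mapping the selected column indices back to their ICD-9 or procedure codes then yields the clinical reading of that dimension.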
Detailed Description of Embodiments
In order to illustrate the technical effect of the invention, the invention is verified by adopting a specific application example.
The experiments used two data sets: medical assistance claims and diabetes claims. The medical assistance data set comprises medical assistance claims from 2011. It contains data relating to 147,810 patients and 1,055,011 visits. The patients' visits were grouped by week, and patients with fewer than 2 visits were excluded. The diabetes data set comprises medical assistance claims from 2012 and 2013 corresponding to patients diagnosed with diabetes (i.e., medical assistance members with ICD-9 diagnosis code 250.xx in a claim). It contains data relating to 22,820 patients with 466,732 visits. The patients' visits were grouped by week, and patients with fewer than 5 visits were excluded.
Each data set was randomly divided into training, validation and test sets at a ratio of 0.75:0.1:0.15. The validation set was used to determine the optimal parameter values; 100 iterations were performed, and the best performance of each method is reported.
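A 0.75:0.1:0.15 random split at the patient level can be sketched as follows; `split_patients` and the patient count are hypothetical:

```python
import random

def split_patients(patients, ratios=(0.75, 0.10, 0.15), seed=0):
    """Randomly split a patient id list into train/validation/test subsets."""
    ids = list(patients)
    random.Random(seed).shuffle(ids)       # deterministic shuffle for a fixed seed
    n = len(ids)
    n_train = int(ratios[0] * n)
    n_val = int(ratios[1] * n)
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]

# toy cohort of 1000 patient ids (hypothetical)
train, val, test = split_patients(range(1000))
print(len(train), len(val), len(test))  # 750 100 150
```

Splitting by patient id, rather than by visit, keeps all visits of one patient inside a single subset.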
Experiment one:
The statistics of the data sets are shown in Table 1:
TABLE 1
Experiment two:
The following baseline models were evaluated: (1) Med2Vec (model 1); (2) RETAIN (model 2); (3) RNN (model 3).
The following RNN-based prediction model is performed: (1) calculating RNN of relative importance with location-based attention functionl(model 4); (2) calculating RNN of relative importance using general attention functionsg(model 5); (3) calculating RNN of relative importance with connection-based attention functionc(model 6).
The following Dipole model was performed: (1) dipole without any attention mechanism-(model 7); (2) computing Dipole of relative importance with location-based attention functionl(model 8); (2) computing Dipole of relative importance with general attention functiong(model 9); (3) computing Dipole of relative importance with a connection-based attention functionc(model 10)
Results and analysis of the experiments
Table 2 shows the accuracy of all the methods on the diabetes data set:
TABLE 2
It can be observed from Table 2 that, since most of the medical codes relate to diabetes, Med2Vec (model 1) can correctly learn the vector representations on the diabetes data set. Thus, Med2Vec achieves the best results among the three baselines. For the medical assistance data set, the accuracy of RETAIN (model 2) is better than that of Med2Vec. The reason is that the medical assistance data set covers many diseases and contains many more types of medical codes than the diabetes data set. In this case, the attention mechanism can help RETAIN learn reasonable parameters and make correct predictions.
The accuracy of RNN (model 3) is the lowest among all methods on both data sets. This is because the predictions of an RNN depend mainly on recent visit information; it cannot remember all past information. However, RETAIN and the proposed RNN variants RNN_l (model 4), RNN_g (model 5) and RNN_c (model 6) can take all previous visit information fully into account, assign different attention scores to past visits, and achieve better performance than the RNN.
Since most of the visits in the diabetes data set are related to diabetes, it is easy to predict the medical codes of the next visit from past visit information. RETAIN uses a reverse-time attention mechanism for prediction, which degrades prediction performance compared with methods using a forward-time attention mechanism; the performance of all three RNN variants is therefore superior to that of RETAIN on the diabetes data set. However, on the medical assistance data set RETAIN is more accurate than the RNN variants, because its data cover many different diseases and the reverse-time attention mechanism helps to learn the correct visit relationships.
Neither RNN nor the proposed Dipole− (model 7) uses any attention mechanism, yet the accuracy of Dipole− is higher than that of RNN on both the diabetes and the medical assistance data sets. This result shows that modeling visit information from both directions can improve prediction performance; it is therefore reasonable to use a bidirectional recurrent neural network for diagnostic prediction.
The proposed Dipole_c (model 10) and Dipole_l (model 8) perform best on the diabetes and the medical assistance data sets, respectively, which shows that modeling the visits from both directions and assigning a different weight to each visit can improve the accuracy of the medical diagnosis prediction task. On the diabetes data set, Dipole_l and Dipole_c are superior to all baselines and the proposed RNN variants. On the medical assistance data set, Dipole_l, Dipole_g and Dipole_c are superior to the baselines and the RNN variants.
FIG. 2 shows a case study of predicting the medical codes of the sixth visit (y_5) on the diabetes data set from the previous visits. The concatenation-based attention weights for the second through fifth visits are calculated from the hidden states h_1, h_2, h_3, …. In FIG. 2, the x-axis is the patient and the y-axis is the attention weight of each visit. In this case study, 5 patients were selected. It can be observed that the attention scores learned by the attention mechanism differ between patients. For the second patient in FIG. 2, all diagnosis codes are listed in Table 3:
TABLE 3
The weights α for the four visits of patient 2, obtained from FIG. 2, are [0.2386, 0.0824, 0.2386, 0.0824]. By analyzing this attention vector, it can be concluded that the medical codes of the second, fourth and fifth visits had a significant impact on the final prediction. As can be seen from Table 3, the patient developed essential hypertension at the second and fourth visits and was diagnosed with diabetes at the fifth visit; therefore, the probability of medical codes for diabetes and essential-hypertension-related diseases at the sixth visit is higher.
In summary, Dipole addresses the challenges of modeling EHR data and interpreting the prediction results. With a bidirectional recurrent neural network, Dipole can remember the hidden knowledge learned from past and future visits. The three attention mechanisms allow the prediction results to be reasonably interpreted. Experimental results on two large real-world EHR data sets demonstrate the effectiveness of Dipole in diagnostic prediction tasks. The analysis shows that, when predicting future visit information, the attention mechanism can assign different weights to previous visits.
By adopting the technical solution disclosed by the invention, the following beneficial effects are obtained: the invention provides a diagnostic prediction method for an attention-based bidirectional recurrent neural network. High-dimensional medical codes (i.e., clinical variables) are first embedded into a low-dimensional code space, and the embedded representations are then input into an attention-based bidirectional recurrent neural network to generate hidden state representations. The hidden representations are fed through a softmax layer to predict the medical codes of future visits. Experimental data show that, when predicting future visit information with the method of this embodiment, the attention mechanism can assign different weights to previous visits, so that the method not only effectively completes the diagnostic prediction task but also reasonably explains the prediction results.
The foregoing is only a preferred embodiment of the present invention, and it should be noted that, for those skilled in the art, various modifications and improvements can be made without departing from the principle of the present invention, and such modifications and improvements should also be considered within the scope of the present invention.