Disclosure of Invention
The present invention is directed to a diagnostic prediction method based on an attention-based bidirectional recurrent neural network, so as to solve the aforementioned problems in the prior art.
In order to achieve the purpose, the technical scheme adopted by the invention is as follows:
a diagnostic prediction method of an attention-based bidirectional recurrent neural network comprises the following steps:
S1, constructing a model of the medical codes x_t in the patient's electronic health record as sequences of diagnoses, intervention codes, admission types and elapsed time; according to the visit information from time 1 to t, embedding the medical code x_i ∈ {0,1}^|C| of the i-th visit into a vector representation v_i as follows:
v_i = ReLU(W_v x_i + b_c)
where |C| is the number of unique medical codes, m is the embedding dimension, W_v ∈ R^{m×|C|} is the weight matrix of the medical codes, and b_c ∈ R^m is a bias vector; ReLU is the rectified linear unit defined as ReLU(v) = max(v, 0), where max() is applied element-wise;
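As an illustrative sketch only (not the patented implementation; the sizes and random weights below are toy assumptions), the embedding of step S1 can be written in NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)

num_codes = 6   # |C|, number of unique medical codes (toy value)
m = 4           # embedding dimension m (toy value)

W_v = rng.normal(size=(m, num_codes))  # code weight matrix, W_v ∈ R^{m×|C|}
b_c = np.zeros(m)                      # bias vector, b_c ∈ R^m

def embed_visit(x):
    """v_i = ReLU(W_v x_i + b_c): embed a multi-hot visit vector."""
    return np.maximum(W_v @ x + b_c, 0.0)  # element-wise max with 0

x_i = np.zeros(num_codes)
x_i[[1, 3]] = 1.0          # a visit containing codes c_2 and c_4
v_i = embed_visit(x_i)
print(v_i.shape)           # (4,)
print(np.all(v_i >= 0))    # True: ReLU output is non-negative
```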
S2, constructing different bidirectional recurrent neural networks according to different attention mechanisms:
inputting the vector v_i into a bidirectional recurrent neural network (BRNN) and taking the output hidden state h_i as the representation of the i-th visit; the set of hidden states is denoted h_1, ..., h_t, where h_i = [→h_i ; ←h_i], →h_i being the forward hidden state and ←h_i the backward hidden state;
S3, computing the context state vector c_t according to the relative importance α_t and the hidden states, as follows:
c_t = Σ_{i=1}^{t-1} α_{ti} h_i
where α_t is obtained by the following softmax function:
α_t = Softmax([α_{t1}, α_{t2}, ..., α_{t(t-1)}])
and each α_{ti} in the relative importance α_t is calculated by an attention mechanism;
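The softmax weighting and weighted sum of step S3 can be sketched in NumPy as follows; the raw scores and hidden states are arbitrary toy inputs, and the attention scoring function is assumed to have run already:

```python
import numpy as np

def context_vector(scores, H):
    """c_t = sum_i alpha_ti * h_i, with alpha_t = Softmax(scores).

    scores: raw attention scores for visits 1..t-1, shape (t-1,)
    H:      hidden states h_1..h_{t-1}, shape (t-1, 2p)
    """
    e = np.exp(scores - scores.max())   # numerically stable softmax
    alpha = e / e.sum()
    return alpha, alpha @ H             # weighted sum of hidden states

H = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # three past visits, 2p = 2
scores = np.array([2.0, 2.0, 2.0])                  # equal scores -> equal weights
alpha, c_t = context_vector(scores, H)
print(alpha)  # each weight ≈ 1/3
print(c_t)    # ≈ [0.667, 0.667], the mean of the rows
```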
S4, combining the information of the context state vector c_t and the hidden state h_t of the t-th visit through a simple concatenation layer to generate the attention hidden state vector h̃_t, as follows:
h̃_t = tanh(W_c [c_t ; h_t])
where W_c ∈ R^{r×4p} is a weight matrix;
S5, feeding the attention hidden state vector h̃_t through a softmax layer to generate the (t+1)-th visit information, defined as:
ŷ_{t+1} = Softmax(W_s h̃_t + b_s)
where W_s ∈ R^{|G|×r} and b_s ∈ R^{|G|} are the parameters to be learned and |G| is the number of unique categories.
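Steps S4 and S5 together can be sketched as below; the weight matrices are random placeholders, and the names W_s and b_s for the softmax-layer parameters are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(1)
two_p, r, n_cat = 4, 3, 5      # 2p, r, |G| (toy sizes)

W_c = rng.normal(size=(r, 2 * two_p))   # W_c ∈ R^{r×4p}
W_s = rng.normal(size=(n_cat, r))       # softmax-layer weights (assumed names)
b_s = np.zeros(n_cat)

def predict_next_visit(c_t, h_t):
    """S4-S5: attention hidden state, then softmax over categories."""
    h_tilde = np.tanh(W_c @ np.concatenate([c_t, h_t]))  # h~_t = tanh(W_c [c_t; h_t])
    z = W_s @ h_tilde + b_s
    e = np.exp(z - z.max())
    return e / e.sum()                                   # y^_{t+1}

y_hat = predict_next_visit(rng.normal(size=two_p), rng.normal(size=two_p))
print(y_hat.shape)                   # (5,)
print(round(float(y_hat.sum()), 6))  # 1.0 -- a valid probability distribution
```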
Preferably, in S2, calculating each α_{ti} in the relative importance α_t by an attention mechanism specifically comprises:
a location-based attention function, which calculates the weight from the current hidden state h_i alone, as follows:
α_{ti} = W_α^T h_i + b_α
where W_α ∈ R^{2p} and b_α ∈ R are parameters to be learned;
or a general attention function, which uses a matrix W_α ∈ R^{2p×2p} to capture the relationship between h_t and h_i and calculates the weight as follows:
α_{ti} = h_t^T W_α h_i
or a concatenation-based attention function, which uses a multi-layer perceptron (MLP): the hidden state h_t of the t-th visit and the hidden state h_i of the i-th visit are concatenated, a latent vector is then obtained by multiplying by a weight matrix W_α ∈ R^{q×4p}, q being the latent dimension, and tanh is selected as the activation function; the attention weight is generated as follows:
α_{ti} = υ_α^T tanh(W_α [h_t ; h_i])
where υ_α ∈ R^q is a parameter to be learned.
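The three attention scoring functions can be sketched as follows, with toy dimensions and random parameters standing in for the learned ones:

```python
import numpy as np

rng = np.random.default_rng(2)
two_p, q = 4, 3                      # 2p and latent dimension q (toy sizes)

w_loc = rng.normal(size=two_p); b_loc = 0.1      # location-based parameters
W_gen = rng.normal(size=(two_p, two_p))          # general: W_alpha ∈ R^{2p×2p}
W_cat = rng.normal(size=(q, 2 * two_p))          # concatenation: W_alpha ∈ R^{q×4p}
v_cat = rng.normal(size=q)                       # upsilon_alpha ∈ R^q

def score_location(h_i):
    return w_loc @ h_i + b_loc                   # alpha_ti = W_alpha^T h_i + b_alpha

def score_general(h_t, h_i):
    return h_t @ W_gen @ h_i                     # alpha_ti = h_t^T W_alpha h_i

def score_concat(h_t, h_i):
    # alpha_ti = upsilon^T tanh(W_alpha [h_t ; h_i])
    return v_cat @ np.tanh(W_cat @ np.concatenate([h_t, h_i]))

h_t, h_i = rng.normal(size=two_p), rng.normal(size=two_p)
for s in (score_location(h_i), score_general(h_t, h_i), score_concat(h_t, h_i)):
    print(np.isfinite(float(s)))   # each mechanism yields one scalar score
```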
Preferably, after S5, a step of interpreting the prediction is further included, specifically:
representing the medical codes with the non-negative matrix ReLU(W_v), then sorting each dimension of the attention hidden state vector in descending order of value, and finally selecting the top k codes with the largest values, where [ReLU(W_v)]_i denotes the i-th column or dimension of ReLU(W_v); the clinical interpretation of each dimension is thereby obtained.
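A minimal sketch of this interpretation step, under toy sizes: for every embedding dimension it keeps the k codes with the largest non-negative weights in ReLU(W_v):

```python
import numpy as np

rng = np.random.default_rng(3)
m, num_codes, k = 3, 8, 2                 # embedding dim, |C|, top-k (toy sizes)
W_v = rng.normal(size=(m, num_codes))

def top_k_codes_per_dimension(W_v, k):
    """For each embedding dimension, pick the k codes with the largest
    non-negative weight in ReLU(W_v) as its clinical interpretation."""
    A = np.maximum(W_v, 0.0)              # non-negative code representation
    # argsort descending along the code axis, keep the first k indices
    return np.argsort(-A, axis=1)[:, :k]

top = top_k_codes_per_dimension(W_v, k)
print(top.shape)   # (3, 2): k code indices for each of the m dimensions
```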
The beneficial effects of the invention are as follows: the invention provides a diagnostic prediction method based on an attention-based bidirectional recurrent neural network, in which high-dimensional medical codes (i.e., clinical variables) are embedded into a low-dimensional code space, and the code representations are then input into the attention-based bidirectional recurrent neural network to generate hidden state representations. The hidden representations are fed through the softmax layer to predict the medical codes of future visits. Experimental data show that, with the method provided by this embodiment, the attention mechanism can assign different weights to previous visits when predicting future visit information, so that the diagnosis prediction task can be completed effectively and the prediction results can be reasonably interpreted.
Detailed Description
In order to make the objects, technical solutions and advantages of the present invention more apparent, the present invention is further described in detail below with reference to the accompanying drawings. It should be understood that the detailed description and specific examples, while indicating the invention, are intended for purposes of illustration only and are not intended to limit the scope of the invention.
As shown in fig. 1, the present invention provides a diagnostic prediction method for attention-based bidirectional recurrent neural network, comprising the following steps:
step 1) constructing a model of the medical codes x_t in the patient's electronic health record as sequences of diagnoses, intervention codes, admission types and elapsed time;
step 2) constructing different bidirectional recurrent neural networks according to different attention mechanisms;
step 3) carrying out diagnosis prediction;
step 4) interpreting the prediction.
The detailed explanation about the above steps is as follows:
Electronic Health Records (EHRs) consist of longitudinal patient health data comprising a sequence of visits over time, each visit containing multiple medical codes, including diagnosis, medication and procedure codes. All unique medical codes from the EHR data are denoted c_1, c_2, ..., c_|C| ∈ C, where |C| is the number of unique medical codes. Assuming there are N patients, the n-th patient has T^(n) visit records in the EHR data and can be represented by a sequence of visits V_1, V_2, ..., V_{T^(n)}. Each visit V_t comprises a subset of medical codes and is represented by a binary vector x_t ∈ {0,1}^|C|, in which the i-th element is 1 if V_t contains the code c_i. Each visit V_t also has a corresponding coarse-grained category representation y_t, where |G| is the number of unique categories. Each diagnosis code can be mapped to a node of the International Classification of Diseases (ICD-9), and each procedure code can be mapped to a node of the Current Procedural Terminology. Since the input x_t is too sparse and high-dimensional, it is natural to learn a low-dimensional and meaningful embedding for it.
According to the visit information from time 1 to t, the medical code x_i ∈ {0,1}^|C| of the i-th visit can be embedded into a vector representation v_i as follows:
v_i = ReLU(W_v x_i + b_c)
where m is the embedding dimension, W_v ∈ R^{m×|C|} is the weight matrix of the medical codes, and b_c ∈ R^m is a bias vector. ReLU is the rectified linear unit defined as ReLU(v) = max(v, 0), where max() is applied element-wise.
The vector v_i is input into a bidirectional recurrent neural network (BRNN), and the output hidden state h_i is taken as the representation of the i-th visit; the set of hidden states is denoted h_1, ..., h_t. According to the relative importance α_t and the hidden states, the context state c_t is computed as follows:
c_t = Σ_{i=1}^{t-1} α_{ti} h_i
The attention weight vector α_t is obtained by the following softmax function:
α_t = Softmax([α_{t1}, α_{t2}, ..., α_{t(t-1)}])
A BRNN contains a forward and a backward recurrent neural network (RNN). The forward RNN reads the input visit sequence from x_1 to x_T and computes a sequence of forward hidden states →h_1, ..., →h_T (→h_i ∈ R^p, where p is the dimension of the hidden states). The backward RNN reads the visit sequence in reverse order, i.e., from x_T to x_1, generating a sequence of backward hidden states ←h_1, ..., ←h_T. By concatenating the forward hidden state →h_i and the backward hidden state ←h_i, the final hidden state representation h_i = [→h_i ; ←h_i] ∈ R^{2p} can be obtained.
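A minimal BRNN sketch with plain tanh cells follows; the patent does not fix a particular recurrent cell, so the tanh cell and all parameter values here are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(4)
m, p = 3, 2                      # input (embedding) dim and hidden dim (toy sizes)

def rnn_pass(V, W, U):
    """Simple tanh RNN over a sequence of visit embeddings V, shape (T, m)."""
    h = np.zeros(p)
    out = []
    for v in V:
        h = np.tanh(W @ v + U @ h)
        out.append(h)
    return np.array(out)         # (T, p)

# separate parameters for the forward and backward RNNs (random placeholders)
Wf, Uf = rng.normal(size=(p, m)), rng.normal(size=(p, p))
Wb, Ub = rng.normal(size=(p, m)), rng.normal(size=(p, p))

V = rng.normal(size=(5, m))                 # five visit embeddings v_1..v_5
h_fwd = rnn_pass(V, Wf, Uf)                 # reads v_1 -> v_5
h_bwd = rnn_pass(V[::-1], Wb, Ub)[::-1]     # reads v_5 -> v_1, then realigned
H = np.concatenate([h_fwd, h_bwd], axis=1)  # h_i = [fwd_i ; bwd_i] ∈ R^{2p}
print(H.shape)   # (5, 4)
```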
There are three attention mechanisms that can be used to calculate each α_{ti} in the relative importance α_t and capture the relevant information:
The location-based attention function calculates the weight from the current hidden state h_i alone, as follows:
α_{ti} = W_α^T h_i + b_α
where W_α ∈ R^{2p} and b_α ∈ R are the parameters to learn.
The general attention function uses a matrix W_α ∈ R^{2p×2p} to capture the relationship between h_t and h_i, and calculates the weight as:
α_{ti} = h_t^T W_α h_i
Another method of calculating α_{ti} is based on concatenation, using a multi-layer perceptron (MLP). The hidden state h_t of the t-th visit and the hidden state h_i of the i-th visit are first concatenated, and a latent vector is then obtained by multiplying by a weight matrix W_α ∈ R^{q×4p}, q being the latent dimension. tanh is chosen as the activation function. The attention weight is generated as follows:
α_{ti} = υ_α^T tanh(W_α [h_t ; h_i])
where υ_α ∈ R^q is a parameter that needs to be learned.
Given the context vector c_t and the current hidden state h_t, a simple concatenation layer is used to combine the information from the two vectors to generate the attention hidden state h̃_t, as follows:
h̃_t = tanh(W_c [c_t ; h_t])
where W_c ∈ R^{r×4p} is a weight matrix. The attention vector h̃_t is fed through the softmax layer to generate the (t+1)-th visit information, defined as:
ŷ_{t+1} = Softmax(W_s h̃_t + b_s)
where W_s ∈ R^{|G|×r} and b_s ∈ R^{|G|} are the parameters to be learned. The cross entropy between the real visit information y_t and the predicted visit ŷ_t is used to calculate the loss for all patients as follows:
L = −(1/N) Σ_{n=1}^{N} [1/(T^(n)−1)] Σ_{t=1}^{T^(n)−1} (y_t^T log ŷ_t + (1−y_t)^T log(1−ŷ_t))
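The per-patient form of this cross-entropy loss can be sketched as follows; the clipping constant eps is an added numerical-stability assumption, not part of the formula above:

```python
import numpy as np

def dipole_loss(Y_true, Y_pred, eps=1e-12):
    """Mean binary cross entropy between real and predicted visit labels.

    Y_true, Y_pred: (T-1, |G|) arrays, one row per predicted visit;
    the row-wise cross entropies are averaged over time steps
    (single-patient version of the loss).
    """
    Y_pred = np.clip(Y_pred, eps, 1 - eps)       # avoid log(0)
    ce = -(Y_true * np.log(Y_pred) + (1 - Y_true) * np.log(1 - Y_pred))
    return ce.sum(axis=1).mean()

Y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
perfect = dipole_loss(Y_true, np.array([[1.0, 0.0], [0.0, 1.0]]))
poor    = dipole_loss(Y_true, np.array([[0.5, 0.5], [0.5, 0.5]]))
print(perfect < poor)   # True: the loss rewards accurate predictions
```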
In healthcare, the interpretability of the learned representations of medical codes and visits is very important; it is necessary to understand the clinical meaning of each dimension of the medical code representations and to analyze which visits are crucial for the prediction. Since the proposed model is based on attention mechanisms, the importance of each visit to a prediction is easily found by analyzing the attention scores. For the prediction at the t-th visit, if the attention score α_{ti} is very large, the information of the (i+1)-th visit is currently highly relevant to the prediction. The non-negative matrix ReLU(W_v) is first used to represent the medical codes. Each dimension of the attention hidden state vector is then sorted in descending order of value. Finally, the k codes with the largest values are selected, where [ReLU(W_v)]_i denotes the i-th column or dimension of ReLU(W_v). By analyzing the selected medical codes, a clinical interpretation of each dimension can be obtained.
DETAILED DESCRIPTION OF EMBODIMENT(S) OF THE INVENTION
In order to illustrate the technical effect of the invention, the invention is verified by adopting a specific application example.
The experiments used two data sets: medical assistance claims and diabetes claims. The medical assistance data set comprises medical assistance claims from 2011 and contains data on 147,810 patients and 1,055,011 visits; the patients' visits were grouped by week, and patients with fewer than 2 visits were excluded. The diabetes data set comprises medical assistance claims from 2012 and 2013 corresponding to patients diagnosed with diabetes (i.e., medical assistance members with ICD-9 diagnosis code 250.xx in their claims); it contains data on 22,820 patients with 466,732 visits. These patients' visits were likewise grouped by week, and patients with fewer than 5 visits were excluded.
Each data set was randomly divided into training, validation and test sets, the training set accounting for 0.75 of the data. The validation set was used to determine the optimal parameter values; 100 iterations were performed, and the best performance of each method is reported.
Experiment one:
the statistics of the data sets are shown in Table 1:
TABLE 1
Experiment two:
the following baseline model was performed: (1) Med2Vec (model 1); (2) RETAIN (model 2); (3) RNN (model 3).
The following RNN-based prediction model is performed: (1) Calculating RNN of relative importance using location-based attention functionl (model 4); (2) Calculating RNN of relative importance using general attention functiong (model 5); (3) Calculating RNN of relative importance with connection-based attention functionc (model 6).
The following Dipole model was performed: (1) Dipole without any attention mechanism- (model 7); (2) Computing Dipole of relative importance with location-based attention functionl (model 8); (2) Computing Dipole of relative importance with general attention functiong (model 9); (3) Computing Dipole of relative importance with a connection-based attention functionc (model 10)
Results and analysis of the experiments
Table 2 shows the accuracy of all methods on the diabetes data set
TABLE 2
It can be observed in Table 2 that, since most medical codes relate to diabetes, Med2Vec (model 1) can correctly learn the vector representations on the diabetes data set. Thus, Med2Vec achieves the best results among the three baselines. For the medical assistance data set, the accuracy of RETAIN (model 2) is better than that of Med2Vec. The reason is that the medical assistance data set covers many diseases and contains many more types of medical codes than the diabetes data set. In this case, the attention mechanism can help RETAIN learn reasonable parameters to make correct predictions.
The accuracy of the RNN (model 3) is the lowest of all methods on both data sets. This is because the prediction of an RNN depends mainly on recent visit information; it cannot memorize all past information. However, RETAIN and the proposed RNN variants RNN_l (model 4), RNN_g (model 5) and RNN_c (model 6) can take all previous visit information fully into account, assign different attention scores to past visits, and thus achieve better performance than the RNN.
Since most of the visits in the diabetes data set are related to diabetes, it is easy to predict the medical codes of the next visit from the past visit information. RETAIN uses a reverse-time attention mechanism for prediction, which degrades its prediction performance compared with methods using a forward-time attention mechanism, and the performance of all three RNN variants is superior to that of RETAIN. On the medical assistance data set, however, RETAIN is more accurate than the RNN variants, because the data concern many different diseases and the reverse-time attention mechanism helps to learn the correct visit relationships.
Neither the RNN nor the proposed Dipole− (model 7) uses any attention mechanism, yet on both the diabetes and medical assistance data sets the accuracy of Dipole− is higher than that of the RNN. This result shows that modeling visit information from both directions can improve prediction performance; it is therefore reasonable to use a bidirectional recurrent neural network for diagnosis prediction.
The proposed Dipole_c (model 10) and Dipole_l (model 8) perform best on the diabetes and medical assistance data sets, respectively, which shows that modeling the visits from two directions and assigning a different weight to each visit can improve the accuracy of the medical diagnosis prediction task. On the diabetes data set, Dipole_l and Dipole_c are superior to all baselines and the proposed RNN variants. On the medical assistance data set, Dipole_l, Dipole_g and Dipole_c are superior to the baselines and RNN variants.
FIG. 2 shows a case study of predicting the medical codes of the sixth visit (y_5) from the previous visits in the diabetes data set. The concatenation-based attention weights of the second through fifth visits are calculated from the hidden states h_1, h_2, h_3 and h_4. In FIG. 2, the x-axis is the patient and the y-axis is the attention weight of each visit. Five patients were selected for this case study. It can be observed that the attention scores learned by the attention mechanism differ from patient to patient. For the second patient in FIG. 2, all diagnosis codes are listed in Table 3:
TABLE 3
The weight vector α = [0.2386, 0.0824, 0.2386, 0.0824] of the four visits of patient 2 is first obtained from FIG. 2. From an analysis of this attention vector, it can be concluded that the medical codes of the second, fourth and fifth visits have a significant impact on the final prediction. As can be seen from Table 3, the patient suffered from essential hypertension at the second and fourth visits and was diagnosed with diabetes at the fifth visit. Therefore, the probability of medical codes related to diabetes and essential hypertension appearing at the sixth visit is high.
In summary, dipole can remedy the challenges of modeling EHR data and interpreting the predicted results. With a bi-directional recurrent neural network, dipole can remember the learned hidden knowledge in past and future accesses. Three attention mechanisms can reasonably explain the prediction result. Experimental results on two large real EHR data sets demonstrate the effectiveness of this Dipole in diagnostic predictive tasks. Analysis shows that in predicting future access information, the attention mechanism may assign different weights to previous accesses.
By adopting the technical scheme disclosed above, the following beneficial effects are obtained: the invention provides a diagnostic prediction method based on an attention-based bidirectional recurrent neural network, in which high-dimensional medical codes (i.e., clinical variables) are embedded into a low-dimensional code space, and the code representations are then input into the attention-based bidirectional recurrent neural network to generate hidden state representations. The hidden representations are fed through the softmax layer to predict the medical codes of future visits. Experimental data show that, with the method provided by this embodiment, the attention mechanism can assign different weights to previous visits when predicting future visit information, so that not only can the diagnosis prediction task be completed effectively, but the prediction results can also be reasonably interpreted.
The foregoing is only a preferred embodiment of the present invention, and it should be noted that, for those skilled in the art, various modifications and improvements can be made without departing from the principle of the present invention, and such modifications and improvements should also be considered within the scope of the present invention.