# -*- coding: UTF-8 -*-
import math
from typing import Tuple

import numpy as np
import tensorflow as tf

from dnlp.core.dnn_crf_base import DnnCrfBase
from dnlp.config import DnnCrfConfig


class DnnCrfEmr(DnnCrfBase):
  def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: type = tf.float32,
               task: str = 'ner', mode: str = 'train', train: str = '', nn: str, model_path: str = ''):
    if mode not in ['train', 'predict']:
      raise Exception('mode error')
    if nn not in ['mlp', 'rnn', 'lstm', 'bilstm', 'gru']:
      raise Exception('name of neural network entered is not supported')

    DnnCrfBase.__init__(self, config, data_path, mode, model_path)
    self.dtype = dtype
    self.mode = mode
    self.nn = nn
    self.task = task

    # Build the computation graph.
    tf.reset_default_graph()
    self.transition = self.__get_variable([self.tags_count, self.tags_count], 'transition')
    self.transition_init = self.__get_variable([self.tags_count], 'transition_init')
    self.params = [self.transition, self.transition_init]
    # Input layer.
    if mode == 'train':
      self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
      self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
      self.seq_length = tf.placeholder(tf.int32, [self.batch_size])
    else:
      self.input = tf.placeholder(tf.int32, [None, self.windows_size])

    # Embedding lookup layer.
    self.embedding_layer = self.get_embedding_layer()
    # Hidden layer.
    if nn == 'mlp':
      self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
    elif nn == 'lstm':
      self.hidden_layer = self.get_lstm_layer(tf.transpose(self.embedding_layer))
    elif nn == 'gru':
      self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
    else:
      self.hidden_layer = self.get_rnn_layer(tf.transpose(self.embedding_layer))
    # Output layer.
    self.output = self.get_output_layer(self.hidden_layer)

    if mode == 'predict':
      self.output = tf.squeeze(self.output, axis=2)
      self.sess = tf.Session()
      self.sess.run(tf.global_variables_initializer())
      tf.train.Saver().restore(save_path=self.model_path, sess=self.sess)
    elif train == 'll':
      self.ll_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
                                                          self.transition)
      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
      self.train_ll = self.optimizer.minimize(-self.ll_loss)
    else:
      # Build the training ops.
      # Placeholders used during training.
      self.ll_corr = tf.placeholder(tf.int32, shape=[None, 3])
      self.ll_curr = tf.placeholder(tf.int32, shape=[None, 3])
      self.trans_corr = tf.placeholder(tf.int32, [None, 2])
      self.trans_curr = tf.placeholder(tf.int32, [None, 2])
      self.trans_init_corr = tf.placeholder(tf.int32, [None, 1])
      self.trans_init_curr = tf.placeholder(tf.int32, [None, 1])
      # Loss function.
      self.loss, self.loss_with_init = self.get_loss()
      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
      self.train = self.optimizer.minimize(self.loss)
      self.train_with_init = self.optimizer.minimize(self.loss_with_init)

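  # Structured-perceptron-style training: each epoch iterates over all batches
  # and applies fit_batch. Note that the interval check is currently commented
  # out below, so a checkpoint is saved after every epoch.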
  def fit(self, epochs: int = 100, interval: int = 20):
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      saver = tf.train.Saver(max_to_keep=100)
      for epoch in range(1, epochs + 1):
        print('epoch:', epoch)
        for _ in range(self.batch_count):
          characters, labels, lengths = self.get_batch()
          self.fit_batch(characters, labels, lengths, sess)
        # if epoch % interval == 0:
        model_path = '../dnlp/models/emr_old/{0}-{1}.ckpt'.format(self.nn, epoch)
        saver.save(sess, model_path)
        self.save_config(model_path)

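  # One margin-based update on a single batch: decode the current best tag
  # sequence with Viterbi, compare it against the gold labels, and collect the
  # emission/transition indices where the two disagree. The gradient step then
  # raises the scores gathered at the gold indices (*_corr/*_pos) and lowers
  # those at the predicted indices (*_curr/*_neg).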
  def fit_batch(self, characters, labels, lengths, sess):
    scores = sess.run(self.output, feed_dict={self.input: characters})
    transition = self.transition.eval(session=sess)
    transition_init = self.transition_init.eval(session=sess)
    update_labels_pos = None
    update_labels_neg = None
    current_labels = []
    trans_pos_indices = []
    trans_neg_indices = []
    trans_init_pos_indices = []
    trans_init_neg_indices = []
    for i in range(self.batch_size):
      current_label = self.viterbi(scores[:, :lengths[i], i], transition, transition_init)
      current_labels.append(current_label)
      diff_tag = np.subtract(labels[i, :lengths[i]], current_label)
      update_index = np.where(diff_tag != 0)[0]
      update_length = len(update_index)
      if update_length == 0:
        continue
      update_label_pos = np.stack([labels[i, update_index], update_index, i * np.ones([update_length])], axis=-1)
      update_label_neg = np.stack([current_label[update_index], update_index, i * np.ones([update_length])], axis=-1)
      if update_labels_pos is not None:
        # np.concatenate returns a new array, so the result must be assigned back.
        update_labels_pos = np.concatenate((update_labels_pos, update_label_pos))
        update_labels_neg = np.concatenate((update_labels_neg, update_label_neg))
      else:
        update_labels_pos = update_label_pos
        update_labels_neg = update_label_neg

      trans_pos_index, trans_neg_index, trans_init_pos, trans_init_neg, update_init = self.generate_transition_update_index(
        labels[i, :lengths[i]], current_labels[i])

      trans_pos_indices.extend(trans_pos_index)
      trans_neg_indices.extend(trans_neg_index)

      if update_init:
        trans_init_pos_indices.append(trans_init_pos)
        trans_init_neg_indices.append(trans_init_neg)

    if update_labels_pos is not None and update_labels_neg is not None:
      feed_dict = {self.input: characters, self.ll_curr: update_labels_neg, self.ll_corr: update_labels_pos,
                   self.trans_curr: trans_neg_indices, self.trans_corr: trans_pos_indices}

      if not trans_init_pos_indices:
        sess.run(self.train, feed_dict)
      else:
        feed_dict[self.trans_init_corr] = trans_init_pos_indices
        feed_dict[self.trans_init_curr] = trans_init_neg_indices
        sess.run(self.train_with_init, feed_dict)

  def fit_ll(self, epochs: int = 100, interval: int = 20):
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      saver = tf.train.Saver(max_to_keep=epochs)
      for epoch in range(1, epochs + 1):
        print('epoch:', epoch)
        for _ in range(self.batch_count):
          characters, labels, lengths = self.get_batch()
          # scores = sess.run(self.output, feed_dict={self.input: characters})
          feed_dict = {self.input: characters, self.real_indices: labels, self.seq_length: lengths}
          sess.run(self.train_ll, feed_dict=feed_dict)
          # self.fit_batch(characters, labels, lengths, sess)
        if epoch % interval == 0:
          model_path = '../dnlp/models/emr_old/{0}-{1}.ckpt'.format(self.nn, epoch)
          saver.save(sess, model_path)
          self.save_config(model_path)

  def fit_batch_ll(self):
    pass

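  # Collects the transition-matrix indices that disagree between the gold and
  # the decoded sequence. For example, with gold [0, 1, 2] and decoded
  # [0, 2, 2] this returns trans_pos = [[0, 1], [1, 2]] and
  # trans_neg = [[0, 2], [2, 2]], with no initial-transition update since the
  # first labels agree.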
  def generate_transition_update_index(self, correct_labels, current_labels):
    if correct_labels.shape != current_labels.shape:
      print('sequence lengths are not equal')
      return None

    before_corr = correct_labels[0]
    before_curr = current_labels[0]
    update_init = False

    trans_init_pos = None
    trans_init_neg = None
    trans_pos = []
    trans_neg = []

    if before_corr != before_curr:
      trans_init_pos = [before_corr]
      trans_init_neg = [before_curr]
      update_init = True

    for corr_label, curr_label in zip(correct_labels[1:], current_labels[1:]):
      if corr_label != curr_label or before_corr != before_curr:
        trans_pos.append([before_corr, corr_label])
        trans_neg.append([before_curr, curr_label])
      before_corr = corr_label
      before_curr = curr_label

    return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init

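  # Prediction: feed the sentence through the network once, fetch the emission
  # scores together with the learned transition parameters, and Viterbi-decode
  # the tag sequence before converting it back to words or entities.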
  def predict_ll(self, sentence: str, return_labels=False):
    if self.mode != 'predict':
      raise Exception('mode is not allowed to predict')

    input = self.indices2input(self.sentence2indices(sentence))
    runner = [self.output, self.transition, self.transition_init]
    output, trans, trans_init = self.sess.run(runner, feed_dict={self.input: input})
    labels = self.viterbi(output, trans, trans_init)
    if self.task == 'cws':
      result = self.tags2words(sentence, labels)
    else:
      result = self.tags2entities(sentence, labels)
    if not return_labels:
      return result
    else:
      return result, self.tag2sequences(labels)

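  # Layer construction. The hidden and output layers operate on transposed,
  # feature-major tensors; the downstream code assumes the training-mode output
  # has shape [tags_count, batch_length, batch_size], which is why fit_batch
  # slices scores[:, :lengths[i], i] for sequence i.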
  def get_embedding_layer(self) -> tf.Tensor:
    embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
    self.params.append(embeddings)
    if self.mode == 'train':
      input_size = [self.batch_size, self.batch_length, self.concat_embed_size]
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), input_size)
    else:
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), [1, -1, self.concat_embed_size])
    return layer

  def get_mlp_layer(self, layer: tf.Tensor) -> tf.Tensor:
    hidden_weight = self.__get_variable([self.hidden_units, self.concat_embed_size], 'hidden_weight')
    hidden_bias = self.__get_variable([self.hidden_units, 1, 1], 'hidden_bias')
    self.params += [hidden_weight, hidden_bias]
    layer = tf.sigmoid(tf.tensordot(hidden_weight, layer, [[1], [0]]) + hidden_bias)
    return layer

  def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
    # BasicRNNCell is the concrete cell; tf.nn.rnn_cell.RNNCell is abstract and
    # cannot be used directly.
    rnn = tf.nn.rnn_cell.BasicRNNCell(self.hidden_units)
    rnn_output, rnn_out_state = tf.nn.dynamic_rnn(rnn, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(rnn_output)

  def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
    lstm = tf.nn.rnn_cell.LSTMCell(self.hidden_units)
    lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(lstm_output)

  def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
    gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
    gru_output, gru_out_state = tf.nn.dynamic_rnn(gru, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(gru_output)

  def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
    return tf.layers.dropout(layer, self.dropout_rate)

  def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
    output_weight = self.__get_variable([self.tags_count, self.hidden_units], 'output_weight')
    output_bias = self.__get_variable([self.tags_count, 1, 1], 'output_bias')
    self.params += [output_weight, output_bias]
    return tf.tensordot(output_weight, layer, [[1], [0]]) + output_bias

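  # Margin loss: the summed difference between the scores gathered at the
  # decoded (incorrect) indices and at the gold indices, over emissions and
  # transitions, plus L2 regularization. The second returned tensor also
  # includes the initial-transition term and is used when the first tag of a
  # sequence was decoded incorrectly.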
  def get_loss(self) -> Tuple[tf.Tensor, tf.Tensor]:
    output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))
    trans_loss = tf.reduce_sum(tf.gather_nd(self.transition, self.trans_curr) -
                               tf.gather_nd(self.transition, self.trans_corr))
    trans_i_curr = tf.gather_nd(self.transition_init, self.trans_init_curr)
    trans_i_corr = tf.gather_nd(self.transition_init, self.trans_init_corr)
    trans_init_loss = tf.reduce_sum(trans_i_curr - trans_i_corr)
    loss = output_loss + trans_loss
    regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam), self.params)
    l1 = loss + regu
    l2 = l1 + trans_init_loss
    return l1, l2

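  # Variables are initialized from a truncated normal with
  # stddev = 1 / sqrt(size[-1]), i.e. scaled by the last dimension of the
  # requested shape.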
  def __get_variable(self, size, name) -> tf.Variable:
    return tf.Variable(tf.truncated_normal(size, stddev=1.0 / math.sqrt(size[-1]), dtype=self.dtype), name=name)
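
# A minimal usage sketch (the data path and epoch count below are hypothetical;
# DnnCrfConfig and the pickled training data are defined elsewhere in dnlp):
#
#   config = DnnCrfConfig()
#   model = DnnCrfEmr(config=config, data_path='../dnlp/data/emr/emr.pickle',
#                     nn='lstm', task='ner', mode='train')
#   model.fit(epochs=50)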