# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
import math
from dnlp.core.dnn_crf_base import DnnCrfBase
from dnlp.config import DnnCrfConfig


class DnnCrfEmr(DnnCrfBase):
  def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: type = tf.float32,
               task: str = 'ner', mode: str = 'train', nn: str, model_path: str = ''):
    if mode not in ['train', 'predict']:
      raise Exception('mode must be either "train" or "predict"')
    if nn not in ['mlp', 'rnn', 'lstm', 'bilstm', 'gru']:
      raise Exception('unsupported neural network type: ' + nn)

    DnnCrfBase.__init__(self, config, data_path, mode, model_path)
    self.dtype = dtype
    self.mode = mode
    self.nn = nn
    self.task = task

    # Build the graph
    tf.reset_default_graph()
    self.transition = self.__get_variable([self.tags_count, self.tags_count], 'transition')
    self.transition_init = self.__get_variable([self.tags_count], 'transition_init')
    self.params = [self.transition, self.transition_init]
    # Input layer
    if mode == 'train':
      self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
      self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
      self.seq_length = tf.placeholder(tf.int32, [None])
    else:
      self.input = tf.placeholder(tf.int32, [None, self.windows_size])

    # Embedding lookup layer
    self.embedding_layer = self.get_embedding_layer()
    # Hidden layer
    if nn == 'mlp':
      self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
    elif nn == 'lstm':
      self.hidden_layer = self.get_lstm_layer(tf.transpose(self.embedding_layer))
    elif nn == 'gru':
      self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
    else:
      self.hidden_layer = self.get_rnn_layer(tf.transpose(self.embedding_layer))
    # Output layer
    self.output = self.get_output_layer(self.hidden_layer)

    if mode == 'predict':
      self.output = tf.squeeze(self.output, axis=1)
      self.sess = tf.Session()
      self.sess.run(tf.global_variables_initializer())
      tf.train.Saver().restore(save_path=self.model_path, sess=self.sess)
    else:
      # Build the training ops
      # Training placeholders: indices of the correct (corr) and currently
      # predicted (curr) emission/transition scores used in the update
      self.ll_corr = tf.placeholder(tf.int32, shape=[None, 3])
      self.ll_curr = tf.placeholder(tf.int32, shape=[None, 3])
      self.trans_corr = tf.placeholder(tf.int32, [None, 2])
      self.trans_curr = tf.placeholder(tf.int32, [None, 2])
      self.trans_init_corr = tf.placeholder(tf.int32, [None, 1])
      self.trans_init_curr = tf.placeholder(tf.int32, [None, 1])
      # Loss function and optimizer
      self.loss, self.loss_with_init = self.get_loss()
      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
      self.train = self.optimizer.minimize(self.loss)
      self.train_with_init = self.optimizer.minimize(self.loss_with_init)

  def fit(self, epochs: int = 100, interval: int = 20):
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      saver = tf.train.Saver(max_to_keep=100)
      for epoch in range(1, epochs + 1):
        print('epoch:', epoch)
        for _ in range(self.batch_count):
          characters, labels, lengths = self.get_batch()
          self.fit_batch(characters, labels, lengths, sess)
        if epoch % interval == 0:
          model_path = '../dnlp/models/emr_old/{0}-{1}.ckpt'.format(self.nn, epoch)
          saver.save(sess, model_path)
          self.save_config(model_path)
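
  # fit_batch applies a structured-perceptron style update: decode the best tag
  # path with Viterbi, then raise the scores and transitions along the gold path
  # and lower them along the predicted path wherever the two paths disagree.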
  def fit_batch(self, characters, labels, lengths, sess):
    scores = sess.run(self.output, feed_dict={self.input: characters})
    transition = self.transition.eval(session=sess)
    transition_init = self.transition_init.eval(session=sess)
    update_labels_pos = None
    update_labels_neg = None
    current_labels = []
    trans_pos_indices = []
    trans_neg_indices = []
    trans_init_pos_indices = []
    trans_init_neg_indices = []
    for i in range(self.batch_size):
      current_label = self.viterbi(scores[:, :lengths[i], i], transition, transition_init)
      current_labels.append(current_label)
      # Positions where the predicted path disagrees with the gold labels
      diff_tag = np.subtract(labels[i, :lengths[i]], current_label)
      update_index = np.where(diff_tag != 0)[0]
      update_length = len(update_index)
      if update_length == 0:
        continue
      update_label_pos = np.stack([labels[i, update_index], update_index, i * np.ones([update_length])], axis=-1)
      update_label_neg = np.stack([current_label[update_index], update_index, i * np.ones([update_length])], axis=-1)
      if update_labels_pos is not None:
        update_labels_pos = np.concatenate((update_labels_pos, update_label_pos))
        update_labels_neg = np.concatenate((update_labels_neg, update_label_neg))
      else:
        update_labels_pos = update_label_pos
        update_labels_neg = update_label_neg

      trans_pos_index, trans_neg_index, trans_init_pos, trans_init_neg, update_init = self.generate_transition_update_index(
        labels[i, :lengths[i]], current_labels[i])

      trans_pos_indices.extend(trans_pos_index)
      trans_neg_indices.extend(trans_neg_index)

      if update_init:
        trans_init_pos_indices.append(trans_init_pos)
        trans_init_neg_indices.append(trans_init_neg)

    if update_labels_pos is not None and update_labels_neg is not None:
      feed_dict = {self.input: characters, self.ll_curr: update_labels_neg, self.ll_corr: update_labels_pos,
                   self.trans_curr: trans_neg_indices, self.trans_corr: trans_pos_indices}

      if not trans_init_pos_indices:
        sess.run(self.train, feed_dict)
      else:
        feed_dict[self.trans_init_corr] = trans_init_pos_indices
        feed_dict[self.trans_init_curr] = trans_init_neg_indices
        sess.run(self.train_with_init, feed_dict)

  def generate_transition_update_index(self, correct_labels, current_labels):
    if correct_labels.shape != current_labels.shape:
      raise ValueError('sequence lengths are not equal')

    before_corr = correct_labels[0]
    before_curr = current_labels[0]
    update_init = False

    trans_init_pos = None
    trans_init_neg = None
    trans_pos = []
    trans_neg = []

    # A mismatch at the first position requires an initial-transition update
    if before_corr != before_curr:
      trans_init_pos = [before_corr]
      trans_init_neg = [before_curr]
      update_init = True

    # Record every transition whose endpoints differ between the two paths
    for corr_label, curr_label in zip(correct_labels[1:], current_labels[1:]):
      if corr_label != curr_label or before_corr != before_curr:
        trans_pos.append([before_corr, corr_label])
        trans_neg.append([before_curr, curr_label])
      before_corr = corr_label
      before_curr = curr_label

    return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init
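
  # Worked example (illustration only): with gold tags [0, 1, 2] and predicted
  # tags [0, 2, 2], the first tags match so update_init stays False; at position
  # 1 the transition (0 -> 1) is reinforced and (0 -> 2) penalised; at position
  # 2 the previous tags still differ, so (1 -> 2) and (2 -> 2) are updated too.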

  def predict_ll(self, sentence: str, return_labels=False):
    if self.mode != 'predict':
      raise Exception('prediction is only available in predict mode')

    input = self.indices2input(self.sentence2indices(sentence))
    runner = [self.output, self.transition, self.transition_init]
    output, trans, trans_init = self.sess.run(runner, feed_dict={self.input: input})
    labels = self.viterbi(output, trans, trans_init)
    if self.task == 'cws':
      result = self.tags2words(sentence, labels)
    else:
      result = self.tags2entities(sentence, labels)
    if not return_labels:
      return result
    else:
      return result, self.tag2sequences(labels)

  def get_embedding_layer(self) -> tf.Tensor:
    embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
    self.params.append(embeddings)
    if self.mode == 'train':
      input_size = [self.batch_size, self.batch_length, self.concat_embed_size]
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), input_size)
    else:
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), [1, -1, self.concat_embed_size])
    return layer

  def get_mlp_layer(self, layer: tf.Tensor) -> tf.Tensor:
    hidden_weight = self.__get_variable([self.hidden_units, self.concat_embed_size], 'hidden_weight')
    hidden_bias = self.__get_variable([self.hidden_units, 1, 1], 'hidden_bias')
    self.params += [hidden_weight, hidden_bias]
    layer = tf.sigmoid(tf.tensordot(hidden_weight, layer, [[1], [0]]) + hidden_bias)
    return layer

  def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
    rnn = tf.nn.rnn_cell.BasicRNNCell(self.hidden_units)
    rnn_output, rnn_out_state = tf.nn.dynamic_rnn(rnn, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(rnn_output)

  def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
    lstm = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units)
    lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(lstm_output)

  def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
    gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
    gru_output, gru_out_state = tf.nn.dynamic_rnn(gru, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(gru_output)

  def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
    return tf.layers.dropout(layer, self.dropout_rate)

  def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
    output_weight = self.__get_variable([self.tags_count, self.hidden_units], 'output_weight')
    output_bias = self.__get_variable([self.tags_count, 1, 1], 'output_bias')
    self.params += [output_weight, output_bias]
    return tf.tensordot(output_weight, layer, [[1], [0]]) + output_bias

  def get_loss(self) -> (tf.Tensor, tf.Tensor):
    # Hinge-style loss: score of the predicted (curr) entries minus the gold (corr) entries
    output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))
    trans_loss = tf.reduce_sum(tf.gather_nd(self.transition, self.trans_curr) - tf.gather_nd(self.transition, self.trans_corr))
    trans_i_curr = tf.gather_nd(self.transition_init, self.trans_init_curr)
    trans_i_corr = tf.gather_nd(self.transition_init, self.trans_init_corr)
    trans_init_loss = tf.reduce_sum(trans_i_curr - trans_i_corr)
    loss = output_loss + trans_loss
    regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam), self.params)
    l1 = loss + regu
    l2 = l1 + trans_init_loss
    return l1, l2

  def __get_variable(self, size, name) -> tf.Variable:
    return tf.Variable(tf.truncated_normal(size, stddev=1.0 / math.sqrt(size[-1]), dtype=self.dtype), name=name)
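

# A minimal usage sketch (not part of the original module). The data and model
# paths below are hypothetical placeholders; adjust them to your own setup.
if __name__ == '__main__':
  # Train an LSTM-CRF tagger, saving a checkpoint every 20 epochs
  model = DnnCrfEmr(config=DnnCrfConfig(), data_path='../dnlp/data/emr/emr_training.pickle',
                    nn='lstm', task='ner')
  model.fit(epochs=100, interval=20)

  # For inference, construct the model in predict mode from a saved checkpoint:
  # model = DnnCrfEmr(config=DnnCrfConfig(), mode='predict', nn='lstm', task='ner',
  #                   model_path='../dnlp/models/emr_old/lstm-100.ckpt')
  # print(model.predict_ll('some EMR sentence'))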