88
99class DnnCrf (DnnCrfBase ):
1010def __init__ (self ,* ,config :DnnCrfConfig = None ,data_path :str = '' ,dtype :type = tf .float32 ,mode :str = 'train' ,
11- train :str = 'll' ,nn :str ,model_path :str = '' ):
11+ predict :str = 'll' ,nn :str ,model_path :str = '' ):
1212if mode not in ['train' ,'predict' ]:
1313raise Exception ('mode error' )
1414if nn not in ['mlp' ,'rnn' ,'lstm' ,'bilstm' ,'gru' ]:
@@ -27,113 +27,42 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
2727if mode == 'train' :
2828self .input = tf .placeholder (tf .int32 , [self .batch_size ,self .batch_length ,self .windows_size ])
2929self .real_indices = tf .placeholder (tf .int32 , [self .batch_size ,self .batch_length ])
30- self .seq_length = tf .placeholder (tf .int32 , [self .batch_size ])
3130else :
3231self .input = tf .placeholder (tf .int32 , [None ,self .windows_size ])
3332
33+ self .seq_length = tf .placeholder (tf .int32 , [None ])
34+
3435# 查找表层
3536self .embedding_layer = self .get_embedding_layer ()
3637# 隐藏层
3738if nn == 'mlp' :
3839self .hidden_layer = self .get_mlp_layer (tf .transpose (self .embedding_layer ))
3940elif nn == 'lstm' :
4041self .hidden_layer = self .get_lstm_layer (self .embedding_layer )
42+ elif nn == 'bilstm' :
43+ self .hidden_layer = self .get_bilstm_layer (self .embedding_layer )
4144elif nn == 'gru' :
42- self .hidden_layer = self .get_gru_layer (tf . transpose ( self .embedding_layer ) )
45+ self .hidden_layer = self .get_gru_layer (self .embedding_layer )
4346else :
44- self .hidden_layer = self .get_rnn_layer (tf . transpose ( self .embedding_layer ) )
47+ self .hidden_layer = self .get_rnn_layer (self .embedding_layer )
4548# 输出层
4649self .output = self .get_output_layer (self .hidden_layer )
4750
4851if mode == 'predict' :
49- self .output = tf .squeeze (tf .transpose (self .output ),axis = 2 )
52+ if predict != 'll' :
53+ self .output = tf .squeeze (tf .transpose (self .output ),axis = 2 )
54+ self .seq ,self .best_score = tf .contrib .crf .crf_decode (self .output ,self .transition ,self .seq_length )
5055self .sess = tf .Session ()
5156self .sess .run (tf .global_variables_initializer ())
5257tf .train .Saver ().restore (save_path = self .model_path ,sess = self .sess )
53- elif train == 'll' :
54- self .ll_loss ,_ = tf .contrib .crf .crf_log_likelihood (self .output ,self .real_indices ,self .seq_length ,
55- self .transition )
56- self .optimizer = tf .train .AdagradOptimizer (self .learning_rate )
57- self .train_ll = self .optimizer .minimize (- self .ll_loss )
5858else :
59- # 构建训练函数
60- # 训练用placeholder
61- self .ll_corr = tf .placeholder (tf .int32 ,shape = [None ,3 ])
62- self .ll_curr = tf .placeholder (tf .int32 ,shape = [None ,3 ])
63- self .trans_corr = tf .placeholder (tf .int32 , [None ,2 ])
64- self .trans_curr = tf .placeholder (tf .int32 , [None ,2 ])
65- self .trans_init_corr = tf .placeholder (tf .int32 , [None ,1 ])
66- self .trans_init_curr = tf .placeholder (tf .int32 , [None ,1 ])
67- # 损失函数
68- self .loss ,self .loss_with_init = self .get_loss ()
59+ self .loss ,_ = tf .contrib .crf .crf_log_likelihood (self .output ,self .real_indices ,self .seq_length ,
60+ self .transition )
6961self .optimizer = tf .train .AdagradOptimizer (self .learning_rate )
70- self .train = self . optimizer . minimize ( self . loss )
71- self .train_with_init = self .optimizer .minimize (self .loss_with_init )
62+ self .new_optimizer = tf . train . AdamOptimizer ( )
63+ self .train = self .optimizer .minimize (- self .loss )
7264
7365def fit (self ,epochs :int = 100 ,interval :int = 20 ):
74- with tf .Session ()as sess :
75- tf .global_variables_initializer ().run ()
76- saver = tf .train .Saver (max_to_keep = 100 )
77- for epoch in range (1 ,epochs + 1 ):
78- print ('epoch:' ,epoch )
79- for _ in range (self .batch_count ):
80- characters ,labels ,lengths = self .get_batch ()
81- self .fit_batch (characters ,labels ,lengths ,sess )
82- # if epoch % interval == 0:
83- model_path = '../dnlp/models/cws{0}.ckpt' .format (epoch )
84- saver .save (sess ,model_path )
85- self .save_config (model_path )
86-
87- def fit_batch (self ,characters ,labels ,lengths ,sess ):
88- scores = sess .run (self .output ,feed_dict = {self .input :characters })
89- transition = self .transition .eval (session = sess )
90- transition_init = self .transition_init .eval (session = sess )
91- update_labels_pos = None
92- update_labels_neg = None
93- current_labels = []
94- trans_pos_indices = []
95- trans_neg_indices = []
96- trans_init_pos_indices = []
97- trans_init_neg_indices = []
98- for i in range (self .batch_size ):
99- current_label = self .viterbi (scores [:, :lengths [i ],i ],transition ,transition_init )
100- current_labels .append (current_label )
101- diff_tag = np .subtract (labels [i , :lengths [i ]],current_label )
102- update_index = np .where (diff_tag != 0 )[0 ]
103- update_length = len (update_index )
104- if update_length == 0 :
105- continue
106- update_label_pos = np .stack ([labels [i ,update_index ],update_index ,i * np .ones ([update_length ])],axis = - 1 )
107- update_label_neg = np .stack ([current_label [update_index ],update_index ,i * np .ones ([update_length ])],axis = - 1 )
108- if update_labels_pos is not None :
109- np .concatenate ((update_labels_pos ,update_label_pos ))
110- np .concatenate ((update_labels_neg ,update_label_neg ))
111- else :
112- update_labels_pos = update_label_pos
113- update_labels_neg = update_label_neg
114-
115- trans_pos_index ,trans_neg_index ,trans_init_pos ,trans_init_neg ,update_init = self .generate_transition_update_index (
116- labels [i , :lengths [i ]],current_labels [i ])
117-
118- trans_pos_indices .extend (trans_pos_index )
119- trans_neg_indices .extend (trans_neg_index )
120-
121- if update_init :
122- trans_init_pos_indices .append (trans_init_pos )
123- trans_init_neg_indices .append (trans_init_neg )
124-
125- if update_labels_pos is not None and update_labels_neg is not None :
126- feed_dict = {self .input :characters ,self .ll_curr :update_labels_neg ,self .ll_corr :update_labels_pos ,
127- self .trans_curr :trans_neg_indices ,self .trans_corr :trans_pos_indices }
128-
129- if not trans_init_pos_indices :
130- sess .run (self .train ,feed_dict )
131- else :
132- feed_dict [self .trans_init_corr ]= trans_init_pos_indices
133- feed_dict [self .trans_init_curr ]= trans_init_neg_indices
134- sess .run (self .train_with_init ,feed_dict )
135-
136- def fit_ll (self ,epochs :int = 100 ,interval :int = 20 ):
13766with tf .Session ()as sess :
13867tf .global_variables_initializer ().run ()
13968saver = tf .train .Saver (max_to_keep = epochs )
@@ -143,44 +72,13 @@ def fit_ll(self, epochs: int = 100, interval: int = 20):
14372characters ,labels ,lengths = self .get_batch ()
14473# scores = sess.run(self.output, feed_dict={self.input: characters})
14574feed_dict = {self .input :characters ,self .real_indices :labels ,self .seq_length :lengths }
146- sess .run (self .train_ll ,feed_dict = feed_dict )
75+ sess .run (self .train ,feed_dict = feed_dict )
14776# self.fit_batch(characters, labels, lengths, sess)
14877# if epoch % interval == 0:
14978model_path = '../dnlp/models/cws{0}.ckpt' .format (epoch )
15079saver .save (sess ,model_path )
15180self .save_config (model_path )
15281
153- def fit_batch_ll (self ):
154- pass
155-
156- def generate_transition_update_index (self ,correct_labels ,current_labels ):
157- if correct_labels .shape != current_labels .shape :
158- print ('sequence length is not equal' )
159- return None
160-
161- before_corr = correct_labels [0 ]
162- before_curr = current_labels [0 ]
163- update_init = False
164-
165- trans_init_pos = None
166- trans_init_neg = None
167- trans_pos = []
168- trans_neg = []
169-
170- if before_corr != before_curr :
171- trans_init_pos = [before_corr ]
172- trans_init_neg = [before_curr ]
173- update_init = True
174-
175- for _ , (corr_label ,curr_label )in enumerate (zip (correct_labels [1 :],current_labels [1 :])):
176- if corr_label != curr_label or before_corr != before_curr :
177- trans_pos .append ([before_corr ,corr_label ])
178- trans_neg .append ([before_curr ,curr_label ])
179- before_corr = corr_label
180- before_curr = curr_label
181-
182- return trans_pos ,trans_neg ,trans_init_pos ,trans_init_neg ,update_init
183-
18482def predict (self ,sentence :str ,return_labels = False ):
18583if self .mode != 'predict' :
18684raise Exception ('mode is not allowed to predict' )
@@ -194,6 +92,22 @@ def predict(self, sentence: str, return_labels=False):
19492else :
19593return self .tags2words (sentence ,labels ),self .tag2sequences (labels )
19694
95+ def predict_ll (self ,sentence :str ,return_labels = False ):
96+ if self .mode != 'predict' :
97+ raise Exception ('mode is not allowed to predict' )
98+
99+ input = self .indices2input (self .sentence2indices (sentence ))
100+ runner = [self .seq ,self .best_score ,self .output ,self .transition ]
101+ labels ,best_score ,output ,trans = self .sess .run (runner ,
102+ feed_dict = {self .input :input ,self .seq_length : [len (sentence )]})
103+ # print(output)
104+ # print(trans)
105+ labels = np .squeeze (labels )
106+ if return_labels :
107+ return self .tags2words (sentence ,labels ),self .tag2sequences (labels )
108+ else :
109+ return self .tags2words (sentence ,labels )
110+
197111def get_embedding_layer (self )-> tf .Tensor :
198112embeddings = self .__get_variable ([self .dict_size ,self .embed_size ],'embeddings' )
199113self .params .append (embeddings )
@@ -215,19 +129,27 @@ def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
215129rnn = tf .nn .rnn_cell .RNNCell (self .hidden_units )
216130rnn_output ,rnn_out_state = tf .nn .dynamic_rnn (rnn ,layer ,dtype = self .dtype )
217131self .params += [v for v in tf .global_variables ()if v .name .startswith ('rnn' )]
218- return tf . transpose ( rnn_output )
132+ return rnn_output
219133
220134def get_lstm_layer (self ,layer :tf .Tensor )-> tf .Tensor :
221135lstm = tf .nn .rnn_cell .LSTMCell (self .hidden_units )
222136lstm_output ,lstm_out_state = tf .nn .dynamic_rnn (lstm ,layer ,dtype = self .dtype )
223137self .params += [v for v in tf .global_variables ()if v .name .startswith ('rnn' )]
224138return lstm_output
225139
140+ def get_bilstm_layer (self ,layer :tf .Tensor )-> tf .Tensor :
141+ lstm_fw = tf .nn .rnn_cell .LSTMCell (self .hidden_units // 2 )
142+ lstm_bw = tf .nn .rnn_cell .LSTMCell (self .hidden_units // 2 )
143+ bilstm_output ,bilstm_output_state = tf .nn .bidirectional_dynamic_rnn (lstm_fw ,lstm_bw ,layer ,self .seq_length ,
144+ dtype = self .dtype )
145+ self .params += [v for v in tf .global_variables ()if v .name .startswith ('rnn' )]
146+ return tf .concat ([bilstm_output [0 ],bilstm_output [1 ]],- 1 )
147+
226148def get_gru_layer (self ,layer :tf .Tensor )-> tf .Tensor :
227149gru = tf .nn .rnn_cell .GRUCell (self .hidden_units )
228150gru_output ,gru_out_state = tf .nn .dynamic_rnn (gru ,layer ,dtype = self .dtype )
229151self .params += [v for v in tf .global_variables ()if v .name .startswith ('rnn' )]
230- return tf . transpose ( gru_output )
152+ return gru_output
231153
232154def get_dropout_layer (self ,layer :tf .Tensor )-> tf .Tensor :
233155return tf .layers .dropout (layer ,self .dropout_rate )
@@ -238,17 +160,5 @@ def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
238160self .params += [output_weight ,output_bias ]
239161return tf .tensordot (layer ,output_weight , [[2 ], [0 ]])+ output_bias
240162
241- def get_loss (self )-> (tf .Tensor ,tf .Tensor ):
242- output_loss = tf .reduce_sum (tf .gather_nd (self .output ,self .ll_curr )- tf .gather_nd (self .output ,self .ll_corr ))
243- trans_loss = tf .gather_nd (self .transition ,self .trans_curr )- tf .gather_nd (self .transition ,self .trans_corr )
244- trans_i_curr = tf .gather_nd (self .transition_init ,self .trans_init_curr )
245- trans_i_corr = tf .gather_nd (self .transition_init ,self .trans_init_corr )
246- trans_init_loss = tf .reduce_sum (trans_i_curr - trans_i_corr )
247- loss = output_loss + trans_loss
248- regu = tf .contrib .layers .apply_regularization (tf .contrib .layers .l2_regularizer (self .lam ),self .params )
249- l1 = loss + regu
250- l2 = l1 + trans_init_loss
251- return l1 ,l2
252-
253163def __get_variable (self ,size ,name )-> tf .Variable :
254164return tf .Variable (tf .truncated_normal (size ,stddev = 1.0 / math .sqrt (size [- 1 ]),dtype = self .dtype ),name = name )