1010
1111class DnnCrf (DnnCrfBase ):
1212def __init__ (self ,* ,config :DnnCrfConfig = None ,task = 'cws' ,data_path :str = '' ,dtype :type = tf .float32 ,
13- mode :str = 'train' ,dropout_position :str = 'input' ,train :str = 'mm' ,predict :str = 'll ' ,nn :str ,
13+ mode :str = 'train' ,dropout_position :str = 'input' ,train :str = 'mm' ,predict :str = 'mm ' ,nn :str ,
1414model_path :str = '' ,
1515embedding_path :str = '' ,remark :str = '' ):
1616if mode not in ['train' ,'predict' ]:
@@ -75,7 +75,8 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
7575if mode == 'predict' :
7676if predict != 'll' :
7777self .output = tf .squeeze (tf .transpose (self .output ),axis = 2 )
78- self .seq ,self .best_score = tf .contrib .crf .crf_decode (self .output ,self .transition ,self .seq_length )
78+ if predict == 'll' :
79+ self .seq ,self .best_score = tf .contrib .crf .crf_decode (self .output ,self .transition ,self .seq_length )
7980self .sess = tf .Session ()
8081self .sess .run (tf .global_variables_initializer ())
8182tf .train .Saver ().restore (save_path = self .model_path ,sess = self .sess )
@@ -105,16 +106,16 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
105106tf .expand_dims (self .pred_seq [:,0 ],1 ))- tf .gather_nd (
106107self .transition_init ,tf .expand_dims (self .true_seq [:,0 ],1 ))
107108self .hinge_loss = tf .count_nonzero (self .pred_seq - self .true_seq ,axis = 1 ,dtype = self .dtype )
108- self .seq ,self .best_score = tf .contrib .crf .crf_decode (self .output ,self .transition ,self .seq_length )
109+ # self.seq, self.best_score = tf.contrib.crf.crf_decode(self.output, self.transition, self.seq_length)
109110# self.score_diff = self.state_difference + self.transition_difference + self.init_transition_difference + self.hinge_rate*self.hinge_loss
110111self .score_diff = self .state_difference + self .transition_difference + self .hinge_rate * self .hinge_loss
111112self .loss = tf .reduce_sum (tf .maximum (0.0 ,self .score_diff ))/ self .batch_size + self .regularization
112- self .learning_rate = 0.005
113- self .optimizer = tf .train .GradientDescentOptimizer (self .learning_rate )
114- # self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
113+ # self.learning_rate = 0.005
114+ # self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
115+ self .optimizer = tf .train .AdagradOptimizer (self .learning_rate )
115116# self.new_optimizer = tf.train.AdamOptimizer()
116- gvs = self .optimizer .compute_gradients (self .loss )
117- cliped_grad = [(tf .clip_by_norm (grad ,10 )if grad is not None else grad ,var )for grad ,var in gvs ]
117+ # gvs = self.optimizer.compute_gradients(self.loss)
118+ # cliped_grad = [(tf.clip_by_norm(grad, 10) if grad is not None else grad, var) for grad, var in gvs]
118119# self.train_model = self.optimizer.apply_gradients(cliped_grad)
119120self .train_model = self .optimizer .minimize (self .loss )
120121
@@ -172,7 +173,7 @@ def fit_mm(self, epochs: int = 50, interval: int = 1):
172173for epoch in range (1 ,epochs + 1 ):
173174print ('epoch:' ,epoch )
174175start = time .time ()
175- for i in range (self .batch_count ):
176+ for j in range (self .batch_count ):
176177sentences ,labels ,lengths = self .get_batch ()
177178transition = self .transition .eval ()
178179transition_init = self .transition_init .eval ()
@@ -182,11 +183,13 @@ def fit_mm(self, epochs: int = 50, interval: int = 1):
182183# seq = sess.run(self.seq, feed_dict=feed_dict)
183184for i in range (self .batch_size ):
184185# seq = sess.run(self.seq,feed_dict=feed_dict)
185- pred_seq .append (self .viterbi (output [i , :lengths [i ], :].T ,transition ,transition_init ,self .batch_length ))
186+ seq = self .viterbi (output [i , :lengths [i ], :].T ,transition ,transition_init ,labels [i ],
187+ self .batch_length ,True )
188+ pred_seq .append (seq )
186189# pred_seq.append(seq)
187190feed_dict = {self .true_seq :labels ,self .pred_seq :pred_seq ,self .output_placeholder :output }
188- if epoch > 2 :
189- self .eval_params (sess ,feed_dict )
191+ # if epoch > 2:
192+ # self.eval_params(sess, feed_dict)
190193sess .run (self .train_model ,feed_dict = feed_dict )
191194if epoch % interval == 0 :
192195if not self .embedding_path :
@@ -214,8 +217,8 @@ def predict(self, sentence: str, return_labels=False):
214217input = self .indices2input (self .sentence2indices (sentence ))
215218runner = [self .output ,self .transition ,self .transition_init ]
216219output ,trans ,trans_init = self .sess .run (runner ,feed_dict = {self .input :input ,self .seq_length : [len (sentence )]})
217- output = np .squeeze (output ,0 )
218- labels = self .viterbi (output . T ,trans ,trans_init )
220+ # output = np.squeeze(output, 0)
221+ labels = self .viterbi (output ,trans ,trans_init )
219222if self .task == 'cws' :
220223result = self .tags2words (sentence ,labels )
221224else :