@@ -174,17 +174,17 @@ def fit_mm(self, epochs: int = 50, interval: int = 1):
174174start = time .time ()
175175for i in range (self .batch_count ):
176176sentences ,labels ,lengths = self .get_batch ()
177- # transition = self.transition.eval()
178- # transition_init = self.transition_init.eval()
177+ transition = self .transition .eval ()
178+ transition_init = self .transition_init .eval ()
179179feed_dict = {self .input :sentences ,self .seq_length :lengths }
180180output = sess .run (self .output ,feed_dict = feed_dict )
181181pred_seq = []
182- seq = sess .run (self .seq ,feed_dict = feed_dict )
183- # for i in range(self.batch_size):
184- # seq = sess.run(self.seq,feed_dict=feed_dict)
185- # pred_seq.append(self.viterbi(output[i, :lengths[i], :].T, transition, transition_init, self.batch_length))
182+ # seq = sess.run(self.seq, feed_dict=feed_dict)
183+ for i in range (self .batch_size ):
184+ # seq = sess.run(self.seq,feed_dict=feed_dict)
185+ pred_seq .append (self .viterbi (output [i , :lengths [i ], :].T ,transition ,transition_init ,self .batch_length ))
186186# pred_seq.append(seq)
187- feed_dict = {self .true_seq :labels ,self .pred_seq :seq ,self .output_placeholder :output }
187+ feed_dict = {self .true_seq :labels ,self .pred_seq :pred_seq ,self .output_placeholder :output }
188188if epoch > 2 :
189189self .eval_params (sess ,feed_dict )
190190sess .run (self .train_model ,feed_dict = feed_dict )