 import numpy as np
 import math
 import os
+import time
 from dnlp.core.dnn_crf_base import DnnCrfBase
-from dnlp.config.config import DnnCrfConfig
+from dnlp.config import DnnCrfConfig


 class DnnCrf(DnnCrfBase):
   def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str = '', dtype: type = tf.float32,
-               mode: str = 'train', dropout_position: str = 'input', predict: str = 'll', nn: str, model_path: str = '',
+               mode: str = 'train', dropout_position: str = 'input', train: str = 'mm', predict: str = 'll', nn: str,
+               model_path: str = '',
                embedding_path: str = '', remark: str = ''):
     if mode not in ['train', 'predict']:
       raise Exception('mode error')
     if nn not in ['mlp', 'rnn', 'lstm', 'bilstm', 'gru']:
-      raise Exception('name of neural network entered is not supported')
+      raise Exception('neural network name entered is not supported')

     DnnCrfBase.__init__(self, config, data_path, mode, model_path)
     self.dtype = dtype
@@ -24,6 +26,7 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
     self.remark = remark
     self.embedding_path = embedding_path
     self.graph = tf.Graph()
+    self.train = train  # training objective: 'll' = CRF log-likelihood, 'mm' = max-margin
     with self.graph.as_default():
       # build the graph
       # tf.reset_default_graph()
@@ -35,6 +38,12 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
       if mode == 'train':
         self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
         self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
+        # self.sentence_inputs = tf.data.Dataset.from_tensor_slices(self.sentences).repeat(-1).batch(self.batch_size)
+        # self.label_inputs = tf.data.Dataset.from_tensor_slices(self.labels).repeat(-1).batch(self.batch_size)
+        # self.length_inputs = tf.data.Dataset.from_tensor_slices(self.sentence_lengths).repeat(-1).batch(self.batch_size)
+        # self.sentence_iterator = self.sentence_inputs.make_initializable_iterator()
+        # self.label_iterator = self.label_inputs.make_initializable_iterator()
+        # self.length_iterator = self.length_inputs.make_initializable_iterator()
       else:
         self.input = tf.placeholder(tf.int32, [None, self.windows_size])

@@ -43,7 +52,7 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
       # embedding lookup layer
       self.embedding_layer = self.get_embedding_layer()
       # apply dropout
-      if dropout_position == 'input':
+      if mode == 'train' and dropout_position == 'input':
         self.embedding_layer = self.get_dropout_layer(self.embedding_layer)
       # hidden layer
       if nn == 'mlp':
@@ -56,10 +65,11 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
         self.hidden_layer = self.get_gru_layer(self.embedding_layer)
       else:
         self.hidden_layer = self.get_rnn_layer(self.embedding_layer)
-      if dropout_position == 'hidden':
+      if mode == 'train' and dropout_position == 'hidden':
         self.hidden_layer = self.get_dropout_layer(self.hidden_layer)
       # output layer
       self.output = self.get_output_layer(self.hidden_layer)
+      # self.output = tf.nn.softmax(self.output, 2)

       if mode == 'predict':
         if predict != 'll':
@@ -69,18 +79,46 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
         self.sess.run(tf.global_variables_initializer())
         tf.train.Saver().restore(save_path=self.model_path, sess=self.sess)
       else:
-        self.crf_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
-                                                             self.transition)
-        # self.loss = -self.loss
         self.regularization = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam),
-                                                                     self.params)
-        self.loss = -self.crf_loss / self.batch_size + self.regularization
+                                                                      self.params)
+        if train == 'll':
+          self.crf_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
+                                                               self.transition)
+          # self.loss = -self.loss
+          self.loss = -self.crf_loss / self.batch_size + self.regularization
+          # self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
+          # self.optimizer.minimize(self.loss)
+          # self.train = self.optimizer.minimize(self.loss)
+        else:
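+          # Max-margin objective (a summary of the graph built below): penalize
+          #   max(0, score(pred_seq) - score(true_seq) + hinge_rate * Hamming(pred_seq, true_seq))
+          # where score() sums the unary outputs plus the (initial) transition weights
+          # along a tag sequence; averaged over the batch, plus L2 regularization.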
+          self.true_seq = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
+          self.pred_seq = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
+          self.output_placeholder = tf.placeholder(self.dtype, [self.batch_size, self.batch_length, self.tags_count])
+          # [batch, time] index grids, used to gather per-position unary scores
+          batch_index = np.repeat(np.expand_dims(np.arange(0, self.batch_size), 1), self.batch_length, 1)
+          sent_index = np.repeat(np.expand_dims(np.arange(0, self.batch_length), 0), self.batch_size, 0)
+          true_index = tf.stack([batch_index, sent_index, self.true_seq], axis=2)
+          pred_index = tf.stack([batch_index, sent_index, self.pred_seq], axis=2)
+          state_difference = tf.reduce_sum(
+            tf.gather_nd(self.output_placeholder, pred_index) - tf.gather_nd(self.output_placeholder, true_index),
+            axis=1)
+          # r = tf.stack([self.true_seq[:, :-1], self.true_seq[:, 1:]], 2)
+          transition_difference = tf.reduce_sum(
+            tf.gather_nd(self.transition, tf.stack([self.pred_seq[:, :-1], self.pred_seq[:, 1:]], 2)) - tf.gather_nd(
+              self.transition, tf.stack([self.true_seq[:, :-1], self.true_seq[:, 1:]], 2)), axis=1)
+          init_transition_difference = tf.gather_nd(self.transition_init,
+                                                    tf.expand_dims(self.pred_seq[:, 0], 1)) - tf.gather_nd(
+            self.transition_init, tf.expand_dims(self.true_seq[:, 0], 1))
+          # Hamming distance: number of positions where the decoded sequence disagrees with the gold one
+          hinge_loss = tf.count_nonzero(self.pred_seq - self.true_seq, axis=1, dtype=self.dtype)
+          self.score_diff = state_difference + transition_difference + init_transition_difference + self.hinge_rate * hinge_loss
+          self.loss = tf.reduce_sum(tf.maximum(0.0, self.score_diff)) / self.batch_size + self.regularization
+          self.learning_rate = 0.01
+          self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)  # overwritten by the Adagrad optimizer below
         self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
-        self.new_optimizer = tf.train.AdamOptimizer()
+        # self.new_optimizer = tf.train.AdamOptimizer()
         gvs = self.optimizer.compute_gradients(self.loss)
-        cliped_grad = [(tf.clip_by_norm(grad, 5) if grad is not None else grad, var) for grad, var in gvs]
-        self.train = self.optimizer.apply_gradients(cliped_grad)  # self.optimizer.minimize(self.loss)
-        # self.train = self.optimizer.minimize(self.loss)
+        cliped_grad = [(tf.clip_by_norm(grad, 10) if grad is not None else grad, var) for grad, var in gvs]
+        # self.train_model = self.optimizer.apply_gradients(cliped_grad)
+        self.train_model = self.optimizer.minimize(self.loss)  # note: minimize() ignores cliped_grad, so clipping is currently unused
+
       current_dir = os.path.dirname(__file__)
       dest_dir = os.path.realpath(os.path.join(current_dir, '..\\data\\logs'))
       self.train_writer = tf.summary.FileWriter(dest_dir, flush_secs=10)
@@ -89,18 +127,63 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
       self.merged = tf.summary.merge_all()

   def fit(self, epochs: int = 50, interval: int = 10):
+    if self.train == 'll':
+      self.fit_ll(epochs, interval)
+    else:
+      self.fit_mm(epochs, interval)
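+    # Minimal usage sketch (hypothetical paths; data_path must point to a pickled
+    # corpus in this project's format):
+    #   crf = DnnCrf(data_path='data/cws.pickle', nn='lstm', train='mm')
+    #   crf.fit(epochs=20, interval=5)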
+
+  def fit_ll(self, epochs: int = 50, interval: int = 10):
     with tf.Session(graph=self.graph) as sess:
       tf.global_variables_initializer().run()
+      # sess.run(self.sentence_iterator.initializer)
+      # sess.run(self.label_iterator.initializer)
+      # sess.run(self.length_iterator.initializer)
+      # sentence = self.sentence_iterator.get_next()
+      # label = self.label_iterator.get_next()
+      # length = self.length_iterator.get_next()
       saver = tf.train.Saver(max_to_keep=epochs)
       for epoch in range(1, epochs + 1):
         print('epoch:', epoch)
         j = 0
         for i in range(self.batch_count):
-          characters, labels, lengths = self.get_batch()
-          feed_dict = {self.input: characters, self.real_indices: labels, self.seq_length: lengths}
-          _, summary, loss = sess.run([self.train, self.merged, self.mean_loss], feed_dict=feed_dict)
+          sentences, labels, lengths = self.get_batch()
+          # sentences = sess.run(sentence)
+          # labels = sess.run(label)
+          # lengths = sess.run(length)
+          feed_dict = {self.input: sentences, self.real_indices: labels, self.seq_length: lengths}
+          _, summary, loss = sess.run([self.train_model, self.merged, self.mean_loss], feed_dict=feed_dict)
           self.train_writer.add_summary(summary, j)
           j += 1
+        if epoch % interval == 0:
+          if not self.embedding_path:
+            if self.remark:
+              model_path = '../dnlp/models/emr/{0}-{1}-{2}-{3}.ckpt'.format(self.task, self.nn, self.remark, epoch)
+            else:
+              model_path = '../dnlp/models/emr/{0}-{1}-{2}.ckpt'.format(self.task, self.nn, epoch)
+          else:
+            model_path = '../dnlp/models/emr/{0}-{1}-embedding-{2}.ckpt'.format(self.task, self.nn, epoch)
+          saver.save(sess, model_path)
+          self.save_config(model_path)
+      self.train_writer.close()
+
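+  # Max-margin training loop: unlike fit_ll, each step first Viterbi-decodes the
+  # batch under the current parameters, then feeds the decoded and gold sequences
+  # into the hinge-loss graph built in __init__ when train='mm'.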
+  def fit_mm(self, epochs: int = 50, interval: int = 1):
+    with tf.Session(graph=self.graph) as sess:
+      tf.global_variables_initializer().run()
+      saver = tf.train.Saver(max_to_keep=epochs)
+      for epoch in range(1, epochs + 1):
+        print('epoch:', epoch)
+        start = time.time()
+        for i in range(self.batch_count):
+          sentences, labels, lengths = self.get_batch()
+          transition = self.transition.eval()
+          transition_init = self.transition_init.eval()
+          feed_dict = {self.input: sentences, self.seq_length: lengths}
+          output = sess.run(self.output, feed_dict=feed_dict)
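+          # Decode the best-scoring tag sequence per sentence with the current
+          # parameters, then take one optimizer step on the margin violation.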
+          pred_seq = []
+          for k in range(self.batch_size):  # k avoids shadowing the batch loop variable i
+            pred_seq.append(self.viterbi(output[k, :lengths[k], :].T, transition, transition_init, self.batch_length))
+          feed_dict = {self.true_seq: labels, self.pred_seq: pred_seq, self.output_placeholder: output}
+          sess.run(self.train_model, feed_dict=feed_dict)
         if epoch % interval == 0:
           if not self.embedding_path:
             if self.remark:
@@ -111,16 +194,17 @@ def fit(self, epochs: int = 50, interval: int = 10):
             model_path = '../dnlp/models/{0}-{1}-embedding-{2}.ckpt'.format(self.task, self.nn, epoch)
           saver.save(sess, model_path)
           self.save_config(model_path)
-      self.train_writer.close()
+        print('epoch time (min):', (time.time() - start) / 60)

   def predict(self, sentence: str, return_labels=False):
     if self.mode != 'predict':
       raise Exception('predict is only available in predict mode')

     input = self.indices2input(self.sentence2indices(sentence))
     runner = [self.output, self.transition, self.transition_init]
121204runner = [self .output ,self .transition ,self .transition_init ]
122- output ,trans ,trans_init = self .sess .run (runner ,feed_dict = {self .input :input })
123- labels = self .viterbi (output ,trans ,trans_init )
205+ output ,trans ,trans_init = self .sess .run (runner ,feed_dict = {self .input :input ,self .seq_length :[len (sentence )]})
206+ output = np .squeeze (output ,0 )
207+ labels = self .viterbi (output .T ,trans ,trans_init )
     if self.task == 'cws':
       result = self.tags2words(sentence, labels)
     else:
@@ -152,7 +236,7 @@ def predict_ll(self, sentence: str, return_labels=False):

   def get_embedding_layer(self) -> tf.Tensor:
     if self.embedding_path:
-      embeddings = tf.Variable(np.load(self.embedding_path), trainable=False, name='embeddings')
+      embeddings = tf.Variable(np.load(self.embedding_path), trainable=True, name='embeddings')
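+      # pretrained embeddings are now fine-tuned during training instead of kept frozen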
     else:
       embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
     self.params.append(embeddings)
@@ -178,7 +262,7 @@ def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:

   def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
     lstm = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units)
-    lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
+    lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, sequence_length=self.seq_length, dtype=self.dtype)
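+    # passing sequence_length lets dynamic_rnn stop updating state past each sentence's true length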
     self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
     return lstm_output
