@@ -29,7 +29,7 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
     if mode == 'train':
       self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
       self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
-      self.seq_length = tf.placeholder(tf.int32, [self.batch_size])
+      self.seq_length = tf.placeholder(tf.int32, [None])
     else:
       self.input = tf.placeholder(tf.int32, [None, self.windows_size])
 
@@ -48,7 +48,7 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
     self.output = self.get_output_layer(self.hidden_layer)
 
     if mode == 'predict':
-      self.output = tf.squeeze(self.output, axis=2)
+      self.output = tf.squeeze(self.output, axis=1)
       self.sess = tf.Session()
       self.sess.run(tf.global_variables_initializer())
       tf.train.Saver().restore(save_path=self.model_path, sess=self.sess)
@@ -81,10 +81,10 @@ def fit(self, epochs: int = 100, interval: int = 20):
       for _ in range(self.batch_count):
         characters, labels, lengths = self.get_batch()
         self.fit_batch(characters, labels, lengths, sess)
-      # if epoch % interval == 0:
-      model_path = '../dnlp/models/emr_old/{0}-{1}.ckpt'.format(self.nn, epoch)
-      saver.save(sess, model_path)
-      self.save_config(model_path)
+      if epoch % interval == 0:
+        model_path = '../dnlp/models/emr_old/{0}-{1}.ckpt'.format(self.nn, epoch)
+        saver.save(sess, model_path)
+        self.save_config(model_path)
 
   def fit_batch(self, characters, labels, lengths, sess):
     scores = sess.run(self.output, feed_dict={self.input: characters})
@@ -224,7 +224,7 @@ def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
     return tf.transpose(rnn_output)
 
   def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
-    lstm = tf.nn.rnn_cell.LSTMCell(self.hidden_units)
+    lstm = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units)
     lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
     self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
     return tf.transpose(lstm_output)