99class RECNN (RECNNBase ):
1010def __init__ (self ,config :RECNNConfig ,dtype :type = tf .float32 ,dict_path :str = '' ,mode :str = 'train' ,
1111data_path :str = '' ,relation_count :int = 2 ,model_path :str = '' ,embedding_path :str = '' ,
12- remark :str = '' ):
12+ remark :str = '' , data_mode = 'prefetch' ):
1313tf .reset_default_graph ()
1414RECNNBase .__init__ (self ,config ,dict_path )
1515self .dtype = dtype
@@ -21,12 +21,6 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
2121self .remark = remark
2222
2323self .concat_embed_size = self .word_embed_size + 2 * self .position_embed_size
24- self .words ,self .primary ,self .secondary ,self .labels = self .load_data ()
25- self .input_words = tf .constant (self .words )
26- self .input_primary = tf .constant (self .primary )
27- self .input_secondary = tf .constant (self .secondary )
28- self .input_labels = tf .constant (self .labels )
29- self .input_indices = tf .placeholder (tf .int32 , [self .batch_size ])
3024self .input_characters = tf .placeholder (tf .int32 , [None ,self .batch_length ])
3125self .input_position = tf .placeholder (tf .int32 , [None ,self .batch_length ])
3226self .input = tf .placeholder (self .dtype , [None ,self .batch_length ,self .concat_embed_size ,1 ])
@@ -43,10 +37,6 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
4337self .full_connected_weight = self .__weight_variable ([self .filter_size * len (self .window_size ),self .relation_count ],
4438name = 'full_connected_weight' )
4539self .full_connected_bias = self .__weight_variable ([self .relation_count ],name = 'full_connected_bias' )
46- self .input_words_lookup = tf .nn .embedding_lookup (self .input_words ,self .input_indices )
47- self .input_primary_lookup = tf .nn .embedding_lookup (self .input_primary ,self .input_indices )
48- self .input_secondary_lookup = tf .nn .embedding_lookup (self .input_secondary ,self .input_indices )
49- self .input_labels_lookup = tf .nn .embedding_lookup (self .input_labels ,self .input_indices )
5040self .position_lookup = tf .nn .embedding_lookup (self .position_embedding ,self .input_position )
5141self .character_lookup = tf .nn .embedding_lookup (self .word_embedding ,self .input_characters )
5242self .character_embed_holder = tf .placeholder (self .dtype ,
@@ -55,17 +45,24 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
5545 [None ,self .batch_length ,self .position_embed_size ])
5646self .secondary_embed_holder = tf .placeholder (self .dtype ,
5747 [None ,self .batch_length ,self .position_embed_size ])
58- self .emebd_concat = tf .expand_dims (
48+ self .embedded_concat = tf .expand_dims (
5949tf .concat ([self .character_embed_holder ,self .primary_embed_holder ,self .secondary_embed_holder ],2 ),3 )
60-
50+ if data_mode == 'prefetch' :
51+ self .words ,self .primary ,self .secondary ,self .labels = self .load_data ()
52+ self .data_count = len (self .words )
53+ self .words = tf .data .Dataset .from_tensor_slices (self .words )
54+ self .primary = tf .data .Dataset .from_tensor_slices (self .primary )
55+ self .secondary = tf .data .Dataset .from_tensor_slices (self .secondary )
56+ self .labels = tf .data .Dataset .from_tensor_slices (self .labels )
57+ self .input_data = tf .data .Dataset .zip ((self .words ,self .primary ,self .secondary ,self .labels ))
58+ self .input_data = self .input_data .repeat (- 1 ).batch (self .batch_size )
59+ self .input_data_iterator = self .input_data .make_initializable_iterator ()
60+ self .iterator = self .input_data_iterator .get_next ()
6161if self .mode == 'train' :
62- self .start = 0
6362self .hidden_layer = tf .layers .dropout (self .get_hidden (),self .dropout_rate )
64- self .data_count = len (self .words )
6563self .saver = tf .train .Saver (max_to_keep = 100 )
6664else :
6765self .hidden_layer = self .get_hidden ()
68- # self.hidden_layer = tf.expand_dims(tf.layers.dropout(self.get_hidden(), self.dropout_rate), 0)
6966self .sess = tf .Session ()
7067self .saver = tf .train .Saver ().restore (self .sess ,self .model_path )
7168self .output_no_softmax = tf .matmul (self .hidden_layer ,self .full_connected_weight )+ self .full_connected_bias
@@ -118,27 +115,19 @@ def fit(self, epochs=50, interval=5):
118115with tf .Session ()as sess :
119116tf .global_variables_initializer ().run ()
120117sess .graph .finalize ()
121- start = 0
118+ sess .run (self .input_data_iterator .initializer )
119+
122120for i in range (1 ,epochs + 1 ):
123121print ('epoch:' + str (i ))
124122for j in range (self .data_count // self .batch_size ):
125- if start + self .batch_size < self .data_count :
126- indices = list (range (start ,start + self .batch_size ))
127- start += self .batch_size
128- else :
129- new_start = self .batch_size - self .data_count + start
130- indices = list (range (start ,self .data_count ))+ list (range (0 ,new_start ))
131- start = new_start
132- words ,primary ,secondary ,labels = sess .run ([self .input_words ,self .input_primary ,self .input_secondary ,
133- self .input_labels ],feed_dict = {self .input_indices :indices })
134- # words, primary, secondary, labels = self.load_batch()
123+ words ,primary ,secondary ,labels = sess .run (self .iterator )
135124character_embeds ,primary_embeds = sess .run ([self .character_lookup ,self .position_lookup ],
136125feed_dict = {self .input_characters :words ,
137126self .input_position :primary })
138127secondary_embeds = sess .run (self .position_lookup ,feed_dict = {self .input_position :secondary })
139- input = sess .run (self .emebd_concat ,feed_dict = {self .character_embed_holder :character_embeds ,
140- self .primary_embed_holder :primary_embeds ,
141- self .secondary_embed_holder :secondary_embeds })
128+ input = sess .run (self .embedded_concat ,feed_dict = {self .character_embed_holder :character_embeds ,
129+ self .primary_embed_holder :primary_embeds ,
130+ self .secondary_embed_holder :secondary_embeds })
142131# sess.run(self.train_model, feed_dict={self.input: input, self.input_relation: batch['label']})
143132sess .run (self .train_cross_entropy_model ,feed_dict = {self .input :input ,self .input_relation :labels })
144133if i % interval == 0 :
@@ -156,10 +145,10 @@ def predict(self, words, primary, secondary):
156145feed_dict = {self .input_characters :words ,
157146self .input_position :primary })
158147secondary_embeds = self .sess .run (self .position_lookup ,feed_dict = {self .input_position :secondary })
159- input = self .sess .run (self .emebd_concat ,feed_dict = {self .character_embed_holder :character_embeds ,
160- self .primary_embed_holder :primary_embeds ,
161- self .secondary_embed_holder :secondary_embeds })
162- output = self .sess .run (self .output ,feed_dict = {self .input :input })
148+ input_matrix = self .sess .run (self .embedded_concat ,feed_dict = {self .character_embed_holder :character_embeds ,
149+ self .primary_embed_holder :primary_embeds ,
150+ self .secondary_embed_holder :secondary_embeds })
151+ output = self .sess .run (self .output ,feed_dict = {self .input :input_matrix })
163152return np .argmax (output ,1 )
164153
165154def evaluate (self ):
@@ -190,23 +179,6 @@ def get_score(self, predict, true):
190179print (prec ,recall ,f1 )
191180return prec ,recall ,f1
192181
def load_batch(self):
    """Return the next training batch, cycling circularly over the dataset.

    Slices ``self.words`` / ``self.primary`` / ``self.secondary`` /
    ``self.labels`` from ``self.start`` forward by ``self.batch_size``
    items; when the window runs past ``self.data_count`` the batch wraps
    around to the beginning, and ``self.start`` is advanced accordingly.

    Returns:
        Tuple ``(words, primary, secondary, labels)`` of equally sized
        batch arrays.
    """
    begin = self.start
    end = begin + self.batch_size
    sources = (self.words, self.primary, self.secondary, self.labels)
    if end > self.data_count:
        # Window overruns the data: take the tail plus a wrapped head slice.
        wrap = end - self.data_count
        words, primary, secondary, labels = (
            np.concatenate([arr[begin:], arr[:wrap]]) for arr in sources
        )
        self.start = wrap
    else:
        # Window fits entirely: plain contiguous slices.
        words, primary, secondary, labels = (arr[begin:end] for arr in sources)
        self.start = end
    return words, primary, secondary, labels
def __weight_variable(self, shape, name):
    """Create a trainable variable of the given shape.

    Initialized from a truncated normal distribution (stddev 0.1) using
    the model's configured dtype.

    Args:
        shape: Shape of the variable to create.
        name: Name assigned to the resulting ``tf.Variable``.

    Returns:
        A ``tf.Variable`` holding the freshly initialized weights.
    """
    return tf.Variable(
        tf.truncated_normal(shape, stddev=0.1, dtype=self.dtype),
        name=name,
    )