 import tensorflow as tf
 import numpy as np
 import pickle
+from collections import Counter
 from dnlp.core.re_cnn_base import RECNNBase
 from dnlp.config import RECNNConfig
 from dnlp.utils.constant import BATCH_PAD, BATCH_PAD_VAL
 
 
 class RECNN(RECNNBase):
   def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str = '', mode: str = 'train',
-               data_path: str = '', relation_count: int = 2, model_path: str = ''):
+               data_path: str = '', relation_count: int = 2, model_path: str = '', embedding_path: str = '',
+               remark: str = ''):
     tf.reset_default_graph()
     RECNNBase.__init__(self, config, dict_path)
     self.dtype = dtype
     self.mode = mode
     self.data_path = data_path
     self.model_path = model_path
     self.relation_count = relation_count
+    self.embedding_path = embedding_path
+    self.remark = remark
 
     self.concat_embed_size = self.word_embed_size + 2 * self.position_embed_size
     self.input_characters = tf.placeholder(tf.int32, [None, self.batch_length])
     self.input_position = tf.placeholder(tf.int32, [None, self.batch_length])
     self.input = tf.placeholder(self.dtype, [None, self.batch_length, self.concat_embed_size, 1])
     self.input_relation = tf.placeholder(self.dtype, [None, self.relation_count])
     self.position_embedding = self.__weight_variable([2 * self.batch_length - 1, self.position_embed_size],
                                                      name='position_embedding')
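+    # NOTE (assumption): embedding_path is expected to point at a NumPy .npy file
+    # holding a [words_size, word_embed_size] float matrix, e.g. one exported from
+    # a separately pretrained word2vec model.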
-    self.word_embedding = self.__weight_variable([self.words_size, self.word_embed_size], name='word_embedding')
+    if self.embedding_path:
+      # initialize from pretrained vectors but keep them trainable for fine-tuning
+      self.word_embedding = tf.Variable(np.load(self.embedding_path), dtype=self.dtype, name='word_embedding',
+                                        trainable=True)
+    else:
+      self.word_embedding = self.__weight_variable([self.words_size, self.word_embed_size], name='word_embedding')
     self.conv_kernel = self.get_conv_kernel()
+    # note: the list multiplication reuses a single bias variable across all window sizes
     self.bias = [self.__weight_variable([self.filter_size], name='conv_bias')] * len(self.window_size)
     self.full_connected_weight = self.__weight_variable([self.filter_size * len(self.window_size), self.relation_count],
@@ -42,13 +50,15 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
     self.emebd_concat = tf.expand_dims(
       tf.concat([self.character_embed_holder, self.primary_embed_holder, self.secondary_embed_holder], 2), 3)
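+    # shape: [batch, batch_length, word_embed_size + 2 * position_embed_size, 1],
+    # i.e. a single-channel feature map laid out as input for the 2-D convolution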
     self.words, self.primary, self.secondary, self.labels = self.load_data()
+
     if self.mode == 'train':
-      self.hidden_layer = tf.layers.dropout(self.get_hidden(), self.dropout_rate)
       self.start = 0
+      # training=True so dropout is actually applied; tf.layers.dropout defaults to inference mode
+      self.hidden_layer = tf.layers.dropout(self.get_hidden(), self.dropout_rate, training=True)
       self.data_count = len(self.words)
       self.saver = tf.train.Saver(max_to_keep=100)
     else:
-      self.hidden_layer = tf.expand_dims(tf.layers.dropout(self.get_hidden(), self.dropout_rate), 0)
+      self.hidden_layer = self.get_hidden()
+      # self.hidden_layer = tf.expand_dims(tf.layers.dropout(self.get_hidden(), self.dropout_rate), 0)
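+      # at inference the restored graph consumes whole batches directly, so neither
+      # dropout nor an extra batch dimension is needed here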
       self.sess = tf.Session()
-      self.saver = tf.train.Saver().restore(self.sess, self.model_path)
+      # Saver().restore(...) returns None; keep the Saver object and restore separately
+      self.saver = tf.train.Saver()
+      self.saver.restore(self.sess, self.model_path)
     self.output_no_softmax = tf.matmul(self.hidden_layer, self.full_connected_weight) + self.full_connected_bias
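+    # output_no_softmax holds the raw logits for the cross-entropy below; self.output
+    # (their softmax, defined in the lines elided here) is what predict() reads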
@@ -60,12 +70,11 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
     self.loss = tf.reduce_sum(tf.square(self.output - self.input_relation)) / self.batch_size + self.regularization
     self.cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=self.input_relation,
                                                                  logits=self.output_no_softmax) + self.regularization
-    self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
-    # self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
+    # self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
+    self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
     self.train_model = self.optimizer.minimize(self.loss)
     self.train_cross_entropy_model = self.optimizer.minimize(self.cross_entropy)
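+    # NOTE: softmax_cross_entropy_with_logits yields one loss per example; minimize()
+    # sums the per-example gradients, so wrapping the term in tf.reduce_mean(...)
+    # would make the objective independent of batch size.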
 
-
   def get_conv_kernel(self):
     conv_kernel = []
     for w in self.window_size:
@@ -85,10 +94,10 @@ def get_hidden(self):
         h = tf.squeeze(self.max_pooling(tf.nn.relu(self.conv(conv) + bias), w))
       else:
         hh = tf.squeeze(self.max_pooling(tf.nn.relu(self.conv(conv) + bias), w))
-        if self.mode == 'train':
-          h = tf.concat([h, hh], 1)
-        else:
-          h = tf.concat([h, hh], 0)
+        # if self.mode == 'train':
+        h = tf.concat([h, hh], 1)
+        # else:
+        #   h = tf.concat([h, hh], 0)
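+        # pooled features for every window size are now concatenated along axis 1 in
+        # both modes, since predict() also feeds batched input through the same graph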
     return h
 
   def conv(self, conv_kernel):
@@ -98,7 +107,7 @@ def max_pooling(self, x, window_size):
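+    # a window of size w over a length-L input leaves L - w + 1 conv positions;
+    # pooling with ksize L - w + 1 keeps one maximum per filter ("max over time"),
+    # e.g. L = 80, w = 3 -> pool window of 78 (hypothetical L)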
     return tf.nn.max_pool(x, ksize=[1, self.batch_length - window_size + 1, 1, 1],
                           strides=[1, 1, 1, 1], padding='VALID')
 
-  def fit(self, epochs=100, interval=20):
+  def fit(self, epochs=40, interval=5):
     with tf.Session() as sess:
       tf.global_variables_initializer().run()
       sess.graph.finalize()
@@ -116,21 +125,53 @@ def fit(self, epochs=100, interval=20):
         # sess.run(self.train_model, feed_dict={self.input: input, self.input_relation: batch['label']})
         sess.run(self.train_cross_entropy_model, feed_dict={self.input: input, self.input_relation: labels})
         if i % interval == 0:
-          model_name = '../dnlp/models/re/{0}-{1}.ckpt'.format(i, '_'.join(map(str, self.window_size)))
+          # binary and multi-class relation models are checkpointed to separate directories
+          flag = 'two' if self.relation_count == 2 else 'multi'
+          model_name = '../dnlp/models/re_{2}/{0}-{1}{3}.ckpt'.format(i, '_'.join(map(str, self.window_size)),
+                                                                      flag, self.remark)
+
           self.saver.save(sess, model_name)
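+          # e.g. epoch 10, window_size (3, 4), relation_count 2, empty remark ->
+          # '../dnlp/models/re_two/10-3_4.ckpt' (hypothetical values)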
+
   def predict(self, words, primary, secondary):
     character_embeds, primary_embeds = self.sess.run([self.character_lookup, self.position_lookup],
                                                      feed_dict={self.input_characters: words,
                                                                 self.input_position: primary})
     secondary_embeds = self.sess.run(self.position_lookup, feed_dict={self.input_position: secondary})
     input = self.sess.run(self.emebd_concat, feed_dict={self.character_embed_holder: character_embeds,
                                                         self.primary_embed_holder: primary_embeds,
                                                         self.secondary_embed_holder: secondary_embeds})
     output = self.sess.run(self.output, feed_dict={self.input: input})
     return np.argmax(output, 1)
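+    # hypothetical usage, assuming index matrices of shape [n, batch_length]:
+    #   relation_ids = model.predict(word_ids, primary_positions, secondary_positions)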
 
   def evaluate(self):
     res = self.predict(self.words, self.primary, self.secondary)
+    # diagnostic counts of class 1 in the predictions and the gold labels
+    res_count = Counter(res)[1]
+    target = np.argmax(self.labels, 1)
+    target_count = Counter(target)[1]
+    # Counter over (prediction - target): the count at key 0 is the number of correct predictions
+    correct_number = Counter(np.array(res) - target)
+    print(correct_number)
+    return self.get_score(np.array(res), target)
+
+  def get_score(self, predict, true):
+    # macro-averaged precision/recall/F1 over the classes that appear in the predictions
+    types = Counter(predict).keys()
+    corr_count = []
+    true_count = []
+    pred_count = []
+
+    for t in types:
+      corr_count.append(len([v for v, c in zip(predict - t, predict - true) if v == 0 and c == 0]))
+      true_count.append(len([te for te in true if te == t]))
+      pred_count.append(len([pd for pd in predict if pd == t]))
+
+    # caveat: the c != 0 filters drop classes with zero correct predictions from the
+    # averages, which inflates the reported scores
+    precs = [c / p for c, p in zip(corr_count, pred_count) if p != 0 and c != 0]
+    recalls = [c / r for c, r in zip(corr_count, true_count) if r != 0 and c != 0]
+    prec = sum(precs) / len(precs)
+    recall = sum(recalls) / len(recalls)
+    f1 = 2 * prec * recall / (prec + recall)
+    print(prec, recall, f1)
+    return prec, recall, f1
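+  # hypothetical example: predict = [1, 1, 0, 2], true = [1, 0, 0, 2] gives per-class
+  # precision 1/2, 1/1, 1/1 and recall 1/1, 1/2, 1/1, so prec = recall = f1 = 5/6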
 
   def load_batch(self):
     if self.start + self.batch_size > self.data_count:
@@ -163,8 +204,8 @@ def load_data(self):
       else:
         sentence_words = sentence_words[:self.batch_length]
       words.append(sentence_words)
+      # relative positions to the two entities, shifted by batch_length - 1 so they
+      # are valid non-negative row indices into position_embedding
       primary.append(np.arange(self.batch_length) - sentence['primary'] + self.batch_length - 1)
       secondary.append(np.arange(self.batch_length) - sentence['secondary'] + self.batch_length - 1)
       sentence_labels = np.zeros([self.relation_count])
       sentence_labels[sentence['type']] = 1
       labels.append(sentence_labels)