1010class RECNN (RECNNBase ):
1111def __init__ (self ,config :RECNNConfig ,dtype :type = tf .float32 ,dict_path :str = '' ,mode :str = 'train' ,
1212data_path :str = '' ,relation_count :int = 2 ,model_path :str = '' ):
13+ tf .reset_default_graph ()
1314RECNNBase .__init__ (self ,config ,dict_path )
1415self .dtype = dtype
1516self .mode = mode
@@ -40,13 +41,16 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
4041 [None ,self .batch_length ,self .position_embed_size ])
4142self .emebd_concat = tf .expand_dims (
4243tf .concat ([self .character_embed_holder ,self .primary_embed_holder ,self .secondary_embed_holder ],2 ),3 )
44+ self .words ,self .primary ,self .secondary ,self .labels = self .load_data ()
4345if self .mode == 'train' :
4446self .hidden_layer = tf .layers .dropout (self .get_hidden (),self .dropout_rate )
45- self .words ,self .primary ,self .secondary ,self .labels = self .load_data ()
4647self .start = 0
4748self .data_count = len (self .words )
49+ self .saver = tf .train .Saver (max_to_keep = 100 )
4850else :
4951self .hidden_layer = tf .expand_dims (tf .layers .dropout (self .get_hidden (),self .dropout_rate ),0 )
52+ self .sess = tf .Session ()
53+ self .saver = tf .train .Saver ().restore (self .sess ,self .model_path )
5054self .output_no_softmax = tf .matmul (self .hidden_layer ,self .full_connected_weight )+ self .full_connected_bias
5155self .output = tf .nn .softmax (tf .matmul (self .hidden_layer ,self .full_connected_weight )+ self .full_connected_bias )
5256self .params = [self .position_embedding ,self .word_embedding ,self .full_connected_weight ,
@@ -60,7 +64,7 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
6064# self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
6165self .train_model = self .optimizer .minimize (self .loss )
6266self .train_cross_entropy_model = self .optimizer .minimize (self .cross_entropy )
63- self . saver = tf . train . Saver ( max_to_keep = 100 )
67+
6468
6569def get_conv_kernel (self ):
6670conv_kernel = []
@@ -114,6 +118,19 @@ def fit(self, epochs=100, interval=20):
114118if i % interval == 0 :
115119model_name = '../dnlp/models/re/{0}-{1}.ckpt' .format (i ,'_' .join (map (str ,self .window_size )))
116120self .saver .save (sess ,model_name )
121+ def predict (self ,words ,primary ,secondary ):
122+ character_embeds ,primary_embeds = self .sess .run ([self .character_lookup ,self .position_lookup ],
123+ feed_dict = {self .input_characters :words ,
124+ self .input_position :primary })
125+ secondary_embeds = self .sess .run (self .position_lookup ,feed_dict = {self .input_position :secondary })
126+ input = self .sess .run (self .emebd_concat ,feed_dict = {self .character_embed_holder :character_embeds ,
127+ self .primary_embed_holder :primary_embeds ,
128+ self .secondary_embed_holder :secondary_embeds })
129+ output = self .sess .run (self .output ,feed_dict = {self .input :input })
130+ return np .argmax (output ,1 )
131+
132+ def evaluate (self ):
133+ res = self .predict (self .words ,self .primary ,self .secondary )
117134
118135def load_batch (self ):
119136if self .start + self .batch_size > self .data_count :