Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit746d2e4

Browse files
add relation extraction by cnn
1 parentce5b31e commit746d2e4

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

‎python/dnlp/core/re_cnn.py‎

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
classRECNN(RECNNBase):
1111
def__init__(self,config:RECNNConfig,dtype:type=tf.float32,dict_path:str='',mode:str='train',
1212
data_path:str='',relation_count:int=2,model_path:str=''):
13+
tf.reset_default_graph()
1314
RECNNBase.__init__(self,config,dict_path)
1415
self.dtype=dtype
1516
self.mode=mode
@@ -40,13 +41,16 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
4041
[None,self.batch_length,self.position_embed_size])
4142
self.emebd_concat=tf.expand_dims(
4243
tf.concat([self.character_embed_holder,self.primary_embed_holder,self.secondary_embed_holder],2),3)
44+
self.words,self.primary,self.secondary,self.labels=self.load_data()
4345
ifself.mode=='train':
4446
self.hidden_layer=tf.layers.dropout(self.get_hidden(),self.dropout_rate)
45-
self.words,self.primary,self.secondary,self.labels=self.load_data()
4647
self.start=0
4748
self.data_count=len(self.words)
49+
self.saver=tf.train.Saver(max_to_keep=100)
4850
else:
4951
self.hidden_layer=tf.expand_dims(tf.layers.dropout(self.get_hidden(),self.dropout_rate),0)
52+
self.sess=tf.Session()
53+
self.saver=tf.train.Saver().restore(self.sess,self.model_path)
5054
self.output_no_softmax=tf.matmul(self.hidden_layer,self.full_connected_weight)+self.full_connected_bias
5155
self.output=tf.nn.softmax(tf.matmul(self.hidden_layer,self.full_connected_weight)+self.full_connected_bias)
5256
self.params= [self.position_embedding,self.word_embedding,self.full_connected_weight,
@@ -60,7 +64,7 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
6064
# self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
6165
self.train_model=self.optimizer.minimize(self.loss)
6266
self.train_cross_entropy_model=self.optimizer.minimize(self.cross_entropy)
63-
self.saver=tf.train.Saver(max_to_keep=100)
67+
6468

6569
defget_conv_kernel(self):
6670
conv_kernel= []
@@ -114,6 +118,19 @@ def fit(self, epochs=100, interval=20):
114118
ifi%interval==0:
115119
model_name='../dnlp/models/re/{0}-{1}.ckpt'.format(i,'_'.join(map(str,self.window_size)))
116120
self.saver.save(sess,model_name)
121+
defpredict(self,words,primary,secondary):
122+
character_embeds,primary_embeds=self.sess.run([self.character_lookup,self.position_lookup],
123+
feed_dict={self.input_characters:words,
124+
self.input_position:primary})
125+
secondary_embeds=self.sess.run(self.position_lookup,feed_dict={self.input_position:secondary})
126+
input=self.sess.run(self.emebd_concat,feed_dict={self.character_embed_holder:character_embeds,
127+
self.primary_embed_holder:primary_embeds,
128+
self.secondary_embed_holder:secondary_embeds})
129+
output=self.sess.run(self.output,feed_dict={self.input:input})
130+
returnnp.argmax(output,1)
131+
132+
defevaluate(self):
133+
res=self.predict(self.words,self.primary,self.secondary)
117134

118135
defload_batch(self):
119136
ifself.start+self.batch_size>self.data_count:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp