Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 6d389a8

Browse files
2 parents d2a926c + 0095bca commit 6d389a8

File tree

4 files changed

+41
-68
lines changed

4 files changed

+41
-68
lines changed

‎python/dnlp/core/re_cnn.py‎

Lines changed: 23 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
classRECNN(RECNNBase):
1010
def__init__(self,config:RECNNConfig,dtype:type=tf.float32,dict_path:str='',mode:str='train',
1111
data_path:str='',relation_count:int=2,model_path:str='',embedding_path:str='',
12-
remark:str=''):
12+
remark:str='',data_mode='prefetch'):
1313
tf.reset_default_graph()
1414
RECNNBase.__init__(self,config,dict_path)
1515
self.dtype=dtype
@@ -21,12 +21,6 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
2121
self.remark=remark
2222

2323
self.concat_embed_size=self.word_embed_size+2*self.position_embed_size
24-
self.words,self.primary,self.secondary,self.labels=self.load_data()
25-
self.input_words=tf.constant(self.words)
26-
self.input_primary=tf.constant(self.primary)
27-
self.input_secondary=tf.constant(self.secondary)
28-
self.input_labels=tf.constant(self.labels)
29-
self.input_indices=tf.placeholder(tf.int32, [self.batch_size])
3024
self.input_characters=tf.placeholder(tf.int32, [None,self.batch_length])
3125
self.input_position=tf.placeholder(tf.int32, [None,self.batch_length])
3226
self.input=tf.placeholder(self.dtype, [None,self.batch_length,self.concat_embed_size,1])
@@ -43,10 +37,6 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
4337
self.full_connected_weight=self.__weight_variable([self.filter_size*len(self.window_size),self.relation_count],
4438
name='full_connected_weight')
4539
self.full_connected_bias=self.__weight_variable([self.relation_count],name='full_connected_bias')
46-
self.input_words_lookup=tf.nn.embedding_lookup(self.input_words,self.input_indices)
47-
self.input_primary_lookup=tf.nn.embedding_lookup(self.input_primary,self.input_indices)
48-
self.input_secondary_lookup=tf.nn.embedding_lookup(self.input_secondary,self.input_indices)
49-
self.input_labels_lookup=tf.nn.embedding_lookup(self.input_labels,self.input_indices)
5040
self.position_lookup=tf.nn.embedding_lookup(self.position_embedding,self.input_position)
5141
self.character_lookup=tf.nn.embedding_lookup(self.word_embedding,self.input_characters)
5242
self.character_embed_holder=tf.placeholder(self.dtype,
@@ -55,17 +45,24 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
5545
[None,self.batch_length,self.position_embed_size])
5646
self.secondary_embed_holder=tf.placeholder(self.dtype,
5747
[None,self.batch_length,self.position_embed_size])
58-
self.emebd_concat=tf.expand_dims(
48+
self.embedded_concat=tf.expand_dims(
5949
tf.concat([self.character_embed_holder,self.primary_embed_holder,self.secondary_embed_holder],2),3)
60-
50+
ifdata_mode=='prefetch':
51+
self.words,self.primary,self.secondary,self.labels=self.load_data()
52+
self.data_count=len(self.words)
53+
self.words=tf.data.Dataset.from_tensor_slices(self.words)
54+
self.primary=tf.data.Dataset.from_tensor_slices(self.primary)
55+
self.secondary=tf.data.Dataset.from_tensor_slices(self.secondary)
56+
self.labels=tf.data.Dataset.from_tensor_slices(self.labels)
57+
self.input_data=tf.data.Dataset.zip((self.words,self.primary,self.secondary,self.labels))
58+
self.input_data=self.input_data.repeat(-1).batch(self.batch_size)
59+
self.input_data_iterator=self.input_data.make_initializable_iterator()
60+
self.iterator=self.input_data_iterator.get_next()
6161
ifself.mode=='train':
62-
self.start=0
6362
self.hidden_layer=tf.layers.dropout(self.get_hidden(),self.dropout_rate)
64-
self.data_count=len(self.words)
6563
self.saver=tf.train.Saver(max_to_keep=100)
6664
else:
6765
self.hidden_layer=self.get_hidden()
68-
# self.hidden_layer = tf.expand_dims(tf.layers.dropout(self.get_hidden(), self.dropout_rate), 0)
6966
self.sess=tf.Session()
7067
self.saver=tf.train.Saver().restore(self.sess,self.model_path)
7168
self.output_no_softmax=tf.matmul(self.hidden_layer,self.full_connected_weight)+self.full_connected_bias
@@ -118,27 +115,19 @@ def fit(self, epochs=50, interval=5):
118115
withtf.Session()assess:
119116
tf.global_variables_initializer().run()
120117
sess.graph.finalize()
121-
start=0
118+
sess.run(self.input_data_iterator.initializer)
119+
122120
foriinrange(1,epochs+1):
123121
print('epoch:'+str(i))
124122
forjinrange(self.data_count//self.batch_size):
125-
ifstart+self.batch_size<self.data_count:
126-
indices=list(range(start,start+self.batch_size))
127-
start+=self.batch_size
128-
else:
129-
new_start=self.batch_size-self.data_count+start
130-
indices=list(range(start,self.data_count))+list(range(0,new_start))
131-
start=new_start
132-
words,primary,secondary,labels=sess.run([self.input_words,self.input_primary,self.input_secondary,
133-
self.input_labels],feed_dict={self.input_indices:indices})
134-
# words, primary, secondary, labels = self.load_batch()
123+
words,primary,secondary,labels=sess.run(self.iterator)
135124
character_embeds,primary_embeds=sess.run([self.character_lookup,self.position_lookup],
136125
feed_dict={self.input_characters:words,
137126
self.input_position:primary})
138127
secondary_embeds=sess.run(self.position_lookup,feed_dict={self.input_position:secondary})
139-
input=sess.run(self.emebd_concat,feed_dict={self.character_embed_holder:character_embeds,
140-
self.primary_embed_holder:primary_embeds,
141-
self.secondary_embed_holder:secondary_embeds})
128+
input=sess.run(self.embedded_concat,feed_dict={self.character_embed_holder:character_embeds,
129+
self.primary_embed_holder:primary_embeds,
130+
self.secondary_embed_holder:secondary_embeds})
142131
# sess.run(self.train_model, feed_dict={self.input: input, self.input_relation: batch['label']})
143132
sess.run(self.train_cross_entropy_model,feed_dict={self.input:input,self.input_relation:labels})
144133
ifi%interval==0:
@@ -156,10 +145,10 @@ def predict(self, words, primary, secondary):
156145
feed_dict={self.input_characters:words,
157146
self.input_position:primary})
158147
secondary_embeds=self.sess.run(self.position_lookup,feed_dict={self.input_position:secondary})
159-
input=self.sess.run(self.emebd_concat,feed_dict={self.character_embed_holder:character_embeds,
160-
self.primary_embed_holder:primary_embeds,
161-
self.secondary_embed_holder:secondary_embeds})
162-
output=self.sess.run(self.output,feed_dict={self.input:input})
148+
input_matrix=self.sess.run(self.embedded_concat,feed_dict={self.character_embed_holder:character_embeds,
149+
self.primary_embed_holder:primary_embeds,
150+
self.secondary_embed_holder:secondary_embeds})
151+
output=self.sess.run(self.output,feed_dict={self.input:input_matrix})
163152
returnnp.argmax(output,1)
164153

165154
defevaluate(self):
@@ -190,23 +179,6 @@ def get_score(self, predict, true):
190179
print(prec,recall,f1)
191180
returnprec,recall,f1
192181

193-
defload_batch(self):
194-
ifself.start+self.batch_size>self.data_count:
195-
new_start=self.start+self.batch_size-self.data_count
196-
words=np.concatenate([self.words[self.start:],self.words[:new_start]])
197-
primary=np.concatenate([self.primary[self.start:],self.primary[:new_start]])
198-
secondary=np.concatenate([self.secondary[self.start:],self.secondary[:new_start]])
199-
labels=np.concatenate([self.labels[self.start:],self.labels[:new_start]])
200-
self.start=new_start
201-
else:
202-
new_start=self.start+self.batch_size
203-
words=self.words[self.start:new_start]
204-
primary=self.primary[self.start:new_start]
205-
secondary=self.secondary[self.start:new_start]
206-
labels=self.labels[self.start:new_start]
207-
self.start=new_start
208-
returnwords,primary,secondary,labels
209-
210182
def__weight_variable(self,shape,name):
211183
initial=tf.truncated_normal(shape,stddev=0.1,dtype=self.dtype)
212184
returntf.Variable(initial,name=name)

‎python/dnlp/core/re_cnn_base.py‎

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,9 @@
44
fromdnlp.configimportRECNNConfig
55
fromdnlp.utils.constantimportBATCH_PAD,BATCH_PAD_VAL
66

7+
78
classRECNNBase(object):
8-
def__init__(self,config:RECNNConfig,dict_path:str,data_path:str=''):
9+
def__init__(self,config:RECNNConfig,dict_path:str,data_path:str=''):
910
self.window_size=config.window_size
1011
self.filter_size=config.filter_size
1112
self.learning_rate=config.learning_rate
@@ -18,8 +19,8 @@ def __init__(self, config:RECNNConfig,dict_path:str,data_path:str=''):
1819
self.dictionary=self.read_dictionary(dict_path)
1920
self.words_size=len(self.dictionary)
2021

21-
defread_dictionary(self,dict_path):
22-
withopen(dict_path,encoding='utf-8')asf:
22+
defread_dictionary(self,dict_path):
23+
withopen(dict_path,encoding='utf-8')asf:
2324
content=f.read().splitlines()
2425
dictionary= {}
2526
dict_arr=map(lambdaitem:item.split(' '),content)
@@ -47,5 +48,4 @@ def load_data(self):
4748
sentence_labels=np.zeros([self.relation_count])
4849
sentence_labels[sentence['type']]=1
4950
labels.append(sentence_labels)
50-
returnnp.array(words,np.int32),np.array(primary,np.int32),np.array(secondary,np.int32),np.array(labels,
51-
np.float32)
51+
returnnp.array(words,np.int32),np.array(primary,np.int32),np.array(secondary,np.int32),np.array(labels,np.float32)

‎python/dnlp/data_process/process_emr.py‎

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -204,8 +204,7 @@ def generate_re_mutli_training_data(self):
204204
word_indices=self.map_to_indices(annotation['words'])
205205
fortrue_rel_nameinannotation['true_relations']:
206206
true_rel=annotation['true_relations'][true_rel_name]
207-
train_data.append({'words':word_indices,'primary':true_rel['primary'],'secondary':true_rel['secondary'],
208-
'type':self.relation_category_labels[true_rel['type']]})
207+
train_data.append({'words':word_indices,'primary':true_rel['primary'],'secondary':true_rel['secondary'],'type':self.relation_category_labels[true_rel['type']]})
209208
returntrain_data
210209

211210
defmap_to_indices(self,words):

‎python/scripts/rel.py‎

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -23,15 +23,16 @@ def train_re_cnn():
2323
cbow_path=BASE_FOLDER+'emr_word_light_cbow.npy'
2424
forwinWINDOW_LIST:
2525
start=time.time()
26-
train_re_cnn_by_window(w,data_path_two_directed,embedding_path=cbow_path,remark='_cbow_directed')
27-
train_re_cnn_by_window(w,data_path_two_directed,embedding_path=embedding_path,remark='_skip_gram_directed')
28-
train_re_cnn_by_window(w,data_path_multi_directed,relation_count=28,embedding_path=cbow_path,
29-
remark='_cbow_directed')
30-
train_re_cnn_by_window(w,data_path_multi_directed,relation_count=28,embedding_path=embedding_path,
31-
remark='_skip_gram_directed')
26+
#train_re_cnn_by_window(w, data_path_two_directed, embedding_path=cbow_path, remark='_cbow_directed')
27+
#train_re_cnn_by_window(w, data_path_two_directed, embedding_path=embedding_path, remark='_skip_gram_directed')
28+
#train_re_cnn_by_window(w, data_path_multi_directed, relation_count=28, embedding_path=cbow_path,
29+
# remark='_cbow_directed')
30+
#train_re_cnn_by_window(w, data_path_multi_directed, relation_count=28, embedding_path=embedding_path,
31+
# remark='_skip_gram_directed')
3232
# train_re_cnn_by_window(w,data_path_two)
33-
train_re_cnn_by_window(w,data_path_multi_directed,28,remark='_directed')
3433
train_re_cnn_by_window(w,data_path_two_directed,remark='_directed')
34+
train_re_cnn_by_window(w,data_path_multi_directed,28,remark='_directed')
35+
3536
print(time.time()-start)
3637

3738

@@ -161,9 +162,10 @@ def test_re_cnn_with_embedding():
161162
# test_re_cnn()
162163
# test_re_cnn_by_window((2,),epoch=1,embedding_path=SKIP_GRAM_PATH,remark='_skip_gram')
163164
# test_re_cnn_by_window((2,), epoch=5, embedding_path=CBOW_PATH, remark='_cbow_directed')
164-
get_re_cnn_result()
165-
get_re_cnn_result('multi')
166-
# test_re_cnn(remark='_directed')
165+
# get_re_cnn_result()
166+
# get_re_cnn_result('multi')
167+
test_re_cnn_by_window((2,3,4),50,mode='two',relation_count=2,remark='_directed')
168+
test_re_cnn(remark='_directed')
167169
# test_re_cnn('multi')
168170
# test_re_cnn_with_embedding()
169171
# test_single_model((2, 3, 4), 1)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp