Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit09a15ba

Browse files
remove init with constant and native python, add dataset api
1 parent1dabce4 commit09a15ba

File tree

2 files changed

+28
-56
lines changed

2 files changed

+28
-56
lines changed

‎python/dnlp/core/re_cnn.py‎

Lines changed: 23 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
classRECNN(RECNNBase):
1010
def__init__(self,config:RECNNConfig,dtype:type=tf.float32,dict_path:str='',mode:str='train',
1111
data_path:str='',relation_count:int=2,model_path:str='',embedding_path:str='',
12-
remark:str=''):
12+
remark:str='',data_mode='prefetch'):
1313
tf.reset_default_graph()
1414
RECNNBase.__init__(self,config,dict_path)
1515
self.dtype=dtype
@@ -21,12 +21,6 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
2121
self.remark=remark
2222

2323
self.concat_embed_size=self.word_embed_size+2*self.position_embed_size
24-
self.words,self.primary,self.secondary,self.labels=self.load_data()
25-
self.input_words=tf.constant(self.words)
26-
self.input_primary=tf.constant(self.primary)
27-
self.input_secondary=tf.constant(self.secondary)
28-
self.input_labels=tf.constant(self.labels)
29-
self.input_indices=tf.placeholder(tf.int32, [self.batch_size])
3024
self.input_characters=tf.placeholder(tf.int32, [None,self.batch_length])
3125
self.input_position=tf.placeholder(tf.int32, [None,self.batch_length])
3226
self.input=tf.placeholder(self.dtype, [None,self.batch_length,self.concat_embed_size,1])
@@ -43,10 +37,6 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
4337
self.full_connected_weight=self.__weight_variable([self.filter_size*len(self.window_size),self.relation_count],
4438
name='full_connected_weight')
4539
self.full_connected_bias=self.__weight_variable([self.relation_count],name='full_connected_bias')
46-
self.input_words_lookup=tf.nn.embedding_lookup(self.input_words,self.input_indices)
47-
self.input_primary_lookup=tf.nn.embedding_lookup(self.input_primary,self.input_indices)
48-
self.input_secondary_lookup=tf.nn.embedding_lookup(self.input_secondary,self.input_indices)
49-
self.input_labels_lookup=tf.nn.embedding_lookup(self.input_labels,self.input_indices)
5040
self.position_lookup=tf.nn.embedding_lookup(self.position_embedding,self.input_position)
5141
self.character_lookup=tf.nn.embedding_lookup(self.word_embedding,self.input_characters)
5242
self.character_embed_holder=tf.placeholder(self.dtype,
@@ -55,17 +45,24 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
5545
[None,self.batch_length,self.position_embed_size])
5646
self.secondary_embed_holder=tf.placeholder(self.dtype,
5747
[None,self.batch_length,self.position_embed_size])
58-
self.emebd_concat=tf.expand_dims(
48+
self.embedded_concat=tf.expand_dims(
5949
tf.concat([self.character_embed_holder,self.primary_embed_holder,self.secondary_embed_holder],2),3)
60-
50+
ifdata_mode=='prefetch':
51+
self.words,self.primary,self.secondary,self.labels=self.load_data()
52+
self.data_count=len(self.words)
53+
self.words=tf.data.Dataset.from_tensor_slices(self.words)
54+
self.primary=tf.data.Dataset.from_tensor_slices(self.primary)
55+
self.secondary=tf.data.Dataset.from_tensor_slices(self.secondary)
56+
self.labels=tf.data.Dataset.from_tensor_slices(self.labels)
57+
self.input_data=tf.data.Dataset.zip((self.words,self.primary,self.secondary,self.labels))
58+
self.input_data=self.input_data.repeat(-1).batch(self.batch_size)
59+
self.input_data_iterator=self.input_data.make_initializable_iterator()
60+
self.iterator=self.input_data_iterator.get_next()
6161
ifself.mode=='train':
62-
self.start=0
6362
self.hidden_layer=tf.layers.dropout(self.get_hidden(),self.dropout_rate)
64-
self.data_count=len(self.words)
6563
self.saver=tf.train.Saver(max_to_keep=100)
6664
else:
6765
self.hidden_layer=self.get_hidden()
68-
# self.hidden_layer = tf.expand_dims(tf.layers.dropout(self.get_hidden(), self.dropout_rate), 0)
6966
self.sess=tf.Session()
7067
self.saver=tf.train.Saver().restore(self.sess,self.model_path)
7168
self.output_no_softmax=tf.matmul(self.hidden_layer,self.full_connected_weight)+self.full_connected_bias
@@ -118,27 +115,19 @@ def fit(self, epochs=50, interval=5):
118115
withtf.Session()assess:
119116
tf.global_variables_initializer().run()
120117
sess.graph.finalize()
121-
start=0
118+
sess.run(self.input_data_iterator.initializer)
119+
122120
foriinrange(1,epochs+1):
123121
print('epoch:'+str(i))
124122
forjinrange(self.data_count//self.batch_size):
125-
ifstart+self.batch_size<self.data_count:
126-
indices=list(range(start,start+self.batch_size))
127-
start+=self.batch_size
128-
else:
129-
new_start=self.batch_size-self.data_count+start
130-
indices=list(range(start,self.data_count))+list(range(0,new_start))
131-
start=new_start
132-
words,primary,secondary,labels=sess.run([self.input_words,self.input_primary,self.input_secondary,
133-
self.input_labels],feed_dict={self.input_indices:indices})
134-
# words, primary, secondary, labels = self.load_batch()
123+
words,primary,secondary,labels=sess.run(self.iterator)
135124
character_embeds,primary_embeds=sess.run([self.character_lookup,self.position_lookup],
136125
feed_dict={self.input_characters:words,
137126
self.input_position:primary})
138127
secondary_embeds=sess.run(self.position_lookup,feed_dict={self.input_position:secondary})
139-
input=sess.run(self.emebd_concat,feed_dict={self.character_embed_holder:character_embeds,
140-
self.primary_embed_holder:primary_embeds,
141-
self.secondary_embed_holder:secondary_embeds})
128+
input=sess.run(self.embedded_concat,feed_dict={self.character_embed_holder:character_embeds,
129+
self.primary_embed_holder:primary_embeds,
130+
self.secondary_embed_holder:secondary_embeds})
142131
# sess.run(self.train_model, feed_dict={self.input: input, self.input_relation: batch['label']})
143132
sess.run(self.train_cross_entropy_model,feed_dict={self.input:input,self.input_relation:labels})
144133
ifi%interval==0:
@@ -156,10 +145,10 @@ def predict(self, words, primary, secondary):
156145
feed_dict={self.input_characters:words,
157146
self.input_position:primary})
158147
secondary_embeds=self.sess.run(self.position_lookup,feed_dict={self.input_position:secondary})
159-
input=self.sess.run(self.emebd_concat,feed_dict={self.character_embed_holder:character_embeds,
160-
self.primary_embed_holder:primary_embeds,
161-
self.secondary_embed_holder:secondary_embeds})
162-
output=self.sess.run(self.output,feed_dict={self.input:input})
148+
input_matrix=self.sess.run(self.embedded_concat,feed_dict={self.character_embed_holder:character_embeds,
149+
self.primary_embed_holder:primary_embeds,
150+
self.secondary_embed_holder:secondary_embeds})
151+
output=self.sess.run(self.output,feed_dict={self.input:input_matrix})
163152
returnnp.argmax(output,1)
164153

165154
defevaluate(self):
@@ -190,23 +179,6 @@ def get_score(self, predict, true):
190179
print(prec,recall,f1)
191180
returnprec,recall,f1
192181

193-
defload_batch(self):
194-
ifself.start+self.batch_size>self.data_count:
195-
new_start=self.start+self.batch_size-self.data_count
196-
words=np.concatenate([self.words[self.start:],self.words[:new_start]])
197-
primary=np.concatenate([self.primary[self.start:],self.primary[:new_start]])
198-
secondary=np.concatenate([self.secondary[self.start:],self.secondary[:new_start]])
199-
labels=np.concatenate([self.labels[self.start:],self.labels[:new_start]])
200-
self.start=new_start
201-
else:
202-
new_start=self.start+self.batch_size
203-
words=self.words[self.start:new_start]
204-
primary=self.primary[self.start:new_start]
205-
secondary=self.secondary[self.start:new_start]
206-
labels=self.labels[self.start:new_start]
207-
self.start=new_start
208-
returnwords,primary,secondary,labels
209-
210182
def__weight_variable(self,shape,name):
211183
initial=tf.truncated_normal(shape,stddev=0.1,dtype=self.dtype)
212184
returntf.Variable(initial,name=name)

‎python/dnlp/core/re_cnn_base.py‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
fromdnlp.configimportRECNNConfig
55
fromdnlp.utils.constantimportBATCH_PAD,BATCH_PAD_VAL
66

7+
78
classRECNNBase(object):
8-
def__init__(self,config:RECNNConfig,dict_path:str,data_path:str=''):
9+
def__init__(self,config:RECNNConfig,dict_path:str,data_path:str=''):
910
self.window_size=config.window_size
1011
self.filter_size=config.filter_size
1112
self.learning_rate=config.learning_rate
@@ -18,8 +19,8 @@ def __init__(self, config:RECNNConfig,dict_path:str,data_path:str=''):
1819
self.dictionary=self.read_dictionary(dict_path)
1920
self.words_size=len(self.dictionary)
2021

21-
defread_dictionary(self,dict_path):
22-
withopen(dict_path,encoding='utf-8')asf:
22+
defread_dictionary(self,dict_path):
23+
withopen(dict_path,encoding='utf-8')asf:
2324
content=f.read().splitlines()
2425
dictionary= {}
2526
dict_arr=map(lambdaitem:item.split(' '),content)
@@ -47,5 +48,4 @@ def load_data(self):
4748
sentence_labels=np.zeros([self.relation_count])
4849
sentence_labels[sentence['type']]=1
4950
labels.append(sentence_labels)
50-
returnnp.array(words,np.int32),np.array(primary,np.int32),np.array(secondary,np.int32),np.array(labels,
51-
np.float32)
51+
returnnp.array(words,np.int32),np.array(primary,np.int32),np.array(secondary,np.int32),np.array(labels,np.float32)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp