Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 68b7b13

Browse files
modify input to constant
1 parent d34d342, commit 68b7b13

File tree

3 files changed

+53
-28
lines changed

3 files changed

+53
-28
lines changed

‎python/dnlp/core/re_cnn.py‎

Lines changed: 26 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
11
# -*- coding: UTF-8 -*-
22
importtensorflowastf
33
importnumpyasnp
4-
importpickle
54
fromcollectionsimportCounter
65
fromdnlp.core.re_cnn_baseimportRECNNBase
76
fromdnlp.configimportRECNNConfig
8-
fromdnlp.utils.constantimportBATCH_PAD,BATCH_PAD_VAL
7+
98

109

1110
classRECNN(RECNNBase):
@@ -23,6 +22,12 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
2322
self.remark=remark
2423

2524
self.concat_embed_size=self.word_embed_size+2*self.position_embed_size
25+
self.words,self.primary,self.secondary,self.labels=self.load_data()
26+
self.input_words=tf.constant(self.words)
27+
self.input_primary=tf.constant(self.primary)
28+
self.input_secondary=tf.constant(self.secondary)
29+
self.input_labels=tf.constant(self.labels)
30+
self.input_indices=tf.placeholder(tf.int32, [self.batch_size])
2631
self.input_characters=tf.placeholder(tf.int32, [None,self.batch_length])
2732
self.input_position=tf.placeholder(tf.int32, [None,self.batch_length])
2833
self.input=tf.placeholder(self.dtype, [None,self.batch_length,self.concat_embed_size,1])
@@ -39,6 +44,10 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
3944
self.full_connected_weight=self.__weight_variable([self.filter_size*len(self.window_size),self.relation_count],
4045
name='full_connected_weight')
4146
self.full_connected_bias=self.__weight_variable([self.relation_count],name='full_connected_bias')
47+
self.input_words_lookup=tf.nn.embedding_lookup(self.input_words,self.input_indices)
48+
self.input_primary_lookup=tf.nn.embedding_lookup(self.input_primary,self.input_indices)
49+
self.input_secondary_lookup=tf.nn.embedding_lookup(self.input_secondary,self.input_indices)
50+
self.input_labels_lookup=tf.nn.embedding_lookup(self.input_labels,self.input_indices)
4251
self.position_lookup=tf.nn.embedding_lookup(self.position_embedding,self.input_position)
4352
self.character_lookup=tf.nn.embedding_lookup(self.word_embedding,self.input_characters)
4453
self.character_embed_holder=tf.placeholder(self.dtype,
@@ -49,7 +58,8 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
4958
[None,self.batch_length,self.position_embed_size])
5059
self.emebd_concat=tf.expand_dims(
5160
tf.concat([self.character_embed_holder,self.primary_embed_holder,self.secondary_embed_holder],2),3)
52-
self.words,self.primary,self.secondary,self.labels=self.load_data()
61+
62+
5363

5464
ifself.mode=='train':
5565
self.start=0
@@ -111,10 +121,20 @@ def fit(self, epochs=50, interval=5):
111121
withtf.Session()assess:
112122
tf.global_variables_initializer().run()
113123
sess.graph.finalize()
124+
start=0
114125
foriinrange(1,epochs+1):
115126
print('epoch:'+str(i))
116-
for_inrange(self.data_count//self.batch_size):
117-
words,primary,secondary,labels=self.load_batch()
127+
forjinrange(self.data_count//self.batch_size):
128+
ifstart+self.batch_size<self.data_count:
129+
indices=list(range(start,start+self.batch_size))
130+
start+=self.batch_size
131+
else:
132+
new_start=self.batch_size-self.data_count+start
133+
indices=list(range(start,self.data_count))+list(range(0,new_start))
134+
start=new_start
135+
words,primary,secondary,labels=sess.run([self.input_words,self.input_primary,self.input_secondary,
136+
self.input_labels],feed_dict={self.input_indices:indices})
137+
# words, primary, secondary, labels = self.load_batch()
118138
character_embeds,primary_embeds=sess.run([self.character_lookup,self.position_lookup],
119139
feed_dict={self.input_characters:words,
120140
self.input_position:primary})
@@ -190,27 +210,7 @@ def load_batch(self):
190210
self.start=new_start
191211
returnwords,primary,secondary,labels
192212

193-
defload_data(self):
194-
primary= []
195-
secondary= []
196-
words= []
197-
labels= []
198-
withopen(self.data_path,'rb')asf:
199-
data=pickle.load(f)
200-
forsentenceindata:
201-
sentence_words=sentence['words']
202-
iflen(sentence_words)<self.batch_length:
203-
sentence_words+= [self.dictionary[BATCH_PAD]]* (self.batch_length-len(sentence_words))
204-
else:
205-
sentence_words=sentence_words[:self.batch_length]
206-
words.append(sentence_words)
207-
primary.append(np.arange(self.batch_length)-sentence['primary']+self.batch_length-1)
208-
secondary.append(np.arange(self.batch_length)-sentence['secondary']+self.batch_length-1)
209-
sentence_labels=np.zeros([self.relation_count])
210-
sentence_labels[sentence['type']]=1
211-
labels.append(sentence_labels)
212-
returnnp.array(words,np.int32),np.array(primary,np.int32),np.array(secondary,np.int32),np.array(labels,
213-
np.float32)
213+
214214

215215
def__weight_variable(self,shape,name):
216216
initial=tf.truncated_normal(shape,stddev=0.1,dtype=self.dtype)

‎python/dnlp/core/re_cnn_base.py‎

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
# -*- coding:utf-8 -*-
2+
importnumpyasnp
3+
importpickle
24
fromdnlp.configimportRECNNConfig
5+
fromdnlp.utils.constantimportBATCH_PAD,BATCH_PAD_VAL
6+
37
classRECNNBase(object):
4-
def__init__(self,config:RECNNConfig,dict_path:str):
8+
def__init__(self,config:RECNNConfig,dict_path:str,data_path:str=''):
59
self.window_size=config.window_size
610
self.filter_size=config.filter_size
711
self.learning_rate=config.learning_rate
@@ -24,3 +28,24 @@ def read_dictionary(self,dict_path):
2428

2529
returndictionary
2630

31+
defload_data(self):
32+
primary= []
33+
secondary= []
34+
words= []
35+
labels= []
36+
withopen(self.data_path,'rb')asf:
37+
data=pickle.load(f)
38+
forsentenceindata:
39+
sentence_words=sentence['words']
40+
iflen(sentence_words)<self.batch_length:
41+
sentence_words+= [self.dictionary[BATCH_PAD]]* (self.batch_length-len(sentence_words))
42+
else:
43+
sentence_words=sentence_words[:self.batch_length]
44+
words.append(sentence_words)
45+
primary.append(np.arange(self.batch_length)-sentence['primary']+self.batch_length-1)
46+
secondary.append(np.arange(self.batch_length)-sentence['secondary']+self.batch_length-1)
47+
sentence_labels=np.zeros([self.relation_count])
48+
sentence_labels[sentence['type']]=1
49+
labels.append(sentence_labels)
50+
returnnp.array(words,np.int32),np.array(primary,np.int32),np.array(secondary,np.int32),np.array(labels,
51+
np.float32)

‎python/scripts/cws_ner.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def export_cws(data, filename):
405405
elifargs.emr:
406406
# train_emr_old_method()
407407
# train_emr_cws()
408-
#train_emr_word_skipgram()
408+
train_emr_word_skipgram()
409409
train_emr_word_cbow()
410410
# train_emr_with_embeddings()
411411
# train_emr_ngram('mlp')

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp