Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c7dce73

Browse files
Merge remote-tracking branch 'origin/emr' into emr
2 parents 145e2d7 + dfd4f58 · commit c7dce73

File tree

6 files changed

+412
-42
lines changed

6 files changed

+412
-42
lines changed

python/dnlp/core/re_cnn.py

Lines changed: 148 additions & 6 deletions
Original file line numberDiff line numberDiff line change
# -*- coding: UTF-8 -*-
import pickle

import numpy as np
import tensorflow as tf

from dnlp.config import RECNNConfig
from dnlp.core.re_cnn_base import RECNNBase
from dnlp.utils.constant import BATCH_PAD, BATCH_PAD_VAL


class RECNN(RECNNBase):
  """Convolutional network for relation extraction.

  Each token is represented by a word embedding concatenated with two
  position embeddings (distance to the primary and secondary entity).
  Convolutions with several window sizes are max-pooled over the sequence,
  concatenated, and fed to a fully connected softmax layer over relation
  types.

  Hyperparameters (window_size, filter_size, batch_length, batch_size,
  embedding sizes, dropout_rate, learning_rate, lam) and the dictionary
  come from ``RECNNBase.__init__``.
  """

  def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str = '', mode: str = 'train',
               data_path: str = '', relation_count: int = 2):
    """Build the graph.

    :param config: hyperparameter bundle consumed by RECNNBase.
    :param dtype: float dtype for all real-valued tensors.
    :param dict_path: path of the character/word dictionary file.
    :param mode: 'train' builds batched training tensors and loads data;
                 any other value builds single-sample inference tensors.
    :param data_path: pickle file of training sentences (train mode only).
    :param relation_count: number of relation classes (one-hot label width).
    """
    RECNNBase.__init__(self, config, dict_path)
    self.dtype = dtype
    self.mode = mode
    self.data_path = data_path
    self.relation_count = relation_count

    # One concatenated row per token: word embedding + two position embeddings.
    self.concat_embed_size = self.word_embed_size + 2 * self.position_embed_size
    self.input_characters = tf.placeholder(tf.int32, [None, self.batch_length])
    self.input_position = tf.placeholder(tf.int32, [None, self.batch_length])
    self.input = tf.placeholder(self.dtype, [None, self.batch_length, self.concat_embed_size, 1])
    self.input_relation = tf.placeholder(self.dtype, [None, self.relation_count])
    # Relative positions lie in [-(batch_length-1), batch_length-1], hence
    # 2 * batch_length - 1 embedding rows.
    self.position_embedding = self.__weight_variable([2 * self.batch_length - 1, self.position_embed_size],
                                                     name='position_embedding')
    self.word_embedding = self.__weight_variable([self.words_size, self.word_embed_size], name='word_embedding')
    self.conv_kernel = self.get_conv_kernel()
    # BUG FIX: the original built this list as [variable] * len(window_size),
    # which repeats ONE shared tf.Variable for every window; each window now
    # gets its own independent bias, matching get_conv_kernel().
    self.bias = [self.__weight_variable([self.filter_size], name='conv_bias') for _ in self.window_size]
    self.full_connected_weight = self.__weight_variable([self.filter_size * len(self.window_size),
                                                         self.relation_count], name='full_connected_weight')
    self.full_connected_bias = self.__weight_variable([self.relation_count], name='full_connected_bias')
    self.position_lookup = tf.nn.embedding_lookup(self.position_embedding, self.input_position)
    self.character_lookup = tf.nn.embedding_lookup(self.word_embedding, self.input_characters)
    # Holders for pre-fetched embedding values; fit() runs the lookups first,
    # then feeds the results back through these to assemble self.input.
    self.character_embed_holder = tf.placeholder(self.dtype,
                                                 [None, self.batch_length, self.word_embed_size])
    self.primary_embed_holder = tf.placeholder(self.dtype,
                                               [None, self.batch_length, self.position_embed_size])
    self.secondary_embed_holder = tf.placeholder(self.dtype,
                                                 [None, self.batch_length, self.position_embed_size])
    # NOTE(review): attribute name 'emebd_concat' is a typo for 'embed_concat';
    # kept verbatim because external callers may reference it.
    self.emebd_concat = tf.expand_dims(
      tf.concat([self.character_embed_holder, self.primary_embed_holder, self.secondary_embed_holder], 2), 3)
    if self.mode == 'train':
      self.hidden_layer = tf.layers.dropout(self.get_hidden(), self.dropout_rate)
      self.words, self.primary, self.secondary, self.labels = self.load_data()
      self.start = 0
      self.data_count = len(self.words)
    else:
      # Inference path: pooled features come out rank-1, add a batch axis.
      self.hidden_layer = tf.expand_dims(tf.layers.dropout(self.get_hidden(), self.dropout_rate), 0)
    self.output_no_softmax = tf.matmul(self.hidden_layer, self.full_connected_weight) + self.full_connected_bias
    # Reuse the logits node instead of duplicating the matmul (original
    # recomputed it; the value is identical).
    self.output = tf.nn.softmax(self.output_no_softmax)
    self.params = [self.position_embedding, self.word_embedding, self.full_connected_weight,
                   self.full_connected_bias] + self.conv_kernel + self.bias
    self.regularization = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam),
                                                                 self.params)
    # Two alternative objectives: squared error on softmax output, and
    # cross-entropy on the raw logits; both are L2-regularized.
    self.loss = tf.reduce_sum(tf.square(self.output - self.input_relation)) / self.batch_size + self.regularization
    self.cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=self.input_relation,
                                                                 logits=self.output_no_softmax) + self.regularization
    self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
    self.train_model = self.optimizer.minimize(self.loss)
    self.train_cross_entropy_model = self.optimizer.minimize(self.cross_entropy)
    self.saver = tf.train.Saver(max_to_keep=100)

  def get_conv_kernel(self):
    """Return one [w, concat_embed_size, 1, filter_size] kernel per window size."""
    return [self.__weight_variable([w, self.concat_embed_size, 1, self.filter_size], name='conv_kernel')
            for w in self.window_size]

  def get_max_pooling(self, x):
    """Max-pool ``x`` once per configured window size."""
    return [self.max_pooling(x, w) for w in self.window_size]

  def get_hidden(self):
    """Concatenate max-pooled ReLU convolution features across window sizes.

    In train mode features are [batch, filter_size] and concatenate on axis 1;
    in inference mode squeeze() drops the size-1 batch axis, so they
    concatenate on axis 0.
    """
    h = None
    for w, kernel, bias in zip(self.window_size, self.conv_kernel, self.bias):
      hh = tf.squeeze(self.max_pooling(tf.nn.relu(self.conv(kernel) + bias), w))
      if h is None:
        h = hh
      else:
        h = tf.concat([h, hh], 1 if self.mode == 'train' else 0)
    return h

  def conv(self, conv_kernel):
    """Valid-padded 2-D convolution of the assembled input with one kernel."""
    return tf.nn.conv2d(self.input, conv_kernel, strides=[1, 1, 1, 1], padding='VALID')

  def max_pooling(self, x, window_size):
    """Pool over the whole sequence axis left by a width-``window_size`` conv."""
    return tf.nn.max_pool(x, ksize=[1, self.batch_length - window_size + 1, 1, 1],
                          strides=[1, 1, 1, 1], padding='VALID')

  def fit(self, epochs=100, interval=20):
    """Train with the cross-entropy objective.

    :param epochs: number of passes over the data.
    :param interval: checkpoint every ``interval`` epochs.
    """
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      sess.graph.finalize()
      for epoch in range(1, epochs + 1):
        print('epoch:' + str(epoch))
        for _ in range(self.data_count // self.batch_size):
          words, primary, secondary, labels = self.load_batch()
          # Stage 1: resolve embeddings for the batch.
          character_embeds, primary_embeds = sess.run([self.character_lookup, self.position_lookup],
                                                      feed_dict={self.input_characters: words,
                                                                 self.input_position: primary})
          secondary_embeds = sess.run(self.position_lookup, feed_dict={self.input_position: secondary})
          # Stage 2: assemble the [batch, length, concat, 1] input tensor.
          # (renamed from 'input', which shadowed the builtin)
          batch_input = sess.run(self.emebd_concat, feed_dict={self.character_embed_holder: character_embeds,
                                                               self.primary_embed_holder: primary_embeds,
                                                               self.secondary_embed_holder: secondary_embeds})
          # Stage 3: one optimizer step on the cross-entropy objective.
          sess.run(self.train_cross_entropy_model,
                   feed_dict={self.input: batch_input, self.input_relation: labels})
        if epoch % interval == 0:
          model_name = '../dnlp/models/re/{0}-{1}.ckpt'.format(epoch, '_'.join(map(str, self.window_size)))
          self.saver.save(sess, model_name)

  def load_batch(self):
    """Return the next (words, primary, secondary, labels) batch, wrapping
    around to the start of the data when the end is reached."""
    if self.start + self.batch_size > self.data_count:
      new_start = self.start + self.batch_size - self.data_count
      words = np.concatenate([self.words[self.start:], self.words[:new_start]])
      primary = np.concatenate([self.primary[self.start:], self.primary[:new_start]])
      secondary = np.concatenate([self.secondary[self.start:], self.secondary[:new_start]])
      labels = np.concatenate([self.labels[self.start:], self.labels[:new_start]])
    else:
      new_start = self.start + self.batch_size
      words = self.words[self.start:new_start]
      primary = self.primary[self.start:new_start]
      secondary = self.secondary[self.start:new_start]
      labels = self.labels[self.start:new_start]
    self.start = new_start
    return words, primary, secondary, labels

  def load_data(self):
    """Load pickled sentences and build fixed-length training arrays.

    Each sentence dict provides 'words' (dictionary indices), 'primary' and
    'secondary' (entity positions) and 'type' (relation class index).
    Word sequences are padded with the BATCH_PAD index or truncated to
    batch_length; position features are offsets shifted by batch_length - 1
    so they index position_embedding non-negatively; labels are one-hot.
    """
    primary = []
    secondary = []
    words = []
    labels = []
    with open(self.data_path, 'rb') as f:
      data = pickle.load(f)
      for sentence in data:
        sentence_words = sentence['words']
        if len(sentence_words) < self.batch_length:
          sentence_words += [self.dictionary[BATCH_PAD]] * (self.batch_length - len(sentence_words))
        else:
          sentence_words = sentence_words[:self.batch_length]
        words.append(sentence_words)
        primary.append(np.arange(self.batch_length) - sentence['primary'] + self.batch_length - 1)
        secondary.append(np.arange(self.batch_length) - sentence['secondary'] + self.batch_length - 1)
        sentence_labels = np.zeros([self.relation_count])
        sentence_labels[sentence['type']] = 1
        labels.append(sentence_labels)
    return np.array(words, np.int32), np.array(primary, np.int32), np.array(secondary, np.int32), np.array(labels,
                                                                                                           np.float32)

  def __weight_variable(self, shape, name):
    """New trainable variable initialized from a stddev-0.1 truncated normal."""
    initial = tf.truncated_normal(shape, stddev=0.1, dtype=self.dtype)
    return tf.Variable(initial, name=name)

python/dnlp/core/re_cnn_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding:utf-8 -*-
22
fromdnlp.configimportRECNNConfig
33
classRECNNBase(object):
4-
def__init__(self,config:RECNNConfig):
4+
def__init__(self,config:RECNNConfig,dict_path:str):
55
self.window_size=config.window_size
66
self.filter_size=config.filter_size
77
self.learning_rate=config.learning_rate
@@ -11,6 +11,8 @@ def __init__(self, config:RECNNConfig):
1111
self.position_embed_size=config.position_embed_size
1212
self.batch_length=config.batch_length
1313
self.batch_size=config.batch_size
14+
self.dictionary=self.read_dictionary(dict_path)
15+
self.words_size=len(self.dictionary)
1416

1517
defread_dictionary(self,dict_path):
1618
withopen(dict_path,encoding='utf-8')asf:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp