Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 417c93c

Browse files
add some codes
1 parent 2aadfc6 · commit 417c93c

File tree

4 files changed

+398
-11
lines changed

4 files changed

+398
-11
lines changed

‎python/dnlp/core/dnn_crf_emr.py‎

Lines changed: 260 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,260 @@
1-
#-*- coding: UTF-8 -*-
1+
# -*- coding: UTF-8 -*-
2+
importtensorflowastf
3+
importnumpyasnp
4+
importmath
5+
fromdnlp.core.dnn_crf_baseimportDnnCrfBase
6+
fromdnlp.configimportDnnCrfConfig
7+
8+
9+
class DnnCrfEmr(DnnCrfBase):
  """DNN-CRF sequence labeller for electronic medical records (EMR).

  Combines a configurable neural encoder (``mlp``/``rnn``/``lstm``/``gru``)
  with CRF transition parameters. Two training schemes are supported:
  perceptron-style structured updates (``fit``/``fit_batch``) and CRF
  log-likelihood training (``fit_ll``, enabled with ``train='ll'``).
  Built against the TensorFlow 1.x graph API.
  """

  def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: type = tf.float32,
               task: str = 'ner', mode: str = 'train', train: str = '', nn: str, model_path: str = ''):
    """Build the computation graph for training or prediction.

    :param config: hyper-parameter bundle consumed by ``DnnCrfBase``.
    :param data_path: training data location (used in train mode).
    :param dtype: TensorFlow float dtype for all variables.
    :param task: 'ner' or 'cws'; selects decoding in ``predict_ll``.
    :param mode: 'train' or 'predict'.
    :param train: 'll' selects log-likelihood training; anything else
      selects the perceptron-style objective.
    :param nn: encoder type, one of 'mlp', 'rnn', 'lstm', 'bilstm', 'gru'.
    :param model_path: checkpoint to restore in predict mode.
    :raises Exception: on an unknown ``mode`` or ``nn`` value.
    """
    if mode not in ['train', 'predict']:
      raise Exception('mode error')
    if nn not in ['mlp', 'rnn', 'lstm', 'bilstm', 'gru']:
      raise Exception('name of neural network entered is not supported')

    DnnCrfBase.__init__(self, config, data_path, mode, model_path)
    self.dtype = dtype
    self.mode = mode
    self.nn = nn
    self.task = task

    # Build the graph.
    tf.reset_default_graph()
    self.transition = self.__get_variable([self.tags_count, self.tags_count], 'transition')
    self.transition_init = self.__get_variable([self.tags_count], 'transition_init')
    self.params = [self.transition, self.transition_init]
    # Input layer.
    if mode == 'train':
      self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
      self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
      self.seq_length = tf.placeholder(tf.int32, [self.batch_size])
    else:
      self.input = tf.placeholder(tf.int32, [None, self.windows_size])

    # Embedding lookup layer.
    self.embedding_layer = self.get_embedding_layer()
    # Hidden layer.
    if nn == 'mlp':
      self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
    elif nn == 'lstm':
      self.hidden_layer = self.get_lstm_layer(tf.transpose(self.embedding_layer))
    elif nn == 'gru':
      self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
    else:
      # NOTE(review): 'bilstm' is accepted by the validation above but falls
      # through to the unidirectional RNN here — confirm this is intended.
      self.hidden_layer = self.get_rnn_layer(tf.transpose(self.embedding_layer))
    # Output layer: per-tag scores.
    self.output = self.get_output_layer(self.hidden_layer)

    if mode == 'predict':
      self.output = tf.squeeze(self.output, axis=2)
      self.sess = tf.Session()
      self.sess.run(tf.global_variables_initializer())
      tf.train.Saver().restore(save_path=self.model_path, sess=self.sess)
    elif train == 'll':
      # Log-likelihood training through the CRF layer.
      self.ll_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices,
                                                          self.seq_length, self.transition)
      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
      self.train_ll = self.optimizer.minimize(-self.ll_loss)
    else:
      # Build the perceptron-style training function.
      # Placeholders for training: indices of correct ('corr') and currently
      # predicted ('curr') emissions, transitions and initial transitions.
      self.ll_corr = tf.placeholder(tf.int32, shape=[None, 3])
      self.ll_curr = tf.placeholder(tf.int32, shape=[None, 3])
      self.trans_corr = tf.placeholder(tf.int32, [None, 2])
      self.trans_curr = tf.placeholder(tf.int32, [None, 2])
      self.trans_init_corr = tf.placeholder(tf.int32, [None, 1])
      self.trans_init_curr = tf.placeholder(tf.int32, [None, 1])
      # Loss functions (without / with initial-transition term).
      self.loss, self.loss_with_init = self.get_loss()
      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
      self.train = self.optimizer.minimize(self.loss)
      self.train_with_init = self.optimizer.minimize(self.loss_with_init)

  def fit(self, epochs: int = 100, interval: int = 20):
    """Train with perceptron-style updates.

    A checkpoint is saved after every epoch.

    :param epochs: number of passes over the data.
    :param interval: currently unused (checkpoints are saved every epoch);
      kept for interface compatibility with ``fit_ll``.
    """
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      saver = tf.train.Saver(max_to_keep=100)
      for epoch in range(1, epochs + 1):
        print('epoch:', epoch)
        for _ in range(self.batch_count):
          characters, labels, lengths = self.get_batch()
          self.fit_batch(characters, labels, lengths, sess)
        model_path = '../dnlp/models/emr_old/{0}-{1}.ckpt'.format(self.nn, epoch)
        saver.save(sess, model_path)
        self.save_config(model_path)

  def fit_batch(self, characters, labels, lengths, sess):
    """Run one perceptron-style update on a batch.

    Decodes each sequence with Viterbi, collects the positions where the
    prediction disagrees with the gold labels, and applies a sparse update
    that rewards the gold (pos) indices and penalizes the predicted (neg)
    ones, for emissions, transitions and (optionally) initial transitions.
    """
    scores = sess.run(self.output, feed_dict={self.input: characters})
    transition = self.transition.eval(session=sess)
    transition_init = self.transition_init.eval(session=sess)
    update_labels_pos = None
    update_labels_neg = None
    current_labels = []
    trans_pos_indices = []
    trans_neg_indices = []
    trans_init_pos_indices = []
    trans_init_neg_indices = []
    for i in range(self.batch_size):
      current_label = self.viterbi(scores[:, :lengths[i], i], transition, transition_init)
      current_labels.append(current_label)
      diff_tag = np.subtract(labels[i, :lengths[i]], current_label)
      update_index = np.where(diff_tag != 0)[0]
      update_length = len(update_index)
      if update_length == 0:
        continue
      update_label_pos = np.stack([labels[i, update_index], update_index, i * np.ones([update_length])], axis=-1)
      update_label_neg = np.stack([current_label[update_index], update_index, i * np.ones([update_length])], axis=-1)
      if update_labels_pos is not None:
        # BUG FIX: np.concatenate returns a new array; the original called it
        # without assigning the result, silently dropping every sequence
        # after the first from the emission update.
        update_labels_pos = np.concatenate((update_labels_pos, update_label_pos))
        update_labels_neg = np.concatenate((update_labels_neg, update_label_neg))
      else:
        update_labels_pos = update_label_pos
        update_labels_neg = update_label_neg

      trans_pos_index, trans_neg_index, trans_init_pos, trans_init_neg, update_init = self.generate_transition_update_index(
        labels[i, :lengths[i]], current_labels[i])

      trans_pos_indices.extend(trans_pos_index)
      trans_neg_indices.extend(trans_neg_index)

      if update_init:
        trans_init_pos_indices.append(trans_init_pos)
        trans_init_neg_indices.append(trans_init_neg)

    if update_labels_pos is not None and update_labels_neg is not None:
      feed_dict = {self.input: characters, self.ll_curr: update_labels_neg, self.ll_corr: update_labels_pos,
                   self.trans_curr: trans_neg_indices, self.trans_corr: trans_pos_indices}

      if not trans_init_pos_indices:
        sess.run(self.train, feed_dict)
      else:
        feed_dict[self.trans_init_corr] = trans_init_pos_indices
        feed_dict[self.trans_init_curr] = trans_init_neg_indices
        sess.run(self.train_with_init, feed_dict)

  def fit_ll(self, epochs: int = 100, interval: int = 20):
    """Train by maximizing the CRF log-likelihood.

    :param epochs: number of passes over the data.
    :param interval: checkpoint every ``interval`` epochs.
    """
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
      saver = tf.train.Saver(max_to_keep=epochs)
      for epoch in range(1, epochs + 1):
        print('epoch:', epoch)
        for _ in range(self.batch_count):
          characters, labels, lengths = self.get_batch()
          feed_dict = {self.input: characters, self.real_indices: labels, self.seq_length: lengths}
          sess.run(self.train_ll, feed_dict=feed_dict)
        if epoch % interval == 0:
          model_path = '../dnlp/models/emr_old/{0}-{1}.ckpt'.format(self.nn, epoch)
          saver.save(sess, model_path)
          self.save_config(model_path)

  def fit_batch_ll(self):
    """Placeholder for per-batch log-likelihood training (not implemented)."""
    pass

  def generate_transition_update_index(self, correct_labels, current_labels):
    """Collect transition index pairs where prediction and gold disagree.

    :param correct_labels: gold tag sequence (1-D array).
    :param current_labels: Viterbi-decoded tag sequence of equal length.
    :return: (trans_pos, trans_neg, trans_init_pos, trans_init_neg,
      update_init) — lists of [from, to] transition indices to reward and
      penalize, the initial-tag indices, and whether the initial transition
      needs updating. Returns ``None`` if the shapes disagree.
      NOTE(review): callers unpack five values, so a ``None`` return would
      raise at the call site — confirm the mismatch can never happen.
    """
    if correct_labels.shape != current_labels.shape:
      print('sequence length is not equal')
      return None

    before_corr = correct_labels[0]
    before_curr = current_labels[0]
    update_init = False

    trans_init_pos = None
    trans_init_neg = None
    trans_pos = []
    trans_neg = []

    if before_corr != before_curr:
      trans_init_pos = [before_corr]
      trans_init_neg = [before_curr]
      update_init = True

    for corr_label, curr_label in zip(correct_labels[1:], current_labels[1:]):
      # A transition is updated when either endpoint disagrees.
      if corr_label != curr_label or before_corr != before_curr:
        trans_pos.append([before_corr, corr_label])
        trans_neg.append([before_curr, curr_label])
      before_corr = corr_label
      before_curr = curr_label

    return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init

  def predict_ll(self, sentence: str, return_labels=False):
    """Decode a sentence with Viterbi and return words or entities.

    :param sentence: raw input text.
    :param return_labels: also return the tag sequence when True.
    :raises Exception: when the model was not built in predict mode.
    """
    if self.mode != 'predict':
      raise Exception('mode is not allowed to predict')

    model_input = self.indices2input(self.sentence2indices(sentence))
    runner = [self.output, self.transition, self.transition_init]
    output, trans, trans_init = self.sess.run(runner, feed_dict={self.input: model_input})
    labels = self.viterbi(output, trans, trans_init)
    if self.task == 'cws':
      result = self.tags2words(sentence, labels)
    else:
      result = self.tags2entities(sentence, labels)
    if not return_labels:
      return result
    else:
      return result, self.tag2sequences(labels)

  def get_embedding_layer(self) -> tf.Tensor:
    """Build the character-window embedding lookup layer."""
    embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
    self.params.append(embeddings)
    if self.mode == 'train':
      input_size = [self.batch_size, self.batch_length, self.concat_embed_size]
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), input_size)
    else:
      layer = tf.reshape(tf.nn.embedding_lookup(embeddings, self.input), [1, -1, self.concat_embed_size])
    return layer

  def get_mlp_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """Single sigmoid feed-forward hidden layer."""
    hidden_weight = self.__get_variable([self.hidden_units, self.concat_embed_size], 'hidden_weight')
    hidden_bias = self.__get_variable([self.hidden_units, 1, 1], 'hidden_bias')
    self.params += [hidden_weight, hidden_bias]
    layer = tf.sigmoid(tf.tensordot(hidden_weight, layer, [[1], [0]]) + hidden_bias)
    return layer

  def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """Plain RNN hidden layer.

    BUG FIX: the original instantiated the abstract base class
    ``tf.nn.rnn_cell.RNNCell``, which cannot be run by ``dynamic_rnn``;
    ``BasicRNNCell`` is the concrete vanilla-RNN cell.
    """
    rnn = tf.nn.rnn_cell.BasicRNNCell(self.hidden_units)
    rnn_output, rnn_out_state = tf.nn.dynamic_rnn(rnn, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(rnn_output)

  def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """LSTM hidden layer."""
    lstm = tf.nn.rnn_cell.LSTMCell(self.hidden_units)
    lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(lstm_output)

  def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """GRU hidden layer."""
    gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
    gru_output, gru_out_state = tf.nn.dynamic_rnn(gru, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    return tf.transpose(gru_output)

  def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """Apply dropout with the configured rate."""
    return tf.layers.dropout(layer, self.dropout_rate)

  def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """Linear projection from hidden units to per-tag scores."""
    output_weight = self.__get_variable([self.tags_count, self.hidden_units], 'output_weight')
    output_bias = self.__get_variable([self.tags_count, 1, 1], 'output_bias')
    self.params += [output_weight, output_bias]
    return tf.tensordot(output_weight, layer, [[1], [0]]) + output_bias

  def get_loss(self) -> (tf.Tensor, tf.Tensor):
    """Build the perceptron-style losses.

    :return: (loss without initial-transition term, loss with it).
    """
    output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))
    # BUG FIX: the original left trans_loss un-reduced, making the total a
    # vector; tf.gradients implicitly sums a non-scalar loss, so reducing
    # here keeps gradients identical while making the loss a proper scalar.
    trans_loss = tf.reduce_sum(
      tf.gather_nd(self.transition, self.trans_curr) - tf.gather_nd(self.transition, self.trans_corr))
    trans_i_curr = tf.gather_nd(self.transition_init, self.trans_init_curr)
    trans_i_corr = tf.gather_nd(self.transition_init, self.trans_init_corr)
    trans_init_loss = tf.reduce_sum(trans_i_curr - trans_i_corr)
    loss = output_loss + trans_loss
    regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam), self.params)
    l1 = loss + regu
    l2 = l1 + trans_init_loss
    return l1, l2

  def __get_variable(self, size, name) -> tf.Variable:
    """Create a truncated-normal variable scaled by 1/sqrt(fan-in)."""
    return tf.Variable(tf.truncated_normal(size, stddev=1.0 / math.sqrt(size[-1]), dtype=self.dtype), name=name)
Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
#-*- coding: UTF-8 -*-
22
importos
3+
importre
4+
RE_SAPCE=re.compile('[ ]+')
35
classProcessEMR(object):
46
def__init__(self,base_folder:str):
57
self.base_folder=base_folder
8+
self.files=self.get_files()
9+
self.read_annotations()
610

711
defget_files(self):
812
files=set()
@@ -11,4 +15,107 @@ def get_files(self):
1115
returnfiles
1216

1317
def read_annotations(self):
  """Parse every discovered file triple (.txt, .ann, .cws)."""
  for stem in self.files:
    path = self.base_folder + stem
    sentence_dict, periods = self.read_entities_in_single_file(path + '.txt', path + '.ann')
    # NOTE(review): the segmented sentence words are computed but not yet
    # stored anywhere — presumably work in progress.
    sentence_words = self.read_cws_file(path + '.cws', periods)
24+
def read_cws_file(self, cws_file, periods):
  """Read a word-segmented (.cws) file and group words into sentences.

  :param cws_file: path to a space-delimited word segmentation file whose
    concatenated words equal the newline-free raw text.
  :param periods: character offsets of sentence-ending periods in that
    text (as produced by ``read_entities_in_single_file``).
  :return: list of word lists, one per span between consecutive periods.
  """
  with open(cws_file, encoding='utf-8') as f:
    words = f.read().replace('\n', '').split(' ')
  word_lens = [len(w) for w in words]
  # Character offset at which each word starts.
  start_point = [0] + [sum(word_lens[:l]) for l in range(1, len(words) + 1)][:-1]
  # BUG FIX: the original built this map with ``zip(start_point)``, which
  # yields 1-tuples and crashes on unpacking ``(w, c)``. ``enumerate`` gives
  # the intended (word index, char offset) pairs, mapping char offset ->
  # word index.
  word_index = {c: w for w, c in enumerate(start_point)}
  sentence_words = []
  # NOTE(review): assumes every offset in ``periods`` is also a word start
  # offset; otherwise this raises KeyError — confirm against the data.
  for s, e in zip(periods[:-1], periods[1:]):
    sentence_words.append(words[word_index[s]:word_index[e]])
  return sentence_words
34+
35+
36+
def read_entities_in_single_file(self, raw_file, ann_file):
  """Read one raw text file together with its brat-style .ann annotations.

  Splits the raw text into period-terminated sentences, remaps entity
  offsets from the original text (which may contain line breaks) onto the
  newline-free text, and attaches each entity to the sentence containing
  it.

  :return: (sentence_dict, periods) — sentence_dict maps sentence text to
    {'text': ..., 'entities': [...]}, periods are offsets of the
    sentence-ending periods in the newline-free text.
  """
  data = {}
  with open(raw_file, encoding='utf-8') as r:
    sentence = r.read()
  # Offsets of line breaks; .ann offsets are indexed against the raw text.
  rn_indices = [m.start() for m in re.finditer('\n', sentence)]
  spans_diff = {}
  if len(rn_indices):
    spans = zip([-1] + rn_indices, rn_indices + [len(sentence) + len(rn_indices)])
    for i, (before, curr) in enumerate(spans):
      # NOTE(review): a shift of 2 per line break suggests the annotations
      # were made against '\r\n' endings — confirm with the source data.
      spans_diff[(before + 2, curr)] = i * 2
    raw_sentence = sentence
    sentence = sentence.replace('\n', '')

  # Split into sentences at '。', except when directly followed by '”'.
  periods = []
  sentence_len = len(sentence)
  last = 0
  sentences = []
  for i, ch in enumerate(sentence):
    if ch != '。':
      continue
    if i < sentence_len - 1 and sentence[i + 1] == '”':
      continue
    periods.append(i)
    sentences.append(sentence[last:i + 1])
    last = i + 1
  if last != len(sentence):
    sentences.append(sentence[last:sentence_len])

  period_spans = {}
  sentence_dict = {k: {'text': k} for k in sentences}

  if len(periods):
    for s, e in zip([-1] + periods, periods + [len(sentence)]):
      period_spans[(s + 1, e + 1)] = s + 1

  with open(ann_file, encoding='utf-8') as a:
    entries = map(lambda l: l.strip().split(' '), a.read().replace('\t', ' ').splitlines())

  for entry in entries:
    id = entry[0]
    # Only 'T' (text-bound) annotations carry entity spans.
    if not id.startswith('T'):
      continue
    start = int(entry[2])
    end = int(entry[3])
    text = entry[4]
    if len(rn_indices):
      # Shift raw-text offsets into the newline-free text.
      flag = False
      for s, e in spans_diff:
        if s <= start and end <= e:
          diff = spans_diff[(s, e)]
          start -= diff
          end -= diff
          flag = True
          break
      if not flag:
        print('a fucked world')
    if sentence[start:end] != text:
      # Annotation does not match the text at its offsets; skip it.
      print('=========')
      continue

    if len(period_spans):
      for s, e in period_spans:
        if s <= start and end <= e:
          new_sentence = sentence[s:e]
          if new_sentence not in sentence_dict:
            print(ann_file)
            print('fuck aa')
          # Rebase the offsets so they are sentence-local.
          new_diff = period_spans[(s, e)]
          start -= new_diff
          end -= new_diff
          if new_sentence[start:end] != text:
            print('fuck')
          entity = {'id': id, 'start': start, 'length': end - start, 'text': text}
          entities = sentence_dict[new_sentence].get('entities')
          if entities is not None:
            entities.append(entity)
          else:
            sentence_dict[new_sentence]['entities'] = [entity]
          break

  return sentence_dict, periods

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp