Commit 63fdb13

Remove the former train method in favor of TensorFlow's CRF log-likelihood, and add a BiLSTM layer.

1 parent 73001ee · commit 63fdb13
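The "log likelihood" the commit message refers to is the CRF log-likelihood from tf.contrib.crf (TensorFlow 1.x). For context, here is a minimal, self-contained sketch of that training pattern; the tensor shapes, learning rate, and placeholder names are illustrative assumptions, not code from this repository. The diff below wires the same pattern into DnnCrf.__init__ for training, except that the class passes its own transition variable explicitly.

# Minimal sketch of CRF log-likelihood training with tf.contrib.crf (TF 1.x).
# Shapes, the learning rate, and all names here are illustrative assumptions.
import tensorflow as tf

num_tags, batch_size, max_len = 4, 20, 50

unary_scores = tf.placeholder(tf.float32, [batch_size, max_len, num_tags])  # per-token emission scores
gold_tags = tf.placeholder(tf.int32, [batch_size, max_len])                 # gold label indices
seq_lengths = tf.placeholder(tf.int32, [batch_size])                        # true lengths before padding

# crf_log_likelihood returns one log-likelihood per sequence plus the transition
# matrix it creates (or is given); minimizing the negative mean trains emission
# and transition scores jointly.
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
    unary_scores, gold_tags, seq_lengths)
loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.AdagradOptimizer(0.1).minimize(loss)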

File tree

1 file changed: +41 −131 lines changed

python/dnlp/core/dnn_crf.py

Lines changed: 41 additions & 131 deletions
@@ -8,7 +8,7 @@
 
 class DnnCrf(DnnCrfBase):
   def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: type = tf.float32, mode: str = 'train',
-               train: str = 'll', nn: str, model_path: str = ''):
+               predict: str = 'll', nn: str, model_path: str = ''):
     if mode not in ['train', 'predict']:
       raise Exception('mode error')
     if nn not in ['mlp', 'rnn', 'lstm', 'bilstm', 'gru']:
@@ -27,113 +27,42 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
     if mode == 'train':
       self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
       self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
-      self.seq_length = tf.placeholder(tf.int32, [self.batch_size])
     else:
       self.input = tf.placeholder(tf.int32, [None, self.windows_size])
+    self.seq_length = tf.placeholder(tf.int32, [None])
+
     # 查找表层 (lookup-table layer)
     self.embedding_layer = self.get_embedding_layer()
     # 隐藏层 (hidden layer)
     if nn == 'mlp':
       self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
     elif nn == 'lstm':
       self.hidden_layer = self.get_lstm_layer(self.embedding_layer)
+    elif nn == 'bilstm':
+      self.hidden_layer = self.get_bilstm_layer(self.embedding_layer)
     elif nn == 'gru':
-      self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
+      self.hidden_layer = self.get_gru_layer(self.embedding_layer)
     else:
-      self.hidden_layer = self.get_rnn_layer(tf.transpose(self.embedding_layer))
+      self.hidden_layer = self.get_rnn_layer(self.embedding_layer)
     # 输出层 (output layer)
     self.output = self.get_output_layer(self.hidden_layer)
 
     if mode == 'predict':
-      self.output = tf.squeeze(tf.transpose(self.output), axis=2)
+      if predict != 'll':
+        self.output = tf.squeeze(tf.transpose(self.output), axis=2)
+      self.seq, self.best_score = tf.contrib.crf.crf_decode(self.output, self.transition, self.seq_length)
       self.sess = tf.Session()
       self.sess.run(tf.global_variables_initializer())
       tf.train.Saver().restore(save_path=self.model_path, sess=self.sess)
-    elif train == 'll':
-      self.ll_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
-                                                          self.transition)
-      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
-      self.train_ll = self.optimizer.minimize(-self.ll_loss)
     else:
-      # 构建训练函数 (build the training function)
-      # 训练用 placeholder (placeholders for training)
-      self.ll_corr = tf.placeholder(tf.int32, shape=[None, 3])
-      self.ll_curr = tf.placeholder(tf.int32, shape=[None, 3])
-      self.trans_corr = tf.placeholder(tf.int32, [None, 2])
-      self.trans_curr = tf.placeholder(tf.int32, [None, 2])
-      self.trans_init_corr = tf.placeholder(tf.int32, [None, 1])
-      self.trans_init_curr = tf.placeholder(tf.int32, [None, 1])
-      # 损失函数 (loss function)
-      self.loss, self.loss_with_init = self.get_loss()
+      self.loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
+                                                       self.transition)
       self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
-      self.train = self.optimizer.minimize(self.loss)
-      self.train_with_init = self.optimizer.minimize(self.loss_with_init)
+      self.new_optimizer = tf.train.AdamOptimizer()
+      self.train = self.optimizer.minimize(-self.loss)
 
   def fit(self, epochs: int = 100, interval: int = 20):
-    with tf.Session() as sess:
-      tf.global_variables_initializer().run()
-      saver = tf.train.Saver(max_to_keep=100)
-      for epoch in range(1, epochs + 1):
-        print('epoch:', epoch)
-        for _ in range(self.batch_count):
-          characters, labels, lengths = self.get_batch()
-          self.fit_batch(characters, labels, lengths, sess)
-        # if epoch % interval == 0:
-        model_path = '../dnlp/models/cws{0}.ckpt'.format(epoch)
-        saver.save(sess, model_path)
-        self.save_config(model_path)
-
-  def fit_batch(self, characters, labels, lengths, sess):
-    scores = sess.run(self.output, feed_dict={self.input: characters})
-    transition = self.transition.eval(session=sess)
-    transition_init = self.transition_init.eval(session=sess)
-    update_labels_pos = None
-    update_labels_neg = None
-    current_labels = []
-    trans_pos_indices = []
-    trans_neg_indices = []
-    trans_init_pos_indices = []
-    trans_init_neg_indices = []
-    for i in range(self.batch_size):
-      current_label = self.viterbi(scores[:, :lengths[i], i], transition, transition_init)
-      current_labels.append(current_label)
-      diff_tag = np.subtract(labels[i, :lengths[i]], current_label)
-      update_index = np.where(diff_tag != 0)[0]
-      update_length = len(update_index)
-      if update_length == 0:
-        continue
-      update_label_pos = np.stack([labels[i, update_index], update_index, i * np.ones([update_length])], axis=-1)
-      update_label_neg = np.stack([current_label[update_index], update_index, i * np.ones([update_length])], axis=-1)
-      if update_labels_pos is not None:
-        np.concatenate((update_labels_pos, update_label_pos))
-        np.concatenate((update_labels_neg, update_label_neg))
-      else:
-        update_labels_pos = update_label_pos
-        update_labels_neg = update_label_neg
-
-      trans_pos_index, trans_neg_index, trans_init_pos, trans_init_neg, update_init = self.generate_transition_update_index(
-        labels[i, :lengths[i]], current_labels[i])
-
-      trans_pos_indices.extend(trans_pos_index)
-      trans_neg_indices.extend(trans_neg_index)
-
-      if update_init:
-        trans_init_pos_indices.append(trans_init_pos)
-        trans_init_neg_indices.append(trans_init_neg)
-
-    if update_labels_pos is not None and update_labels_neg is not None:
-      feed_dict = {self.input: characters, self.ll_curr: update_labels_neg, self.ll_corr: update_labels_pos,
-                   self.trans_curr: trans_neg_indices, self.trans_corr: trans_pos_indices}
-
-      if not trans_init_pos_indices:
-        sess.run(self.train, feed_dict)
-      else:
-        feed_dict[self.trans_init_corr] = trans_init_pos_indices
-        feed_dict[self.trans_init_curr] = trans_init_neg_indices
-        sess.run(self.train_with_init, feed_dict)
-
-  def fit_ll(self, epochs: int = 100, interval: int = 20):
     with tf.Session() as sess:
       tf.global_variables_initializer().run()
       saver = tf.train.Saver(max_to_keep=epochs)
@@ -143,44 +72,13 @@ def fit_ll(self, epochs: int = 100, interval: int = 20):
         characters, labels, lengths = self.get_batch()
         # scores = sess.run(self.output, feed_dict={self.input: characters})
         feed_dict = {self.input: characters, self.real_indices: labels, self.seq_length: lengths}
-        sess.run(self.train_ll, feed_dict=feed_dict)
+        sess.run(self.train, feed_dict=feed_dict)
        # self.fit_batch(characters, labels, lengths, sess)
        # if epoch % interval == 0:
        model_path = '../dnlp/models/cws{0}.ckpt'.format(epoch)
        saver.save(sess, model_path)
        self.save_config(model_path)
 
-  def fit_batch_ll(self):
-    pass
-
-  def generate_transition_update_index(self, correct_labels, current_labels):
-    if correct_labels.shape != current_labels.shape:
-      print('sequence length is not equal')
-      return None
-
-    before_corr = correct_labels[0]
-    before_curr = current_labels[0]
-    update_init = False
-
-    trans_init_pos = None
-    trans_init_neg = None
-    trans_pos = []
-    trans_neg = []
-
-    if before_corr != before_curr:
-      trans_init_pos = [before_corr]
-      trans_init_neg = [before_curr]
-      update_init = True
-
-    for _, (corr_label, curr_label) in enumerate(zip(correct_labels[1:], current_labels[1:])):
-      if corr_label != curr_label or before_corr != before_curr:
-        trans_pos.append([before_corr, corr_label])
-        trans_neg.append([before_curr, curr_label])
-      before_corr = corr_label
-      before_curr = curr_label
-
-    return trans_pos, trans_neg, trans_init_pos, trans_init_neg, update_init
-
   def predict(self, sentence: str, return_labels=False):
     if self.mode != 'predict':
       raise Exception('mode is not allowed to predict')
@@ -194,6 +92,22 @@ def predict(self, sentence: str, return_labels=False):
     else:
       return self.tags2words(sentence, labels), self.tag2sequences(labels)
 
+  def predict_ll(self, sentence: str, return_labels=False):
+    if self.mode != 'predict':
+      raise Exception('mode is not allowed to predict')
+
+    input = self.indices2input(self.sentence2indices(sentence))
+    runner = [self.seq, self.best_score, self.output, self.transition]
+    labels, best_score, output, trans = self.sess.run(runner,
+                                                      feed_dict={self.input: input, self.seq_length: [len(sentence)]})
+    # print(output)
+    # print(trans)
+    labels = np.squeeze(labels)
+    if return_labels:
+      return self.tags2words(sentence, labels), self.tag2sequences(labels)
+    else:
+      return self.tags2words(sentence, labels)
+
   def get_embedding_layer(self) -> tf.Tensor:
     embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
     self.params.append(embeddings)
@@ -215,19 +129,27 @@ def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
     rnn = tf.nn.rnn_cell.RNNCell(self.hidden_units)
     rnn_output, rnn_out_state = tf.nn.dynamic_rnn(rnn, layer, dtype=self.dtype)
     self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
-    return tf.transpose(rnn_output)
+    return rnn_output
 
   def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
     lstm = tf.nn.rnn_cell.LSTMCell(self.hidden_units)
     lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
     self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
     return lstm_output
 
+  def get_bilstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
+    lstm_fw = tf.nn.rnn_cell.LSTMCell(self.hidden_units // 2)
+    lstm_bw = tf.nn.rnn_cell.LSTMCell(self.hidden_units // 2)
+    bilstm_output, bilstm_output_state = tf.nn.bidirectional_dynamic_rnn(lstm_fw, lstm_bw, layer, self.seq_length,
+                                                                         dtype=self.dtype)
+    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
+    return tf.concat([bilstm_output[0], bilstm_output[1]], -1)
+
   def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
     gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
     gru_output, gru_out_state = tf.nn.dynamic_rnn(gru, layer, dtype=self.dtype)
     self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
-    return tf.transpose(gru_output)
+    return gru_output
 
   def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
     return tf.layers.dropout(layer, self.dropout_rate)
@@ -238,17 +160,5 @@ def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
     self.params += [output_weight, output_bias]
     return tf.tensordot(layer, output_weight, [[2], [0]]) + output_bias
 
-  def get_loss(self) -> (tf.Tensor, tf.Tensor):
-    output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))
-    trans_loss = tf.gather_nd(self.transition, self.trans_curr) - tf.gather_nd(self.transition, self.trans_corr)
-    trans_i_curr = tf.gather_nd(self.transition_init, self.trans_init_curr)
-    trans_i_corr = tf.gather_nd(self.transition_init, self.trans_init_corr)
-    trans_init_loss = tf.reduce_sum(trans_i_curr - trans_i_corr)
-    loss = output_loss + trans_loss
-    regu = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam), self.params)
-    l1 = loss + regu
-    l2 = l1 + trans_init_loss
-    return l1, l2
-
   def __get_variable(self, size, name) -> tf.Variable:
     return tf.Variable(tf.truncated_normal(size, stddev=1.0 / math.sqrt(size[-1]), dtype=self.dtype), name=name)
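For reference, the new get_bilstm_layer and predict_ll methods pair a bidirectional LSTM feature layer with in-graph Viterbi decoding via tf.contrib.crf.crf_decode. The sketch below shows that combination in isolation (TensorFlow 1.x); the embedding size, tag count, dense projection, and variable names are illustrative assumptions, not the repository's configuration.

# Minimal sketch of BiLSTM features followed by CRF decoding (TF 1.x).
# Dimensions and names are illustrative assumptions.
import tensorflow as tf

num_tags, hidden_units, embed_size = 4, 100, 50

embedded = tf.placeholder(tf.float32, [None, None, embed_size])   # [batch, time, embed]
seq_lengths = tf.placeholder(tf.int32, [None])

# Forward and backward cells each get half the hidden units, mirroring the
# hidden_units // 2 split in the diff, so the concatenated output keeps the
# original feature width.
fw = tf.nn.rnn_cell.LSTMCell(hidden_units // 2)
bw = tf.nn.rnn_cell.LSTMCell(hidden_units // 2)
(out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
    fw, bw, embedded, sequence_length=seq_lengths, dtype=tf.float32)
features = tf.concat([out_fw, out_bw], axis=-1)                   # [batch, time, hidden_units]

unary_scores = tf.layers.dense(features, num_tags)                # emission scores per tag
transition = tf.get_variable('transition', [num_tags, num_tags])  # CRF transition scores

# crf_decode runs Viterbi inside the graph, returning the best tag sequence and
# its score for every batch element.
best_tags, best_score = tf.contrib.crf.crf_decode(unary_scores, transition, seq_lengths)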
