Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit1f5e99e

Browse files
add ner task mode, add dropout layer
1 parent3527d9d commit1f5e99e

File tree

1 file changed

+31
-20
lines changed

1 file changed

+31
-20
lines changed

‎python/dnlp/core/dnn_crf.py‎

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88

99
classDnnCrf(DnnCrfBase):
10-
def__init__(self,*,config:DnnCrfConfig=None,data_path:str='',dtype:type=tf.float32,mode:str='train',
10+
def__init__(self,*,config:DnnCrfConfig=None,task='cws',data_path:str='',dtype:type=tf.float32,
11+
mode:str='train',
1112
predict:str='ll',nn:str,model_path:str=''):
1213
ifmodenotin ['train','predict']:
1314
raiseException('mode error')
@@ -17,6 +18,8 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
1718
DnnCrfBase.__init__(self,config,data_path,mode,model_path)
1819
self.dtype=dtype
1920
self.mode=mode
21+
self.task=task
22+
self.nn=nn
2023

2124
# 构建
2225
tf.reset_default_graph()
@@ -34,9 +37,11 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
3437

3538
# 查找表层
3639
self.embedding_layer=self.get_embedding_layer()
40+
# 执行drpout
41+
self.embedding_layer=self.get_dropout_layer(self.embedding_layer)
3742
# 隐藏层
3843
ifnn=='mlp':
39-
self.hidden_layer=self.get_mlp_layer(tf.transpose(self.embedding_layer))
44+
self.hidden_layer=self.get_mlp_layer(self.embedding_layer)
4045
elifnn=='lstm':
4146
self.hidden_layer=self.get_lstm_layer(self.embedding_layer)
4247
elifnn=='bilstm':
@@ -62,22 +67,20 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
6267
self.new_optimizer=tf.train.AdamOptimizer()
6368
self.train=self.optimizer.minimize(-self.loss)
6469

65-
deffit(self,epochs:int=100,interval:int=20):
70+
deffit(self,epochs:int=50,interval:int=10):
6671
withtf.Session()assess:
6772
tf.global_variables_initializer().run()
6873
saver=tf.train.Saver(max_to_keep=epochs)
6974
forepochinrange(1,epochs+1):
7075
print('epoch:',epoch)
7176
for_inrange(self.batch_count):
7277
characters,labels,lengths=self.get_batch()
73-
# scores = sess.run(self.output, feed_dict={self.input: characters})
7478
feed_dict= {self.input:characters,self.real_indices:labels,self.seq_length:lengths}
7579
sess.run(self.train,feed_dict=feed_dict)
76-
# self.fit_batch(characters, labels, lengths, sess)
77-
# if epoch % interval == 0:
78-
model_path='../dnlp/models/cws{0}.ckpt'.format(epoch)
79-
saver.save(sess,model_path)
80-
self.save_config(model_path)
80+
ifepoch%interval==0:
81+
model_path='../dnlp/models/{0}-{1}-{2}.ckpt'.format(self.task,self.nn,epoch)
82+
saver.save(sess,model_path)
83+
self.save_config(model_path)
8184

8285
defpredict(self,sentence:str,return_labels=False):
8386
ifself.mode!='predict':
@@ -87,10 +90,14 @@ def predict(self, sentence: str, return_labels=False):
8790
runner= [self.output,self.transition,self.transition_init]
8891
output,trans,trans_init=self.sess.run(runner,feed_dict={self.input:input})
8992
labels=self.viterbi(output,trans,trans_init)
93+
ifself.task=='cws':
94+
result=self.tags2words(sentence,labels)
95+
else:
96+
result=self.tags2entities(sentence,labels)
9097
ifnotreturn_labels:
91-
returnself.tags2words(sentence,labels)
98+
returnresult
9299
else:
93-
returnself.tags2words(sentence,labels),self.tag2sequences(labels)
100+
returnresult,self.tag2sequences(labels)
94101

95102
defpredict_ll(self,sentence:str,return_labels=False):
96103
ifself.mode!='predict':
@@ -103,10 +110,14 @@ def predict_ll(self, sentence: str, return_labels=False):
103110
# print(output)
104111
# print(trans)
105112
labels=np.squeeze(labels)
113+
ifself.task=='cws':
114+
result=self.tags2words(sentence,labels)
115+
else:
116+
result=self.tags2entities(sentence,labels)
106117
ifreturn_labels:
107-
returnself.tags2words(sentence,labels),self.tag2sequences(labels)
118+
returnresult,self.tag2sequences(labels)
108119
else:
109-
returnself.tags2words(sentence,labels)
120+
returnresult
110121

111122
defget_embedding_layer(self)->tf.Tensor:
112123
embeddings=self.__get_variable([self.dict_size,self.embed_size],'embeddings')
@@ -122,28 +133,28 @@ def get_mlp_layer(self, layer: tf.Tensor) -> tf.Tensor:
122133
hidden_weight=self.__get_variable([self.hidden_units,self.concat_embed_size],'hidden_weight')
123134
hidden_bias=self.__get_variable([self.hidden_units,1,1],'hidden_bias')
124135
self.params+= [hidden_weight,hidden_bias]
125-
layer=tf.sigmoid(tf.tensordot(hidden_weight,layer, [[1], [0]])+hidden_bias)
126-
returnlayer
136+
layer=tf.sigmoid(tf.tensordot(hidden_weight,tf.transpose(layer), [[1], [0]])+hidden_bias)
137+
returntf.transpose(layer)
127138

128139
defget_rnn_layer(self,layer:tf.Tensor)->tf.Tensor:
129-
rnn=tf.nn.rnn_cell.RNNCell(self.hidden_units)
140+
rnn=tf.nn.rnn_cell.BasicRNNCell(self.hidden_units)
130141
rnn_output,rnn_out_state=tf.nn.dynamic_rnn(rnn,layer,dtype=self.dtype)
131142
self.params+= [vforvintf.global_variables()ifv.name.startswith('rnn')]
132143
returnrnn_output
133144

134145
defget_lstm_layer(self,layer:tf.Tensor)->tf.Tensor:
135-
lstm=tf.nn.rnn_cell.LSTMCell(self.hidden_units)
146+
lstm=tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units)
136147
lstm_output,lstm_out_state=tf.nn.dynamic_rnn(lstm,layer,dtype=self.dtype)
137148
self.params+= [vforvintf.global_variables()ifv.name.startswith('rnn')]
138149
returnlstm_output
139150

140151
defget_bilstm_layer(self,layer:tf.Tensor)->tf.Tensor:
141-
lstm_fw=tf.nn.rnn_cell.LSTMCell(self.hidden_units//2)
142-
lstm_bw=tf.nn.rnn_cell.LSTMCell(self.hidden_units//2)
152+
lstm_fw=tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units//2)
153+
lstm_bw=tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units//2)
143154
bilstm_output,bilstm_output_state=tf.nn.bidirectional_dynamic_rnn(lstm_fw,lstm_bw,layer,self.seq_length,
144155
dtype=self.dtype)
145156
self.params+= [vforvintf.global_variables()ifv.name.startswith('rnn')]
146-
returntf.concat([bilstm_output[0],bilstm_output[1]],-1)
157+
returntf.concat([bilstm_output[0],bilstm_output[1]],-1)
147158

148159
defget_gru_layer(self,layer:tf.Tensor)->tf.Tensor:
149160
gru=tf.nn.rnn_cell.GRUCell(self.hidden_units)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp