Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitbb93e22

Browse files
Merge remote-tracking branch 'origin/develop' into develop
2 parents641b54f +1f5e99e commitbb93e22

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

‎python/dnlp/core/dnn_crf.py‎

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88

99
classDnnCrf(DnnCrfBase):
10-
def__init__(self,*,config:DnnCrfConfig=None,data_path:str='',dtype:type=tf.float32,mode:str='train',
10+
def__init__(self,*,config:DnnCrfConfig=None,task='cws',data_path:str='',dtype:type=tf.float32,
11+
mode:str='train',
1112
predict:str='ll',nn:str,model_path:str=''):
1213
ifmodenotin ['train','predict']:
1314
raiseException('mode error')
@@ -17,6 +18,8 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
1718
DnnCrfBase.__init__(self,config,data_path,mode,model_path)
1819
self.dtype=dtype
1920
self.mode=mode
21+
self.task=task
22+
self.nn=nn
2023

2124
# 构建
2225
tf.reset_default_graph()
@@ -34,9 +37,11 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
3437

3538
# 查找表层
3639
self.embedding_layer=self.get_embedding_layer()
40+
# 执行drpout
41+
self.embedding_layer=self.get_dropout_layer(self.embedding_layer)
3742
# 隐藏层
3843
ifnn=='mlp':
39-
self.hidden_layer=self.get_mlp_layer(tf.transpose(self.embedding_layer))
44+
self.hidden_layer=self.get_mlp_layer(self.embedding_layer)
4045
elifnn=='lstm':
4146
self.hidden_layer=self.get_lstm_layer(self.embedding_layer)
4247
elifnn=='bilstm':
@@ -63,22 +68,20 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
6368
self.new_optimizer=tf.train.AdamOptimizer()
6469
self.train=self.optimizer.minimize(-self.loss)
6570

66-
deffit(self,epochs:int=100,interval:int=20):
71+
deffit(self,epochs:int=50,interval:int=10):
6772
withtf.Session()assess:
6873
tf.global_variables_initializer().run()
6974
saver=tf.train.Saver(max_to_keep=epochs)
7075
forepochinrange(1,epochs+1):
7176
print('epoch:',epoch)
7277
for_inrange(self.batch_count):
7378
characters,labels,lengths=self.get_batch()
74-
# scores = sess.run(self.output, feed_dict={self.input: characters})
7579
feed_dict= {self.input:characters,self.real_indices:labels,self.seq_length:lengths}
7680
sess.run(self.train,feed_dict=feed_dict)
77-
# self.fit_batch(characters, labels, lengths, sess)
78-
# if epoch % interval == 0:
79-
model_path='../dnlp/models/cws{0}.ckpt'.format(epoch)
80-
saver.save(sess,model_path)
81-
self.save_config(model_path)
81+
ifepoch%interval==0:
82+
model_path='../dnlp/models/{0}-{1}-{2}.ckpt'.format(self.task,self.nn,epoch)
83+
saver.save(sess,model_path)
84+
self.save_config(model_path)
8285

8386
defpredict(self,sentence:str,return_labels=False):
8487
ifself.mode!='predict':
@@ -88,10 +91,14 @@ def predict(self, sentence: str, return_labels=False):
8891
runner= [self.output,self.transition,self.transition_init]
8992
output,trans,trans_init=self.sess.run(runner,feed_dict={self.input:input})
9093
labels=self.viterbi(output,trans,trans_init)
94+
ifself.task=='cws':
95+
result=self.tags2words(sentence,labels)
96+
else:
97+
result=self.tags2entities(sentence,labels)
9198
ifnotreturn_labels:
92-
returnself.tags2words(sentence,labels)
99+
returnresult
93100
else:
94-
returnself.tags2words(sentence,labels),self.tag2sequences(labels)
101+
returnresult,self.tag2sequences(labels)
95102

96103
defpredict_ll(self,sentence:str,return_labels=False):
97104
ifself.mode!='predict':
@@ -104,10 +111,14 @@ def predict_ll(self, sentence: str, return_labels=False):
104111
# print(output)
105112
# print(trans)
106113
labels=np.squeeze(labels)
114+
ifself.task=='cws':
115+
result=self.tags2words(sentence,labels)
116+
else:
117+
result=self.tags2entities(sentence,labels)
107118
ifreturn_labels:
108-
returnself.tags2words(sentence,labels),self.tag2sequences(labels)
119+
returnresult,self.tag2sequences(labels)
109120
else:
110-
returnself.tags2words(sentence,labels)
121+
returnresult
111122

112123
defget_embedding_layer(self)->tf.Tensor:
113124
# embeddings = self.__get_variable([self.dict_size, self.embed_size], 'embeddings')
@@ -125,8 +136,8 @@ def get_mlp_layer(self, layer: tf.Tensor) -> tf.Tensor:
125136
hidden_weight=self.__get_variable([self.hidden_units,self.concat_embed_size],'hidden_weight')
126137
hidden_bias=self.__get_variable([self.hidden_units,1,1],'hidden_bias')
127138
self.params+= [hidden_weight,hidden_bias]
128-
layer=tf.sigmoid(tf.tensordot(hidden_weight,layer, [[1], [0]])+hidden_bias)
129-
returnlayer
139+
layer=tf.sigmoid(tf.tensordot(hidden_weight,tf.transpose(layer), [[1], [0]])+hidden_bias)
140+
returntf.transpose(layer)
130141

131142
defget_rnn_layer(self,layer:tf.Tensor)->tf.Tensor:
132143
rnn=tf.nn.rnn_cell.BasicRNNCell(self.hidden_units)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp