Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit145e2d7

Browse files
add some codes
1 parent417c93c commit145e2d7

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

‎python/dnlp/core/dnn_crf_emr.py‎

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
2929
ifmode=='train':
3030
self.input=tf.placeholder(tf.int32, [self.batch_size,self.batch_length,self.windows_size])
3131
self.real_indices=tf.placeholder(tf.int32, [self.batch_size,self.batch_length])
32-
self.seq_length=tf.placeholder(tf.int32, [self.batch_size])
32+
self.seq_length=tf.placeholder(tf.int32, [None])
3333
else:
3434
self.input=tf.placeholder(tf.int32, [None,self.windows_size])
3535

@@ -48,7 +48,7 @@ def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: t
4848
self.output=self.get_output_layer(self.hidden_layer)
4949

5050
ifmode=='predict':
51-
self.output=tf.squeeze(self.output,axis=2)
51+
self.output=tf.squeeze(self.output,axis=1)
5252
self.sess=tf.Session()
5353
self.sess.run(tf.global_variables_initializer())
5454
tf.train.Saver().restore(save_path=self.model_path,sess=self.sess)
@@ -81,10 +81,10 @@ def fit(self, epochs: int = 100, interval: int = 20):
8181
for_inrange(self.batch_count):
8282
characters,labels,lengths=self.get_batch()
8383
self.fit_batch(characters,labels,lengths,sess)
84-
#if epoch % interval == 0:
85-
model_path='../dnlp/models/emr_old/{0}-{1}.ckpt'.format(self.nn,epoch)
86-
saver.save(sess,model_path)
87-
self.save_config(model_path)
84+
ifepoch%interval==0:
85+
model_path='../dnlp/models/emr_old/{0}-{1}.ckpt'.format(self.nn,epoch)
86+
saver.save(sess,model_path)
87+
self.save_config(model_path)
8888

8989
deffit_batch(self,characters,labels,lengths,sess):
9090
scores=sess.run(self.output,feed_dict={self.input:characters})
@@ -224,7 +224,7 @@ def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
224224
returntf.transpose(rnn_output)
225225

226226
defget_lstm_layer(self,layer:tf.Tensor)->tf.Tensor:
227-
lstm=tf.nn.rnn_cell.LSTMCell(self.hidden_units)
227+
lstm=tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units)
228228
lstm_output,lstm_out_state=tf.nn.dynamic_rnn(lstm,layer,dtype=self.dtype)
229229
self.params+= [vforvintf.global_variables()ifv.name.startswith('rnn')]
230230
returntf.transpose(lstm_output)

‎python/scripts/cws_ner.py‎

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,13 @@ def test_emr_ngram():
6969
deftrain_emr_old_method():
7070
data_path='../dnlp/data/emr/emr_training.pickle'
7171
config=DnnCrfConfig()
72-
mlpcrf=DnnCrfEmr(config=config,task='ner',data_path=data_path,nn='mlp')
72+
mlpcrf=DnnCrfEmr(config=config,task='ner',data_path=data_path,nn='lstm')
7373
mlpcrf.fit(interval=1)
74+
7475
deftest_emr_old_method():
75-
model_path='../dnlp/models/emr_old/mlp-1.ckpt'
76+
model_path='../dnlp/models/emr_old/lstm-2.ckpt'
7677
config=DnnCrfConfig()
77-
mlpcrf=DnnCrfEmr(config=config,task='ner',mode='predict',model_path=model_path,nn='mlp')
78+
mlpcrf=DnnCrfEmr(config=config,task='ner',mode='predict',model_path=model_path,nn='lstm')
7879

7980
evaluate_ner(mlpcrf,'../dnlp/data/emr/emr_test.pickle')
8081

‎python/setup.py‎

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
author_email='supercoderhawk@gmail.com',
1212
description='deep learning-based natural language processing lib',
1313
install_requires=[
14-
"tensorflow >= 1.3.0"
15-
]
14+
],
15+
extra_require= {
16+
'tf':'tensorflow >= 1.3.0',
17+
'tf-gpu':'tensorflow_gpu >= 1.3.0'
18+
}
1619
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp