Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 2aadfc6

Browse files
Merge branches 'develop' and 'emr' of github.com:supercoderhawk/DeepLearning_NLP into emr
# Conflicts:
#   python/dnlp/core/re_cnn.py
#   python/dnlp/core/re_cnn_base.py
add old method
2 parents 971f093 + fe60bba, commit 2aadfc6

File tree

11 files changed

+186
-41
lines changed

11 files changed

+186
-41
lines changed

‎python/dnlp/config/__init__.py‎

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1,3 @@
1-
#-*- coding: UTF-8 -*-
1+
#-*- coding: UTF-8 -*-
2+
fromdnlp.config.sequence_labeling_configimportDnnCrfConfig
3+
fromdnlp.config.re_configimportRECNNConfig

‎python/dnlp/config/re_config.py‎

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,51 @@
1+
# -*- coding:utf-8 -*-
2+
3+
classRECNNConfig(object):
4+
def__init__(self,window_size:tuple=(3,),filter_size:int=150,learning_rate:float=0.4,dropout_rate:float=0.5,
5+
lam:float=1e-4,word_embed_size:int=300,position_embed_size:int=50,batch_length:int=85,
6+
batch_size:int=50):
7+
self.__window_size=window_size
8+
self.__filter_size=filter_size
9+
self.__learning_rate=learning_rate
10+
self.__dropout_rate=dropout_rate
11+
self.__lam=lam
12+
self.__word_embed_size=word_embed_size
13+
self.__position_embed_size=position_embed_size
14+
self.__batch_length=batch_length
15+
self.__batch_size=batch_size
16+
17+
@property
18+
defwindow_size(self):
19+
returnself.__window_size
20+
21+
@property
22+
deffilter_size(self):
23+
returnself.__filter_size
24+
25+
@property
26+
deflearning_rate(self):
27+
returnself.__learning_rate
28+
29+
@property
30+
defdropout_rate(self):
31+
returnself.__dropout_rate
32+
33+
@property
34+
deflam(self):
35+
returnself.__lam
36+
37+
@property
38+
defword_embed_size(self):
39+
returnself.__word_embed_size
40+
41+
@property
42+
defposition_embed_size(self):
43+
returnself.__position_embed_size
44+
45+
@property
46+
defbatch_length(self):
47+
returnself.__batch_length
48+
49+
@property
50+
defbatch_size(self):
51+
returnself.__batch_size
File renamed without changes.

‎python/dnlp/core/dnn_crf.py‎

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
importmath
55
importos
66
fromdnlp.core.dnn_crf_baseimportDnnCrfBase
7-
fromdnlp.config.configimportDnnCrfConfig
7+
fromdnlp.configimportDnnCrfConfig
88

99

1010
classDnnCrf(DnnCrfBase):
@@ -152,7 +152,7 @@ def predict_ll(self, sentence: str, return_labels=False):
152152

153153
defget_embedding_layer(self)->tf.Tensor:
154154
ifself.embedding_path:
155-
embeddings=tf.Variable(np.load(self.embedding_path),trainable=False,name='embeddings')
155+
embeddings=tf.Variable(np.load(self.embedding_path),trainable=True,name='embeddings')
156156
else:
157157
embeddings=self.__get_variable([self.dict_size,self.embed_size],'embeddings')
158158
self.params.append(embeddings)

‎python/dnlp/core/dnn_crf_base.py‎

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
# -*- coding: UTF-8 -*-
22
importnumpyasnp
33
importpickle
4-
fromdnlp.config.configimportDnnCrfConfig
4+
fromdnlp.configimportDnnCrfConfig
55
fromdnlp.utils.constantimportBATCH_PAD,UNK,STRT_VAL,END_VAL,TAG_OTHER,TAG_BEGIN,TAG_INSIDE,TAG_SINGLE
66

77

‎python/dnlp/core/dnn_crf_emr.py‎

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
#-*- coding: UTF-8 -*-

‎python/dnlp/core/re_cnn.py‎

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,16 @@
11
#-*- coding: UTF-8 -*-
22
importtensorflowastf
3-
importpickle
3+
fromdnlp.core.re_cnn_baseimportRECNNBase
4+
fromdnlp.configimportRECNNConfig
45

5-
classRECNN(object):
6-
def__init__(self):
7-
pass
6+
classRECNN(RECNNBase):
7+
def__init__(self,config:RECNNConfig,dtype:type=tf.float32,dict_path:str='',mode:str='train'):
8+
RECNNBase.__init__(self,config)
9+
self.dtype=dtype
10+
self.mode=mode
11+
self.dictionary=self.read_dictionary(dict_path)
12+
13+
14+
def__weight_variable(self,shape,name):
15+
initial=tf.truncated_normal(shape,stddev=0.1,dtype=self.dtype)
16+
returntf.Variable(initial,name=name)

‎python/dnlp/core/re_cnn_base.py‎

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1,24 @@
1-
#-*- coding: UTF-8 -*-
1+
# -*- coding:utf-8 -*-
2+
fromdnlp.configimportRECNNConfig
3+
classRECNNBase(object):
4+
def__init__(self,config:RECNNConfig):
5+
self.window_size=config.window_size
6+
self.filter_size=config.filter_size
7+
self.learning_rate=config.learning_rate
8+
self.dropout_rate=config.dropout_rate
9+
self.lam=config.lam
10+
self.word_embed_size=config.word_embed_size
11+
self.position_embed_size=config.position_embed_size
12+
self.batch_length=config.batch_length
13+
self.batch_size=config.batch_size
14+
15+
defread_dictionary(self,dict_path):
16+
withopen(dict_path,encoding='utf-8')asf:
17+
content=f.read().splitlines()
18+
dictionary= {}
19+
dict_arr=map(lambdaitem:item.split(' '),content)
20+
for_,dict_iteminenumerate(dict_arr):
21+
dictionary[dict_item[0]]=int(dict_item[1])
22+
23+
returndictionary
24+
Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1,14 @@
1-
#-*- coding: UTF-8 -*-
1+
#-*- coding: UTF-8 -*-
2+
importos
3+
classProcessEMR(object):
4+
def__init__(self,base_folder:str):
5+
self.base_folder=base_folder
6+
7+
defget_files(self):
8+
files=set()
9+
forlinos.listdir(self.base_folder):
10+
files.add(os.path.splitext(l)[0])
11+
returnfiles
12+
13+
defread_annotations(self):
14+
pass

‎python/scripts/cws_ner.py‎

Lines changed: 56 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
# -*- coding: UTF-8 -*-
22
importsys
33
importargparse
4-
fromdnlp.config.configimportDnnCrfConfig
4+
fromdnlp.configimportDnnCrfConfig
55
fromdnlp.core.dnn_crfimportDnnCrf
66
fromdnlp.core.skip_gramimportSkipGram
77
fromdnlp.utils.evaluationimportevaluate_cws,evaluate_ner
@@ -44,10 +44,26 @@ def train_emr_ngram():
4444

4545

4646
deftest_emr_ngram():
47-
model_path='../dnlp/models/ner-lstm-bi_bigram-10.ckpt'
47+
bi_bigram_model_path='../dnlp/models/ner-lstm-bi_bigram-50.ckpt'
4848
config_bi_bigram=DnnCrfConfig(skip_left=1,skip_right=1)
49-
mlpcrf_bi_bigram=DnnCrf(model_path=model_path,config=config_bi_bigram,mode='predict',task='ner',nn='lstm')
49+
mlpcrf_bi_bigram=DnnCrf(model_path=bi_bigram_model_path,config=config_bi_bigram,mode='predict',task='ner',
50+
nn='lstm')
5051
evaluate_ner(mlpcrf_bi_bigram,'../dnlp/data/emr/emr_test.pickle')
52+
left_bigram_model_path='../dnlp/models/ner-lstm-left_bigram-50.ckpt'
53+
config_left_bigram=DnnCrfConfig(skip_left=1,skip_right=0)
54+
mlpcrf_left_bigram=DnnCrf(model_path=left_bigram_model_path,config=config_left_bigram,mode='predict',task='ner',
55+
nn='lstm')
56+
evaluate_ner(mlpcrf_left_bigram,'../dnlp/data/emr/emr_test.pickle')
57+
right_bigram_model_path='../dnlp/models/ner-lstm-right_bigram-50.ckpt'
58+
config_right_bigram=DnnCrfConfig(skip_left=0,skip_right=1)
59+
mlpcrf_right_bigram=DnnCrf(model_path=right_bigram_model_path,config=config_right_bigram,mode='predict',
60+
task='ner',nn='lstm')
61+
evaluate_ner(mlpcrf_right_bigram,'../dnlp/data/emr/emr_test.pickle')
62+
unigram_model_path='../dnlp/models/ner-lstm-unigram-50.ckpt'
63+
config_unigram=DnnCrfConfig(skip_left=0,skip_right=0)
64+
mlpcrf_unigram=DnnCrf(model_path=unigram_model_path,config=config_unigram,mode='predict',task='ner',
65+
nn='lstm')
66+
evaluate_ner(mlpcrf_unigram,'../dnlp/data/emr/emr_test.pickle')
5167

5268

5369
deftrain_emr_random_init():
@@ -64,42 +80,56 @@ def train_emr_random_init():
6480
grulstmcrf=DnnCrf(config=config,task='ner',data_path=data_path,nn='gru')
6581
grulstmcrf.fit()
6682

83+
deftest_emr_random_init():
84+
config=DnnCrfConfig()
85+
mlp_model_path='../dnlp/models/ner-mlp-50.ckpt'
86+
rnn_model_path='../dnlp/models/ner-rnn-50.ckpt'
87+
lstm_model_path='../dnlp/models/ner-lstm-50.ckpt'
88+
bilstm_model_path='../dnlp/models/ner-bilstm-50.ckpt'
89+
gru_model_path='../dnlp/models/ner-gru-50.ckpt'
90+
mlpcrf=DnnCrf(config=config,task='ner',mode='predict',model_path=mlp_model_path,nn='mlp')
91+
rnncrf=DnnCrf(config=config,task='ner',mode='predict',model_path=rnn_model_path,nn='rnn')
92+
lstmcrf=DnnCrf(config=config,task='ner',mode='predict',model_path=lstm_model_path,nn='lstm')
93+
bilstmcrf=DnnCrf(config=config,task='ner',mode='predict',model_path=bilstm_model_path,nn='bilstm')
94+
grucrf=DnnCrf(config=config,task='ner',mode='predict',model_path=gru_model_path,nn='gru')
95+
evaluate_ner(mlpcrf,'../dnlp/data/emr/emr_test.pickle')
96+
evaluate_ner(rnncrf,'../dnlp/data/emr/emr_test.pickle')
97+
evaluate_ner(lstmcrf,'../dnlp/data/emr/emr_test.pickle')
98+
evaluate_ner(bilstmcrf,'../dnlp/data/emr/emr_test.pickle')
99+
evaluate_ner(grucrf,'../dnlp/data/emr/emr_test.pickle')
67100

68101
deftrain_emr_with_embeddings():
69102
data_path='../dnlp/data/emr/emr_training.pickle'
70103
embedding_path='../dnlp/data/emr/emr_skip_gram.npy'
71104
config=DnnCrfConfig()
72105
mlpcrf=DnnCrf(config=config,task='ner',data_path=data_path,nn='mlp',embedding_path=embedding_path)
73-
mlpcrf.fit()
106+
#mlpcrf.fit()
74107
rnncrf=DnnCrf(config=config,task='ner',data_path=data_path,nn='rnn',embedding_path=embedding_path)
75-
rnncrf.fit()
108+
#rnncrf.fit()
76109
lstmcrf=DnnCrf(config=config,task='ner',data_path=data_path,nn='lstm',embedding_path=embedding_path)
77-
lstmcrf.fit()
110+
#lstmcrf.fit()
78111
bilstmcrf=DnnCrf(config=config,task='ner',data_path=data_path,nn='bilstm',embedding_path=embedding_path)
79112
bilstmcrf.fit()
80113
grulstmcrf=DnnCrf(config=config,task='ner',data_path=data_path,nn='gru',embedding_path=embedding_path)
81114
grulstmcrf.fit()
82115

83116

84117
deftest_emr_with_embeddings():
85-
sentence='多饮多尿多食'
86118
config=DnnCrfConfig()
87-
# dnncrf = DnnCrf(config=config, task='ner', mode='predict', model_path=model_path, nn='lstm')
88-
# res = dnncrf.predict_ll(sentence)
89-
# print(res)
90119
embedding_path='../dnlp/data/emr/emr_skip_gram.npy'
91-
mlp_model_path='../dnlp/models/ner-mlp-50.ckpt'
92-
rnn_model_path='../dnlp/models/ner-rnn-50.ckpt'
93-
lstm_model_path='../dnlp/models/ner-lstm-50.ckpt'
94-
bilstm_model_path='../dnlp/models/ner-bilstm-50.ckpt'
95-
gru_model_path='../dnlp/models/ner-gru-50.ckpt'
120+
mlp_model_path='../dnlp/models/ner-mlp-embedding-50.ckpt'
121+
rnn_model_path='../dnlp/models/ner-rnn-embedding-50.ckpt'
122+
lstm_model_path='../dnlp/models/ner-lstm-embedding-50.ckpt'
123+
bilstm_model_path='../dnlp/models/ner-bilstm-embedding-50.ckpt'
124+
gru_model_path='../dnlp/models/ner-gru-embedding-50.ckpt'
96125
mlpcrf=DnnCrf(config=config,task='ner',mode='predict',model_path=mlp_model_path,nn='mlp',
97126
embedding_path=embedding_path)
98127
rnncrf=DnnCrf(config=config,task='ner',mode='predict',model_path=rnn_model_path,nn='rnn',
99128
embedding_path=embedding_path)
100-
lstmcrf=DnnCrf(config=config,task='ner',mode='predict',model_path=lstm_model_path,nn='lstm')
101-
bilstmcrf=DnnCrf(config=config,task='ner',mode='predict',model_path=bilstm_model_path,nn='bilstm')
102-
grucrf=DnnCrf(config=config,task='ner',mode='predict',model_path=gru_model_path,nn='gru')
129+
lstmcrf=DnnCrf(config=config,task='ner',mode='predict',model_path=lstm_model_path,nn='lstm',embedding_path=embedding_path)
130+
bilstmcrf=DnnCrf(config=config,task='ner',mode='predict',model_path=bilstm_model_path,nn='bilstm',
131+
embedding_path=embedding_path)
132+
grucrf=DnnCrf(config=config,task='ner',mode='predict',model_path=gru_model_path,nn='gru',embedding_path=embedding_path)
103133
evaluate_ner(mlpcrf,'../dnlp/data/emr/emr_test.pickle')
104134
evaluate_ner(rnncrf,'../dnlp/data/emr/emr_test.pickle')
105135
evaluate_ner(lstmcrf,'../dnlp/data/emr/emr_test.pickle')
@@ -117,8 +147,8 @@ def train_emr_skipgram():
117147
parser=argparse.ArgumentParser()
118148
parser.add_argument('-t','--t',dest='train',action='store_true',default=False)
119149
parser.add_argument('-p','--p',dest='predict',action='store_true',default=False)
120-
parser.add_argument('-c','--c',dest='cws',action='store_true',default=False)
121-
parser.add_argument('-e','--e',dest='emr',action='store_true',default=False)
150+
parser.add_argument('-c','--c',dest='cws',action='store_true',default=False)
151+
parser.add_argument('-e','--e',dest='emr',action='store_true',default=False)
122152
args=parser.parse_args(sys.argv[1:])
123153
train=args.train
124154
predict=args.predict
@@ -133,11 +163,15 @@ def train_emr_skipgram():
133163
ifargs.cws:
134164
train_cws()
135165
elifargs.emr:
136-
train_emr_ngram()
166+
train_emr_with_embeddings()
167+
# train_emr_ngram()
137168
# train_emr_random_init()
138169
# train_emr_skipgram()
139170
else:
140171
ifargs.cws:
141172
test_cws()
142173
elifargs.emr:
143-
test_emr_ngram()
174+
# test_emr_ngram()
175+
test_emr_random_init()
176+
print('embedding')
177+
test_emr_with_embeddings()

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp