11# -*- coding: UTF-8 -*-
22import sys
33import argparse
4- from dnlp .config . config import DnnCrfConfig
4+ from dnlp .config import DnnCrfConfig
55from dnlp .core .dnn_crf import DnnCrf
66from dnlp .core .skip_gram import SkipGram
77from dnlp .utils .evaluation import evaluate_cws ,evaluate_ner
@@ -44,10 +44,26 @@ def train_emr_ngram():
4444
4545
def test_emr_ngram():
    """Evaluate the LSTM NER checkpoints for each context-window setting.

    For every entry in ``windows`` a ``DnnCrfConfig`` is created with the
    listed ``skip_left`` / ``skip_right`` values, the matching ``-50``
    checkpoint is loaded into a ``DnnCrf`` in predict mode, and the model is
    handed to ``evaluate_ner`` together with the EMR test pickle.

    Each model is built and evaluated before the next one is built, in the
    order bi_bigram, left_bigram, right_bigram, unigram.
    """
    test_path = '../dnlp/data/emr/emr_test.pickle'
    # (checkpoint tag, skip_left, skip_right)
    windows = (
        ('bi_bigram', 1, 1),
        ('left_bigram', 1, 0),
        ('right_bigram', 0, 1),
        ('unigram', 0, 0),
    )
    for tag, skip_left, skip_right in windows:
        # All checkpoints share one naming scheme; only the tag differs.
        model_path = '../dnlp/models/ner-lstm-{0}-50.ckpt'.format(tag)
        config = DnnCrfConfig(skip_left=skip_left, skip_right=skip_right)
        model = DnnCrf(model_path=model_path, config=config, mode='predict',
                       task='ner', nn='lstm')
        evaluate_ner(model, test_path)
5167
5268
5369def train_emr_random_init ():
@@ -64,42 +80,56 @@ def train_emr_random_init():
6480grulstmcrf = DnnCrf (config = config ,task = 'ner' ,data_path = data_path ,nn = 'gru' )
6581grulstmcrf .fit ()
6682
def test_emr_random_init():
    """Evaluate the NER checkpoints that were built without an embedding file.

    One shared default ``DnnCrfConfig`` is used for the mlp, rnn, lstm,
    bilstm and gru networks. Every model is loaded in predict mode from its
    ``ner-<nn>-50.ckpt`` checkpoint, and only after all five have been
    constructed is each one passed to ``evaluate_ner`` with the EMR test
    pickle, in the same order.
    """
    config = DnnCrfConfig()
    test_path = '../dnlp/data/emr/emr_test.pickle'
    nn_types = ('mlp', 'rnn', 'lstm', 'bilstm', 'gru')

    # Construct every model up front so the build/evaluate ordering stays
    # "all builds first, then all evaluations".
    models = [
        DnnCrf(config=config, task='ner', mode='predict',
               model_path='../dnlp/models/ner-{0}-50.ckpt'.format(nn), nn=nn)
        for nn in nn_types
    ]
    for model in models:
        evaluate_ner(model, test_path)
67100
def train_emr_with_embeddings():
    """Train EMR NER models that are given the skip-gram embedding file.

    A ``DnnCrf`` is constructed for each of mlp, rnn, lstm, bilstm and gru
    with the same default config, training pickle and ``embedding_path``.
    Only the network types named in ``nn_to_fit`` are actually fitted;
    training for the others is currently switched off.
    """
    data_path = '../dnlp/data/emr/emr_training.pickle'
    embedding_path = '../dnlp/data/emr/emr_skip_gram.npy'
    config = DnnCrfConfig()

    # Add 'mlp', 'rnn' or 'lstm' here to re-enable their training runs.
    nn_to_fit = ('bilstm', 'gru')

    for nn in ('mlp', 'rnn', 'lstm', 'bilstm', 'gru'):
        # Models that are not fitted are still constructed, exactly as
        # before. NOTE(review): construction may have side effects; confirm
        # before skipping it for the disabled network types.
        model = DnnCrf(config=config, task='ner', data_path=data_path, nn=nn,
                       embedding_path=embedding_path)
        if nn in nn_to_fit:
            model.fit()
82115
83116
84117def test_emr_with_embeddings ():
85- sentence = '多饮多尿多食'
86118config = DnnCrfConfig ()
87- # dnncrf = DnnCrf(config=config, task='ner', mode='predict', model_path=model_path, nn='lstm')
88- # res = dnncrf.predict_ll(sentence)
89- # print(res)
90119embedding_path = '../dnlp/data/emr/emr_skip_gram.npy'
91- mlp_model_path = '../dnlp/models/ner-mlp-50.ckpt'
92- rnn_model_path = '../dnlp/models/ner-rnn-50.ckpt'
93- lstm_model_path = '../dnlp/models/ner-lstm-50.ckpt'
94- bilstm_model_path = '../dnlp/models/ner-bilstm-50.ckpt'
95- gru_model_path = '../dnlp/models/ner-gru-50.ckpt'
120+ mlp_model_path = '../dnlp/models/ner-mlp-embedding- 50.ckpt'
121+ rnn_model_path = '../dnlp/models/ner-rnn-embedding- 50.ckpt'
122+ lstm_model_path = '../dnlp/models/ner-lstm-embedding- 50.ckpt'
123+ bilstm_model_path = '../dnlp/models/ner-bilstm-embedding- 50.ckpt'
124+ gru_model_path = '../dnlp/models/ner-gru-embedding- 50.ckpt'
96125mlpcrf = DnnCrf (config = config ,task = 'ner' ,mode = 'predict' ,model_path = mlp_model_path ,nn = 'mlp' ,
97126embedding_path = embedding_path )
98127rnncrf = DnnCrf (config = config ,task = 'ner' ,mode = 'predict' ,model_path = rnn_model_path ,nn = 'rnn' ,
99128embedding_path = embedding_path )
100- lstmcrf = DnnCrf (config = config ,task = 'ner' ,mode = 'predict' ,model_path = lstm_model_path ,nn = 'lstm' )
101- bilstmcrf = DnnCrf (config = config ,task = 'ner' ,mode = 'predict' ,model_path = bilstm_model_path ,nn = 'bilstm' )
102- grucrf = DnnCrf (config = config ,task = 'ner' ,mode = 'predict' ,model_path = gru_model_path ,nn = 'gru' )
129+ lstmcrf = DnnCrf (config = config ,task = 'ner' ,mode = 'predict' ,model_path = lstm_model_path ,nn = 'lstm' ,embedding_path = embedding_path )
130+ bilstmcrf = DnnCrf (config = config ,task = 'ner' ,mode = 'predict' ,model_path = bilstm_model_path ,nn = 'bilstm' ,
131+ embedding_path = embedding_path )
132+ grucrf = DnnCrf (config = config ,task = 'ner' ,mode = 'predict' ,model_path = gru_model_path ,nn = 'gru' ,embedding_path = embedding_path )
103133evaluate_ner (mlpcrf ,'../dnlp/data/emr/emr_test.pickle' )
104134evaluate_ner (rnncrf ,'../dnlp/data/emr/emr_test.pickle' )
105135evaluate_ner (lstmcrf ,'../dnlp/data/emr/emr_test.pickle' )
@@ -117,8 +147,8 @@ def train_emr_skipgram():
117147parser = argparse .ArgumentParser ()
118148parser .add_argument ('-t' ,'--t' ,dest = 'train' ,action = 'store_true' ,default = False )
119149parser .add_argument ('-p' ,'--p' ,dest = 'predict' ,action = 'store_true' ,default = False )
120- parser .add_argument ('-c' ,'--c' ,dest = 'cws' ,action = 'store_true' ,default = False )
121- parser .add_argument ('-e' ,'--e' ,dest = 'emr' ,action = 'store_true' ,default = False )
150+ parser .add_argument ('-c' ,'--c' ,dest = 'cws' ,action = 'store_true' ,default = False )
151+ parser .add_argument ('-e' ,'--e' ,dest = 'emr' ,action = 'store_true' ,default = False )
122152args = parser .parse_args (sys .argv [1 :])
123153train = args .train
124154predict = args .predict
@@ -133,11 +163,15 @@ def train_emr_skipgram():
133163if args .cws :
134164train_cws ()
135165elif args .emr :
136- train_emr_ngram ()
166+ train_emr_with_embeddings ()
167+ # train_emr_ngram()
137168# train_emr_random_init()
138169# train_emr_skipgram()
139170else :
140171if args .cws :
141172test_cws ()
142173elif args .emr :
143- test_emr_ngram ()
174+ # test_emr_ngram()
175+ test_emr_random_init ()
176+ print ('embedding' )
177+ test_emr_with_embeddings ()