Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit9a13f6d

Browse files
2 parents08352f3 +20ea280 commit9a13f6d

File tree

6 files changed

+43
-17
lines changed

6 files changed

+43
-17
lines changed

‎python/dnlp/core/dnn_crf.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def predict_ll(self, sentence: str, return_labels=False):
155155
input=self.indices2input(self.sentence2indices(sentence))
156156
runner= [self.seq,self.best_score,self.output,self.transition]
157157
labels,best_score,output,trans=self.sess.run(runner,
158-
feed_dict={self.input:input,self.seq_length: [len(sentence)]})
158+
feed_dict={self.input:input,self.seq_length: [len(sentence)]})
159159
# print(output)
160160
# print(trans)
161161
labels=np.squeeze(labels)

‎python/dnlp/core/re_cnn.py‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
1111
data_path:str='',relation_count:int=2,model_path:str='',embedding_path:str='',
1212
remark:str='',data_mode='prefetch'):
1313
tf.reset_default_graph()
14-
RECNNBase.__init__(self,config,dict_path)
14+
RECNNBase.__init__(self,config,dict_path,mode=mode)
1515
self.dtype=dtype
1616
self.mode=mode
1717
self.data_path=data_path
@@ -62,6 +62,8 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
6262
self.hidden_layer=tf.layers.dropout(self.get_hidden(),self.dropout_rate)
6363
self.saver=tf.train.Saver(max_to_keep=100)
6464
else:
65+
ifself.data_path:
66+
self.words,self.primary,self.secondary,self.labels=self.load_data()
6567
self.hidden_layer=self.get_hidden()
6668
self.sess=tf.Session()
6769
self.saver=tf.train.Saver().restore(self.sess,self.model_path)
@@ -152,6 +154,7 @@ def predict(self, words, primary, secondary):
152154
returnnp.argmax(output,1)
153155

154156
defevaluate(self):
157+
# self.sess.run(self.input_data)
155158
res=self.predict(self.words,self.primary,self.secondary)
156159
res_count=Counter(res)[1]
157160
target=np.argmax(self.labels,1)
@@ -173,6 +176,8 @@ def get_score(self, predict, true):
173176

174177
precs= [c/pforc,pinzip(corr_count,pred_count)ifp!=0andc!=0]
175178
recalls= [c/rforc,rinzip(corr_count,true_count)ifr!=0andc!=0]
179+
print(precs)
180+
print(recalls)
176181
prec=sum(precs)/len(precs)
177182
recall=sum(recalls)/len(recalls)
178183
f1=2*prec*recall/ (prec+recall)

‎python/dnlp/core/re_cnn_base.py‎

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
classRECNNBase(object):
9-
def__init__(self,config:RECNNConfig,dict_path:str,data_path:str=''):
9+
def__init__(self,config:RECNNConfig,dict_path:str,data_path:str='',mode='train'):
1010
self.window_size=config.window_size
1111
self.filter_size=config.filter_size
1212
self.learning_rate=config.learning_rate
@@ -15,7 +15,10 @@ def __init__(self, config: RECNNConfig, dict_path: str, data_path: str = ''):
1515
self.word_embed_size=config.word_embed_size
1616
self.position_embed_size=config.position_embed_size
1717
self.batch_length=config.batch_length
18-
self.batch_size=config.batch_size
18+
ifmode=='train':
19+
self.batch_size=config.batch_size
20+
else:
21+
self.batch_size=1
1922
self.dictionary=self.read_dictionary(dict_path)
2023
self.words_size=len(self.dictionary)
2124

‎python/scripts/init_datasets.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def build_emr_cws_files(base_folder):
8282
# copy()
8383
# build_cws_datasets()
8484
# build_emr_datasets()
85-
#build_emr_re()
86-
base_folder='../dnlp/data/emr/'
87-
dict_path=base_folder+'emr_merged_word_dict.utf8'
88-
ProcessEMR(base_folder=base_folder,dict_path=dict_path,directed=True)
85+
build_emr_re()
86+
#base_folder = '../dnlp/data/emr/'
87+
#dict_path = base_folder + 'emr_merged_word_dict.utf8'
88+
#ProcessEMR(base_folder=base_folder, dict_path=dict_path, directed=True)

‎python/scripts/pipeline.py‎

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# -*- coding:utf-8 -*-
22
importnumpyasnp
33
importpickle
4+
importjson
5+
fromoperatorimportitemgetter
46
fromitertoolsimportaccumulate,permutations
57
fromdnlp.config.sequence_labeling_configimportDnnCrfConfig
68
fromdnlp.core.dnn_crfimportDnnCrf
@@ -111,29 +113,44 @@ def rel_extract(sentences):
111113
rel_pairs.extend(pp)
112114
config_two=RECNNConfig(window_size=(2,3,4))
113115
config_mutli=RECNNConfig(window_size=(2,3,4))
114-
model_path_two='../dnlp/models/re_two/50-2_3_4_directed.ckpt'
116+
model_path_two='../dnlp/models/re_two/8-2_3_4_directed.ckpt'
115117
model_path_multi='../dnlp/models/re_multi/50-2_3_4_directed.ckpt'
116118
recnn2=RECNN(config=config_two,dict_path=DICT_PATH,mode='test',model_path=model_path_two,relation_count=2,data_mode='test')
117119
recnn=RECNN(config=config_two,dict_path=DICT_PATH,mode='test',model_path=model_path_multi,relation_count=28,data_mode='test')
118120
two_res=recnn2.predict(sentence_words,primary,secondary)
119-
true_words= [words[i]foriintwo_resifi]
120-
true_rel_pairs= [rel_pairs[i]foriintwo_resifi ]
121-
true_sentence_words= [sentence_words[i]foriintwo_resifi]
122-
true_primary= [primary[i]foriintwo_resifi]
123-
true_secondary= [secondary[i]foriintwo_resifi]
121+
get_true_rel=itemgetter(*[iiforii,iinenumerate(two_res)ifi])
122+
true_words=get_true_rel(words)
123+
true_rel_pairs=get_true_rel(rel_pairs)
124+
true_sentence_words=get_true_rel(sentence_words)
125+
true_primary=get_true_rel(primary)
126+
true_secondary=get_true_rel(secondary)
124127
multi_res=recnn.predict(true_sentence_words,true_primary,true_secondary)
125128
get_rel_result(true_words,true_rel_pairs,multi_res)
126129

127130
defget_rel_result(words,rel_pairs,rel_types):
128131
result= {}
132+
print(len(rel_pairs))
129133
forsentence_words, (primary_idx,secondary_idx),rel_typeinzip(words,rel_pairs,rel_types):
130134
rel_type_name=REL_NAME_LIST[rel_type]
131135
primary=sentence_words[primary_idx]
132136
secondary=sentence_words[secondary_idx]
133137
primary_type,secondary_type=REL_PAIR_NAMES[rel_type_name]
134138
primary_type=ENTITY_NAMES[primary_type]
135139
secondary_type=ENTITY_NAMES[secondary_type]
136-
# result[]
140+
rel= {'value':secondary,'entity_type':primary_type,'type':REL_NAMES[rel_type_name]}
141+
ifnotresult.get(primary):
142+
result[primary]= [rel]
143+
else:
144+
result[primary].append(rel)
145+
print(result)
146+
merged_result= {t:[]fortinset([rel[0]['entity_type']forrelinresult.values() ])}
147+
forprimary,valueinresult.items():
148+
res= {primary:{v['type']:v['value']forvinvalue}}
149+
primary_type=value[0]['entity_type']
150+
merged_result[primary_type].append(res)
151+
print(merged_result)
152+
153+
137154

138155

139156
defexport():

‎python/scripts/rel.py‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,9 @@ def test_re_cnn_with_embedding():
164164
# test_re_cnn_by_window((2,), epoch=5, embedding_path=CBOW_PATH, remark='_cbow_directed')
165165
# get_re_cnn_result()
166166
# get_re_cnn_result('multi')
167-
test_re_cnn_by_window((2,3,4),50,mode='two',relation_count=2,remark='_directed')
168-
test_re_cnn(remark='_directed')
167+
# test_re_cnn_by_window((3,4), 8, mode='two', relation_count=2, remark='_directed')
168+
test_re_cnn_by_window((3,4),8,mode='multi',relation_count=28,remark='_directed')
169+
# test_re_cnn(remark='_directed')
169170
# test_re_cnn('multi')
170171
# test_re_cnn_with_embedding()
171172
# test_single_model((2, 3, 4), 1)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp