Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit73001ee

Browse files
fix dnn-crf prediction bug
1 parent2e755ad commit73001ee

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

‎python/dnlp/core/dnn_crf_base.py‎

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
importnumpyasnp
33
importpickle
44
fromdnlp.config.configimportDnnCrfConfig
5-
fromdnlp.utils.constantimportBATCH_PAD,STRT_VAL,END_VAL,TAG_PAD,TAG_BEGIN,TAG_INSIDE,TAG_SINGLE
5+
fromdnlp.utils.constantimportBATCH_PAD,UNK,STRT_VAL,END_VAL,TAG_PAD,TAG_BEGIN,TAG_INSIDE,TAG_SINGLE
66

77

88
classDnnCrfBase(object):
9-
def__init__(self,config:DnnCrfConfig=None,data_path:str='',mode:str='train',model_path:str=''):
9+
def__init__(self,config:DnnCrfConfig=None,data_path:str='',mode:str='train',model_path:str=''):
1010
# 加载数据
1111
self.data_path=data_path
1212
self.config_suffix='.config.pickle'
@@ -18,7 +18,7 @@ def __init__(self, config: DnnCrfConfig=None, data_path: str = '', mode: str = '
1818
self.dictionary,self.tags=self.__load_config()
1919
self.tags_count=len(self.tags)-1# 忽略TAG_PAD
2020
self.tags_map=self.__generate_tag_map()
21-
self.reversed_tags_map=dict(zip(self.tags_map.values(),self.tags_map.keys()))
21+
self.reversed_tags_map=dict(zip(self.tags_map.values(),self.tags_map.keys()))
2222
self.dict_size=len(self.dictionary)
2323
# 初始化超参数
2424
self.skip_left=config.skip_left
@@ -82,7 +82,7 @@ def get_batch(self) -> (np.ndarray, np.ndarray, np.ndarray):
8282
else:
8383
ext_size=self.batch_length-len(chs)
8484
chs_batch[i]=chs+ext_size* [self.dictionary[BATCH_PAD]]
85-
lls_batch[i]=list(map(lambdat:self.tags_map[t],lls))+ext_size* [0]#[self.tags_map[TAG_PAD]]
85+
lls_batch[i]=list(map(lambdat:self.tags_map[t],lls))+ext_size* [0]#[self.tags_map[TAG_PAD]]
8686

8787
self.batch_start=new_start
8888
returnself.indices2input(chs_batch),np.array(lls_batch,dtype=np.int32),np.array(len_batch,dtype=np.int32)
@@ -111,7 +111,8 @@ def viterbi(self, emission: np.ndarray, transition: np.ndarray, transition_init:
111111
returncorr_path
112112

113113
defsentence2indices(self,sentence:str)->list:
114-
returnlist(map(lambdach:self.dictionary[ch],sentence))
114+
expr=lambdach:self.dictionary[ch]ifchinself.dictionaryelseself.dictionary[UNK]
115+
returnlist(map(expr,sentence))
115116

116117
defindices2input(self,indices:list)->np.ndarray:
117118
res= []
@@ -173,10 +174,10 @@ def tags2entities(self, sentence: str, tags_seq: np.ndarray, return_start: bool
173174
else:
174175
returnentities
175176

176-
deftag2sequences(self,tags_seq:np.ndarray):
177+
deftag2sequences(self,tags_seq:np.ndarray):
177178
seq= []
178179

179180
fortagintags_seq:
180181
seq.append(self.reversed_tags_map[tag])
181182

182-
returnseq
183+
returnseq

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp