22import numpy as np
33import pickle
44from dnlp .config .config import DnnCrfConfig
5- from dnlp .utils .constant import BATCH_PAD ,STRT_VAL ,END_VAL ,TAG_PAD ,TAG_BEGIN ,TAG_INSIDE ,TAG_SINGLE
5+ from dnlp .utils .constant import BATCH_PAD ,UNK , STRT_VAL ,END_VAL ,TAG_PAD ,TAG_BEGIN ,TAG_INSIDE ,TAG_SINGLE
66
77
88class DnnCrfBase (object ):
9- def __init__ (self ,config :DnnCrfConfig = None ,data_path :str = '' ,mode :str = 'train' ,model_path :str = '' ):
9+ def __init__ (self ,config :DnnCrfConfig = None ,data_path :str = '' ,mode :str = 'train' ,model_path :str = '' ):
1010# 加载数据
1111self .data_path = data_path
1212self .config_suffix = '.config.pickle'
@@ -18,7 +18,7 @@ def __init__(self, config: DnnCrfConfig=None, data_path: str = '', mode: str = '
1818self .dictionary ,self .tags = self .__load_config ()
1919self .tags_count = len (self .tags )- 1 # 忽略TAG_PAD
2020self .tags_map = self .__generate_tag_map ()
21- self .reversed_tags_map = dict (zip (self .tags_map .values (),self .tags_map .keys ()))
21+ self .reversed_tags_map = dict (zip (self .tags_map .values (),self .tags_map .keys ()))
2222self .dict_size = len (self .dictionary )
2323# 初始化超参数
2424self .skip_left = config .skip_left
@@ -82,7 +82,7 @@ def get_batch(self) -> (np.ndarray, np.ndarray, np.ndarray):
8282else :
8383ext_size = self .batch_length - len (chs )
8484chs_batch [i ]= chs + ext_size * [self .dictionary [BATCH_PAD ]]
85- lls_batch [i ]= list (map (lambda t :self .tags_map [t ],lls ))+ ext_size * [0 ]# [self.tags_map[TAG_PAD]]
85+ lls_batch [i ]= list (map (lambda t :self .tags_map [t ],lls ))+ ext_size * [0 ]# [self.tags_map[TAG_PAD]]
8686
8787self .batch_start = new_start
8888return self .indices2input (chs_batch ),np .array (lls_batch ,dtype = np .int32 ),np .array (len_batch ,dtype = np .int32 )
@@ -111,7 +111,8 @@ def viterbi(self, emission: np.ndarray, transition: np.ndarray, transition_init:
111111return corr_path
112112
def sentence2indices(self, sentence: str) -> list:
    """Convert a sentence into a list of dictionary indices, one per character.

    Args:
        sentence: The text to convert; it is iterated character by character.

    Returns:
        A list with one entry per character of ``sentence``. Each entry is
        the value stored in ``self.dictionary`` for that character, or the
        value stored under ``UNK`` when the character is not a key of the
        dictionary.

    Raises:
        KeyError: If a character is missing from ``self.dictionary`` and the
            dictionary has no ``UNK`` entry either.
    """
    dictionary = self.dictionary
    # The UNK entry is looked up only when it is actually needed, so a
    # dictionary without UNK still works for fully in-vocabulary input.
    return [
        dictionary[ch] if ch in dictionary else dictionary[UNK]
        for ch in sentence
    ]
115116
116117def indices2input (self ,indices :list )-> np .ndarray :
117118res = []
@@ -173,10 +174,10 @@ def tags2entities(self, sentence: str, tags_seq: np.ndarray, return_start: bool
173174else :
174175return entities
175176
def tag2sequences(self, tags_seq: np.ndarray):
    """Translate a sequence of tag indices into their tag labels.

    Args:
        tags_seq: Iterable of tag indices; each one must be a key of
            ``self.reversed_tags_map``.

    Returns:
        A list holding ``self.reversed_tags_map[tag]`` for every ``tag`` in
        ``tags_seq``, in the same order.

    Raises:
        KeyError: If an index in ``tags_seq`` is not a key of
            ``self.reversed_tags_map``.
    """
    lookup = self.reversed_tags_map
    return [lookup[tag] for tag in tags_seq]