@@ -10,16 +10,6 @@ def __init__(self, config: DnnCrfConfig = None, data_path: str = '', mode: str =
1010# 加载数据
1111self .data_path = data_path
1212self .config_suffix = '.config.pickle'
13- if mode == 'train' :
14- self .dictionary ,self .tags ,self .characters ,self .labels = self .__load_data ()
15- else :
16- self .model_path = model_path
17- self .config_path = self .model_path + self .config_suffix
18- self .dictionary ,self .tags = self .__load_config ()
19- self .tags_count = len (self .tags )- 1 # 忽略TAG_PAD
20- self .tags_map = self .__generate_tag_map ()
21- self .reversed_tags_map = dict (zip (self .tags_map .values (),self .tags_map .keys ()))
22- self .dict_size = len (self .dictionary )
2313# 初始化超参数
2414self .skip_left = config .skip_left
2515self .skip_right = config .skip_right
@@ -32,12 +22,25 @@ def __init__(self, config: DnnCrfConfig = None, data_path: str = '', mode: str =
3222self .concat_embed_size = self .embed_size * self .windows_size
3323self .batch_length = config .batch_length
3424self .batch_size = config .batch_size
35- # 数据
25+
3626if mode == 'train' :
37- self .sentences_length = list (map (lambda s :len (s ),self .characters ))
38- self .sentences_count = len (self .sentences_length )
27+ self .dictionary ,self .tags ,self .sentences ,self .labels = self .__load_data ()
28+ self .sentence_lengths = list (map (lambda s :len (s ),self .sentences ))
29+ self .sentences_count = len (self .sentence_lengths )
3930self .batch_count = self .sentences_count // self .batch_size
4031self .batch_start = 0
32+ self .dataset_start = 0
33+ else :
34+ self .model_path = model_path
35+ self .config_path = self .model_path + self .config_suffix
36+ self .dictionary ,self .tags = self .__load_config ()
37+ self .tags_count = len (self .tags )- 1 # 忽略TAG_PAD
38+ self .tags_map = self .__generate_tag_map ()
39+ self .reversed_tags_map = dict (zip (self .tags_map .values (),self .tags_map .keys ()))
40+ self .dict_size = len (self .dictionary )
41+ if mode == 'train' :
42+ self .preprocess ()
43+
4144
4245def __load_data (self )-> (dict ,tuple ,np .ndarray ,np .ndarray ):
4346with open (self .data_path ,'rb' )as f :
@@ -63,17 +66,29 @@ def __generate_tag_map(self):
6366tags_map [self .tags [i ]]= i
6467return tags_map
6568
def preprocess(self):
    """Bring every training sentence and its labels to ``self.batch_length``.

    For each entry of ``self.sentences`` / ``self.labels``:

    * at most the first ``self.batch_length`` items are kept;
    * the sentence is passed through ``self.__indices2input_single``;
    * the labels are mapped through ``self.tags_map``;
    * when the sentence is shorter than ``self.batch_length``, the sentence
      is extended with rows of ``self.windows_size`` copies of
      ``self.dictionary[BATCH_PAD]`` and the labels are extended with ``0``.

    Both containers are updated in place and nothing is returned.
    ``self.sentence_lengths`` is read but not modified, so it keeps the
    original lengths.

    Sentences whose length equals ``self.batch_length`` are converted as
    well; previously they matched neither branch and were left as raw
    indices / raw tags next to already converted entries.
    """
    for i, (sentence, labels, length) in enumerate(
            zip(self.sentences, self.labels, self.sentence_lengths)):
        # Slicing is a no-op when the sentence already fits.
        windows = self.__indices2input_single(sentence[:self.batch_length])
        tag_ids = [self.tags_map[l] for l in labels[:self.batch_length]]

        ext_size = self.batch_length - length
        if ext_size > 0:
            pad_index = self.dictionary[BATCH_PAD]
            # Build one fresh row per padded position so that rows do not
            # alias the same list object.
            # NOTE(review): assumes __indices2input_single returns a list
            # (list concatenation), as the original code did - confirm.
            windows = windows + [[pad_index] * self.windows_size
                                 for _ in range(ext_size)]
            tag_ids = tag_ids + [0] * ext_size

        self.sentences[i] = windows
        self.labels[i] = tag_ids
79+
80+
6681def get_batch (self )-> (np .ndarray ,np .ndarray ,np .ndarray ):
6782if self .batch_start + self .batch_size > self .sentences_count :
6883new_start = self .batch_start + self .batch_size - self .sentences_count
69- chs_batch = self .characters [self .batch_start :]+ self .characters [:new_start ]
84+ chs_batch = self .sentences [self .batch_start :]+ self .sentences [:new_start ]
7085lls_batch = self .labels [self .batch_start :]+ self .labels [:new_start ]
71- len_batch = self .sentences_length [self .batch_start :]+ self .sentences_length [:new_start ]
86+ len_batch = self .sentence_lengths [self .batch_start :]+ self .sentence_lengths [:new_start ]
7287else :
7388new_start = self .batch_start + self .batch_size
74- chs_batch = self .characters [self .batch_start :new_start ]
89+ chs_batch = self .sentences [self .batch_start :new_start ]
7590lls_batch = self .labels [self .batch_start :new_start ]
76- len_batch = self .sentences_length [self .batch_start :new_start ]
91+ len_batch = self .sentence_lengths [self .batch_start :new_start ]
7792for i , (chs ,lls )in enumerate (zip (chs_batch ,lls_batch )):
7893if len (chs )> self .batch_length :
7994chs_batch [i ]= chs [:self .batch_length ]
@@ -162,13 +177,13 @@ def tags2entities(self, sentence: str, tags_seq: np.ndarray, return_start: bool
162177continue
163178elif tag == self .tags_map [TAG_BEGIN ]:
164179if entity :
165- entities .append (entity )
180+ entities .append (( entity , tag_index ) )
166181entity = sentence [tag_index ]
167182entity_starts .append (tag_index )
168183else :
169184entity += sentence [tag_index ]
170185if entity != '' :
171- entities .append (entity )
186+ entities .append (( entity , len ( sentence ) - len ( entity )) )
172187if return_start :
173188return entities ,entity_starts
174189else :