Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 1dabce4

Browse files
modify to dataset api
1 parent ed380f9 · commit 1dabce4

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

python/dnlp/core/dnn_crf.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,11 +35,18 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
3535
ifmode=='train':
3636
self.input=tf.placeholder(tf.int32, [self.batch_size,self.batch_length,self.windows_size])
3737
self.real_indices=tf.placeholder(tf.int32, [self.batch_size,self.batch_length])
38+
self.sentence_inputs=tf.data.Dataset.from_tensor_slices(self.sentences).repeat(-1).batch(self.batch_size)
39+
self.label_inputs=tf.data.Dataset.from_tensor_slices(self.labels).repeat(-1).batch(self.batch_size)
40+
self.length_inputs=tf.data.Dataset.from_tensor_slices(self.sentence_lengths).repeat(-1).batch(self.batch_size)
41+
self.sentence_iterator=self.sentence_inputs.make_initializable_iterator()
42+
self.label_iterator=self.label_inputs.make_initializable_iterator()
43+
self.length_iterator=self.length_inputs.make_initializable_iterator()
3844
else:
3945
self.input=tf.placeholder(tf.int32, [None,self.windows_size])
4046

4147
self.seq_length=tf.placeholder(tf.int32, [None])
4248

49+
4350
# 查找表层
4451
self.embedding_layer=self.get_embedding_layer()
4552
# 执行drpout
@@ -91,13 +98,22 @@ def __init__(self, *, config: DnnCrfConfig = None, task='cws', data_path: str =
9198
deffit(self,epochs:int=50,interval:int=10):
9299
withtf.Session(graph=self.graph)assess:
93100
tf.global_variables_initializer().run()
101+
sess.run(self.sentence_iterator.initializer)
102+
sess.run(self.label_iterator.initializer)
103+
sess.run(self.length_iterator.initializer)
104+
sentence=self.sentence_iterator.get_next()
105+
label=self.label_iterator.get_next()
106+
length=self.length_iterator.get_next()
94107
saver=tf.train.Saver(max_to_keep=epochs)
95108
forepochinrange(1,epochs+1):
96109
print('epoch:',epoch)
97110
j=0
98111
foriinrange(self.batch_count):
99-
characters,labels,lengths=self.get_batch()
100-
feed_dict= {self.input:characters,self.real_indices:labels,self.seq_length:lengths}
112+
# sentences, labels, lengths = self.get_batch()
113+
sentences=sess.run(sentence)
114+
labels=sess.run(label)
115+
lengths=sess.run(length)
116+
feed_dict= {self.input:sentences,self.real_indices:labels,self.seq_length:lengths}
101117
_,summary,loss=sess.run([self.train,self.merged,self.mean_loss],feed_dict=feed_dict)
102118
self.train_writer.add_summary(summary,j)
103119
j+=1

python/dnlp/core/dnn_crf_base.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,6 @@ def __init__(self, config: DnnCrfConfig = None, data_path: str = '', mode: str =
1010
# 加载数据
1111
self.data_path=data_path
1212
self.config_suffix='.config.pickle'
13-
ifmode=='train':
14-
self.dictionary,self.tags,self.characters,self.labels=self.__load_data()
15-
else:
16-
self.model_path=model_path
17-
self.config_path=self.model_path+self.config_suffix
18-
self.dictionary,self.tags=self.__load_config()
19-
self.tags_count=len(self.tags)-1# 忽略TAG_PAD
20-
self.tags_map=self.__generate_tag_map()
21-
self.reversed_tags_map=dict(zip(self.tags_map.values(),self.tags_map.keys()))
22-
self.dict_size=len(self.dictionary)
2313
# 初始化超参数
2414
self.skip_left=config.skip_left
2515
self.skip_right=config.skip_right
@@ -32,12 +22,25 @@ def __init__(self, config: DnnCrfConfig = None, data_path: str = '', mode: str =
3222
self.concat_embed_size=self.embed_size*self.windows_size
3323
self.batch_length=config.batch_length
3424
self.batch_size=config.batch_size
35-
# 数据
25+
3626
ifmode=='train':
37-
self.sentences_length=list(map(lambdas:len(s),self.characters))
38-
self.sentences_count=len(self.sentences_length)
27+
self.dictionary,self.tags,self.sentences,self.labels=self.__load_data()
28+
self.sentence_lengths=list(map(lambdas:len(s),self.sentences))
29+
self.sentences_count=len(self.sentence_lengths)
3930
self.batch_count=self.sentences_count//self.batch_size
4031
self.batch_start=0
32+
self.dataset_start=0
33+
else:
34+
self.model_path=model_path
35+
self.config_path=self.model_path+self.config_suffix
36+
self.dictionary,self.tags=self.__load_config()
37+
self.tags_count=len(self.tags)-1# 忽略TAG_PAD
38+
self.tags_map=self.__generate_tag_map()
39+
self.reversed_tags_map=dict(zip(self.tags_map.values(),self.tags_map.keys()))
40+
self.dict_size=len(self.dictionary)
41+
ifmode=='train':
42+
self.preprocess()
43+
4144

4245
def__load_data(self)-> (dict,tuple,np.ndarray,np.ndarray):
4346
withopen(self.data_path,'rb')asf:
@@ -63,17 +66,29 @@ def __generate_tag_map(self):
6366
tags_map[self.tags[i]]=i
6467
returntags_map
6568

69+
defpreprocess(self):
70+
fori,(sentence,labels,length)inenumerate(zip(self.sentences,self.labels,self.sentence_lengths)):
71+
iflength<self.batch_length:
72+
ext_size=self.batch_length-length
73+
sentence=self.__indices2input_single(sentence)
74+
self.sentences[i]=sentence+ [[self.dictionary[BATCH_PAD]]*self.windows_size]*ext_size
75+
self.labels[i]= [self.tags_map[l]forlinlabels]+[0]*ext_size
76+
eliflength>self.batch_length:
77+
self.sentences[i]=self.__indices2input_single(sentence[:self.batch_length])
78+
self.labels[i]= [self.tags_map[l]forlinlabels[:self.batch_length]]
79+
80+
6681
defget_batch(self)-> (np.ndarray,np.ndarray,np.ndarray):
6782
ifself.batch_start+self.batch_size>self.sentences_count:
6883
new_start=self.batch_start+self.batch_size-self.sentences_count
69-
chs_batch=self.characters[self.batch_start:]+self.characters[:new_start]
84+
chs_batch=self.sentences[self.batch_start:]+self.sentences[:new_start]
7085
lls_batch=self.labels[self.batch_start:]+self.labels[:new_start]
71-
len_batch=self.sentences_length[self.batch_start:]+self.sentences_length[:new_start]
86+
len_batch=self.sentence_lengths[self.batch_start:]+self.sentence_lengths[:new_start]
7287
else:
7388
new_start=self.batch_start+self.batch_size
74-
chs_batch=self.characters[self.batch_start:new_start]
89+
chs_batch=self.sentences[self.batch_start:new_start]
7590
lls_batch=self.labels[self.batch_start:new_start]
76-
len_batch=self.sentences_length[self.batch_start:new_start]
91+
len_batch=self.sentence_lengths[self.batch_start:new_start]
7792
fori, (chs,lls)inenumerate(zip(chs_batch,lls_batch)):
7893
iflen(chs)>self.batch_length:
7994
chs_batch[i]=chs[:self.batch_length]
@@ -162,13 +177,13 @@ def tags2entities(self, sentence: str, tags_seq: np.ndarray, return_start: bool
162177
continue
163178
eliftag==self.tags_map[TAG_BEGIN]:
164179
ifentity:
165-
entities.append(entity)
180+
entities.append((entity,tag_index))
166181
entity=sentence[tag_index]
167182
entity_starts.append(tag_index)
168183
else:
169184
entity+=sentence[tag_index]
170185
ifentity!='':
171-
entities.append(entity)
186+
entities.append((entity,len(sentence)-len(entity)))
172187
ifreturn_start:
173188
returnentities,entity_starts
174189
else:

0 commit comments

Comments (0)

[8]ページ先頭

©2009-2025 Movatter.jp