11# -*- coding: UTF-8 -*-
22import tensorflow as tf
33import numpy as np
4- import pickle
54from collections import Counter
65from dnlp .core .re_cnn_base import RECNNBase
76from dnlp .config import RECNNConfig
8- from dnlp . utils . constant import BATCH_PAD , BATCH_PAD_VAL
7+
98
109
1110class RECNN (RECNNBase ):
@@ -23,6 +22,12 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
2322self .remark = remark
2423
2524self .concat_embed_size = self .word_embed_size + 2 * self .position_embed_size
25+ self .words ,self .primary ,self .secondary ,self .labels = self .load_data ()
26+ self .input_words = tf .constant (self .words )
27+ self .input_primary = tf .constant (self .primary )
28+ self .input_secondary = tf .constant (self .secondary )
29+ self .input_labels = tf .constant (self .labels )
30+ self .input_indices = tf .placeholder (tf .int32 , [self .batch_size ])
2631self .input_characters = tf .placeholder (tf .int32 , [None ,self .batch_length ])
2732self .input_position = tf .placeholder (tf .int32 , [None ,self .batch_length ])
2833self .input = tf .placeholder (self .dtype , [None ,self .batch_length ,self .concat_embed_size ,1 ])
@@ -39,6 +44,10 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
3944self .full_connected_weight = self .__weight_variable ([self .filter_size * len (self .window_size ),self .relation_count ],
4045name = 'full_connected_weight' )
4146self .full_connected_bias = self .__weight_variable ([self .relation_count ],name = 'full_connected_bias' )
47+ self .input_words_lookup = tf .nn .embedding_lookup (self .input_words ,self .input_indices )
48+ self .input_primary_lookup = tf .nn .embedding_lookup (self .input_primary ,self .input_indices )
49+ self .input_secondary_lookup = tf .nn .embedding_lookup (self .input_secondary ,self .input_indices )
50+ self .input_labels_lookup = tf .nn .embedding_lookup (self .input_labels ,self .input_indices )
4251self .position_lookup = tf .nn .embedding_lookup (self .position_embedding ,self .input_position )
4352self .character_lookup = tf .nn .embedding_lookup (self .word_embedding ,self .input_characters )
4453self .character_embed_holder = tf .placeholder (self .dtype ,
@@ -49,7 +58,8 @@ def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str
4958 [None ,self .batch_length ,self .position_embed_size ])
5059self .emebd_concat = tf .expand_dims (
5160tf .concat ([self .character_embed_holder ,self .primary_embed_holder ,self .secondary_embed_holder ],2 ),3 )
52- self .words ,self .primary ,self .secondary ,self .labels = self .load_data ()
61+
62+
5363
5464if self .mode == 'train' :
5565self .start = 0
@@ -111,10 +121,20 @@ def fit(self, epochs=50, interval=5):
111121with tf .Session ()as sess :
112122tf .global_variables_initializer ().run ()
113123sess .graph .finalize ()
124+ start = 0
114125for i in range (1 ,epochs + 1 ):
115126print ('epoch:' + str (i ))
116- for _ in range (self .data_count // self .batch_size ):
117- words ,primary ,secondary ,labels = self .load_batch ()
127+ for j in range (self .data_count // self .batch_size ):
# Select the row indices of the next mini-batch. When fewer than
# batch_size rows remain, take the tail of the data set and wrap around
# to the beginning so every batch holds exactly batch_size indices
# (input_indices is a fixed-size placeholder of length batch_size).
if start + self.batch_size < self.data_count:
    indices = list(range(start, start + self.batch_size))
    start += self.batch_size
else:
    new_start = self.batch_size - self.data_count + start
    indices = list(range(start, self.data_count)) + list(range(0, new_start))
    start = new_start
# Fetch the gathered rows (the *_lookup ops), not the underlying constants:
# the constants do not depend on input_indices, so fetching them would
# return the entire data set on every step instead of one batch.
words, primary, secondary, labels = sess.run(
    [self.input_words_lookup, self.input_primary_lookup,
     self.input_secondary_lookup, self.input_labels_lookup],
    feed_dict={self.input_indices: indices})
118138character_embeds ,primary_embeds = sess .run ([self .character_lookup ,self .position_lookup ],
119139feed_dict = {self .input_characters :words ,
120140self .input_position :primary })
@@ -190,27 +210,7 @@ def load_batch(self):
190210self .start = new_start
191211return words ,primary ,secondary ,labels
192212
def load_data(self):
    """Load the pickled sentences and turn them into fixed-width arrays.

    The object unpickled from ``self.data_path`` is iterated; every item is
    read through the keys ``'words'``, ``'primary'``, ``'secondary'`` and
    ``'type'``.

    Returns:
        A 4-tuple ``(words, primary, secondary, labels)`` of NumPy arrays:

        * ``words`` (int32): each sentence's word ids, padded with
          ``self.dictionary[BATCH_PAD]`` or truncated to ``self.batch_length``.
        * ``primary`` / ``secondary`` (int32): for each sentence,
          ``arange(batch_length) - index + batch_length - 1`` where ``index``
          is the item's ``'primary'`` / ``'secondary'`` value (presumably the
          entity positions -- confirm against the data producer).
        * ``labels`` (float32): one row of length ``self.relation_count`` per
          sentence with a 1 at the item's ``'type'`` index.
    """
    # NOTE(review): pickle.load can execute arbitrary code while
    # deserialising; only point data_path at trusted files.
    with open(self.data_path, 'rb') as source:
        corpus = pickle.load(source)

    width = self.batch_length
    offsets = np.arange(width)
    word_rows, primary_rows, secondary_rows, label_rows = [], [], [], []
    for item in corpus:
        tokens = item['words']
        shortfall = width - len(tokens)
        if shortfall > 0:
            # Too short: extend with the padding id up to the fixed width.
            tokens += [self.dictionary[BATCH_PAD]] * shortfall
        else:
            # Long enough: keep only the first `width` ids.
            tokens = tokens[:width]
        word_rows.append(tokens)

        # Offset of every slot relative to the given index, shifted by
        # (width - 1).
        primary_rows.append(offsets - item['primary'] + width - 1)
        secondary_rows.append(offsets - item['secondary'] + width - 1)

        one_hot = np.zeros([self.relation_count])
        one_hot[item['type']] = 1
        label_rows.append(one_hot)

    return (np.array(word_rows, np.int32),
            np.array(primary_rows, np.int32),
            np.array(secondary_rows, np.int32),
            np.array(label_rows, np.float32))
213+
214214
215215def __weight_variable (self ,shape ,name ):
216216initial = tf .truncated_normal (shape ,stddev = 0.1 ,dtype = self .dtype )