-#-*- coding: UTF-8 -*-
+# -*- coding: UTF-8 -*-
 import tensorflow as tf
+import numpy as np
+import pickle
 from dnlp.core.re_cnn_base import RECNNBase
 from dnlp.config import RECNNConfig
+from dnlp.utils.constant import BATCH_PAD, BATCH_PAD_VAL
+
 
 class RECNN(RECNNBase):
-  def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str = '', mode: str = 'train'):
-    RECNNBase.__init__(self, config)
+  def __init__(self, config: RECNNConfig, dtype: type = tf.float32, dict_path: str = '', mode: str = 'train',
+               data_path: str = '', relation_count: int = 2):
+    RECNNBase.__init__(self, config, dict_path)
     self.dtype = dtype
     self.mode = mode
-    self.dictionary = self.read_dictionary(dict_path)
+    self.data_path = data_path
+    self.relation_count = relation_count
+
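+    # feature width per token: its word embedding plus two relative-position embeddings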
+    self.concat_embed_size = self.word_embed_size + 2 * self.position_embed_size
+    self.input_characters = tf.placeholder(tf.int32, [None, self.batch_length])
+    self.input_position = tf.placeholder(tf.int32, [None, self.batch_length])
+    self.input = tf.placeholder(self.dtype, [None, self.batch_length, self.concat_embed_size, 1])
+    self.input_relation = tf.placeholder(self.dtype, [None, self.relation_count])
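+    # relative positions lie in [-(batch_length - 1), batch_length - 1], hence 2 * batch_length - 1 rows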
+    self.position_embedding = self.__weight_variable([2 * self.batch_length - 1, self.position_embed_size],
+                                                     name='position_embedding')
+    self.word_embedding = self.__weight_variable([self.words_size, self.word_embed_size], name='word_embedding')
+    self.conv_kernel = self.get_conv_kernel()
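+    # one bias vector per convolution window; a comprehension creates distinct variables
+    # (multiplying a one-element list would alias a single variable across all windows)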
+    self.bias = [self.__weight_variable([self.filter_size], name='conv_bias') for _ in self.window_size]
+    self.full_connected_weight = self.__weight_variable([self.filter_size * len(self.window_size), self.relation_count],
+                                                        name='full_connected_weight')
+    self.full_connected_bias = self.__weight_variable([self.relation_count], name='full_connected_bias')
+    self.position_lookup = tf.nn.embedding_lookup(self.position_embedding, self.input_position)
+    self.character_lookup = tf.nn.embedding_lookup(self.word_embedding, self.input_characters)
+    self.character_embed_holder = tf.placeholder(self.dtype, [None, self.batch_length, self.word_embed_size])
+    self.primary_embed_holder = tf.placeholder(self.dtype, [None, self.batch_length, self.position_embed_size])
+    self.secondary_embed_holder = tf.placeholder(self.dtype, [None, self.batch_length, self.position_embed_size])
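+    # concatenate word and position embeddings along the feature axis, then add a channel dim for conv2d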
+    self.embed_concat = tf.expand_dims(
+      tf.concat([self.character_embed_holder, self.primary_embed_holder, self.secondary_embed_holder], 2), 3)
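+    # tf.layers.dropout only drops units when training=True; at inference it is the
+    # identity, and a leading batch dimension is added for the single-sample case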
+    if self.mode == 'train':
+      self.hidden_layer = tf.layers.dropout(self.get_hidden(), self.dropout_rate, training=True)
+      self.words, self.primary, self.secondary, self.labels = self.load_data()
+      self.start = 0
+      self.data_count = len(self.words)
+    else:
+      self.hidden_layer = tf.expand_dims(tf.layers.dropout(self.get_hidden(), self.dropout_rate), 0)
+    self.output_no_softmax = tf.matmul(self.hidden_layer, self.full_connected_weight) + self.full_connected_bias
+    self.output = tf.nn.softmax(self.output_no_softmax)
+    self.params = [self.position_embedding, self.word_embedding, self.full_connected_weight,
+                   self.full_connected_bias] + self.conv_kernel + self.bias
+    self.regularization = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(self.lam),
+                                                                 self.params)
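+    # two alternative objectives: squared error on the softmax output, and softmax cross-entropy on the logits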
+    self.loss = tf.reduce_sum(tf.square(self.output - self.input_relation)) / self.batch_size + self.regularization
+    self.cross_entropy = tf.reduce_mean(
+      tf.nn.softmax_cross_entropy_with_logits(labels=self.input_relation,
+                                              logits=self.output_no_softmax)) + self.regularization
+    self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate)
+    # self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
+    self.train_model = self.optimizer.minimize(self.loss)
+    self.train_cross_entropy_model = self.optimizer.minimize(self.cross_entropy)
+    self.saver = tf.train.Saver(max_to_keep=100)
+
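+  # one [window_size, concat_embed_size, 1, filter_size] kernel per window size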
+  def get_conv_kernel(self):
+    conv_kernel = []
+    for w in self.window_size:
+      conv_kernel.append(self.__weight_variable([w, self.concat_embed_size, 1, self.filter_size], name='conv_kernel'))
+    return conv_kernel
+
+  def get_max_pooling(self, x):
+    max_pooling = []
+    for w in self.window_size:
+      max_pooling.append(self.max_pooling(x, w))
+    return max_pooling
+
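+  # convolution + ReLU + max-over-time pooling for every window size, concatenated into one feature vector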
+  def get_hidden(self):
+    h = None
+    for w, conv, bias in zip(self.window_size, self.conv_kernel, self.bias):
+      if h is None:
+        h = tf.squeeze(self.max_pooling(tf.nn.relu(self.conv(conv) + bias), w))
+      else:
+        hh = tf.squeeze(self.max_pooling(tf.nn.relu(self.conv(conv) + bias), w))
+        if self.mode == 'train':
+          h = tf.concat([h, hh], 1)
+        else:
+          h = tf.concat([h, hh], 0)
+    return h
+
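+  # narrow (VALID) convolution spanning the full embedding width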
+  def conv(self, conv_kernel):
+    return tf.nn.conv2d(self.input, conv_kernel, strides=[1, 1, 1, 1], padding='VALID')
+
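+  # max-over-time pooling: the window covers all batch_length - window_size + 1 convolution outputs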
+  def max_pooling(self, x, window_size):
+    return tf.nn.max_pool(x, ksize=[1, self.batch_length - window_size + 1, 1, 1],
+                          strides=[1, 1, 1, 1], padding='VALID')
+
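+  # note: embeddings are looked up with sess.run and fed back in as constants, so
+  # the embedding tables receive no gradient updates from this training loop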
+  def fit(self, epochs=100, interval=20):
+    with tf.Session() as sess:
+      tf.global_variables_initializer().run()
+      sess.graph.finalize()
+      for i in range(1, epochs + 1):
+        print('epoch:' + str(i))
+        for _ in range(self.data_count // self.batch_size):
+          words, primary, secondary, labels = self.load_batch()
+          character_embeds, primary_embeds = sess.run([self.character_lookup, self.position_lookup],
+                                                      feed_dict={self.input_characters: words,
+                                                                 self.input_position: primary})
+          secondary_embeds = sess.run(self.position_lookup, feed_dict={self.input_position: secondary})
+          batch_input = sess.run(self.embed_concat, feed_dict={self.character_embed_holder: character_embeds,
+                                                               self.primary_embed_holder: primary_embeds,
+                                                               self.secondary_embed_holder: secondary_embeds})
+          # sess.run(self.train_model, feed_dict={self.input: batch_input, self.input_relation: labels})
+          sess.run(self.train_cross_entropy_model, feed_dict={self.input: batch_input, self.input_relation: labels})
+        if i % interval == 0:
+          model_name = '../dnlp/models/re/{0}-{1}.ckpt'.format(i, '_'.join(map(str, self.window_size)))
+          self.saver.save(sess, model_name)
+
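+  # sequential mini-batches with wrap-around at the end of the dataset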
+  def load_batch(self):
+    if self.start + self.batch_size > self.data_count:
+      new_start = self.start + self.batch_size - self.data_count
+      words = np.concatenate([self.words[self.start:], self.words[:new_start]])
+      primary = np.concatenate([self.primary[self.start:], self.primary[:new_start]])
+      secondary = np.concatenate([self.secondary[self.start:], self.secondary[:new_start]])
+      labels = np.concatenate([self.labels[self.start:], self.labels[:new_start]])
+      self.start = new_start
+    else:
+      new_start = self.start + self.batch_size
+      words = self.words[self.start:new_start]
+      primary = self.primary[self.start:new_start]
+      secondary = self.secondary[self.start:new_start]
+      labels = self.labels[self.start:new_start]
+      self.start = new_start
+    return words, primary, secondary, labels
 
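+  # reads the pickled corpus, pads/truncates word ids to batch_length, and shifts
+  # relative positions by batch_length - 1 so they are valid embedding indices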
+  def load_data(self):
+    primary = []
+    secondary = []
+    words = []
+    labels = []
+    with open(self.data_path, 'rb') as f:
+      data = pickle.load(f)
+      for sentence in data:
+        sentence_words = sentence['words']
+        if len(sentence_words) < self.batch_length:
+          sentence_words += [self.dictionary[BATCH_PAD]] * (self.batch_length - len(sentence_words))
+        else:
+          sentence_words = sentence_words[:self.batch_length]
+        words.append(sentence_words)
+        primary.append(np.arange(self.batch_length) - sentence['primary'] + self.batch_length - 1)
+        secondary.append(np.arange(self.batch_length) - sentence['secondary'] + self.batch_length - 1)
+        sentence_labels = np.zeros([self.relation_count])
+        sentence_labels[sentence['type']] = 1
+        labels.append(sentence_labels)
+    return (np.array(words, np.int32), np.array(primary, np.int32), np.array(secondary, np.int32),
+            np.array(labels, np.float32))
 
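+  # all weight variables share a truncated-normal initializer (stddev 0.1)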
-  def __weight_variable(self, shape, name):
+  def __weight_variable(self, shape, name):
     initial = tf.truncated_normal(shape, stddev=0.1, dtype=self.dtype)
-    return tf.Variable(initial, name=name)
+    return tf.Variable(initial, name=name)