Text summarization using seq2seq in Keras
chen0040/keras-text-summarization
Text summarization using seq2seq and encoder-decoder recurrent networks in Keras
The following neural network models are implemented and studied for text summarization:
The seq2seq models encode the content of an article (encoder input) and one character (decoder input) from the summarized text to predict the next character in the summarized text.
The implementation can be found in keras_text_summarization/library/seq2seq.py.
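To make the next-character setup concrete, the sketch below shows, without any deep learning framework, how one summary could be split into (decoder input, target) pairs for teacher-forced training. The start and end markers ('\t' and '\n') and the function name make_decoder_pairs are illustrative assumptions, not taken from seq2seq.py.

```python
from typing import List, Tuple

START, END = "\t", "\n"  # hypothetical start/end-of-summary markers

def make_decoder_pairs(summary: str) -> List[Tuple[str, str]]:
    """Pair each decoder input character with the character that should follow it.

    The encoder would see the whole article separately; this only builds the
    decoder side: at step t the input is token t and the target is token t + 1.
    """
    tokens = [START] + list(summary) + [END]
    return [(tokens[t], tokens[t + 1]) for t in range(len(tokens) - 1)]

if __name__ == "__main__":
    for decoder_input, target in make_decoder_pairs("Hi!"):
        print(repr(decoder_input), "->", repr(target))
```

Each pair corresponds to one decoding step; the article encoding stays the same across all steps of a given summary.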
There are three variants of the seq2seq model implemented for text summarization:
- Seq2SeqSummarizer (one-hot encoding)
  - training: run demo/seq2seq_train.py
  - prediction: demo code is available in demo/seq2seq_predict.py
- Seq2SeqGloVeSummarizer (GloVe encoding for the encoder input)
  - training: run demo/seq2seq_glove_train.py
  - prediction: demo code is available in demo/seq2seq_glove_predict.py
- Seq2SeqGloVeSummarizerV2 (GloVe encoding for both the encoder input and the decoder input)
  - training: run demo/seq2seq_glove_v2_train.py
  - prediction: demo code is available in demo/seq2seq_glove_v2_predict.py
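The three variants differ mainly in how tokens are turned into vectors before they reach the recurrent layers. As a rough, framework-free illustration (the helper names are assumptions, not the library's code), one-hot encoding produces a sparse vector the size of the vocabulary, while a GloVe lookup returns a dense pre-trained vector parsed from GloVe's plain-text format (a word followed by its vector components on each line). Lowercasing lookups and mapping unknown words to a zero vector are assumptions of this sketch.

```python
from typing import Dict, Iterable, List

def one_hot(index: int, vocab_size: int) -> List[float]:
    """One-hot vector: all zeros except a 1.0 at `index`."""
    if not 0 <= index < vocab_size:
        raise ValueError("index out of range")
    vec = [0.0] * vocab_size
    vec[index] = 1.0
    return vec

def load_glove(lines: Iterable[str]) -> Dict[str, List[float]]:
    """Parse GloVe text lines of the form 'word x1 x2 ... xd' into a dict."""
    vectors: Dict[str, List[float]] = {}
    for line in lines:
        parts = line.split()
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        vectors[parts[0]] = [float(x) for x in parts[1:]]
    return vectors

def embed(word: str, vectors: Dict[str, List[float]], dim: int) -> List[float]:
    """Dense vector for `word`; unknown words map to a zero vector (assumption)."""
    return vectors.get(word.lower(), [0.0] * dim)

if __name__ == "__main__":
    glove = load_glove(["the 0.1 0.2", "cat 0.3 -0.4"])
    print(one_hot(2, 4))           # [0.0, 0.0, 1.0, 0.0]
    print(embed("Cat", glove, 2))  # [0.3, -0.4]
```

Dense pre-trained vectors keep the input dimension small regardless of vocabulary size, which is the usual motivation for swapping one-hot inputs for GloVe.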
There are currently three other encoder-decoder recurrent models, based on some recommendations here.
The implementation can be found in keras_text_summarization/library/rnn.py.
- One-Shot RNN (OneShotRNN in rnn.py): a very simple encoder-decoder recurrent network that encodes the content of an article and decodes the entire summarized text in one pass.
  - training: run demo/one_hot_rnn_train.py
  - prediction: run demo/one_hot_rnn_predict.py
- Recursive RNN 1 (RecursiveRNN1 in rnn.py): takes the article content and the summarized text built up so far to predict the next character of the summarized text.
  - training: run demo/recursive_rnn_v1_train.py
  - prediction: run demo/recursive_rnn_v1_predict.py
- Recursive RNN 2 (RecursiveRNN2 in rnn.py): like Recursive RNN 1, takes the article content and the summarized text built up so far to predict the next character of the summarized text, with one additional LSTM decoder layer.
  - training: run demo/recursive_rnn_v2_train.py
  - prediction: run demo/recursive_rnn_v2_predict.py
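The recursive models build the summary one character at a time, feeding the text generated so far back in alongside the article. A minimal sketch of that greedy decoding loop is shown below; predict_next is a hypothetical stand-in for a trained model, and the end marker and max_len cap are assumptions rather than details of rnn.py.

```python
from typing import Callable

END = "\n"  # hypothetical end-of-summary marker

def recursive_summarize(article: str,
                        predict_next: Callable[[str, str], str],
                        max_len: int = 50) -> str:
    """Greedy recursive decoding: (article, summary so far) -> next character."""
    summary = ""
    for _ in range(max_len):
        next_char = predict_next(article, summary)
        if next_char == END:
            break  # the model signalled the end of the summary
        summary += next_char
    return summary

if __name__ == "__main__":
    # Toy stand-in for a trained model: it copies the article's first word.
    def toy_model(article: str, summary: str) -> str:
        first_word = article.split()[0]
        return first_word[len(summary)] if len(summary) < len(first_word) else END

    print(recursive_summarize("Keras makes summarization easy", toy_model))  # Keras
```

The one-shot model, by contrast, would emit every position of the summary from a single prediction instead of looping.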
The trained models are available in the demo/models folder.
The demo below shows how to use seq2seq for training and prediction; the other models described above follow the same training and prediction process.
To train a deep learning model, say Seq2SeqSummarizer, run the following commands:
```bash
pip install -r requirements.txt
cd demo
python seq2seq_train.py
```
The training code in seq2seq_train.py is quite straightforward and illustrated below:
```python
from __future__ import print_function

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from keras_text_summarization.library.applications.fake_news_loader import fit_text
from keras_text_summarization.library.seq2seq import Seq2SeqSummarizer
from keras_text_summarization.library.utility.plot_utils import plot_and_save_history

LOAD_EXISTING_WEIGHTS = True  # continue training from previously saved weights

np.random.seed(42)
data_dir_path = './data'
report_dir_path = './reports'
model_dir_path = './models'

print('loading csv file ...')
df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv")

print('extract configuration from input texts ...')
Y = df.title    # headlines are the target summaries
X = df['text']  # article bodies are the encoder input
config = fit_text(X, Y)

summarizer = Seq2SeqSummarizer(config)

if LOAD_EXISTING_WEIGHTS:
    summarizer.load_weights(weight_file_path=Seq2SeqSummarizer.get_weight_file_path(model_dir_path=model_dir_path))

# hold out 20% of the articles for validation
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42)

history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=100)

# when continuing from saved weights, include the model version in the plot file name
history_plot_file_path = report_dir_path + '/' + Seq2SeqSummarizer.model_name + '-history.png'
if LOAD_EXISTING_WEIGHTS:
    history_plot_file_path = report_dir_path + '/' + Seq2SeqSummarizer.model_name + '-history-v' + str(summarizer.version) + '.png'
plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'})
```
After training completes, the trained models are saved in the demo/models folder.
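The training script relies on fit_text(X, Y) to derive a configuration from the texts before the model is built. Its exact contents are not shown here, so the function below is a hypothetical, character-level stand-in that illustrates the kind of statistics such a step typically collects (vocabularies, index maps, maximum sequence lengths). The function name, the truncation limits, and all key names are assumptions, not the library's config format.

```python
from typing import Dict, Iterable

def fit_text_sketch(X: Iterable[str], Y: Iterable[str],
                    max_input_len: int = 500,
                    max_target_len: int = 50) -> Dict[str, object]:
    """Hypothetical fit_text-like helper: collect character vocabularies and lengths."""
    X, Y = list(X), list(Y)  # also accepts pandas Series such as df['text']
    input_chars = sorted({c for text in X for c in text[:max_input_len]})
    target_chars = sorted({c for text in Y for c in text[:max_target_len]})
    return {
        "input_char2idx": {c: i for i, c in enumerate(input_chars)},
        "target_char2idx": {c: i for i, c in enumerate(target_chars)},
        "num_input_tokens": len(input_chars),
        "num_target_tokens": len(target_chars),
        # longest (truncated) sequence on each side, used to size padded tensors
        "max_input_seq_length": max((min(len(t), max_input_len) for t in X), default=0),
        "max_target_seq_length": max((min(len(t), max_target_len) for t in Y), default=0),
    }
```

Capping sequence lengths keeps the padded input tensors a fixed, manageable size even when a few articles are very long.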
The following code demonstrates how to use the trained model to summarize articles:
```python
from __future__ import print_function

import numpy as np
import pandas as pd

from keras_text_summarization.library.seq2seq import Seq2SeqSummarizer

np.random.seed(42)
data_dir_path = './data'  # refers to the demo/data folder
model_dir_path = './models'  # refers to the demo/models folder

print('loading csv file ...')
df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv")
X = df['text']
Y = df.title

# the config is stored as a pickled dict; NumPy >= 1.16.3 needs allow_pickle=True to load it
config = np.load(Seq2SeqSummarizer.get_config_file_path(model_dir_path=model_dir_path), allow_pickle=True).item()

summarizer = Seq2SeqSummarizer(config)
summarizer.load_weights(weight_file_path=Seq2SeqSummarizer.get_weight_file_path(model_dir_path=model_dir_path))

print('start predicting ...')
for i in range(20):  # summarize the first 20 articles
    x = X[i]
    actual_headline = Y[i]
    headline = summarizer.summarize(x)
    print('Article: ', x)
    print('Generated Headline: ', headline)
    print('Original Headline: ', actual_headline)
```
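The prediction script reads the configuration back with np.load(...).item(). A Python dict passed to np.save is stored as a pickled 0-d object array, and since NumPy 1.16.3 np.load refuses pickled data unless allow_pickle=True is passed. The round trip below is a small sketch of that save/load pattern; the helper names, the config values, and the temporary file path are made up for illustration.

```python
import os
import tempfile

import numpy as np

def save_config(config: dict, path: str) -> None:
    # np.save wraps the dict in a 0-d object array and pickles it
    np.save(path, config)

def load_config(path: str) -> dict:
    # allow_pickle=True is needed to read object arrays (NumPy >= 1.16.3);
    # .item() unwraps the 0-d array back into the original dict
    return np.load(path, allow_pickle=True).item()

if __name__ == "__main__":
    demo_config = {"num_input_tokens": 3, "max_target_seq_length": 50}  # made-up values
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo-config.npy")
        save_config(demo_config, path)
        print(load_config(path) == demo_config)  # True
```

Only load pickled .npy files from sources you trust, since unpickling can execute arbitrary code.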
To run training on a GPU on Windows:
- Step 1: Change tensorflow to tensorflow-gpu in requirements.txt and install tensorflow-gpu.
- Step 2: Download and install the CUDA® Toolkit 9.0. (Note that CUDA® Toolkit 9.1 is currently not supported by TensorFlow, so download CUDA® Toolkit 9.0.)
- Step 3: Download and unzip cuDNN 7.0.4 for CUDA® Toolkit 9.0, and add the bin folder of the unzipped directory to the PATH of your Windows environment.
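Step 3 is easy to get wrong, so one quick check is to search the directories listed in PATH for the cuDNN library. The sketch below assumes the cuDNN 7 DLL on Windows is named cudnn64_7.dll; it only inspects the file system and does not confirm that TensorFlow can actually load the library.

```python
import os
from typing import Optional

CUDNN_DLL = "cudnn64_7.dll"  # assumed DLL name for cuDNN 7.x on Windows

def find_on_path(filename: str, path_value: Optional[str] = None) -> Optional[str]:
    """Return the first directory in PATH that contains `filename`, or None."""
    if path_value is None:
        path_value = os.environ.get("PATH", "")
    for directory in path_value.split(os.pathsep):
        if directory and os.path.isfile(os.path.join(directory, filename)):
            return directory
    return None

if __name__ == "__main__":
    location = find_on_path(CUDNN_DLL)
    if location:
        print("cuDNN found in", location)
    else:
        print(CUDNN_DLL, "not found; check that the cuDNN bin folder is on PATH")
```

Remember that a new PATH value only takes effect in terminals opened after the change.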