jayparks/quasi-rnn
PyTorch implementation of Neural Machine Translation using "Quasi-Recurrent Neural Networks", ICLR 2017
- NumPy >= 1.11.1
- PyTorch >= 0.2.0
- layer.py: Implementation of the quasi-recurrent layer
- model.py: Implementation of the Encoder-Decoder model using the QRNN layer
- train.py: Code to train an NMT model
- decode.py: Code to translate a source file using a trained model
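For orientation, the recurrent part of a quasi-recurrent layer can be sketched as the "fo-pooling" step described in the QRNN paper: `c_t = f_t * c_{t-1} + (1 - f_t) * z_t` and `h_t = o_t * c_t`. The snippet below is a minimal, dependency-free illustration in plain Python and is not the code in layer.py. The function name `fo_pool` and its list-based inputs are assumptions made for this example, and the gates `z`, `f`, `o` are assumed to be already computed (in a QRNN they come from a convolution over the input sequence followed by tanh or sigmoid).

```python
def fo_pool(z, f, o, c0=None):
    """Apply fo-pooling over a sequence (illustrative sketch only).

    z, f, o: lists of length T; each element is a list of hidden_size floats
             holding the already-activated candidate, forget and output gates.
    c0:      optional initial cell state (defaults to zeros).
    Returns (hidden_states, final_cell_state).
    """
    hidden_size = len(z[0])
    c = list(c0) if c0 is not None else [0.0] * hidden_size
    hidden_states = []
    for z_t, f_t, o_t in zip(z, f, o):
        # c_t = f_t * c_{t-1} + (1 - f_t) * z_t  (element-wise)
        c = [ft * cp + (1.0 - ft) * zt for ft, cp, zt in zip(f_t, c, z_t)]
        # h_t = o_t * c_t  (element-wise)
        hidden_states.append([ot * ct for ot, ct in zip(o_t, c)])
    return hidden_states, c


# Toy usage: two time steps, hidden size 1.
hs, c_last = fo_pool(z=[[1.0], [1.0]], f=[[0.5], [0.5]], o=[[1.0], [1.0]])
```

Only this element-wise loop is sequential; the gate computation itself can run in parallel across time steps, which is the main motivation for the architecture.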
To train a quasi-RNN NMT model:

    $ python train.py --kernel_size 3 \
        --hidden_size 640 \
        --emb_size 500 \
        --num_enc_symbols 30000 \
        --num_dec_symbols 30000 ...
To run the trained model for translation:

    $ python eval.py --model_path $path_to_model \
        --decode_input $path_to_source \
        --decode_output $path_to_output --max_decode_step 300 \
        --batch_size 30 ...
For simplicity, we use greedy decoding at each time step rather than beam search.
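As a rough illustration of what greedy decoding means here, the sketch below picks the highest-scoring token at every step and stops at an end-of-sentence token or after `max_decode_step` steps (the name mirrors the command-line flag above). It is not the repository's decoding code: `greedy_decode`, `step_fn`, `start_id`, `end_id` and the toy scoring function are all hypothetical names introduced for this example.

```python
def greedy_decode(step_fn, start_id, end_id, max_decode_step=300):
    """Greedy decoding sketch.

    step_fn: hypothetical callable mapping the token prefix generated so far
             to a list of scores, one per vocabulary entry.
    Returns the generated token ids (without the start token).
    """
    tokens = [start_id]
    for _ in range(max_decode_step):
        scores = step_fn(tokens)
        # Greedy choice: keep only the single best-scoring token.
        next_id = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(next_id)
        if next_id == end_id:
            break
    return tokens[1:]


def toy_step(prefix):
    """Toy stand-in for a trained model: always prefers (last token + 1)."""
    vocab_size = 5
    scores = [0.0] * vocab_size
    scores[(prefix[-1] + 1) % vocab_size] = 1.0
    return scores


output = greedy_decode(toy_step, start_id=0, end_id=3)
```

Beam search would instead keep several candidate prefixes per step, which usually improves translation quality at a higher decoding cost.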
For more in-depth exploration, a QRNN API for PyTorch is available: https://github.com/salesforce/pytorch-qrnn
For any comments or feedback, please email me at pjh0308@gmail.com or open an issue here.