JitGRU: Second-order differentiable PyTorch GRUs in JIT with TorchScript
A simple implementation of GRUs using PyTorch's JIT (TorchScript). The API follows that of `torch.nn.GRU`. Should run reasonably fast.
At the time of writing, PyTorch does not support second-order derivatives for GRUs with CUDA (see this issue). As a result, any loss function that depends on computing the second derivatives of GRUs doesn't work out of the box. I needed double `backward()` calls for a project, so here it is!
The main implementation is available in `jit_gru.py`. I've implemented equivalents of `torch.nn.GRUCell` and `torch.nn.GRU` in that file. Look at the test cases that I've included in the implementation; those should help you get started.
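As a rough sketch of basic usage and a double-`backward()` pass (assuming the GRU class in `jit_gru.py` is named `JitGRU` and takes `torch.nn.GRU`-style constructor arguments; the test cases in that file show the exact signatures):

```python
import torch
from jit_gru import JitGRU  # assumed class name; see jit_gru.py

# Assumed to mirror torch.nn.GRU: (input_size, hidden_size, num_layers)
gru = JitGRU(8, 16, 2)

# (seq_len, batch, input_size), the same layout torch.nn.GRU uses by default
x = torch.randn(5, 3, 8, requires_grad=True)
output, h_n = gru(x)

# First-order gradients, keeping the graph so we can differentiate again
grad_x, = torch.autograd.grad(output.sum(), x, create_graph=True)

# Second-order gradients -- the part that stock CUDA GRUs can't do
grad_x.pow(2).sum().backward()
```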
Support for bi-directional GRUs with variable input lengths was recently added (credits go to @elixir-code). This implementation is available separately in `jit_bigru.py`. See the included test cases in that file for example usage.
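Purely as a hypothetical illustration of the usual padded-batch-plus-lengths pattern (the class name `JitBiGRU` and the way lengths are passed below are assumptions, not the confirmed API of `jit_bigru.py`; the test cases in that file are the authoritative reference):

```python
import torch
from jit_bigru import JitBiGRU  # hypothetical name; check jit_bigru.py for the real class

# Three padded sequences of different lengths, sequence-first layout
x = torch.zeros(7, 3, 8)
lengths = torch.tensor([7, 5, 2])

bigru = JitBiGRU(8, 16, 1)       # (input_size, hidden_size, num_layers) -- assumed
output, h_n = bigru(x, lengths)  # assumed call signature; see the file's test cases
```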
Check out DeepNAG, which contains a GAN-based sequence generation model, as well as a non-adversarial sequence generator. The GAN-based sequence generator in the aforementioned repository is trained with the improved Wasserstein GAN loss function and relies on the code from this repository.
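For context, the improved Wasserstein (WGAN-GP) setting is exactly where double `backward()` matters: the gradient penalty is itself a function of first-order gradients, so optimizing it requires differentiating through them. Below is a minimal sketch of that penalty with a toy recurrent critic; the critic is an illustrative stand-in built on the assumed `JitGRU` class above, not DeepNAG's actual model:

```python
import torch
from jit_gru import JitGRU  # assumed class name from this repository

class ToyCritic(torch.nn.Module):
    """Toy sequence critic: a GRU followed by a linear scoring head."""
    def __init__(self, input_size=8, hidden_size=16):
        super().__init__()
        self.gru = JitGRU(input_size, hidden_size, 1)  # (input_size, hidden_size, num_layers)
        self.head = torch.nn.Linear(hidden_size, 1)

    def forward(self, x):
        output, _ = self.gru(x)
        return self.head(output[-1])  # score each sequence from its last time step

critic = ToyCritic()
real = torch.randn(5, 3, 8)  # (seq_len, batch, input_size)
fake = torch.randn(5, 3, 8)

# Interpolate between real and fake samples, as in WGAN-GP
alpha = torch.rand(1, 3, 1)
interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

scores = critic(interp)
grads, = torch.autograd.grad(scores.sum(), interp, create_graph=True)

# Penalize the per-sample gradient norm; backpropagating through it needs
# second-order derivatives of the GRU, which is what this repository enables.
per_sample = grads.permute(1, 0, 2).reshape(grads.shape[1], -1)
gradient_penalty = ((per_sample.norm(2, dim=1) - 1) ** 2).mean()
gradient_penalty.backward()
```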
If you find our work useful, please consider starring this repository and citing our work:
@phdthesis{maghoumi2020dissertation,
  title  = {{Deep Recurrent Networks for Gesture Recognition and Synthesis}},
  author = {Mehran Maghoumi},
  year   = {2020},
  school = {University of Central Florida Orlando, Florida}
}

@misc{maghoumi2020deepnag,
  title         = {{DeepNAG: Deep Non-Adversarial Gesture Generation}},
  author        = {Mehran Maghoumi and Eugene M. Taranta II and Joseph J. LaViola Jr},
  year          = {2020},
  eprint        = {2011.09149},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CV}
}
I'm actively using this implementation, so contributions are very welcome, as they help my work too. If you think you can improve this project, or implement something more efficiently, feel free to submit pull requests!
This project is licensed under the MIT License.