Pretrained deep learning models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet, etc.
matthias-wright/flaxmodels
The goal of this project is to make current deep learning models more easily available for the awesome Jax/Flax ecosystem.
- GPT2 [model]
- StyleGAN2 [model] [training]
- ResNet{18, 34, 50, 101, 152} [model] [training]
- VGG{16, 19} [model] [training]
- FewShotGanAdaption [model] [training]
You will need Python 3.7 or later.
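Before installing, the version requirement can be checked from Python itself. The following is a minimal sketch; `python_is_supported` and `MIN_PYTHON` are hypothetical names used only for illustration and are not part of flaxmodels:

```python
import sys

# Minimum interpreter version stated in this README (3.7).
MIN_PYTHON = (3, 7)


def python_is_supported(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the (major, minor) part of version_info meets the minimum."""
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    if not python_is_supported():
        # str.format rather than an f-string, so this file also parses on Python < 3.6.
        sys.exit("Python {}.{} or later is required, found {}".format(
            MIN_PYTHON[0], MIN_PYTHON[1], sys.version.split()[0]))
    print("Python version OK")
```

Comparing `(major, minor)` tuples avoids string comparison pitfalls such as `"3.10" < "3.7"`.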
1. For GPU usage, follow the Jax installation with CUDA.
2. Then install:
> pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git
For CPU-only you can skip step 1.
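After installation, a quick sanity check can confirm that the package is importable and show which backend JAX selected. This is a sketch under assumptions: `check_install` is a hypothetical helper, not part of flaxmodels, and it relies only on `importlib` and `jax.default_backend()`:

```python
import importlib
import importlib.util


def check_install():
    """Report whether flaxmodels and JAX are importable, and which JAX backend is active."""
    report = {
        # find_spec locates a top-level module without importing it.
        "flaxmodels": importlib.util.find_spec("flaxmodels") is not None,
        "jax": importlib.util.find_spec("jax") is not None,
        "backend": None,
    }
    if report["jax"]:
        jax = importlib.import_module("jax")
        # Platform name of the default backend, e.g. "cpu" or "gpu".
        report["backend"] = jax.default_backend()
    return report


if __name__ == "__main__":
    for key, value in check_install().items():
        print(f"{key}: {value}")
```

On a CPU-only setup the backend should typically be `cpu`; after following step 1 on a machine with a CUDA GPU it should typically report `gpu`.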
The documentation for the models can be found here.
The checkpoints are taken from the repositories that are referenced on the model pages. The processing steps and the format of the checkpoints are documented here.
To run the tests, pytest needs to be installed.
> git clone https://github.com/matthias-wright/flaxmodels.git
> cd flaxmodels
> python -m pytest tests/
See here for an explanation of the testing strategy.
Thank you to the developers of Jax and Flax. The title image is a photograph of a flax flower, kindly made available by Marta Matyszczyk.
Each model has an individual license.