# jax-nn-template
Neural Network Development and Training template with JAX.
Introduction | Structure | Usage | Test Environments | Acknowledgements
## Introduction

JAX is awesome ❤️; this project aims to provide a starting template for training neural networks built with JAX.
Some features that this template has:

- Clear structure 📚 to differentiate between the multiple parts involved in training a neural network (data loading, modelling, training, etc.)
- Custom(izable) neural network library 🤔 with custom(izable) optimizers and update functions
- JSON configuration file ✏️ to easily define/switch between training hyper-parameters
- Checkpoint saving ✅ and training logging
- WandB integration 📉
- Experimental `tflite` and `ONNX` conversion support 🔃
The template is preloaded with an MNIST training sample, so it can be run straight away to see how the cookie 🍪 crumbles (in a good way!).
The project sample doesn't use Flax or Haiku for modelling the neural network (or for any of the optimizers, the update function, etc.), but I think (not tested) it would be possible to combine those libraries with the custom definitions in the template. This is intentional, as I wanted to practice designing and understanding the layers involved in building a neural network, and also to have maximum customization powers. The way JAX handles pytrees is very helpful in creating user-defined trees to represent various parts of the training process (parameters, etc.), as the snippet below illustrates.
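For instance, here is a toy illustration (not code from the template) of operating leaf-wise on a nested parameter tree:

```python
import jax
import jax.numpy as jnp

# A user-defined pytree: a nested list/dict of arrays, much like the
# template's model parameters. tree_map applies a function to every leaf,
# so a whole parameter tree can be updated in one call.
params = [{"w": jnp.ones((3, 3)), "b": jnp.zeros(3)},
          {"w": jnp.ones((3, 1)), "b": jnp.zeros(1)}]

decayed = jax.tree_util.tree_map(lambda p: 0.9 * p, params)
print(jax.tree_util.tree_structure(decayed))
```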
## Structure

Directory/Root File | Description
---|---
`data` | A place where your data resides
`data_loader` | `base.py` contains a base class for data loaders. Implement a `Dataset` and `Dataloader` in `dataloader.py`
`inference` | `convert_tflite.py` is a script that converts your JAX model and weights into a `tflite` file and, if possible, an `onnx` file (check out the Usage section below)
`model` | `modules` contains basic building blocks used in creating neural networks; very raw at this stage, extend as necessary. Use `model.py` to define the neural networks to use in training
`saved` | Empty folder that will be populated with trained weight checkpoints
`training` | This is where you define your loss function, training/validation metrics, and optimizers. The `update` function is in `optimizers.py`
`config.json` | JSON file containing hyper-parameters for the training runs
`logger_config.json` | Logger configuration file
`test.py` | Empty file, to be populated with test scripts, etc.
`train.py` | THE MAIN FILE - controls the entire logic of training
`utils.py` | Utility file containing various small functions that have been refactored out
`requirements.txt` | Python package requirements file
## Usage

The usage of this template is typically separated into three parts: Development, Training, and Inference.
### Data Loading

- To use this template, first add training/testing data to the `data` folder.
- The data will be fed to the trainer using the data loader defined in `data_loader/dataloader.py`.
- Create a dataset class that implements at least the `__getitem__` and `__len__` functions.
- The dataset object is fed into a dataloader that inherits from `BaseDataLoader`.
- Define a custom batch collate function if needed and add it as a parameter to the dataloader (if you want to control this from `config.json`, create a `getattr` parser as well!). A minimal dataset sketch follows this list.
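For illustration, a toy dataset (the data and shapes here are hypothetical; only the `__getitem__`/`__len__` contract is what the template requires):

```python
import numpy as np

# A toy dataset: anything exposing __getitem__ and __len__
# satisfies the contract the dataloader needs.
class ToyImageDataset:
    def __init__(self, n_samples: int = 100):
        rng = np.random.default_rng(0)
        self.images = rng.random((n_samples, 1, 28, 28), dtype=np.float32)
        self.labels = rng.integers(0, 10, size=n_samples)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

# Hypothetical usage; the actual constructor arguments of the dataloader
# subclassing BaseDataLoader are defined by you in data_loader/dataloader.py:
# loader = Dataloader(ToyImageDataset(), batch_size=32, collate_fn=my_collate)
```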
### Model Creation

- Try to adopt the functional programming paradigm, especially for JIT.
- Create the neural networks to train in `model/model.py`, as Python functions, using the base layers in the `model/modules` module, custom layers, or Flax, Haiku, etc.
- The 'made-from-scratch' models in this template carry the model parameters in a nested tree of lists and dictionaries. This tree is fed into the model as a function parameter.
- Base layers require a parameter-initialization branch that creates initial parameters when none are given.
- To make a model work with the experimental onnx converter, you might need to use approximations or alternative forms of the same layer (as shown in the template). A sketch of this functional style follows the list.
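A minimal sketch of that functional style (an illustrative MLP, not the template's model code): parameters live in a pytree and are passed to the jitted model function explicitly.

```python
import jax
import jax.numpy as jnp

def init_mlp_params(key, sizes):
    """Build the parameter pytree (a list of dicts) for an MLP."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append({
            "w": jax.random.normal(sub, (n_in, n_out)) * jnp.sqrt(2.0 / n_in),
            "b": jnp.zeros(n_out),
        })
    return params

@jax.jit
def mlp(params, x):
    """The model is a pure function of (params, inputs)."""
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])
    return x @ params[-1]["w"] + params[-1]["b"]

params = init_mlp_params(jax.random.PRNGKey(0), [784, 128, 10])
logits = mlp(params, jnp.ones((1, 784)))  # shape (1, 10)
```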
### Define Loss, Metric, Optimizer

- Again as Python functions, define your loss functions, metric functions, and optimizers inside the `training` folder. `vmap` comes in handy here for easily making batched computations (see the sketch after this list).
- Optimizers can carry parameters as well, handled the same way as model parameters.
- The `update` function in `training/optimizers.py` can probably handle multiple optimizers and multiple models at once, but fix it as necessary.
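A rough sketch of the idea (a stand-in linear model and plain SGD, not the template's actual loss/optimizer signatures): write the loss per example, lift it over the batch with `vmap`, and differentiate the mean.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    """Toy linear model standing in for whatever lives in model/model.py."""
    return x @ params["w"] + params["b"]

def nll(params, x, label):
    """Cross-entropy loss written for a single example."""
    return -jax.nn.log_softmax(predict(params, x))[label]

def batched_loss(params, xs, labels):
    """vmap lifts the per-example loss over the batch dimension."""
    return jax.vmap(nll, in_axes=(None, 0, 0))(params, xs, labels).mean()

@jax.jit
def sgd_update(params, xs, labels, lr=1e-3):
    """A bare SGD step; the template's update lives in training/optimizers.py."""
    grads = jax.grad(batched_loss)(params, xs, labels)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.zeros((784, 10)), "b": jnp.zeros((10,))}
params = sgd_update(params, jnp.ones((32, 784)), jnp.zeros((32,), dtype=jnp.int32))
```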
### Tune the main training script

Tune the `train.py` main script as necessary for your experimentation.
### Training

Once the development phase is over, you are ready to train. The following are controlled by the configuration JSON file `config.json`; edit it to your needs.
Hyper-parameters in JSON:
- WandB logging configurations
- Which model to use (by name), and its arguments
- Which data/test-loader to use (by name), and its arguments
- Which optimizer to use (by name), and its arguments
- Which loss function to use (by name)
- Which metrics to track (by name)
- How many epochs to train for
- Whether to load pre-trained weights before training
- Where to save checkpoints and the interval of doing so
Add more options and change the codebase accordingly for extended use cases.
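For orientation, a hypothetical shape such a `config.json` could take (the actual key names and values in the template may differ):

```json
{
  "title": "MNIST_Sample",
  "wandb": { "project": "jax-nn-template", "mode": "online" },
  "model": { "name": "conv_net", "args": { "num_classes": 10 } },
  "dataloader": { "name": "Dataloader", "args": { "batch_size": 512, "shuffle": true } },
  "optimizer": { "name": "adam", "args": { "lr": 0.001 } },
  "loss": "cross_entropy",
  "metrics": ["accuracy"],
  "epochs": 10,
  "load_pretrained": false,
  "save_dir": "saved/",
  "save_period": 2
}
```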
Once you are set to go, run the following command to start training with your configuration:

```sh
python train.py -c config.json
```
Hopefully, it will produce a console log similar to the following:

```
[main INFO]: 2022-04-25 02:36:14,020 - Start Training! [MNIST_Sample]
[main INFO]: 2022-04-25 02:36:14,053 - Dataloader, Model, Optimizer, Loss function loaded!
[absl INFO]: 2022-04-25 02:36:14,058 - Remote TPU is not linked into jax; skipping remote TPU.
[absl INFO]: 2022-04-25 02:36:14,058 - Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
[absl INFO]: 2022-04-25 02:36:14,138 - Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
[main INFO]: 2022-04-25 02:36:18,151 - The model has 13810 parameters
[main INFO]: 2022-04-25 02:36:19,369 - ---- STARTING EPOCH 1 ----
[main INFO]: 2022-04-25 02:36:24,354 - Epoch 1 [0/118 (0%)] -- Loss: 5.887911796569824
[main INFO]: 2022-04-25 02:36:30,918 - Epoch 1 [22/118 (19%)] -- Loss: 0.8348695635795593
[main INFO]: 2022-04-25 02:36:37,899 - Epoch 1 [44/118 (37%)] -- Loss: 0.5723302960395813
[main INFO]: 2022-04-25 02:36:44,871 - Epoch 1 [66/118 (56%)] -- Loss: 0.3618883192539215
[main INFO]: 2022-04-25 02:36:51,849 - Epoch 1 [88/118 (75%)] -- Loss: 0.36900243163108826
[main INFO]: 2022-04-25 02:36:58,628 - Epoch 1 [110/118 (93%)] -- Loss: 0.33672940731048584
[main INFO]: 2022-04-25 02:37:43,137 - Epoch 1 [train] accuracy @ 0.9055989980697632
[main INFO]: 2022-04-25 02:37:54,400 - Epoch 1 [test] avg Loss: 0.2996721863746643
[main INFO]: 2022-04-25 02:37:54,401 - Epoch 1 [test] accuracy @ 0.9099782109260559
[main INFO]: 2022-04-25 02:37:54,401 - ---- Epoch 1 took 95.03 seconds to complete! ----
[main INFO]: 2022-04-25 02:37:54,401 - ---- STARTING EPOCH 2 ----
...
```
As the training progresses, weight checkpoints will be saved under the designated 'save folder' specified in the JSON configuration file.
```
saved/
│
├── 'title' given in config.json
    ...
    ├── datetime / wandb_id depending on configuration
        ...
        ├── model/                     - copy of model/ directory when training script was invoked
        ├── checkpoint-epoch_n.params  # n being multiples of save_period
        ├── ...
        ├── train.log                  - copy of console log as file
        └── config.json                - copy of config.json when training script was invoked
```
### Inference

This last part deals with exporting the model and weights into something more standardized, such as ONNX.

The script `inference/convert_tflite.py` converts a given model definition and weights into a tensorflow-lite model file. If possible, it also tries to convert the tflite file into an onnx file using `tf2onnx`.
The syntax for the command is:

```sh
python inference/convert_tflite.py -c config.json --input (1, 1, 28, 28) --name jax_mnist -w saved/MNIST_Sample/../checkpoint-epoch_10.params
```
It uses the model and model args given in the `config.json` file to prepare the model.

This depends on experimental features such as JAX-tflite conversion and `tf2onnx`, so making it work might take some close examination and pedantic model poking.
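For context, the underlying path looks roughly like the following sketch (a trivial stand-in function instead of a real model; the actual script wires in your model and weights from `config.json`):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Stand-in for model(params, x); any pure JAX function works the same way.
def infer(x):
    return jnp.tanh(x)

# Convert the JAX function to a tf.function with a fixed input signature,
# then feed its concrete function to the TFLite converter.
tf_fn = tf.function(
    jax2tf.convert(infer),
    autograph=False,
    input_signature=[tf.TensorSpec((1, 1, 28, 28), tf.float32)],
)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_fn.get_concrete_function()]
)
tflite_model = converter.convert()  # bytes; written out as <name>.tflite
```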
If run correctly, the script creates a folder with the name provided by the `--name` parameter under the `inference` folder.

```
inference/
│
├── <name>
    ├── model/         - copy of model/ directory when training script was invoked
    ├── config.json    - copy of config.json when training script was invoked
    ├── <name>.tflite  - converted tflite model+weights file
    └── <name>.onnx    - converted onnx model+weights file
```
The script prints logs to the console while it runs, providing information such as test outputs of each model and the input and output node names. The JAX model is run natively, the tflite model with the TensorFlow Interpreter, and the onnx model with onnxruntime.
There's still a long way to go for this part, such as adding dynamic axes options, etc.
Now the world is your oyster, deploy your neural network into various services and projects!
## Test Environments

MNIST Training Sample and TFLITE/ONNX export tested with:
**Conf 1**
- Windows 11
- Python 3.9.12
- jax/jaxlib==0.3.2 from the jax-windows-builder pre-built wheels
- CUDA 11.6, cuDNN 8.3
IMPORTANT: use the option `--use-deprecated legacy-resolver` when installing the `requirements.txt` packages with pip under Windows.
**Conf 2**
- Ubuntu 20.04
- Python 3.8.10
- jax/jaxlib==0.3.7 from JAX official linux CUDA/cuDNN build
- CUDA 11.6, cuDNN 8.3
For some reason, I needed to set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.80` or else it would produce a cuDNN conv error. The GPU used for Conf 2 was a mobile one with very little VRAM.
## Acknowledgements

As a PyTorch guy, I've been utilizing the project pytorch-template for a while, adding functionality as needed. This project stems from my experience using it!

The Windows build of `jaxlib` with CUDA and cuDNN is provided by cloudhan with the repo jax-windows-builder; this saved me a lot of hassle when working in a Windows environment!

If you have any questions about the project or would like to contribute, please feel free to submit an issue, a PR, or send me a DM on Twitter (@jeehoonlerenard).