JAX Logo

Neural Network Development and Training template with JAX.

Introduction | Structure | Usage | Test Envs | Acknowledgements

Introduction

JAX is awesome ❤️; this project aims to provide a starting template for training neural networks built with JAX.

Some features this template provides:

  • Clear structure 📚 to differentiate between the multiple parts involved in training a neural network (data loading, modelling, training, etc.)
  • Custom(izable) neural network library 🤔 with custom(izable) optimizers and update functions
  • JSON configuration file ✏️ to easily define/switch between training hyper-parameters
  • Checkpoint saving ✅ and training logging
  • WandB integration 📉
  • Experimental tflite and ONNX conversion support 🔃

The template is preloaded with a MNIST training sample, so it can be run straight away to see how the cookie 🍪 crumbles (in a good way!)

The project sample doesn't use Flax or Haiku for modelling the neural network (or any of the optimizers, the update function, etc.), but I think (not tested) it would be possible to combine those libraries with the custom definitions in the template. This is intentional: I wanted to practice designing and understanding the layers involved in building a neural network, and to have maximum customization power. The way JAX handles pytrees is very helpful in creating user-defined trees to represent various parts of the training process (parameters, etc.).

Structure

Directory / root file descriptions:

  • data: A place where your data resides
  • data_loader: base.py contains a base class for data loaders; implement a Dataset and a Dataloader in dataloader.py
  • inference: convert_tflite.py is a script that converts your JAX model and weights into a tflite file and, if possible, an onnx file (check out the Usage section below)
  • model: modules contains basic building blocks used in creating neural networks, very raw at this stage, extend as necessary; use model.py to define the neural networks to use in training
  • saved: Empty folder that will be populated with trained weight checkpoints
  • training: This is where you define your loss function, training/validation metrics, and optimizers; the update function is in optimizers.py
  • config.json: JSON file containing hyper-parameters for the training runs
  • logger_config.json: Logger configuration file
  • test.py: Empty file, to be populated with test scripts, etc.
  • train.py: THE MAIN FILE, controls the entire logic of training
  • utils.py: Utility file containing various small functions that have been refactored out
  • requirements.txt: Python package requirements file

Usage

The usage of this template is typically separated into three parts: Development, Training, and Inference.

Development

Data Loading

  • To use this template, first add your training/testing data to the data folder
  • The data is fed to the trainer through the data loader defined in data_loader/dataloader.py
  • Create a dataset class that implements at least the __getitem__ and __len__ methods
  • The dataset object is fed into a dataloader that inherits from BaseDataLoader
  • Define a custom batch collate function if needed and add it as a parameter to the dataloader (if you want to control this from config.json, create a getattr parser as well!); a rough sketch follows below
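For reference, here is a minimal sketch of what a dataset and collate function might look like (the class and function names are illustrative, not the template's actual API; the loader's real constructor lives in data_loader/base.py):

import numpy as np


class ArrayDataset:
    """Illustrative dataset over in-memory arrays; any object with __getitem__ and __len__ works."""

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


def stack_collate(samples):
    """Hypothetical collate function: stack individual samples into batched arrays."""
    images, labels = zip(*samples)
    return np.stack(images), np.stack(labels)


# The dataset and collate function are then handed to a loader class that
# inherits from BaseDataLoader (check data_loader/base.py for the exact signature).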

Model Creation

  • Try to adopt the functional programming paradigm, especially for JIT
  • Create the neural networks to train in model/model.py as Python functions, using the base layers in the model/modules module, custom layers, or Flax, Haiku, etc.
  • The 'made-from-scratch' models in this template carry the model parameters in a nested tree of lists and dictionaries, which is fed into the model as a function parameter (see the sketch after this list)
  • Base layers include a parameter-initialization step that creates initial parameters when none are given
  • To make a layer work with the experimental onnx converter, you might need to use approximations or alternative forms of the same layer (as shown in the template)
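As a rough illustration of that functional style (the layer names and the exact parameter layout here are assumptions, not the template's conventions), a model is just a pure function that takes a parameter pytree and an input:

import jax
import jax.numpy as jnp


def linear(params, x):
    # A dense layer: params is a dict holding the weight matrix and bias
    return x @ params["w"] + params["b"]


def mlp(params, x):
    # A tiny MLP whose parameters live in a list of per-layer dicts (a pytree)
    for layer in params[:-1]:
        x = jax.nn.relu(linear(layer, x))
    return linear(params[-1], x)


def init_mlp(key, sizes):
    # Create initial parameters when none are given, one dict per layer
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append({"w": jax.random.normal(sub, (d_in, d_out)) * 0.01,
                       "b": jnp.zeros(d_out)})
    return params


# Because the parameters are a plain pytree, the model composes cleanly with jax.jit
params = init_mlp(jax.random.PRNGKey(0), [784, 128, 10])
logits = jax.jit(mlp)(params, jnp.ones((1, 784)))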

Define Loss, Metric, Optimizer

  • Again as Python functions, define your loss functions, metric functions, and optimizers inside the training folder
  • vmap comes in handy here for easily making batched computations
  • Optimizers can carry parameters as well, handled in the same way as model parameters
  • The update function in training/optimizers.py can probably handle multiple optimizers and multiple models at once, but adjust it as necessary; a minimal sketch of the overall pattern follows below
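A minimal sketch of that pattern, assuming the mlp-style model above (the function names and the plain SGD step are placeholders; the template's actual update function in training/optimizers.py is more involved):

from functools import partial

import jax
import jax.numpy as jnp


def cross_entropy(logits, label):
    # Per-example loss; vmap lifts it over the batch below
    return -jax.nn.log_softmax(logits)[label]


def batched_loss(params, model_fn, images, labels):
    logits = model_fn(params, images)
    return jnp.mean(jax.vmap(cross_entropy)(logits, labels))


@partial(jax.jit, static_argnums=1)
def update(params, model_fn, images, labels):
    # One toy SGD step; optimizer state could be threaded through as another pytree
    lr = 1e-3
    loss, grads = jax.value_and_grad(batched_loss)(params, model_fn, images, labels)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss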

Tune the main training script

  • Tune the train.py main script as necessary for your experimentation

Training

Once the development phase is over, you are ready to train. The following are controlled by the configuration file config.json; edit it to your needs:

Hyper-parameters in JSON:

  • WandB logging configurations
  • Which model to use (by name), and its arguments
  • Which data/test-loader to use (by name), and its arguments
  • Which optimizer to use (by name), and its arguments
  • Which loss function to use (by name)
  • Which metrics to track (by name)
  • How many epochs to train for
  • Whether to load pre-trained weights before training
  • Where to save checkpoints and the interval of doing so

Add more options and change the codebase accordingly for extended use cases.
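The by-name entries above are usually resolved with getattr against the relevant modules; a hedged sketch of that lookup (the key names below are hypothetical, the real schema is whatever the bundled config.json defines):

import json

import model.model as models                # module paths follow the template layout
import training.optimizers as optimizers


with open("config.json") as f:
    cfg = json.load(f)

# Hypothetical keys -- adapt to the actual structure of config.json
model_fn = getattr(models, cfg["model"]["name"])
optimizer_fn = getattr(optimizers, cfg["optimizer"]["name"])
model_kwargs = cfg["model"].get("args", {})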

Once you are set to go, run the following command to start training with your configuration

python train.py -c config.json

Hopefully, it will produce a console log similar to the following:

[main INFO]: 2022-04-25 02:36:14,020 - Start Training! [MNIST_Sample]
[main INFO]: 2022-04-25 02:36:14,053 - Dataloader, Model, Optimizer, Loss function loaded!
[absl INFO]: 2022-04-25 02:36:14,058 - Remote TPU is not linked into jax; skipping remote TPU.
[absl INFO]: 2022-04-25 02:36:14,058 - Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
[absl INFO]: 2022-04-25 02:36:14,138 - Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
[main INFO]: 2022-04-25 02:36:18,151 - The model has 13810 parameters
[main INFO]: 2022-04-25 02:36:19,369 - ---- STARTING EPOCH 1 ----
[main INFO]: 2022-04-25 02:36:24,354 - Epoch 1 [0/118 (0%)] -- Loss: 5.887911796569824
[main INFO]: 2022-04-25 02:36:30,918 - Epoch 1 [22/118 (19%)] -- Loss: 0.8348695635795593
[main INFO]: 2022-04-25 02:36:37,899 - Epoch 1 [44/118 (37%)] -- Loss: 0.5723302960395813
[main INFO]: 2022-04-25 02:36:44,871 - Epoch 1 [66/118 (56%)] -- Loss: 0.3618883192539215
[main INFO]: 2022-04-25 02:36:51,849 - Epoch 1 [88/118 (75%)] -- Loss: 0.36900243163108826
[main INFO]: 2022-04-25 02:36:58,628 - Epoch 1 [110/118 (93%)] -- Loss: 0.33672940731048584
[main INFO]: 2022-04-25 02:37:43,137 - Epoch 1 [train] accuracy @ 0.9055989980697632
[main INFO]: 2022-04-25 02:37:54,400 - Epoch 1 [test] avg Loss: 0.2996721863746643
[main INFO]: 2022-04-25 02:37:54,401 - Epoch 1 [test] accuracy @ 0.9099782109260559
[main INFO]: 2022-04-25 02:37:54,401 - ---- Epoch 1 took 95.03 seconds to complete! ----
[main INFO]: 2022-04-25 02:37:54,401 - ---- STARTING EPOCH 2 ----
...

As the training progresses, weight checkpoints will be saved under the designated 'save folder' specified in the JSON configuration file.

saved/
├── 'title' given in config.json
│   ...
│   ├── datetime / wandb_id depending on configuration
│       ...
│       ├── model/                      -  copy of the model/ directory when the training script was invoked
│       ├── checkpoint-epoch_n.params   #  n being multiples of save_period
│       ├── ...
│       ├── train.log                   -  copy of the console log as a file
│       └── config.json                 -  copy of config.json when the training script was invoked

Inference

This last part deals with exporting the model and weights into something more standardized, such as ONNX.

The script inference/convert_tflite.py converts a given model definition and weights into a TensorFlow Lite model file. If possible, it also tries to convert the tflite file into an onnx file using tf2onnx.

The syntax for the command is:

python inference/convert_tflite.py -c config.json --input (1, 1, 28, 28) --name jax_mnist -w saved/MNIST_Sample/../checkpoint-epoch_10.params

It uses the model and model args given in the config.json file to prepare the model.

This depends on experimental features such as the JAX-to-tflite conversion and tf2onnx, so making it work might take some close examination and pedantic model poking.

If run correctly, the script creates a folder with the name provided by the --name parameter under the inference folder.

inference/
├── <name>
    ├── model/          -  copy of the model/ directory when the training script was invoked
    ├── config.json     -  copy of config.json when the training script was invoked
    ├── <name>.tflite   -  converted tflite model+weights file
    └── <name>.onnx     -  converted onnx model+weights file

The script prints logs to the console while it runs and provides information such as test outputs of each model, and input and output node names.

The JAX model is run natively, the tflite model with the TensorFlow Lite Interpreter, and the onnx model with onnxruntime.
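For instance, the two converted artifacts can be sanity-checked against each other along these lines (a generic sketch using the --name value from the example command above, not the script's exact code; the converted input layout may differ from the JAX one):

import numpy as np
import onnxruntime as ort
import tensorflow as tf

x = np.random.rand(1, 1, 28, 28).astype(np.float32)

# tflite: run the converted model with the TensorFlow Lite Interpreter
interpreter = tf.lite.Interpreter(model_path="inference/jax_mnist/jax_mnist.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
interpreter.set_tensor(inp["index"], x)
interpreter.invoke()
tflite_out = interpreter.get_tensor(out["index"])

# onnx: run the converted model with onnxruntime
session = ort.InferenceSession("inference/jax_mnist/jax_mnist.onnx")
onnx_out = session.run(None, {session.get_inputs()[0].name: x})[0]

print(np.max(np.abs(tflite_out - onnx_out)))  # the two outputs should roughly agree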

There's still a long way to go for this part, such as adding dynamic axes options, etc.

Now the world is your oyster, deploy your neural network into various services and projects!

Test Envs

The MNIST training sample and the tflite/ONNX export were tested with:

  • Conf 1

    • Windows 11
    • Python 3.9.12
    • jax/jaxlib==0.3.2 from the jax-windows-builder pre-built wheels
    • CUDA 11.6, cuDNN 8.3
    • IMPORTANT: use the option --use-deprecated legacy-resolver when installing the requirements.txt packages with pip under Windows

  • Conf 2

    • Ubuntu 20.04
    • Python 3.8.10
    • jax/jaxlib==0.3.7 from the official JAX Linux CUDA/cuDNN build
    • CUDA 11.6, cuDNN 8.3
    • For some reason, I needed to set XLA_PYTHON_CLIENT_MEM_FRACTION=0.80 or else it would produce a cuDNN conv error. The GPU used for Conf 2 was a mobile one with very little VRAM.

Acknowledgements

As a PyTorch guy, I've been using the pytorch-template project for a while, adding functionality as needed. This project stems from my experience with it!

The Windows build of 'jaxlib' with CUDA and cuDNN is provided by cloudhan in the jax-windows-builder repo; it saved me a lot of hassle when working in a Windows environment!

If you have any questions about the project or would like to contribute, please feel free to submit an issue, open a PR, or send me a DM on Twitter (@jeehoonlerenard).
