Unofficial TensorFlow implementation of the "Attentive Generative Adversarial Network for Raindrop Removal from A Single Image" (CVPR 2018) model. https://maybeshewill-cv.github.io/attentive-gan-derainnet/
MaybeShewill-CV/attentive-gan-derainnet
This project uses TensorFlow to implement a deep convolutional generative adversarial network for the image deraining task, mainly based on the CVPR 2018 paper "Attentive Generative Adversarial Network for Raindrop Removal from A Single Image". You can refer to the paper for details: https://arxiv.org/abs/1711.10098. The model consists of an attentive-recurrent network, a contextual autoencoder network and a discriminative network. A convolutional LSTM unit generates the attention map, which helps locate the raindrops, and multi-scale losses together with a perceptual loss are used to train the contextual autoencoder network. Thanks to the original author, Rui Qian.
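The multi-scale losses mentioned above can be illustrated with a minimal sketch in plain Python. It is not the code used in this repository. The function names, the 2x2 average pooling used to shrink the label, and the default scale weights are all assumptions made for illustration; check the paper and the model code for the actual settings.

```python
# Sketch of a multi-scale reconstruction loss on plain Python lists.
# The scale weights and the 2x2 average pooling are illustrative assumptions.

def mse(a, b):
    """Mean squared error between two equally sized 2-D grids."""
    total = 0.0
    count = 0
    for row_a, row_b in zip(a, b):
        for x, y in zip(row_a, row_b):
            total += (x - y) ** 2
            count += 1
    return total / count


def downsample_2x(img):
    """Halve height and width by averaging each 2x2 block."""
    h, w = len(img) // 2, len(img[0]) // 2
    return [
        [
            (img[2 * r][2 * c] + img[2 * r][2 * c + 1]
             + img[2 * r + 1][2 * c] + img[2 * r + 1][2 * c + 1]) / 4.0
            for c in range(w)
        ]
        for r in range(h)
    ]


def multi_scale_loss(outputs, label, weights=(0.6, 0.8, 1.0)):
    """Weighted sum of MSE terms, one per output scale.

    outputs: predictions ordered from the coarsest scale to full size.
    label:   full-size clean image; it is pooled to match each output.
    """
    # Build the label pyramid (full, 1/2, 1/4, ...), then order it coarsest first.
    pyramid = [label]
    for _ in range(len(outputs) - 1):
        pyramid.append(downsample_2x(pyramid[-1]))
    pyramid.reverse()
    return sum(w * mse(o, t) for w, o, t in zip(weights, outputs, pyramid))
```

Giving the full-size output the largest weight keeps the final image as the main training target, while the coarser terms still push the intermediate outputs toward the label.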
The main network architecture is as follows:
This software has only been tested on Ubuntu 16.04 (x64) with Python 3.5, CUDA 9.0 and cuDNN 7.0 on a GTX 1070 GPU. To install this software you need TensorFlow 1.15.0. Other versions of TensorFlow have not been tested, but I think it should work properly with TensorFlow versions above 1.10. You can install the other required packages with
pip3 install -r requirements.txt
In this repo I uploaded a model trained on the dataset provided by the original author (origin_dataset).
The weight files of the trained derain net model are stored in the folder weights/.
You can test a single image with the trained model as follows:
cd REPO_ROOT_DIR
python tools/test_model.py --weights_path ./weights/derain_gan/derain_gan.ckpt-100000 --image_path ./data/test_data/test_1.png
The results are as follows:
Test Input Image
Test Derain result image
Test Attention Map at time 1
Test Attention Map at time 2
Test Attention Map at time 3
Test Attention Map at time 4
You need to organize your training examples. Put all of your rain images and clean images in two separate folders named SOURCE_DATA_ROOT_DIR/rain_image and SOURCE_DATA_ROOT_DIR/clean_image. The rest of the preparation work is done by running the following script:
cd PROJECT_ROOT_DIR
python data_provider/data_feed_pipline.py --dataset_dir SOURCE_DATA_ROOT_DIR --tfrecords_dir TFRECORDS_SAVE_DIR
Each training sample consists of two components: a clean label image free from raindrops and an original image degraded by raindrops.
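Before generating the records, it can help to check that every rain image has a clean counterpart. The sketch below pairs files from the two folders. The helper name and the rule that both images of a pair share the same file name are hypothetical; the repository's own script may match files differently, so adapt the rule to your dataset.

```python
from pathlib import Path


def pair_training_images(source_root):
    """Return sorted (rain_path, clean_path) pairs matched by file name.

    Assumption (hypothetical): a rain image and its clean label share the
    same file name inside rain_image/ and clean_image/.
    """
    root = Path(source_root)
    rain_dir = root / "rain_image"
    clean_dir = root / "clean_image"
    pairs = []
    for rain_path in sorted(rain_dir.iterdir()):
        clean_path = clean_dir / rain_path.name
        # Keep only rain images that have a matching clean label.
        if rain_path.is_file() and clean_path.is_file():
            pairs.append((rain_path, clean_path))
    return pairs
```

Comparing the number of returned pairs with the number of files in rain_image is a quick way to spot missing labels.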
All of your training images will be automatically scaled to the same size according to the config file and converted into TensorFlow records for an efficient data feed pipeline.
In my experiment the number of training epochs is 100010, the batch size is 1 and the initial learning rate is 0.002. For details about the training parameters you can check global_configuration/config.py.
You may call the following script to train your own model
cd REPO_ROOT_DIR
python tools/train_model.py --dataset_dir SOURCE_DATA_ROOT_DIR
You can also continue the training process from the snapshot by
cd REPO_ROOT_DIR
python tools/train_model.py --dataset_dir SOURCE_DATA_ROOT_DIR --weights_path path/to/your/last/checkpoint
You may monitor the training process using tensorboard tools
During my experiment the G loss drops as follows:
The image SSIM between the generated image and the clean label image rises as follows:
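For readers unfamiliar with the metric, here is a minimal sketch of SSIM computed over a single global window in plain Python. Standard SSIM implementations average the index over local (usually Gaussian-weighted) windows, so this simplified version will not reproduce the values in the tensorboard curves. The function name and the flat-list input format are assumptions made for illustration.

```python
def global_ssim(x, y, max_val=255.0):
    """Single-window SSIM over two equally sized flat pixel sequences.

    Simplification: statistics are taken over the whole image instead of
    over sliding local windows as in the standard SSIM definition.
    """
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) * (a - mu_x) for a in x) / n
    var_y = sum((b - mu_y) * (b - mu_y) for b in y) / n
    cov = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    # Stabilizing constants from the usual SSIM definition.
    c1 = (0.01 * max_val) ** 2
    c2 = (0.03 * max_val) ** 2
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov + c2)
    denominator = (mu_x * mu_x + mu_y * mu_y + c1) * (var_x + var_y + c2)
    return numerator / denominator
```

The index is 1 for identical images and drops as luminance, contrast or structure diverge, which is why a rising curve indicates that the generated images are getting closer to the clean labels.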
Please cite my repo attentive-gan-derainnet if you find it helpful.
The trained model can be converted into a TensorFlow saved model and a TensorFlow.js model for web usage. If you want to convert the ckpt model into a TensorFlow saved model, you may run the following script:
cd PROJECT_ROOT_DIR
python tools/export_tf_saved_model.py --export_dir ./weights/derain_gan_saved_model --ckpt_path ./weights/derain_gan/derain_gan.ckpt-100000
If you want to convert it into a TensorFlow.js model, you can modify the bash script and run it:
cd PROJECT_ROOT_DIR
bash tools/convert_tfjs_model.sh
Several users have found that a NaN loss may occasionally occur during training under TensorFlow v1.3.0. I think it may be caused by the random parameter initialization. My solution is to kill the training process and restart it to find suitable initial parameters. Meanwhile, I have found that the NaN loss problem does not happen if you use the model under TensorFlow v1.10.0. The reason may be a difference in the parameter initialization function or the loss optimizer function between older and newer TensorFlow versions. If the NaN loss still troubles you when training the model, upgrading your local TensorFlow may be a good option. Good luck with training!
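The kill-and-restart workaround can also be automated. The following is only a sketch of the idea and is not part of this repository: run_training is a hypothetical callable that starts training from a fresh initialization for a given seed and yields one loss value per step.

```python
import math


def train_with_restarts(run_training, max_restarts=3):
    """Re-run training from a fresh initialization when the loss turns NaN.

    run_training is a hypothetical callable that takes a seed and yields
    one loss value per training step.
    Returns (seed, final_loss) of the first run that finishes without NaN.
    """
    for seed in range(max_restarts + 1):
        last_loss = None
        diverged = False
        for loss in run_training(seed):
            if math.isnan(loss):
                diverged = True  # abandon this run; the next seed is tried
                break
            last_loss = loss
        if not diverged:
            return seed, last_loss
    raise RuntimeError("loss was NaN in every attempt")
```

Stopping at the first NaN avoids wasting GPU time on a run that cannot recover, and the attempt limit prevents an endless loop when the problem is not caused by initialization.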
Thanks to Jay-Jia for raising the issues.
Adjusted the initial learning rate and used an exponential decay strategy to adjust the learning rate during training. Used traditional image augmentation functions, including random crop and random flip, to augment the training dataset, which improved the new model's performance. I have uploaded a new tensorboard record file, and you can check the image SSIM to compare the two models. The new model weights can be found under the weights/new_model folder.
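The two changes described above can be sketched in plain Python. This is an illustration only, not the repository's code. The decay steps, decay rate and crop size are placeholder arguments, and the helper names are hypothetical. The point to note in the augmentation is that the rain image and its clean label must receive the same crop and the same flip.

```python
import random


def exponential_decay(initial_lr, step, decay_steps, decay_rate, staircase=False):
    """Return initial_lr * decay_rate ** (step / decay_steps).

    With staircase=True the exponent is an integer, so the rate drops in
    discrete jumps every decay_steps steps.
    """
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent


def random_crop_and_flip(rain, clean, crop_h, crop_w, rng=random):
    """Apply the same random crop and horizontal flip to an image pair.

    rain and clean are 2-D grids (lists of rows) of identical size.
    """
    top = rng.randint(0, len(rain) - crop_h)
    left = rng.randint(0, len(rain[0]) - crop_w)
    flip = rng.random() < 0.5

    def transform(img):
        rows = [row[left:left + crop_w] for row in img[top:top + crop_h]]
        return [row[::-1] for row in rows] if flip else rows

    # Both images go through the identical transform to stay aligned.
    return transform(rain), transform(clean)
```

Passing a seeded random.Random instance as rng makes the augmentation reproducible, which is convenient when comparing two training runs.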
The first row is the source test image in the folder ./data/test_data, the second row is the derain result generated by the old model and the last row is the derain result generated by the new model. As you can see, the new model recovers more vivid details than the old model. I will upload a figure of SSIM and PSNR to illustrate the new model's improvement.
Since the batch size is 1 during training, the batch normalization layers seem to be useless. All the BN layers were removed in the new update. I have trained a new model based on the newest code; the new model is placed in the folder root_dir/weights/new_model and the model updated on 2018.10.12 is placed in the folder root_dir/weights/old_model. The new model can present more vivid details than the old model. The comparison between the models is shown below.
The first row is the source test image in the folder ./data/test_data, the second row is the derain result generated by the old model and the last row is the derain result generated by the new model. As you can see, the new model performs much better than the old model.
Since the BN layers lead to unstable results, the deeper attention maps of the old model do not capture valid information, which is supposed to guide the model to focus on the raindrops. The comparison of the attention maps is shown below.
Model attention map result comparison
The first row is the source test image in the folder ./data/test_data, the second row is attention map 4 generated by the old model and the last row is attention map 4 generated by the new model. As you can see, the new model captures much more valid attention information than the old model.
- Parameter adjustment
- Test different loss function design
- Add tensorflow service