Test-time augmentation (TTA) wrapper for Keras image segmentation and classification models.
```
        Input
          |           # input image; shape 1, H, W, C
       / / / \ \ \    # duplicate image for augmentation; shape N, H, W, C
      | | |   | | |   # apply augmentations (flips, rotation, shifts)
     your Keras model
      | | |   | | |   # reverse transformations
       \ \ \ / / /    # merge predictions (mean, max, gmean)
          |           # output mask; shape 1, H, W, C
        Output
```
Example usage:

```python
from keras.models import load_model
from tta_wrapper import tta_segmentation

model = load_model('path/to/model.h5')
tta_model = tta_segmentation(model, h_flip=True, rotation=(90, 270),
                             h_shift=(-5, 5), merge='mean')
y = tta_model.predict(x)
```