Test Time image Augmentation (TTA) wrapper for Keras model.

qubvel/tta_wrapper

TTA wrapper

Test-time augmentation wrapper for Keras image segmentation and classification models.

Description

How it works

The wrapper adds augmentation layers to your Keras model like this:

          Input
            |           # input image; shape 1, H, W, C
       / / / \ \ \      # duplicate image for augmentation; shape N, H, W, C
      | | |   | | |     # apply augmentations (flips, rotation, shifts)
     your Keras model
      | | |   | | |     # reverse transformations
       \ \ \ / / /      # merge predictions (mean, max, gmean)
            |           # output mask; shape 1, H, W, C
          Output
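The flow in the diagram can be illustrated outside Keras with plain NumPy. This is a minimal sketch under stated assumptions, not the wrapper's implementation: `tta_predict_sketch`, `dummy_model` and the flip helpers are hypothetical names, only flips are used, and predictions are merged with the mean.

```python
import numpy as np


def identity(t):
    return t


def flip_horizontal(t):
    # Reverse the width axis of an (N, H, W, C) tensor.
    return t[:, :, ::-1, :]


def flip_vertical(t):
    # Reverse the height axis of an (N, H, W, C) tensor.
    return t[:, ::-1, :, :]


def dummy_model(batch):
    """Hypothetical stand-in for a segmentation model: (N, H, W, C) -> (N, H, W, 1)."""
    return batch.mean(axis=-1, keepdims=True)


def tta_predict_sketch(predict_fn, image, h_flip=True, v_flip=True):
    """Illustrative flip-only TTA for a single image of shape (1, H, W, C)."""
    assert image.shape[0] == 1, "sketch assumes batch_size == 1"

    # Each entry is (forward transform, reverse transform); flips are self-inverse.
    transforms = [(identity, identity)]
    if h_flip:
        transforms.append((flip_horizontal, flip_horizontal))
    if v_flip:
        transforms.append((flip_vertical, flip_vertical))

    # Duplicate and augment the image: shape (N, H, W, C).
    batch = np.concatenate([fwd(image) for fwd, _ in transforms], axis=0)

    # Run the model once on the whole augmented batch.
    preds = predict_fn(batch)

    # Undo each transformation on its own prediction, then merge with the mean.
    restored = [rev(preds[i:i + 1]) for i, (_, rev) in enumerate(transforms)]
    return np.mean(np.concatenate(restored, axis=0), axis=0, keepdims=True)


x = np.random.rand(1, 4, 6, 3)
y = tta_predict_sketch(dummy_model, x)
print(y.shape)  # (1, 4, 6, 1)
```

Because `dummy_model` acts on each pixel independently, the merged output here equals the plain prediction; with a real model the augmented predictions differ and the merge averages them.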

Arguments

  • h_flip - bool, horizontal flip augmentation
  • v_flip - bool, vertical flip augmentation
  • rotation - list, allowable angles - 90, 180, 270
  • h_shift - list of int, horizontal shift augmentation in pixels
  • v_shift - list of int, vertical shift augmentation in pixels
  • add - list of int/float, additive factor (aug_image = image + factor)
  • mul - list of int/float, multiplicative factor (aug_image = image * factor)
  • contrast - list of int/float, contrast adjustment factor (aug_image = (image - mean) * factor + mean)
  • merge - one of 'mean', 'gmean' and 'max' - mode of merging augmented predictions together
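The intensity formulas and merge modes listed above can be written out in NumPy as a rough illustration. The function names are hypothetical, the contrast sketch assumes `mean` is the global mean of the image (the list does not say which mean is used), and `gmean` is taken to be the standard geometric mean; the wrapper's own Keras layers may differ in detail.

```python
import numpy as np


def add_aug(image, factor):
    # aug_image = image + factor
    return image + factor


def mul_aug(image, factor):
    # aug_image = image * factor
    return image * factor


def contrast_aug(image, factor):
    # aug_image = (image - mean) * factor + mean
    # Assumption: mean is taken over the whole image.
    mean = image.mean()
    return (image - mean) * factor + mean


def merge_predictions(preds, mode="mean"):
    """Merge predictions stacked along axis 0, shape (N, ...), into shape (1, ...)."""
    if mode == "mean":
        return preds.mean(axis=0, keepdims=True)
    if mode == "max":
        return preds.max(axis=0, keepdims=True)
    if mode == "gmean":
        # Geometric mean: N-th root of the product; assumes non-negative predictions.
        return np.prod(preds, axis=0, keepdims=True) ** (1.0 / preds.shape[0])
    raise ValueError(f"unknown merge mode: {mode!r}")


preds = np.array([[1.0, 4.0], [4.0, 1.0]])
print(merge_predictions(preds, "mean"))   # [[2.5 2.5]]
print(merge_predictions(preds, "max"))    # [[4. 4.]]
print(merge_predictions(preds, "gmean"))  # [[2. 2.]]
```

The geometric mean is pulled toward the smaller value (2.0 instead of 2.5 here), so a pixel needs consistently high scores across augmentations to score high after merging.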

Constraints

  1. the model must have exactly 1 input and 1 output
  2. inference batch_size == 1
  3. image height == width if rotation augmentation is used
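The two shape-related constraints can be checked before calling `predict`. The helper below is a hypothetical sketch, not part of the package; it assumes a channels-last input shape (N, H, W, C) and does not inspect the model, so the single-input, single-output constraint is not covered.

```python
def check_tta_input_shape(shape, rotation=()):
    """Raise ValueError if an input shape (N, H, W, C) violates the constraints above."""
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D shape (N, H, W, C), got {shape}")
    batch, height, width, _ = shape
    if batch != 1:
        raise ValueError(f"batch_size must be 1, got {batch}")
    # Rotation is only allowed for square images.
    if rotation and height != width:
        raise ValueError(f"rotation requires height == width, got {height}x{width}")


check_tta_input_shape((1, 256, 256, 3), rotation=(90, 270))  # passes silently
check_tta_input_shape((1, 128, 256, 3))                      # passes: no rotation requested
```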

Installation

  1. PyPI package:
$ pip install tta-wrapper
  2. Latest version:
$ pip install git+https://github.com/qubvel/tta_wrapper/

Example

from keras.models import load_model
from tta_wrapper import tta_segmentation

model = load_model('path/to/model.h5')
tta_model = tta_segmentation(model, h_flip=True, rotation=(90, 270),
                             h_shift=(-5, 5), merge='mean')

# x is a single input image with shape (1, H, W, C); batch_size must be 1
y = tta_model.predict(x)
