lartpang/tta.pytorch

Test-Time Augmentation library for PyTorch.

Features

  • Supports image segmentation tasks.
  • Supports image classification tasks.

Requirements

  • torch
  • torchvision

Usage

More details can be found in the examples.

First, we need to import all necessary classes and define a TTA transform and a dummy image.

```python
import torch

from tta_pytorch import TYPES, Chain, Compose, Flip, HFlip, Merger, Rescale, Resize, VFlip

tta_trans = Compose(
    [
        Rescale(
            scales=[0.5],
            image_mode="bilinear",
            image_align_corners=False,
            mask_mode="bilinear",
            mask_align_corners=False,
        ),
        Resize(
            sizes=[128],
            image_mode="bilinear",
            image_align_corners=False,
            mask_mode="bilinear",
            mask_align_corners=False,
        ),
        Flip(),
        HFlip(),
        # VFlip(),
    ],
    verbose=True,
)

# Dummy input batch: [B, C, H, W] = [3, 1, 50, 50]
image = torch.randn(3, 1, 50, 50, dtype=torch.float32)
```
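The `Compose` object above is iterated later as a sequence of `Chain` objects, one per augmented view. A common way to build such chains (used, for example, by the ttach library) is to take every combination of each transform's parameter options. The plain-Python sketch below illustrates that idea with `itertools.product`. Whether `tta_pytorch` enumerates its chains exactly this way, and the option lists used here (including the `None`/`False` "skip this step" entries), are assumptions for illustration only.

```python
from itertools import product

# Hypothetical option lists, one per transform. None / False stand for
# "leave this step out"; these values are illustrative, not read from tta_pytorch.
TRANSFORM_OPTIONS = {
    "rescale": [None, 0.5],
    "resize": [None, 128],
    "hflip": [False, True],
}


def enumerate_chains(options):
    """Yield one dict per combination of options, i.e. one augmentation chain."""
    names = list(options)
    for values in product(*(options[name] for name in names)):
        yield dict(zip(names, values))


chains = list(enumerate_chains(TRANSFORM_OPTIONS))
print(len(chains))  # 8
print(chains[0])    # {'rescale': None, 'resize': None, 'hflip': False}
```

Under this scheme the number of views is the product of the option counts (2 × 2 × 2 = 8 here), so inference cost grows multiplicatively with every transform added to the composition.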

Next, we can use the TTA transform to augment the image and get the deaugmented results.
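Every chain follows the same cycle: transform the input, run the model on the augmented view, map the prediction back to the original frame (deaugmentation), and merge the per-chain results. The minimal pure-Python sketch below shows that cycle, assuming nested lists in place of tensors, a horizontal flip as the only augmentation, and a hypothetical `toy_model` that is deliberately not flip-invariant. None of these names come from `tta_pytorch`.

```python
def identity(img):
    """Return a copy of a 2-D image (list of rows) unchanged."""
    return [row[:] for row in img]


def hflip(img):
    """Mirror a 2-D image left to right; applying it twice restores the input."""
    return [row[::-1] for row in img]


# Each augmentation is a (do, undo) pair; a flip is its own inverse.
AUGMENTATIONS = [(identity, identity), (hflip, hflip)]


def toy_model(img):
    """Hypothetical per-pixel 'segmentation' score that depends on the column
    index, so its raw output differs between flipped and unflipped inputs."""
    return [[value + col for col, value in enumerate(row)] for row in img]


def tta_predict(img, model, augmentations):
    undone = []
    for do, undo in augmentations:
        pred = model(do(img))       # predict on the augmented view
        undone.append(undo(pred))   # deaugment: map back to the original frame
    # Merge by element-wise mean over all views.
    n = len(undone)
    rows, cols = len(img), len(img[0])
    return [[sum(p[i][j] for p in undone) / n for j in range(cols)] for i in range(rows)]


image = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
print(tta_predict(image, toy_model, AUGMENTATIONS))  # [[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]
```

Averaging is only meaningful because each `undo` restores the original orientation; without it, the flipped prediction would be misaligned with the unflipped one pixel by pixel.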

Here are some examples of how to use the TTA transform.

  1. Basic usage.

```python
tta_results = Merger()
for trans in tta_trans:
    trans: Chain
    aug_image = trans.do_image(image)
    undo_image = trans.undo_image(aug_image)
    tta_results.append(undo_image)
seg_results = tta_results.result
```
  2. Usage with the Merger class and the integrated augmentation and deaugmentation pipeline.

```python
from typing import List

tta_results = Merger()
for trans in tta_trans:
    trans: Chain
    aug_images: List[torch.Tensor] = trans.do_all(inputs=[image], input_types=[TYPES.IMAGE])
    undo_images = trans.undo_all(outputs=aug_images, output_types=[TYPES.MASK])
    tta_results.append(undo_images[0])
seg_results = tta_results.result
```
  3. Usage with the Merger class for segmentation and classification, and separate augmentation and deaugmentation pipelines.

```python
tta_seg_merger = Merger(mode="mean")
tta_cls_merger = Merger(mode="mean")
tta_seg_merger.reset()
tta_cls_merger.reset()
for tran in tta_trans:
    tran: Chain
    aug_tensor = tran.do_image(image)

    # simulate real model outputs: reuse the augmented image as a mask
    # and draw random scores for 1000 classes
    mask = aug_tensor
    label = torch.randn(3, 1000, dtype=torch.float32)

    # for segmentation, [B,K,H,W]
    undo_mask = tran.undo_image(mask)
    tta_seg_merger.append(undo_mask)

    # for classification, [B,K]
    undo_label = tran.undo_label(label)
    tta_cls_merger.append(undo_label)

seg_results = tta_seg_merger.result
seg_mask = seg_results.argmax(dim=1)  # [B,H,W]
cls_results = tta_cls_merger.result
cls_score, cls_index = cls_results.max(dim=1)  # [B], [B]
```
  4. Usage with built-in lists and separate augmentation and deaugmentation pipelines.

```python
tta_seg_results = []
tta_cls_results = []
for tran in tta_trans:
    tran: Chain
    aug_tensor = tran.do_image(image)

    # simulate real model outputs: reuse the augmented image as a mask
    # and draw random scores for 1000 classes
    mask = aug_tensor
    label = torch.randn(3, 1000, dtype=torch.float32)

    # for segmentation, [B,K,H,W]
    undo_mask = tran.undo_image(mask)
    tta_seg_results.append(undo_mask)

    # for classification, [B,K]
    undo_label = tran.undo_label(label)
    tta_cls_results.append(undo_label)

# mean over all augmented views
seg_results = sum(tta_seg_results) / len(tta_seg_results)
seg_mask = seg_results.argmax(dim=1)  # [B,H,W]
cls_results = sum(tta_cls_results) / len(tta_cls_results)
cls_score, cls_index = cls_results.max(dim=1)  # [B], [B]
```
  5. Usage with the decorator.

```python
@tta_trans.decorate(
    input_infos={"image": TYPES.IMAGE},
    output_infos={"mask": TYPES.MASK, "label": TYPES.LABEL},
    merge_mode="mean",
)
def do_something(image=None):
    label = torch.randn(3, 1000, dtype=torch.float32)
    return {"mask": image, "label": label}


tta_results = do_something(image=image)
```
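The decorator form hides the augment, predict, deaugment, and merge loop. The plain-Python sketch below reproduces that mechanic to show what such a decorator has to do. In this sketch, inputs tagged as images are augmented, outputs tagged as masks are mapped back through the inverse transform, labels pass through unchanged, and each output is averaged across chains. `tta_decorate`, `MeanMerger`, and the string type tags are hypothetical stand-ins for `tta_trans.decorate`, `Merger(mode="mean")`, and `TYPES`, not the library's implementation, and treating labels as unaffected by flips is an assumption.

```python
import functools

# Hypothetical string tags standing in for TYPES.IMAGE / TYPES.MASK / TYPES.LABEL.
IMAGE, MASK, LABEL = "image", "mask", "label"


def identity(x):
    return x


def hflip(img):
    """Mirror a 2-D image (list of rows) left to right; it is its own inverse."""
    return [row[::-1] for row in img]


# (do, undo) pairs, one per hypothetical augmentation chain.
AUGS = [(identity, identity), (hflip, hflip)]


def _elementwise_mean(items):
    """Mean of equally shaped nested lists of numbers."""
    first = items[0]
    if isinstance(first, list):
        return [_elementwise_mean([it[i] for it in items]) for i in range(len(first))]
    return sum(items) / len(items)


class MeanMerger:
    """Hypothetical stand-in for a mean-mode merger: append, then read .result."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._items = []

    def append(self, item):
        self._items.append(item)

    @property
    def result(self):
        return _elementwise_mean(self._items)


def tta_decorate(input_infos, output_infos, augs=AUGS):
    """Hypothetical decorator: augment typed inputs, undo typed outputs, average."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(**kwargs):  # keyword arguments only, as in do_something(image=image)
            mergers = {name: MeanMerger() for name in output_infos}
            for do, undo in augs:
                aug_kwargs = {
                    k: do(v) if input_infos.get(k) == IMAGE else v
                    for k, v in kwargs.items()
                }
                outputs = func(**aug_kwargs)
                for name, kind in output_infos.items():
                    out = outputs[name]
                    # Spatial outputs are mapped back to the original frame;
                    # labels are assumed flip-invariant and pass through unchanged.
                    mergers[name].append(undo(out) if kind in (IMAGE, MASK) else out)
            return {name: m.result for name, m in mergers.items()}

        return wrapper

    return decorator


@tta_decorate(input_infos={"image": IMAGE}, output_infos={"mask": MASK, "label": LABEL})
def do_something(image=None):
    # Toy "model": the mask is the image itself; the label scores the first row.
    return {"mask": image, "label": [float(sum(image[0])), 0.0]}


result = do_something(image=[[1.0, 2.0], [3.0, 4.0]])
print(result)  # {'mask': [[1.0, 2.0], [3.0, 4.0]], 'label': [3.0, 0.0]}
```

Because the flipped mask is flipped back before merging, the averaged mask lines up with the original input, which is the property the README's decorator example relies on when it returns the image itself as the mask.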

Cite

If you find this library useful, please cite it with the following BibTeX entry:

```bibtex
@online{tta.pytorch,
  author = "lartpang",
  title  = "{Test-Time Augmentation library for Pytorch}",
  url    = "https://github.com/lartpang/tta.pytorch",
  note   = "(Dec 20, 2023)",
}
```
