[CVPR 2022 Oral] Balanced MSE for Imbalanced Visual Regression https://arxiv.org/abs/2203.16427


jiawei-ren/BalancedMSE


Code for the paper:

Balanced MSE for Imbalanced Visual Regression
Jiawei Ren, Mingyuan Zhang, Cunjun Yu, Ziwei Liu

CVPR 2022 (Oral)

News

Live Demo

Check out our live demo in the Hugging Face 🤗 Space!

Tutorial

We provide a minimal working example of Balanced MSE using the BMC implementation on a small-scale dataset, the Boston Housing dataset.

Open In Colab

The notebook is built on top of the Deep Imbalanced Regression (DIR) Tutorial; we thank the authors for their excellent tutorial!
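The Colab notebook is not reproduced here. As a rough, self-contained stand-in, the sketch below trains a one-feature linear model with the 1-D BMC loss and a fixed `noise_var`. The synthetic data (exponentially distributed inputs, so large targets are rare), the model, `noise_var = 1.0`, and the optimizer settings are all our assumptions rather than the notebook's; Boston Housing is not loaded.

```python
import torch
import torch.nn.functional as F


def bmc_loss(pred, target, noise_var):
    """1-D BMC loss mirroring the snippet in this README; `noise_var` is a fixed float here."""
    # logits[i, j] = -(pred_i - target_j)^2 / (2 * noise_var), shape [batch, batch]
    logits = -(pred - target.T).pow(2) / (2 * noise_var)
    # Row i's "correct class" is its own target, i.e. the diagonal entry.
    labels = torch.arange(pred.shape[0], device=pred.device)
    loss = F.cross_entropy(logits, labels)
    return loss * (2 * noise_var)  # restore the MSE-like scale (constant, so no detach needed)


torch.manual_seed(0)

# Hypothetical skewed regression data: x ~ Exponential(1), so large targets are rare.
x = torch.distributions.Exponential(1.0).sample((256, 1))
y = 2.0 * x + 1.0 + 0.1 * torch.randn(256, 1)

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
noise_var = 1.0  # assumed fixed value for this sketch

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = bmc_loss(model(x), y, noise_var)  # full-batch: every target is a candidate "class"
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Because the in-batch targets form the softmax denominator, larger batches give the loss a better picture of the label distribution; this toy run uses the whole dataset as one batch.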

Quick Preview

A code snippet of the Balanced MSE loss is shown below. We use the BMC (Batch-based Monte-Carlo) implementation for demonstration; BMC does not require the label prior to be known beforehand.

One-dimensional Balanced MSE

```python
import torch
import torch.nn.functional as F


def bmc_loss(pred, target, noise_var):
    """Compute the Balanced MSE Loss (BMC) between `pred` and the ground truth `target`.
    Args:
      pred: A float tensor of size [batch, 1].
      target: A float tensor of size [batch, 1].
      noise_var: A float number or tensor.
    Returns:
      loss: A float tensor. Balanced MSE Loss.
    """
    logits = - (pred - target.T).pow(2) / (2 * noise_var)   # logit size: [batch, batch]
    labels = torch.arange(pred.shape[0], device=pred.device)  # sample i's own target is its "class"
    loss = F.cross_entropy(logits, labels)     # contrastive-like loss
    # optional: restore the loss scale, 'detach' when noise is learnable
    # (a plain float has no .detach(), so only a tensor scale is detached)
    scale = 2 * noise_var
    loss = loss * (scale.detach() if torch.is_tensor(scale) else scale)

    return loss
```
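As we read the snippet (our own restatement, not a formula quoted from the paper), for a batch of $B$ predictions $\hat{y}_i$ with targets $y_i$ and $\sigma^2$ = `noise_var`, it minimizes the in-batch cross-entropy below. The other targets $y_j$ in the batch act as samples from the training label distribution, which is why no label prior has to be supplied:

$$
\mathcal{L}_{\mathrm{BMC}} = -\frac{1}{B}\sum_{i=1}^{B} \log \frac{\exp\left(-\frac{(\hat{y}_i - y_i)^2}{2\sigma^2}\right)}{\sum_{j=1}^{B} \exp\left(-\frac{(\hat{y}_i - y_j)^2}{2\sigma^2}\right)}
$$

The optional last step multiplies $\mathcal{L}_{\mathrm{BMC}}$ by $2\sigma^2$ with the gradient blocked through that factor. Adding the same constant to every logit in row $i$ leaves the softmax unchanged, so using the full Gaussian log-density $\log \mathcal{N}(y_j; \hat{y}_i, \sigma^2)$ as the logit gives the same loss; this is what allows the multi-dimensional version to use `MultivariateNormal.log_prob`.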

`noise_var` is a one-dimensional hyper-parameter. It can optionally be optimized during training:

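```python
from torch.nn.modules.loss import _Loss


class BMCLoss(_Loss):
    def __init__(self, init_noise_sigma):
        super(BMCLoss, self).__init__()
        # float() so that an integer init value still yields a trainable float parameter
        self.noise_sigma = torch.nn.Parameter(torch.tensor(float(init_noise_sigma)))

    def forward(self, pred, target):
        noise_var = self.noise_sigma ** 2
        return bmc_loss(pred, target, noise_var)


# `init_noise_sigma`, `sigma_lr`, and `optimizer` come from your training script
criterion = BMCLoss(init_noise_sigma)
optimizer.add_param_group({'params': criterion.noise_sigma, 'lr': sigma_lr, 'name': 'noise_sigma'})
```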
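A minimal sketch of the learnable-noise variant in use. The toy `torch.nn.Linear(4, 1)` model, the random data, plain SGD, and the values `init_noise_sigma = 1.0` and `sigma_lr = 0.01` are hypothetical choices for illustration, not the authors' settings. It runs one optimization step and checks that `noise_sigma` is updated through its own parameter group.

```python
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


def bmc_loss(pred, target, noise_var):
    logits = -(pred - target.T).pow(2) / (2 * noise_var)  # [batch, batch]
    labels = torch.arange(pred.shape[0], device=pred.device)
    loss = F.cross_entropy(logits, labels)
    return loss * (2 * noise_var).detach()  # noise_var is a tensor in this sketch


class BMCLoss(_Loss):
    def __init__(self, init_noise_sigma):
        super().__init__()
        self.noise_sigma = torch.nn.Parameter(torch.tensor(float(init_noise_sigma)))

    def forward(self, pred, target):
        return bmc_loss(pred, target, self.noise_sigma ** 2)


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)                   # hypothetical toy regressor
criterion = BMCLoss(init_noise_sigma=1.0)       # hypothetical initial sigma
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
sigma_lr = 0.01                                 # hypothetical learning rate for sigma
optimizer.add_param_group({'params': criterion.noise_sigma, 'lr': sigma_lr, 'name': 'noise_sigma'})

x, y = torch.randn(32, 4), torch.randn(32, 1)   # random stand-in batch
sigma_before = criterion.noise_sigma.item()
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
sigma_after = criterion.noise_sigma.item()
print(f"noise_sigma: {sigma_before:.4f} -> {sigma_after:.4f}")
```

Detaching the `2 * noise_var` factor means the rescaling does not contribute its own gradient to `noise_sigma`; in this sketch the update to sigma comes only from the cross-entropy term.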

Multi-dimensional Balanced MSE

The multi-dimensional implementation is compatible with the 1-D version.

```python
import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal as MVN


def bmc_loss_md(pred, target, noise_var):
    """Compute the Multidimensional Balanced MSE Loss (BMC) between `pred` and the ground truth `target`.
    Args:
      pred: A float tensor of size [batch, d].
      target: A float tensor of size [batch, d].
      noise_var: A float number or tensor.
    Returns:
      loss: A float tensor. Balanced MSE Loss.
    """
    I = torch.eye(pred.shape[-1], device=pred.device)
    logits = MVN(pred.unsqueeze(1), noise_var * I).log_prob(target.unsqueeze(0))  # logit size: [batch, batch]
    labels = torch.arange(pred.shape[0], device=pred.device)
    loss = F.cross_entropy(logits, labels)     # contrastive-like loss
    # optional: restore the loss scale, 'detach' when noise is learnable
    scale = 2 * noise_var
    loss = loss * (scale.detach() if torch.is_tensor(scale) else scale)

    return loss
```

`noise_var` is still a one-dimensional hyper-parameter (the covariance is isotropic, `noise_var * I`) and can optionally be learned during training.
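The compatibility claim can be checked numerically. The sketch below compares the two implementations on random 1-D inputs, with our choice of `noise_var = 0.5` and the optional rescaling left out of both functions. In 1-D, the `MultivariateNormal` log-density differs from the 1-D logits only by a constant per row, which the softmax inside `F.cross_entropy` cancels.

```python
import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal as MVN


def bmc_loss(pred, target, noise_var):
    # 1-D logits: -(pred_i - target_j)^2 / (2 * noise_var)
    logits = -(pred - target.T).pow(2) / (2 * noise_var)
    return F.cross_entropy(logits, torch.arange(pred.shape[0]))


def bmc_loss_md(pred, target, noise_var):
    # Gaussian log-density of target_j under N(pred_i, noise_var * I)
    I = torch.eye(pred.shape[-1])
    logits = MVN(pred.unsqueeze(1), noise_var * I).log_prob(target.unsqueeze(0))
    return F.cross_entropy(logits, torch.arange(pred.shape[0]))


torch.manual_seed(0)
pred, target = torch.randn(16, 1), torch.randn(16, 1)
one_d = bmc_loss(pred, target, 0.5)
multi_d = bmc_loss_md(pred, target, 0.5)
print(torch.allclose(one_d, multi_d, atol=1e-5))  # expected to print True
```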

Run Experiments

To run experiments, please go into the corresponding sub-folder.

For IHMR, we have released our code and pretrained models in MMHuman3d.

Citation

```bibtex
@inproceedings{ren2021bmse,
  title={Balanced MSE for Imbalanced Visual Regression},
  author={Ren, Jiawei and Zhang, Mingyuan and Yu, Cunjun and Liu, Ziwei},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}
```

Acknowledgment

This work is supported by NTU NAP, MOE AcRF Tier 2 (T2EP20221-0033), the National Research Foundation, Singapore under its AI Singapore Programme, and under the RIE2020 Industry Alignment Fund – Industry Collaboration Projects (IAF-ICP) Funding Initiative, as well as cash and in-kind contribution from the industry partner(s).

The code is built on top of Delving into Deep Imbalanced Regression.

