Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit24f1c45

Browse files
authored
Merge pull request#9 from luzai/master
fix depth of resnet/preresnet on cifar10/cifar100
2 parents6e1d46e +803ce40 commit24f1c45

File tree

3 files changed

+25
-10
lines changed

3 files changed

+25
-10
lines changed

‎cifar.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@
6464
' | '.join(model_names)+
6565
' (default: resnet18)')
6666
parser.add_argument('--depth',type=int,default=29,help='Model depth.')
67+
parser.add_argument('--block-name',type=str,default='BasicBlock',
68+
help='the building block for Resnet and Preresnet: BasicBlock, Bottleneck (default: Basicblock for cifar10/cifar100)')
6769
parser.add_argument('--cardinality',type=int,default=8,help='Model cardinality (group).')
6870
parser.add_argument('--widen-factor',type=int,default=4,help='Widen factor. 4 -> 64, 8 -> 128, ...')
6971
parser.add_argument('--growthRate',type=int,default=12,help='Growth rate for DenseNet.')
@@ -161,14 +163,14 @@ def main():
161163
model=models.__dict__[args.arch](
162164
num_classes=num_classes,
163165
depth=args.depth,
166+
block_name=args.block_name,
164167
)
165168
else:
166169
model=models.__dict__[args.arch](num_classes=num_classes)
167170

168171
model=torch.nn.DataParallel(model).cuda()
169172
cudnn.benchmark=True
170173
print(' Total params: %.2fM'% (sum(p.numel()forpinmodel.parameters())/1000000.0))
171-
172174
criterion=nn.CrossEntropyLoss()
173175
optimizer=optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
174176

‎models/cifar/preresnet.py‎

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,19 @@ def forward(self, x):
9292

9393
classPreResNet(nn.Module):
9494

95-
def__init__(self,depth,num_classes=1000):
95+
def__init__(self,depth,num_classes=1000,block_name='BasicBlock'):
9696
super(PreResNet,self).__init__()
9797
# Model type specifies number of layers for CIFAR-10 model
98-
assert (depth-2)%6==0,'depth should be 6n+2'
99-
n= (depth-2)//6
100-
101-
block=Bottleneckifdepth>=44elseBasicBlock
98+
ifblock_name.lower()=='basicblock':
99+
assert (depth-2)%6==0,'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
100+
n= (depth-2)//6
101+
block=BasicBlock
102+
elifblock_name.lower()=='bottleneck':
103+
assert (depth-2)%9==0,'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
104+
n= (depth-2)//9
105+
block=Bottleneck
106+
else:
107+
raiseValueError('block_name shoule be Basicblock or Bottleneck')
102108

103109
self.inplanes=16
104110
self.conv1=nn.Conv2d(3,16,kernel_size=3,padding=1,

‎models/cifar/resnet.py‎

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,20 @@ def forward(self, x):
9292

9393
classResNet(nn.Module):
9494

95-
def__init__(self,depth,num_classes=1000):
95+
def__init__(self,depth,num_classes=1000,block_name='BasicBlock'):
9696
super(ResNet,self).__init__()
9797
# Model type specifies number of layers for CIFAR-10 model
98-
assert (depth-2)%6==0,'depth should be 6n+2'
99-
n= (depth-2)//6
98+
ifblock_name.lower()=='basicblock':
99+
assert (depth-2)%6==0,'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
100+
n= (depth-2)//6
101+
block=BasicBlock
102+
elifblock_name.lower()=='bottleneck':
103+
assert (depth-2)%9==0,'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
104+
n= (depth-2)//9
105+
block=Bottleneck
106+
else:
107+
raiseValueError('block_name shoule be Basicblock or Bottleneck')
100108

101-
block=Bottleneckifdepth>=44elseBasicBlock
102109

103110
self.inplanes=16
104111
self.conv1=nn.Conv2d(3,16,kernel_size=3,padding=1,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp