PyTorch 搭建 ResNet 模型对 CIFAR10 数据集分类总结
ResNet 模型搭建
结构分析
通用框架
根据 Deep Residual Learning for Image Recognition 论文中的信息,可以得到常规 ResNet 模型的通用框架如下
layer_name | out_size | (18/34 layers) out_channel | (50/101/152 layers) out_channel | kernel_size | stride | padding |
---|---|---|---|---|---|---|
Input | 224*224 | 3 | 3 | None | None | None |
Conv1 | 112*112 | 64 | 64 | 7 | 2 | 3 |
Maxpool | 56*56 | 64 | 64 | 3 | 2 | 1 |
Conv2_x | 56*56 | 64 | 64*4=256 | - | - | - |
Conv3_x | 28*28 | 128 | 128*4=512 | - | - | - |
Conv4_x | 14*14 | 256 | 256*4=1024 | - | - | - |
Conv5_x | 7*7 | 512 | 512*4=2048 | - | - | - |
Avgpool | 1*1 | 512 | 2048 | None | None | None |
Flatten | 2048 | 1 | 1 | None | None | None |
FC | 1000 | 1 | 1 | None | None | None |
其中 Conv2_x 、 Conv3_x 、Conv4_x 、 Conv5_x 层可由 BasicBlock 和 Bottleneck 两种基本模型组合而成
layer_name | ResNet18 | ResNet34 | ResNet50 | ResNet101 | ResNet152 |
---|---|---|---|---|---|
Conv2_x | BasicBlock*2 | BasicBlock*3 | Bottleneck*3 | Bottleneck*3 | Bottleneck*3 |
Conv3_x | BasicBlock*2 | BasicBlock*4 | Bottleneck*4 | Bottleneck*4 | Bottleneck*8 |
Conv4_x | BasicBlock*2 | BasicBlock*6 | Bottleneck*6 | Bottleneck*23 | Bottleneck*36 |
Conv5_x | BasicBlock*2 | BasicBlock*3 | Bottleneck*3 | Bottleneck*3 | Bottleneck*3 |
基础结构
ResNet 18/34 使用的 BasicBlock 结构如下
layer_name | in_size | out_size | out_channel | kernel_size | stride | padding |
---|---|---|---|---|---|---|
Conv1 | x*x | (x/stride)*(x/stride) | out_channel | 3 | stride | 1 |
Conv2 | x'*x' | x'*x' | out_channel | 3 | 1 | 1 |
identity |
ResNet 50/101/152 使用的 Bottleneck 结构如下
layer_name | in_size | out_size | out_channel | kernel_size | stride | padding |
---|---|---|---|---|---|---|
Conv1 | x*x | (x/stride)*(x/stride) | out_channel | 1 | 1 | 0 |
Conv2 | x'*x' | x'*x' | out_channel | 3 | stride | 1 |
Conv3 | x'*x' | x'*x' | out_channel | 1 | 1 | 0 |
identity |
stride 和 identity
当基础结构是 Conv3_x 、Conv4_x 、 Conv5_x 的第一层时, stride=2
且 identity 为下采样后的输入
1 | nn.Sequential( |
当基础结构是 Conv3_x 、Conv4_x 、 Conv5_x 的其他层或在 Conv2_x 层时, stride=1
且 identity 为输入本身
1 | nn.Sequential() |
网络实现
BasicBlock
1 | class BasicBlock(nn.Module): |
Bottleneck
1 | class Bottleneck(nn.Module): |
通用框架
1 | class ResNet(nn.Module): |
构造网络
1 | def ResNet18() -> ResNet: |
参考网页
PyTorch - SOURCE CODE FOR TORCHVISION.MODELS.RESNET
明素 - ResNet详解
回顾
nn.Conv2d()
函数 bias
参数的设置
当 nn.Conv2d()
后接 nn.BatchNorm2d()
时,可以把 bias
参数设置为 False
因为在 BN 层中,输入是否存在偏置不影响输出结果
不添加偏置还可以减少显卡内存的占用
参考网页
7s记忆的鱼 - 【pytorch】Conv2d()里面的参数bias什么时候加,什么时候不加?
nn.AdaptiveAvgPool2d
函数
参考网页
*参数
的作用
*参数
可以解压参数
1 | a = (0,1,2,3,4,5,6,7,8,9) |
将 List 和 Tuple 中的元素逐一解压出来
1 | 0 1 2 3 4 5 6 7 8 9 |
参考网页
TEDxPY - Python *args 用法笔记
pip install
默认安装在 base
环境
使用 pip install
时改用如下指令即可安装到当前虚拟环境中
1 | python -m pip install ** |
参考网页
timertimer - 在conda虚拟环境中用pip安装包总是在base环境中的解决办法
CIFAR10 特化模型
layer_name | out_size | out_channel | kernel_size | stride | padding | |
---|---|---|---|---|---|---|
Input | 32*32 | 3 | None | None | None | |
Conv1 | 32*32 | 16 | 1 | 1 | 31 | |
Conv2_x | 32*32 | 16 | 64*4=256 | - | - | - |
Conv3_x | 16*16 | 32 | 128*4=512 | - | - | - |
Conv4_x | 8*8 | 64 | 256*4=1024 | - | - | - |
Avgpool | 1*1 | 64 | None | None | None | |
Flatten | 64 | 1 | None | None | None | |
FC | 10 | 1 | None | None | None |
其中 Conv2_x 、 Conv3_x 、Conv4_x 层由 Block 组成
layer_name | ResNet_CIFAR10 |
---|---|
Conv2_x | Block*n |
Conv3_x | Block*n |
Conv4_x | Block*n |
Block 结构如下
layer_name | in_size | out_size | out_channel | kernel_size | stride | padding |
---|---|---|---|---|---|---|
Conv1 | x*x | (x/stride)*(x/stride) | out_channel | 3 | stride | 1 |
Conv2 | x'*x' | x'*x' | out_channel | 3 | 1 | 1 |
identity |
当基础结构是 Conv3_x 、Conv4_x 、 Conv5_x 的第一层时, stride=2
且 identity 为下采样后的输入
1 | nn.Sequential( |
当基础结构是 Conv3_x 、Conv4_x 、 Conv5_x 的其他层或在 Conv2_x 层时, stride=1
且 identity 为输入本身
1 | nn.Sequential() |
模型实现
1 | class Block(nn.Module): |
1 | class ResNet(nn.Module): |
1 | def ResNet20() -> ResNet: |
模型训练
模型训练内容与 AlexNet_CIFAR10 项目相似,相同之处不再赘述
封装自定义 Python 库
此次实验中将 AlexNet_CIFAR10 项目中计算数据集均值和方差封装在 utils
文件夹下
需要在 utils
文件夹下生成空的 __init__.py
文件,声明 utils
文件夹为封装好的 Python 库