PyTorch 搭建 ResNet 模型对 CIFAR10 数据集分类总结
ResNet 模型搭建
结构分析
通用框架
根据 Deep Residual Learning for Image Recognition 论文中的信息,可以得到常规 ResNet 模型的通用框架如下
| layer_name | out_size | (18/34 layers) out_channel | (50/101/152 layers) out_channel | kernel_size | stride | padding |
|---|---|---|---|---|---|---|
| Input | 224*224 | 3 | 3 | None | None | None |
| Conv1 | 112*112 | 64 | 64 | 7 | 2 | 3 |
| Maxpool | 56*56 | 64 | 64 | 3 | 2 | 1 |
| Conv2_x | 56*56 | 64 | 64*4=256 | - | - | - |
| Conv3_x | 28*28 | 128 | 128*4=512 | - | - | - |
| Conv4_x | 14*14 | 256 | 256*4=1024 | - | - | - |
| Conv5_x | 7*7 | 512 | 512*4=2048 | - | - | - |
| Avgpool | 1*1 | 512 | 2048 | None | None | None |
| Flatten | 2048 | 1 | 1 | None | None | None |
| FC | 1000 | 1 | 1 | None | None | None |
其中 Conv2_x 、 Conv3_x 、Conv4_x 、 Conv5_x 层可由 BasicBlock 和 Bottleneck 两种基本模型组合而成
| layer_name | ResNet18 | ResNet34 | ResNet50 | ResNet101 | ResNet152 |
|---|---|---|---|---|---|
| Conv2_x | BasicBlock*2 | BasicBlock*3 | Bottleneck*3 | Bottleneck*3 | Bottleneck*3 |
| Conv3_x | BasicBlock*2 | BasicBlock*4 | Bottleneck*4 | Bottleneck*4 | Bottleneck*8 |
| Conv4_x | BasicBlock*2 | BasicBlock*6 | Bottleneck*6 | Bottleneck*23 | Bottleneck*36 |
| Conv5_x | BasicBlock*2 | BasicBlock*3 | Bottleneck*3 | Bottleneck*3 | Bottleneck*3 |
基础结构
ResNet 18/34 使用的 BasicBlock 结构如下
| layer_name | in_size | out_size | out_channel | kernel_size | stride | padding |
|---|---|---|---|---|---|---|
| Conv1 | x*x | (x/stride)*(x/stride) | out_channel | 3 | stride | 1 |
| Conv2 | x’*x’ | x’*x’ | out_channel | 3 | 1 | 1 |
| identity |
ResNet 50/101/152 使用的 Bottleneck 结构如下
| layer_name | in_size | out_size | out_channel | kernel_size | stride | padding |
|---|---|---|---|---|---|---|
| Conv1 | x*x | (x/stride)*(x/stride) | out_channel | 1 | 1 | 0 |
| Conv2 | x’*x’ | x’*x’ | out_channel | 3 | stride | 1 |
| Conv3 | x’*x’ | x’*x’ | out_channel | 1 | 1 | 0 |
| identity |
stride 和 identity
当基础结构是 Conv3_x 、Conv4_x 、 Conv5_x 的第一层时, stride=2 且 identity 为下采样后的输入
1 | nn.Sequential( |
当基础结构是 Conv3_x 、Conv4_x 、 Conv5_x 的其他层或在 Conv2_x 层时, stride=1 且 identity 为输入本身
1 | nn.Sequential() |
网络实现
BasicBlock
1 | class BasicBlock(nn.Module): |
Bottleneck
1 | class Bottleneck(nn.Module): |
通用框架
1 | class ResNet(nn.Module): |
构造网络
1 | def ResNet18() -> ResNet: |
参考网页
PyTorch - SOURCE CODE FOR TORCHVISION.MODELS.RESNET
明素 - ResNet详解
回顾
nn.Conv2d() 函数 bias 参数的设置
当 nn.Conv2d() 后接 nn.BatchNorm2d() 时,可以把 bias 参数设置为 False
因为在 BN 层中,输入是否存在偏置不影响输出结果
不添加偏置还可以减少显卡内存的占用
参考网页
7s记忆的鱼 - 【pytorch】Conv2d()里面的参数bias什么时候加,什么时候不加?
nn.AdaptiveAvgPool2d 函数
参考网页
*参数 的作用
*参数 可以解压参数
1 | a = (0,1,2,3,4,5,6,7,8,9) |
将 List 和 Tuple 中的元素逐一解压出来
1 | 0 1 2 3 4 5 6 7 8 9 |
参考网页
TEDxPY - Python *args 用法笔记
pip install 默认安装在 base 环境
使用 pip install 时改用如下指令即可安装到当前虚拟环境中
1 | python -m pip install ** |
参考网页
timertimer - 在conda虚拟环境中用pip安装包总是在base环境中的解决办法
CIFAR10 特化模型
| layer_name | out_size | out_channel | kernel_size | stride | padding | |
|---|---|---|---|---|---|---|
| Input | 32*32 | 3 | None | None | None | |
| Conv1 | 32*32 | 16 | 1 | 1 | 31 | |
| Conv2_x | 32*32 | 16 | 64*4=256 | - | - | - |
| Conv3_x | 16*16 | 32 | 128*4=512 | - | - | - |
| Conv4_x | 8*8 | 64 | 256*4=1024 | - | - | - |
| Avgpool | 1*1 | 64 | None | None | None | |
| Flatten | 64 | 1 | None | None | None | |
| FC | 10 | 1 | None | None | None |
其中 Conv2_x 、 Conv3_x 、Conv4_x 层由 Block 组成
| layer_name | ResNet_CIFAR10 |
|---|---|
| Conv2_x | Block*n |
| Conv3_x | Block*n |
| Conv4_x | Block*n |
Block 结构如下
| layer_name | in_size | out_size | out_channel | kernel_size | stride | padding |
|---|---|---|---|---|---|---|
| Conv1 | x*x | (x/stride)*(x/stride) | out_channel | 3 | stride | 1 |
| Conv2 | x’*x’ | x’*x’ | out_channel | 3 | 1 | 1 |
| identity |
当基础结构是 Conv3_x 、Conv4_x 、 Conv5_x 的第一层时, stride=2 且 identity 为下采样后的输入
1 | nn.Sequential( |
当基础结构是 Conv3_x 、Conv4_x 、 Conv5_x 的其他层或在 Conv2_x 层时, stride=1 且 identity 为输入本身
1 | nn.Sequential() |
模型实现
1 | class Block(nn.Module): |
1 | class ResNet(nn.Module): |
1 | def ResNet20() -> ResNet: |
模型训练
模型训练内容与 AlexNet_CIFAR10 项目相似,相同之处不再赘述
封装自定义 Python 库
此次实验中将 AlexNet_CIFAR10 项目中计算数据集均值和方差封装在 utils 文件夹下
需要在 utils 文件夹下生成空的 __init__.py 文件,声明 utils 文件夹为封装好的 Python 库