PyTorch 搭建 DenseNet 模型对 CIFAR10 数据集分类总结
DenseNet 模型搭建
结构分析
通用框架
根据论文中的信息,可以得到常规 DenseNet 模型(忽略 batch_size )的通用框架如下
layer_name | out_size | kernel_size | stride | padding |
---|---|---|---|---|
Input | 224 * 224 | None | None | None |
Conv | 112 * 112 | 7 | 2 | 3 |
Maxpool | 56 * 56 | 3 | 2 | 1 |
Dense Layer_1 | 56 * 56 | - | - | - |
Transition Layer_1 | 28 * 28 | - | - | - |
Dense Layer_2 | 28 * 28 | - | - | - |
Transition Layer_2 | 14 * 14 | - | - | - |
Dense Layer_3 | 14 * 14 | - | - | - |
Transition Layer_3 | 7 * 7 | - | - | - |
Dense Layer_4 | 7 * 7 | - | - | - |
Global Avgpool | 1 * 1 | 7 | 0 | 0 |
FC | 1000 | None | None | None |
其中 Dense Layer 由多个 Dense Block 组成
layer_name | ResNet18 | ResNet34 | ResNet50 |
---|---|---|---|
Dense Layer_1 | Dense Block * 6 | Dense Block * 6 | Dense Block * 6 |
Dense Layer_2 | Dense Block * 12 | Dense Block * 12 | Dense Block * 12 |
Dense Layer_3 | Dense Block * 24 | Dense Block * 32 | Dense Block * 48 |
Dense Layer_4 | Dense Block * 16 | Dense Block * 32 | Dense Block * 32 |
基础结构
DenseNet 使用的 Dense Block 结构如下
layer_name | in_size | out_size | out_channel | kernel_size | stride | padding |
---|---|---|---|---|---|---|
Conv1 | x * x | x * x | 4 * growth_rate | 1 | 1 | 0 |
Conv2 | x * x | x * x | growth_rate | 3 | 1 | 1 |
Concatence | None | None | in_channel + growth_rate | None | None | None |
DenseNet 使用的 Dense Block 结构如下
layer_name | in_size | out_size | out_channel | kernel_size | stride | padding |
---|---|---|---|---|---|---|
Conv | x * x | x * x | in_channel // 2 | 1 | 1 | 0 |
Avgpool | x * x | (x/2) * (x/2) | in_channel // 2 | 2 | 2 | 0 |
Conv
DenseNet 的 Conv 层和 ResNet 的 Conv 层不同
ResNet 的 Conv 层实际上是 Conv - BatchNorm - ReLU
而 DenseNet 中除了第一个 Conv 层以外的其他 Conv 层实际上是 ReLU - BatchNorm - Conv
网络实现
Dense Block
1 | class _DenseLayer(nn.Module): |
Transition
1 | class _Transition(nn.Module): |
通用框架
1 | class DenseNet(nn.Module): |
构造网络
1 | def DenseNet121() -> DenseNet: |
参考网页
PyTorch - SOURCE CODE FOR TORCHVISION.MODELS.DENSENET
Mayurji - Image-Classification-PyTorch/DenseNet.py
wmpscc - CNN-Series-Getting-Started-and-PyTorch-Implementation/DenseNet/DenseNet-Torch.py
pytorch - vision/torchvision/models/densenet.py
模型训练
模型训练内容与 AlexNet_CIFAR10 项目相似,相同之处不再赘述
总结
DenseNet 和 ResNet 很像, ResNet 是使用了 short cut,而 DenseNet 可以理解为将所有输出都进行 short cut 连接到了其后面的所有输出