SSD: Single Shot MultiBox Detector 训练KITTI数据集(2)

前言

博主在上篇中花了很大篇幅讲解如何一步步把KITTI原始数据做成了SSD可以训练的格式,接下来就可以使用相关caffe代码实现SSD的训练了。

下载VGG预训练模型

将 SSD 用于自己的检测任务,是需要 Fine-tuning a pretrained network,看过论文的朋友可能都知道,论文中的SSD框架是是由VGG网络为基底(base)的。除此之外,作者也提供了另外两种结构的网络:ZF-SSD和Resnet-SSD,可以在caffe/examples/ssd 文件夹中查看相应的python训练代码。

初次训练,还是用VGG网络吧,下载该预训练模型,将其放到/home/mx/caffe/models/VGGNet文件夹之下。

修改训练代码

这里先说下电脑硬件配置,这决定了一些训练参数的设定。博主使用的主机CPU为Intel i7 6700,GPU为TITAN X,运行内存16GB,搭载Ubuntu16.04系统。

一般的caffe训练都是使用train.prototxt和solver.prototxt文件,有同学可能想问了,为什么SSD项目下找不到这些文件呢?原因是SSD的模型很大,train.prototxt就有1000多行,直接修改参数的工作量太大,而且train.prototxt一旦改动,test.prototxt和solver.prototxt也要跟着改动。因此,作者使用了一个很有效的方法,利用python脚本,自动生成这些文件。初次训练,博主选择了ssd_pascal.py 脚本来训练SSD_300x300,那么将ssd_pascal.py复制一份,重命名为ssd_pascal_kitti.py,然后修改这个ython文件。

ssd_pascal.py脚本也有500多行,博主也不可能全部贴出来,这里就贴出可以修改的部分,作为一个参照。

PS.根据博友反馈,还是提供修改过的的训练脚本 ssd_pascal_kitti.py,以供参考。

自定义路径和常用参数

train_data = "examples/VOC0712/VOC0712_trainval_lmdb" # 训练数据路径,修改前
train_data = "examples/KITTI/KITTI_trainval_lmdb" # 修改后
------
test_data = "examples/VOC0712/VOC0712_test_lmdb" # 测试数据路径
test_data = "examples/KITTI/KITTI_test_lmdb"
------
model_name = "VGG_VOC0712_{}".format(job_name) # 模型名字
model_name = "KITTI_{}".format(job_name)
------
save_dir = "models/VGGNet/VOC0712/{}".format(job_name) # 模型保存路径
save_dir = "models/VGGNet/KITTI/{}".format(job_name)
------
snapshot_dir = "models/VGGNet/VOC0712/{}".format(job_name) # snapshot快照保存路径
snapshot_dir = "models/VGGNet/KITTI/{}".format(job_name)
------
job_dir = "jobs/VGGNet/VOC0712/{}".format(job_name) # job保存路径
job_dir = "jobs/VGGNet/KITTI/{}".format(job_name)
------
output_result_dir = "{}/data/VOCdevkit/results/VOC2007/{}/Main".format(os.environ['HOME'], job_name) # 测试结果txt保存路径
output_result_dir = "{}/data/KITTIdevkit/results/KITTI/{}/Main".format(os.environ['HOME'], job_name)
-------
name_size_file = "data/VOC0712/test_name_size.txt" # test_name_size.txt文件路径
name_size_file = "data/KITTI/test_name_size.txt"
-------
label_map_file = "data/VOC0712/labelmap_voc.prototxt" # label文件路径
label_map_file = "data/KITTI/labelmap_kitti.prototxt"
------
num_classes = 21 # 总类别数
num_classes = 4
------
gpus = "0,1,2,3" # 使用哪块GPU
gpus = "0" # 只有1块GPU
------
batch_size = 32 # 一次处理的图片数
batch_size = 自定义 # TITAN X刚好可以达到32,请根据显存大小调整 
------
num_test_image = 4952 # 测试图片数量
num_test_image = 自定义 # 这个数量应该和test_name_size.txt保持一致
------
run_soon = True # 生成文件后自动开始训练
run_soon = False # 手动挡
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

自定义训练参数

SSD的训练参数比较多,具体含义之前也了解过。但是初次试验,不太清楚该怎么调,所以决定基本先不改动其他参数(等训练一次后根据结果再调整,反复训练),只对初始学习率做调整。事实证明,对于KITTI数据集,初始学习率0.001过大,会导致网络不收敛,此处应调整为0.0001,具体如下:

# If true, use batch norm for all newly added layers.
# Currently only the non batch norm version has been tested.
use_batchnorm = False
lr_mult = 1
# Use different initial learning rate.
if use_batchnorm:
    base_lr = 0.0004
else:
    # A learning rate for batch_size = 1, num_gpus = 1.
    base_lr = 0.00004
......
if normalization_mode == P.Loss.NONE:
  base_lr /= batch_size_per_device
elif normalization_mode == P.Loss.VALID:
  base_lr *= 25. / loc_weight
elif normalization_mode == P.Loss.FULL:
  # Roughly there are 2000 prior bboxes per image.
  # TODO(weiliu89): Estimate the exact # of priors.
  base_lr *= 2000.
# 从以上代码可以看出,训练脚本没有使用batchnorm,未修改前,初始学习率 = base_lr * 25 / loc_weight=0.001
# 将base_lr变量改为原来十分之一,也就是0.00004->0.000004,就能把学习率调整为0.0001
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

PS.学习率0.0001确实偏小,可以适当调大一些,我自己的试验表明,0.0005~0.0007都是可选的参数,高于0.0008就会发散。

训练模型

修改完脚本参数后,运行该脚本程序。

$ cd caffe/
$ python examples/ssd/ssd_pascal_kitti.py
  • 1
  • 2

然后模型就训练起来了,可能会遇到“out of memory”的问题,那是batch_size设的太高,导致显存不足,此时调低batch_size再重新运行即可解决。

本次训练迭代120000次,博主估算训练时间大约50多小时,由于该主机还有别的任务,也只能断断续续进行训练(直接导致训练日志不完整,没法画accuracy和loss曲线)。感觉直接套用人家的训练参数,loss收敛比较慢。

这里写图片描述

终于训练完成,截图留念。

这里写图片描述

测试模型

从训练结束的截图来看,本次训练结果是不太理想的,经过120000次迭代后,loss还在2左右,收敛性比较差;测试准确率也仅有57.3%,和论文中VOC训练结果差得蛮远。虽然模型有点渣,但还是拿出来用用,看到底是个什么程度。

这里使用ssd_detect.py来检测单张图片,部分检测效果如下:

这里写图片描述
这里写图片描述
这里写图片描述

由图可知,汽车和行人训练的比较充分,检测结果还差强人意;自行车不知道是标注样本太少还是训练的不好,基本很难检测出来。

总结和思考

本次训练,过程是跑通了,但是结果并不是很理想,总结下可能的原因和改进方向:

  • 训练参数上,直接照搬肯定不行,后面需要有选择性的调参;
  • VOC图片大小多为500x375,而KITTI图片大小为1242x375,不仅分辨率上去了,长宽比也达到了3:1,仍然使用SSD_300x300可能不太合适,之后将试着训练SSD_512x512或者其他大小的模型,准确率应该能上升,但是运行时间将变长;
  • 换底,看看ZF-SSD和Resnet-SSD的表现如何,一方面是准确率,另一方面是运行时间;
  • 为了更高的准确率,感觉训练图片仍显不够(尤其在自行车这一类上),可以试着补充类似的数据集。

更新:针对精度低的问题,博主阅读一些论文,发现数据集失衡应该是主要原因,确切来说,我制作的KITTI中car占比太大,而其他类别太少,mAP被拉低了。后来从VOC数据集中挑出上千个只含有person或bycicle的样本补充进去训练,形成KITTI plus数据集,目前的mAP已经提高到了78.5%,达到了可用的水平。

版权声明:本文为zouyu1746430162原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/zouyu1746430162/article/details/79607035

智能推荐

Python学习练习6----列表、字典的运用2

range 用法参见http://blog.csdn.net/chiclewu/article/details/50592368 直接在 在线编程工具中练习: https://www.tutorialspoint.com/execute_python_online.php 代码如下,增加range、列表的len()、字典的items()函数,for 函数也有了新变化 练习2: 2的运行结果,注意p...

PoolThreadCache

缓存构成   PoolThreadCache的缓存由三部分构成:tiny、small 和 normal。 tiny   缓存数据大小区间为[16B, 496B]数据,数组长度为32,根据数据大小计算索引的办法:数据大小除以16,如下代码所示: small   缓存数据大小区间为[512B, 4KB]数据,数组长度为4,根据数据大小计算索引的办法:数据大小除以512,然后log2得到指数,如下代码所...

Intellij IDEA 搭建Spring Boot项目(一)

Intellij IDEA 搭建Spring Boot项目 标签(空格分隔): SpringBoot JAVA后台 第一步 选择File –> New –> Project –>Spring Initialer –> 点击Next  第二步 自己修改 Group 和 Artif...

CentOS学习之路1-wget下载安装配置

参考1: https://blog.csdn.net/zhaoyanjun6/article/details/79108129 参考2: http://www.souvc.com/?p=1569 CentOS学习之路1-wget下载安装配置 1.wget的安装与基本使用 安装wget yum 安装软件 默认安装保存在/var/cache/yum ,用于所有用户使用。 帮助命令 基本用法 例子:下载...

深入浅出Spring的IOC容器,对Spring的IOC容器源码进行深入理解

文章目录 DispatcherServlet整体继承图 入口:DispatcherServlet.init() HttpServletBean.init() FrameworkServlet.initServletBean() 首先大家,去看Spring的源码入口,第一个就是DispatcherServlet DispatcherServlet整体继承图 入口:DispatcherServlet....

猜你喜欢

laravel框架的课堂知识点概总

1. MVC 1.1 概念理解 MVC全名是Model View Controller,是模型(model)-视图(view)-控制器(controller)的缩写,一种软件设计典范,用一种业务逻辑、数据、界面显示分离的方法组织代码,将业务逻辑聚集到一个部件里面,在改进和个性化定制界面及用户交互的同时,不需要重新编写业务逻辑 MVC 是一种使用 MVC(Model View Controller ...

Unity人物角色动画系统学习总结

使用动画系统控制人物行走、转向、翻墙、滑行、拾取木头 混合树用来混合多个动画 MatchTarget用来匹配翻墙贴合墙上的某一点,人物以此为支点翻墙跳跃 IK动画类似于MatchTarget,控制两只手上的两个点来指定手的旋转和位置,使得拾取木头时更逼真 创建AnimatorController: 首先创建一个混合树,然后双击 可以看到该混合树有五种状态机,分别是Idle、WalkForward、...

Composer 安装 ThinkPHP6 问题

Composer 安装 ThinkPHP6 问题 先说说问题 一.运行环境要求 二.配置 参考: ThinkPHP6.0完全开发手册 先说说问题 执行ThinkPHP6的安装命令 遇到问题汇总如下: 看提示是要更新版本,执行命令更新。 更新之后,再次安装ThinkPHP,之后遇到如下问题。 尝试了很多方法,依然不能解决。其中包括使用https://packagist.phpcomposer.com...

Spring Boot 整合JDBC

今天主要讲解一下SpringBoot如何整合JDBC,没啥理论好说的,直接上代码,看项目整体结构 看一下对应的pom.xml 定义User.java 定义数据源配置,这里使用druid,所以需要写一个配置类 上面指定druid的属性配置,和用户登录的账号信息以及对应的过滤规则: 下面定义数据访问接口和对应的实现: 数据访问层很简单,直接注入JdbcTemplate模板即可,下面再看对应的servi...

html鼠标悬停显示样式

1.显示小手:     在style中添加cursor:pointer 实现鼠标悬停变成小手样式     实例:         其他参数: cursor语法: cursor : auto | crosshair | default | hand | move | help | wait | tex...