本文介绍 TensorFlow Object Detection API 1
TensorFlow Object Detection API
This article was original written by Jin Tian, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat:
jintianiloveu
这篇博客主要记录一下TensorFlow最新的Object Detection API的使用和思考。Google团队还是非常的强大,他们将众多的目标检测算法都包装成了一个现成的API,可以组合不同的网络和检测算法,从而实现一个非常可定制话的PipeLine,不管是训练自己的数据集还是公开的数据集,Object Detection的API都是非常不错的可用资源,这些实现有时候往往比自己去实现好简单的多,因为它有着两个好处:
- 它可以轻而易举的组合不同的算法,比如Faster-RCNN,SSD都可以通过一个简单的配置文件来组合;
- 它可以到处为冰冻图,从而实现开箱即用的预测。
好了,说了这么多,我们分两部来研究他吧。
1. 安装
首先这个东西最好还是安装一下。安装tensorflow gpu版本以及将slim加载到.zshrc里面去,其次把protobuf-compiler安装一下:
|
|
如果test显示success那么安装就成功了。其实我们进入下一步,来测试一下pets数据集的制作和训练。
2. 准备pets的数据集
官方教程使用的是google-cloud来训练,我们现在直接用本地图片来训练啦。其实也非常简单,我们先把数据下载一下,为了方便,我们将数据下载到object_detection/data
这个目录下:
|
|
现在你已经把数据下载到了object_detection/data/pets_data/
这个目录下面。接下来我们用 object_detection/data_utils/create_tfrecords_pets.py
这个脚本来生成pets这个数据集的tfrecords文件。./object_detection/data/
|
|
这个时候如果一些没有错,你会成功生成 pets_train.tfrecord
, 和 pets_val.tfrecord
两个文件:
|
|
其他的文件大家可以先不管,好啦,我们现在有自己的数据集啦。
3. 设置训练配置的PipeLine
这个时候只需要设置一下pipeline就可以了,在 object_detection/samples/config
下面有几个配置的模板,我们看看:
|
|
大家可以看到了,基本上没一个数据集都有一个不同的网络和算法的配置,甚至包括一个rfcn分割的模型。这些配置的不同之处只不过是不同的数据集的类别和文件路径不同其他都是一样的,那么这个时候我们就可以看到了,我们直接拷贝一个放到 object_detection/data
下面去,其实要修改的部分主要有:
- 设置一下预训练权重的保存路径;
- 设置一下input的tfrecord的路径;
- 设置一下val的tfrecord路径;
其实他们都是 object_detection/data
了,因为他们都在里面。好了,这个时候,我们要进一步的开始训练.
4. 开始训练
直接新建一个 train.sh 的启动文件吧,内容:
|
|
可以看到,我们只要制定一下train_dir 以及piple_line的配置路径,就可以训练不同的网络模型和方法,简直他妈的简单到想哭!!
5. 导出图和导入图做预测
最后一部了,导出图,并且导入图来做inference。这里我就不说了,关键部分。我们不得不说,我好像遇到了一个nms的问题,不知道你怎么看。我直接把推理部分的代码贴上来吧:
|
|
为什么会有300多个box?我日了狗狗了,是不是没有加nms。我日。