本文介绍 TensorFlow Object Detection API 1

TensorFlow Object Detection API

This article was original written by Jin Tian, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat: jintianiloveu

这篇博客主要记录一下TensorFlow最新的Object Detection API的使用和思考。Google团队还是非常的强大,他们将众多的目标检测算法都包装成了一个现成的API,可以组合不同的网络和检测算法,从而实现一个非常可定制话的PipeLine,不管是训练自己的数据集还是公开的数据集,Object Detection的API都是非常不错的可用资源,这些实现有时候往往比自己去实现好简单的多,因为它有着两个好处:

  • 它可以轻而易举的组合不同的算法,比如Faster-RCNN,SSD都可以通过一个简单的配置文件来组合;
  • 它可以到处为冰冻图,从而实现开箱即用的预测。

好了,说了这么多,我们分两部来研究他吧。

1. 安装

首先这个东西最好还是安装一下。安装tensorflow gpu版本以及将slim加载到.zshrc里面去,其次把protobuf-compiler安装一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
sudo apt install protobuf-compiler
# all the following are run under research dir
cd /path/to/models/research
# generate protos
protoc object_detection/protos/*.proto --python_out=.
# echo pwd to PYTHONPATH
echo 'export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim' >> ~/.zshrc
source ~/.zshrc
# test installation
python object_detection/builders/model_builder_test.py

如果test显示success那么安装就成功了。其实我们进入下一步,来测试一下pets数据集的制作和训练。

2. 准备pets的数据集

官方教程使用的是google-cloud来训练,我们现在直接用本地图片来训练啦。其实也非常简单,我们先把数据下载一下,为了方便,我们将数据下载到object_detection/data这个目录下:

1
2
3
4
5
6
7
8
9
# you still under research
cd object_detection/data
mkdir pets_data
cd pets_data
wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
tar -xvf images.tar.gz
tar -xvf annotations.tar.gz
pwd

现在你已经把数据下载到了object_detection/data/pets_data/ 这个目录下面。接下来我们用 object_detection/data_utils/create_tfrecords_pets.py 这个脚本来生成pets这个数据集的tfrecords文件。./object_detection/data/

1
2
3
# back to research dir
cd /path/to/research
python object_detection/dataset_tools/create_pet_tf_record.py --label_map_path ./object_detection/data/pet_label_map.pbtxt --data_dir ./object_detection/data/pets_data --output_dir ./object_detection/data/

这个时候如果一些没有错,你会成功生成 pets_train.tfrecord, 和 pets_val.tfrecord两个文件:

1
2
3
4
5
6
7
8
9
10
11
12
.
├── faster_rcnn_resnet101_pets.config
├── kitti_label_map.pbtxt
├── model.ckpt.data-00000-of-00001
├── model.ckpt.index
├── model.ckpt.meta
├── mscoco_label_map.pbtxt
├── oid_bbox_trainable_label_map.pbtxt
├── pascal_label_map.pbtxt
├── pet_label_map.pbtxt
├── pet_train.record
└── pet_val.record

其他的文件大家可以先不管,好啦,我们现在有自己的数据集啦。

3. 设置训练配置的PipeLine

这个时候只需要设置一下pipeline就可以了,在 object_detection/samples/config下面有几个配置的模板,我们看看:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
faster_rcnn_inception_resnet_v2_atrous_coco.config
faster_rcnn_inception_resnet_v2_atrous_pets.config
faster_rcnn_inception_v2_coco.config
faster_rcnn_inception_v2_pets.config
faster_rcnn_nas_coco.config
faster_rcnn_resnet101_coco.config
faster_rcnn_resnet101_kitti.config
faster_rcnn_resnet101_pets.config
faster_rcnn_resnet101_voc07.config
faster_rcnn_resnet152_coco.config
faster_rcnn_resnet152_pets.config
faster_rcnn_resnet50_coco.config
faster_rcnn_resnet50_pets.config
rfcn_resnet101_coco.config
rfcn_resnet101_pets.config
ssd_inception_v2_coco.config
ssd_inception_v2_pets.config
ssd_mobilenet_v1_coco.config
ssd_mobilenet_v1_pets.config

大家可以看到了,基本上没一个数据集都有一个不同的网络和算法的配置,甚至包括一个rfcn分割的模型。这些配置的不同之处只不过是不同的数据集的类别和文件路径不同其他都是一样的,那么这个时候我们就可以看到了,我们直接拷贝一个放到 object_detection/data下面去,其实要修改的部分主要有:

  • 设置一下预训练权重的保存路径;
  • 设置一下input的tfrecord的路径;
  • 设置一下val的tfrecord路径;

其实他们都是 object_detection/data 了,因为他们都在里面。好了,这个时候,我们要进一步的开始训练.

4. 开始训练

直接新建一个 train.sh 的启动文件吧,内容:

1
2
3
python train.py \
--train_dir=train_dir \
--pipeline_config_path=data/faster_rcnn_resnet101_pets.config

可以看到,我们只要制定一下train_dir 以及piple_line的配置路径,就可以训练不同的网络模型和方法,简直他妈的简单到想哭!!

5. 导出图和导入图做预测

最后一部了,导出图,并且导入图来做inference。这里我就不说了,关键部分。我们不得不说,我好像遇到了一个nms的问题,不知道你怎么看。我直接把推理部分的代码贴上来吧:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
this is the demo file, to predict on the model that exported using
TensorFlow object detection API.
We only need 3 things:
- the exported graph.pb file;
- the label_map.pbtxt file;
- a test image
the graph.pb file are default in `exported_graphs/frozen_inference_graph.pb`
"""
import numpy as np
import os
import tensorflow as tf
from PIL import Image
import argparse
import time
if tf.__version__ != '1.4.0':
raise ImportError('Please upgrade your TensorFlow version to v1.4.0!')
def parse_args():
arg_parser = argparse.ArgumentParser('demo for object detection API.')
arg_parser.add_argument('--pb', default='exported_graphs/frozen_inference_graph.pb', help='the pb file.')
arg_parser.add_argument('--label_map', default='data/pet_label_map.pbtxt', help='the label_map file')
arg_parser.add_argument('--image', help='image dir or image file.')
return arg_parser.parse_args()
class Detector(object):
def __init__(self, pb_file):
self.pb_file = pb_file
self._get_model()
self._init_sess()
self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
self.detection_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
self.detection_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
self.detection_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
def _get_model(self):
self.detection_graph = tf.Graph()
with self.detection_graph.as_default():
self.od_graph_def = tf.GraphDef()
with tf.gfile.GFile(self.pb_file, 'rb') as fid:
serialized_graph = fid.read()
self.od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(self.od_graph_def, name='')
def _init_sess(self):
with self.detection_graph.as_default():
with tf.Session(graph=self.detection_graph) as sess:
self.sess = sess
print('# Session and model loaded!')
def detect_on_img(self, img_array):
assert isinstance(img_array, np.ndarray), 'img_array must be a numpy array object'
img_np_expanded = np.expand_dims(img_array, axis=0)
return self.sess.run(
[self.detection_boxes, self.detection_scores,
self.detection_classes, self.num_detections],
feed_dict={self.image_tensor: img_np_expanded}
)
if __name__ == '__main__':
args = parse_args()
pb_file_ = args.pb
label_map_file = args.label_map
images = args.image
if os.path.exists(pb_file_) and os.path.exists(label_map_file):
detector = Detector(pb_file=pb_file_)
if os.path.isfile(images):
img = Image.open(images)
img = np.array(img)
tic = time.clock()
boxes, scores, classes, num = detector.detect_on_img(img)
print('boxes: ', boxes)
print('scores: ', scores)
print('classes: ', classes)
print('num: ', num)
print('done in {} seconds!'.format(time.clock() - tic))
else:
pass
else:
print('pb file or label_map file not exist.')

为什么会有300多个box?我日了狗狗了,是不是没有加nms。我日。