Embedded Deep Learning III: Speeding Up Deployment Inference with the TensorRT Python API

This article was originally written by Jin Tian. Reposting is welcome, but please keep this copyright info and the original source https://jinfagang.github.io . Any questions can be asked via WeChat: jintianiloveu

Freezing a TensorFlow or Keras Model

This is the first step of accelerating deployment: the model needs to be frozen. Freezing really just amounts to saving the model, since the saved model already contains the network definition. Take Keras as an example:

from keras.applications.vgg19 import VGG19
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model, load_model
from keras.layers import Dense, GlobalAveragePooling2D
from keras import backend as K
import keras
import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.python.tools import freeze_graph
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.training import saver as saver_lib

# Now, let's use the Tensorflow backend to get the TF graphdef and frozen graph
K.set_learning_phase(0)
sess = K.get_session()
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)

# save model weights in a TF checkpoint
checkpoint_path = saver.save(
    sess, config['snapshot_dir'], global_step=0,
    latest_filename='checkpoint_state')

# remove nodes not needed for inference from the graph def
train_graph = sess.graph
inference_graph = tf.graph_util.remove_training_nodes(
    train_graph.as_graph_def())

# write the graph definition to a file.
# You can view this file to see your network structure and
# to determine the names of your network's input/output layers.
graph_io.write_graph(inference_graph, '.', config['graphdef_file'])

# specify which layer is the output layer for your graph.
# In this case, we want to specify the softmax layer after our
# last dense (fully connected) layer.
out_names = config['out_layer']

# freeze your inference graph and save it for later! (Tensorflow)
freeze_graph.freeze_graph(
    config['graphdef_file'],
    '',
    False,
    checkpoint_path,
    out_names,
    "save/restore_all",
    "save/Const:0",
    config['frozen_model_file'],
    False,
    ""
)

This is only part of the code; the key API is freeze_graph. This method from the tools package freezes the weights stored in a checkpoint into a single model file. Note that you must also pass in the path where the graphdef was saved, which means you have to write out the graph definition beforehand.
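
One thing the snippet glosses over is how to find the output node name that ends up in config['out_layer']. A minimal sketch of one way to look it up, assuming model is the Keras model being frozen (the printed names are only illustrative; a Dense + softmax head typically ends in something like dense_2/Softmax):

from keras import backend as K

# Tensor names carry a ':0' suffix; freeze_graph wants the bare node name.
print(model.input.name)   # e.g. 'input_1:0'         -> input node 'input_1'
print(model.output.name)  # e.g. 'dense_2/Softmax:0' -> out_layer 'dense_2/Softmax'

# Alternatively, list the last few ops of the underlying TF graph
# and look for the final softmax.
for op in K.get_session().graph.get_operations()[-5:]:
    print(op.name)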

Converting the Frozen Model to UFF

Here, UFF is the network-definition format that TensorRT can parse. The process goes as follows:

import uff
import tensorrt as trt
from tensorrt.parsers import uffparser

G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)

# Load your newly created Tensorflow frozen model and convert it to UFF
uff_model = uff.from_tensorflow_frozen_model(
    config['frozen_model_file'], OUTPUT_LAYERS)
print('so we got the uff model.')

# Create a UFF parser to parse the UFF file created from your TF frozen model
parser = uffparser.create_uff_parser()
parser.register_input(INPUT_LAYERS[0], (INPUT_C, INPUT_H, INPUT_W), 0)
parser.register_output(OUTPUT_LAYERS[0])

# Build your TensorRT inference engine
if config['precision'] == 'fp32':
    engine = trt.utils.uff_to_trt_engine(
        G_LOGGER,
        uff_model,
        parser,
        INFERENCE_BATCH_SIZE,
        1 << 20,  # max workspace size in bytes
        trt.infer.DataType.FLOAT
    )
elif config['precision'] == 'fp16':
    engine = trt.utils.uff_to_trt_engine(
        G_LOGGER,
        uff_model,
        parser,
        INFERENCE_BATCH_SIZE,
        1 << 20,
        trt.infer.DataType.HALF
    )

# Serialize the TensorRT engine to a file for when you are ready to deploy your model.
save_path = str(config['engine_save_dir']) + "keras_vgg19_b" + \
    str(INFERENCE_BATCH_SIZE) + "_" + str(config['precision']) + ".engine"
trt.utils.write_engine_to_file(save_path, engine.serialize())
print("Saved TRT engine to {}".format(save_path))

The call uff.from_tensorflow_frozen_model gives you the UFF model directly. This step is really just a mapping: as long as every layer used in the frozen model is supported by UFF, the conversion goes through. Next, the UFF model has to be turned into a TensorRT engine, which requires a parser; the parser is very simple as well, you only need to register the network's input and output with it.
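
Before deploying, it can be worth sanity-checking the serialized engine by loading it back. A minimal sketch, assuming the legacy TensorRT 3.x Python API used above (trt.utils.load_engine is the read-side counterpart of write_engine_to_file; save_path comes from the previous snippet):

import tensorrt as trt

G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.ERROR)

# Deserialize the engine we just wrote; if this succeeds and an execution
# context can be created, the engine file is usable for inference.
engine = trt.utils.load_engine(G_LOGGER, save_path)
context = engine.create_execution_context()
print('engine loaded, ready for inference')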

Running Inference with the TensorRT Engine

Once the model has been converted into a TensorRT engine, inference becomes very simple. Well, not quite that simple; the code is as follows:

from tensorrt.lite import Engine
from PIL import Image
import numpy as np
import os
import functools
import time
import matplotlib.pyplot as plt

PLAN_single = '/tmp/model/keras_vgg19_b1_fp32.engine'  # engine filename for batch size 1
PLAN_half = '/tmp/model/keras_vgg19_b1_fp16.engine'
IMAGE_DIR = '/tmp/data/val/roses'
BATCH_SIZE = 1

def analyze(output_data):
    LABELS = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
    output = output_data.reshape(-1, len(LABELS))
    top_classes = [LABELS[idx] for idx in np.argmax(output, axis=1)]
    top_classes_prob = np.amax(output, axis=1)
    return top_classes, top_classes_prob

def image_to_np_CHW(image):
    return np.asarray(
        image.resize((224, 224), Image.ANTIALIAS)
    ).transpose([2, 0, 1]).astype(np.float32)

def load_and_preprocess_images():
    file_list = [f for f in os.listdir(IMAGE_DIR)
                 if os.path.isfile(os.path.join(IMAGE_DIR, f))]
    images_trt = []
    for f in file_list:
        images_trt.append(image_to_np_CHW(Image.open(os.path.join(IMAGE_DIR, f))))
    images_trt = np.stack(images_trt)
    # group the images into batches of BATCH_SIZE, dropping the remainder
    num_batches = int(len(images_trt) / BATCH_SIZE)
    images_trt = np.reshape(images_trt[0:num_batches * BATCH_SIZE], [
        num_batches,
        BATCH_SIZE,
        images_trt.shape[1],
        images_trt.shape[2],
        images_trt.shape[3]
    ])
    return images_trt

def timeit(func):
    @functools.wraps(func)
    def newfunc(*args, **kwargs):
        startTime = time.time()
        retargs = func(*args, **kwargs)
        elapsedTime = time.time() - startTime
        print('function [{}] finished in {} ms'.format(
            func.__name__, int(elapsedTime * 1000)))
        return retargs
    return newfunc

def load_TRT_engine(plan):
    # the postprocessor is applied to the named output tensor after inference
    engine = Engine(PLAN=plan, postprocessors={"dense_2/Softmax": analyze})
    return engine

engine_single = load_TRT_engine(PLAN_single)
engine_half = load_TRT_engine(PLAN_half)
images_trt = load_and_preprocess_images()

@timeit
def infer_all_images_trt(engine):
    results = []
    for image in images_trt:
        result = engine.infer(image)
        results.append(result)
    return results

results_trt_single = infer_all_images_trt(engine_single)
results_trt_half = infer_all_images_trt(engine_half)

for i in range(len(results_trt_single)):
    plt.imshow(images_trt[i, 0, 0], cmap='gray')
    plt.show()
    print(results_trt_single[i][0][0][0])
    print(results_trt_half[i][0][0][0])
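
One caveat about the preprocessing above: image_to_np_CHW only resizes and transposes the image, while the keras.applications VGG19 weights were originally trained with preprocess_input (BGR channel order plus ImageNet mean subtraction). If the fine-tuned model was trained with that preprocessing, the TensorRT input should match it. A minimal sketch of a matching variant, as a hypothetical drop-in replacement for image_to_np_CHW:

import numpy as np
from PIL import Image

# Per-channel ImageNet means used by keras.applications.vgg19.preprocess_input
# ('caffe' mode): inputs are converted to BGR and mean-subtracted, no scaling.
VGG_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def image_to_np_CHW_vgg(image):
    x = np.asarray(image.resize((224, 224), Image.ANTIALIAS)).astype(np.float32)
    x = x[..., ::-1]               # RGB -> BGR
    x -= VGG_MEAN_BGR              # subtract per-channel means
    return x.transpose([2, 0, 1])  # HWC -> CHW, as the TensorRT engine expects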