// trt_sample.cpp
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <memory>
#include <numeric>
#include <vector>
#include <cuda_runtime_api.h>
#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <opencv2/core.hpp>
#include <opencv2/core/cuda.hpp>
#include <opencv2/cudaarithm.hpp>
#include <opencv2/cudaimgproc.hpp>
#include <opencv2/cudawarping.hpp>
#include <opencv2/imgcodecs.hpp>
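// build sketch, assuming default CUDA/TensorRT install paths and a CUDA-enabled OpenCV:
//   g++ -std=c++14 trt_sample.cpp -o trt_sample \
//       -I/usr/local/cuda/include -L/usr/local/cuda/lib64 -lcudart \
//       -lnvinfer -lnvonnxparser $(pkg-config --cflags --libs opencv4)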
// utilities ----------------------------------------------------------------------------------------------------------
// class to log errors, warnings, and other information during the build and inference phases
class Logger : public nvinfer1::ILogger
{
public:
    void log(Severity severity, const char* msg) override
    {
        // only print errors; relax this filter if you need more verbose output
        if ((severity == Severity::kERROR) || (severity == Severity::kINTERNAL_ERROR))
        {
            std::cout << msg << "\n";
        }
    }
} gLogger;
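// note: TensorRT 8+ declares ILogger::log as noexcept, so the override above
// needs a noexcept qualifier when building against newer releases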
// custom deleter: TensorRT objects are released through destroy(), not delete,
// so unique_ptr needs this deleter to clean them up automatically
struct TRTDestroy
{
    template <class T>
    void operator()(T* obj) const
    {
        if (obj)
        {
            obj->destroy();
        }
    }
};
template <class T>
using TRTUniquePtr = std::unique_ptr<T, TRTDestroy>;
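// note: destroy() is the TensorRT 7-era release call; TensorRT 8+ deprecates it
// in favor of plain delete, so newer code can drop this custom deleter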
// calculate the number of elements in a tensor
size_t getSizeByDim(const nvinfer1::Dims& dims)
{
    size_t size = 1;
    for (int32_t i = 0; i < dims.nbDims; ++i)
    {
        size *= dims.d[i];
    }
    return size;
}
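// e.g. a CHW input binding of {3, 224, 224} (implicit batch, no leading N)
// gives 3 * 224 * 224 = 150528 elements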
// get class names
std::vector<std::string> getClassNames(const std::string& imagenet_classes)
{
    std::ifstream classes_file(imagenet_classes);
    std::vector<std::string> classes;
    if (!classes_file.good())
    {
        std::cerr << "ERROR: can't read the file with class names.\n";
        return classes;
    }
    std::string class_name;
    while (std::getline(classes_file, class_name))
    {
        classes.push_back(class_name);
    }
    return classes;
}
// preprocessing stage ------------------------------------------------------------------------------------------------
void preprocessImage(const std::string& image_path, float* gpu_input, const nvinfer1::Dims& dims)
{
    // read input image
    cv::Mat frame = cv::imread(image_path);
    if (frame.empty())
    {
        std::cerr << "Input image " << image_path << " load failed\n";
        return;
    }
    // upload image to GPU
    cv::cuda::GpuMat gpu_frame;
    gpu_frame.upload(frame);
    // OpenCV reads images as BGR, while the ImageNet mean/std below are RGB statistics
    cv::cuda::GpuMat rgb;
    cv::cuda::cvtColor(gpu_frame, rgb, cv::COLOR_BGR2RGB);
    // implicit-batch engine: binding dims are CHW without the leading batch dimension
    auto input_width = dims.d[2];
    auto input_height = dims.d[1];
    auto channels = dims.d[0];
    auto input_size = cv::Size(input_width, input_height);
    // resize
    cv::cuda::GpuMat resized;
    cv::cuda::resize(rgb, resized, input_size, 0, 0, cv::INTER_NEAREST);
    // normalize with the standard ImageNet mean and standard deviation
    cv::cuda::GpuMat flt_image;
    resized.convertTo(flt_image, CV_32FC3, 1.f / 255.f);
    cv::cuda::subtract(flt_image, cv::Scalar(0.485f, 0.456f, 0.406f), flt_image, cv::noArray(), -1);
    cv::cuda::divide(flt_image, cv::Scalar(0.229f, 0.224f, 0.225f), flt_image, 1, -1);
    // to tensor: split the interleaved HWC image into planar CHW, writing each
    // channel plane directly into the engine's input buffer (no extra device copy)
    std::vector<cv::cuda::GpuMat> chw;
    for (int i = 0; i < channels; ++i)
    {
        chw.emplace_back(cv::cuda::GpuMat(input_size, CV_32FC1, gpu_input + i * input_width * input_height));
    }
    cv::cuda::split(flt_image, chw);
}
// post-processing stage ----------------------------------------------------------------------------------------------
void postprocessResults(float* gpu_output, const nvinfer1::Dims& dims, int batch_size)
{
    // get class names
    auto classes = getClassNames("imagenet_classes.txt");
    // copy results from GPU to CPU
    std::vector<float> cpu_output(getSizeByDim(dims) * batch_size);
    cudaMemcpy(cpu_output.data(), gpu_output, cpu_output.size() * sizeof(float), cudaMemcpyDeviceToHost);
    // calculate softmax; subtract the max logit first so exp() cannot overflow
    float max_logit = *std::max_element(cpu_output.begin(), cpu_output.end());
    std::transform(cpu_output.begin(), cpu_output.end(), cpu_output.begin(),
                   [max_logit](float val) { return std::exp(val - max_logit); });
    float sum = std::accumulate(cpu_output.begin(), cpu_output.end(), 0.0f);
    // rank class indices by score
    std::vector<int> indices(cpu_output.size());
    std::iota(indices.begin(), indices.end(), 0); // generate sequence 0, 1, 2, ..., N - 1
    std::sort(indices.begin(), indices.end(),
              [&cpu_output](int i1, int i2) { return cpu_output[i1] > cpu_output[i2]; });
    // print every class whose softmax probability exceeds 0.5%
    size_t i = 0;
    while (i < indices.size() && cpu_output[indices[i]] / sum > 0.005f)
    {
        if (static_cast<size_t>(indices[i]) < classes.size())
        {
            std::cout << "class: " << classes[indices[i]] << " | ";
        }
        std::cout << "confidence: " << 100 * cpu_output[indices[i]] / sum << "% | index: " << indices[i] << "\n";
        ++i;
    }
}
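// e.g. raw scores {2.0, 1.0, 0.1} map to probabilities {0.659, 0.242, 0.099},
// all of which clear the 0.005 reporting threshold above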
// initialize TensorRT engine and parse ONNX model --------------------------------------------------------------------
void parseOnnxModel(const std::string& model_path, TRTUniquePtr<nvinfer1::ICudaEngine>& engine,
                    TRTUniquePtr<nvinfer1::IExecutionContext>& context)
{
    TRTUniquePtr<nvinfer1::IBuilder> builder{nvinfer1::createInferBuilder(gLogger)};
    TRTUniquePtr<nvinfer1::INetworkDefinition> network{builder->createNetwork()};
    TRTUniquePtr<nvonnxparser::IParser> parser{nvonnxparser::createParser(*network, gLogger)};
    TRTUniquePtr<nvinfer1::IBuilderConfig> config{builder->createBuilderConfig()};
    // parse ONNX
    if (!parser->parseFromFile(model_path.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kINFO)))
    {
        std::cerr << "ERROR: could not parse the model.\n";
        return;
    }
    // allow TensorRT to use up to 1GB of GPU memory for tactic selection
    config->setMaxWorkspaceSize(1ULL << 30);
    // use FP16 mode if the hardware supports it
    if (builder->platformHasFastFp16())
    {
        config->setFlag(nvinfer1::BuilderFlag::kFP16);
    }
    // we have only one image per batch
    builder->setMaxBatchSize(1);
    // generate the TensorRT engine optimized for the target platform
    engine.reset(builder->buildEngineWithConfig(*network, *config));
    if (!engine)
    {
        std::cerr << "ERROR: failed to build the TensorRT engine.\n";
        return;
    }
    context.reset(engine->createExecutionContext());
}
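// note: this sketch keeps the implicit-batch workflow (createNetwork / setMaxBatchSize /
// enqueue); recent TensorRT releases require an explicit-batch network created with
// createNetworkV2(1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH))
// and run with enqueueV2 or executeV2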
// main pipeline ------------------------------------------------------------------------------------------------------
int main(int argc, char* argv[])
{
    if (argc < 3)
    {
        std::cerr << "usage: " << argv[0] << " model.onnx image.jpg\n";
        return -1;
    }
    std::string model_path(argv[1]);
    std::string image_path(argv[2]);
    int batch_size = 1;
    // initialize TensorRT engine and parse ONNX model
    TRTUniquePtr<nvinfer1::ICudaEngine> engine{nullptr};
    TRTUniquePtr<nvinfer1::IExecutionContext> context{nullptr};
    parseOnnxModel(model_path, engine, context);
    if (!engine || !context)
    {
        std::cerr << "ERROR: engine initialization failed\n";
        return -1;
    }
    // get sizes of input and output, and allocate GPU memory for them
    std::vector<nvinfer1::Dims> input_dims;  // we expect only one input
    std::vector<nvinfer1::Dims> output_dims; // and one output
    std::vector<void*> buffers(engine->getNbBindings()); // buffers for input and output data
    for (int i = 0; i < engine->getNbBindings(); ++i)
    {
        auto binding_size = getSizeByDim(engine->getBindingDimensions(i)) * batch_size * sizeof(float);
        cudaMalloc(&buffers[i], binding_size);
        if (engine->bindingIsInput(i))
        {
            input_dims.emplace_back(engine->getBindingDimensions(i));
        }
        else
        {
            output_dims.emplace_back(engine->getBindingDimensions(i));
        }
    }
    if (input_dims.empty() || output_dims.empty())
    {
        std::cerr << "Expect at least one input and one output for network\n";
        return -1;
    }
    // preprocess input data
    preprocessImage(image_path, (float*) buffers[0], input_dims[0]);
    // inference on the default CUDA stream; the blocking cudaMemcpy in
    // postprocessResults implicitly synchronizes before reading the output
    context->enqueue(batch_size, buffers.data(), 0, nullptr);
    // postprocess results
    postprocessResults((float*) buffers[1], output_dims[0], batch_size);
    // free GPU memory
    for (void* buf : buffers)
    {
        cudaFree(buf);
    }
    return 0;
}
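// run sketch (file names are placeholders): ./trt_sample resnet50.onnx image.jpg
// getClassNames() above expects imagenet_classes.txt (one label per line) in the
// working directory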