Commit 9ac67037 authored by xin.wang.waytous's avatar xin.wang.waytous

head file change path

parent 76fa60c0
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
修改CMakeLists.txt中cuda、tensorrt、yaml、gflag和glog等库的路径 修改CMakeLists.txt中cuda、tensorrt、yaml、gflag和glog等库的路径
`mkdir build` `mkdir build && cd build`
`cmake ..` `cmake ..`
......
...@@ -345,6 +345,10 @@ bool YoloV5TRTInference::BuildEngine(YAML::Node& configNode){ ...@@ -345,6 +345,10 @@ bool YoloV5TRTInference::BuildEngine(YAML::Node& configNode){
nms->getOutput(3)->setName(outputNames[3].c_str()); nms->getOutput(3)->setName(outputNames[3].c_str());
network->markOutput(*nms->getOutput(3)); network->markOutput(*nms->getOutput(3));
// get box-prob
// yolo->getOutput(2)->setName(outputNames[4].c_str());
// network->markOutput(*yolo->getOutput(2));
// Build engine // Build engine
builder->setMaxBatchSize(maxBatchSize); builder->setMaxBatchSize(maxBatchSize);
config->setMaxWorkspaceSize((maxBatchSize * (1UL << 30))); // 2GB config->setMaxWorkspaceSize((maxBatchSize * (1UL << 30))); // 2GB
......
...@@ -105,6 +105,9 @@ namespace nvinfer1 ...@@ -105,6 +105,9 @@ namespace nvinfer1
return Dims3(mMaxOutObject, 1, 4); return Dims3(mMaxOutObject, 1, 4);
} }
return DimsHW(mMaxOutObject, mClassCount); return DimsHW(mMaxOutObject, mClassCount);
// else{
// return DimsHW(mMaxOutObject, 1);
// }
} }
// Set plugin namespace // Set plugin namespace
...@@ -193,6 +196,7 @@ namespace nvinfer1 ...@@ -193,6 +196,7 @@ namespace nvinfer1
int count = (int)atomicAdd(res_count, 1); int count = (int)atomicAdd(res_count, 1);
if (count >= maxoutobject) return; if (count >= maxoutobject) return;
// probData[bnIdx * maxoutobject + count] = box_prob;
float *curBbox = bboxData + bnIdx * maxoutobject * 4 + count * 4; float *curBbox = bboxData + bnIdx * maxoutobject * 4 + count * 4;
float *curScore = scoreData + bnIdx * maxoutobject * classes + count * classes; float *curScore = scoreData + bnIdx * maxoutobject * classes + count * classes;
...@@ -233,12 +237,14 @@ namespace nvinfer1 ...@@ -233,12 +237,14 @@ namespace nvinfer1
{ {
float *bboxData = (float *)outputs[0]; float *bboxData = (float *)outputs[0];
float *scoreData = (float *)outputs[1]; float *scoreData = (float *)outputs[1];
// float *probData = (float*)outputs[2];
int *countData = (int *)workspace; int *countData = (int *)workspace;
CUDA_CHECK(cudaMemset(countData, 0, sizeof(int) * batchSize)); CUDA_CHECK(cudaMemset(countData, 0, sizeof(int) * batchSize));
CUDA_CHECK(cudaMemset(bboxData, 0, sizeof(float) * mMaxOutObject * 4 * batchSize)); CUDA_CHECK(cudaMemset(bboxData, 0, sizeof(float) * mMaxOutObject * 4 * batchSize));
CUDA_CHECK(cudaMemset(scoreData, 0, sizeof(float) * mMaxOutObject * mClassCount * batchSize)); CUDA_CHECK(cudaMemset(scoreData, 0, sizeof(float) * mMaxOutObject * mClassCount * batchSize));
// CUDA_CHECK(cudaMemset(probData, 0, sizeof(float) * mMaxOutObject * 1 * batchSize));
int numElem = 0; int numElem = 0;
for (unsigned int i = 0; i < mYoloKernel.size(); ++i){ for (unsigned int i = 0; i < mYoloKernel.size(); ++i){
......
...@@ -22,6 +22,7 @@ namespace nvinfer1 ...@@ -22,6 +22,7 @@ namespace nvinfer1
int getNbOutputs() const TRT_NOEXCEPT override int getNbOutputs() const TRT_NOEXCEPT override
{ {
return 2; return 2;
// return 3; // get res of box-prob and class-prob
} }
Dims getOutputDimensions(int index, const Dims *inputs, int nbInputDims) TRT_NOEXCEPT override; Dims getOutputDimensions(int index, const Dims *inputs, int nbInputDims) TRT_NOEXCEPT override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment