#ifndef WAYTOUS_DEEPINFER_INFERENCE_TRT_INFER_H_
#define WAYTOUS_DEEPINFER_INFERENCE_TRT_INFER_H_
#pragma once

#include <map>
#include <memory>
#include <string>
#include <vector>

#include <yaml-cpp/yaml.h>
// GPU TensorRT
#include <cuda_runtime_api.h>
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "NvOnnxParser.h"

#include "base/blob.h"
#include "common/file.h"
#include "common/register.h"

#include "interfaces/base_unit.h"
#include "libs/ios/normal_ios.h"
#include "libs/inferences/tensorrt/trt_utils.h"
#include "libs/inferences/tensorrt/trt_calibrator.h"

namespace waytous {
namespace deepinfer {
namespace inference {

// Logger object for TensorRT calls (Logger is presumably an
// nvinfer1::ILogger implementation from trt_utils.h — TODO confirm).
// NOTE(review): namespace-scope `static` in a header gives every translation
// unit that includes this file its own private Logger instance. If a single
// process-wide logger is intended, consider an `inline` variable (C++17) or an
// `extern` declaration with exactly one definition in a .cpp file.
static Logger gLogger;

// Map from a string key (presumably the blob/tensor name) to a shared float blob.
using BlobMap = std::map<std::string, std::shared_ptr<base::Blob<float>>>;

class TRTInference: public interfaces::BaseUnit {
public:
    ~TRTInference();

    bool Init(YAML::Node& node) override;
    virtual bool BuildEngine(YAML::Node& node);
    bool Exec() override;
    virtual std::string Name() override;

    std::shared_ptr<base::Blob<float>> get_blob(const std::string &name);

public:
    nvinfer1::IExecutionContext* mContext = nullptr;
    nvinfer1::ICudaEngine* mEngine = nullptr;
    nvinfer1::IRuntime* mRunTime = nullptr;
    Profiler mProfiler;

    cudaStream_t mCudaStream;
    std::vector<void*> mBindings;
    BlobMap blobs_;

    std::string engineFile;
    bool inferDynamic = false;
    int inferBatchSize = 1;
    int inputWidth, inputHeight;
};  // class TRTInferenceEngine

// Registers TRTInference through the DEEPINFER_REGISTER_UNIT macro from
// common/register.h. Presumably this makes the unit creatable by name through
// the project's unit registry — TODO confirm against register.h.
// NOTE(review): this expansion lives in a header, so it is repeated in every
// translation unit that includes this file. Confirm the macro is safe to
// expand more than once (no duplicate registration, no ODR/link errors).
DEEPINFER_REGISTER_UNIT(TRTInference);

}  // namespace inference
}  // namespace deepinfer
}  // namespace waytous


#endif  // WAYTOUS_DEEPINFER_INFERENCE_TRT_INFER_H_

