

#include "libs/inferences/tensorrt/trt_infer.h"

namespace waytous {
namespace deepinfer {
namespace inference {


/// Release the CUDA stream and the TensorRT objects owned by this unit.
///
/// TensorRT objects are destroyed in reverse order of creation: the execution
/// context before the engine it was created from, and the engine before the
/// runtime that deserialized it (the original destroyed the runtime first).
TRTInference::~TRTInference(){
    if(mCudaStream){
        // Wait for any work still queued on the stream before tearing it down.
        cudaStreamSynchronize(mCudaStream);
        cudaStreamDestroy(mCudaStream);
    }

    if (mContext)
        mContext->destroy();
    if (mEngine)
        mEngine->destroy();
    if (mRunTime)
        mRunTime->destroy();
}


/// Initialise the TensorRT inference unit from its YAML configuration.
///
/// Reads `inputHeight`, `inputWidth`, `inferBatchSize`, `engineFile` and
/// `inferDynamic` from @p configNode. `engineFile` is resolved against the
/// config root. If it does not exist yet, it is built via BuildEngine().
/// The serialized engine is then deserialized, an execution context is
/// created, and device buffers are bound for every engine input/output.
/// Output blobs are registered with `interfaces::SetIOPtr` under the engine's
/// binding names.
///
/// @param configNode unit configuration (also forwarded to BaseUnit::Init).
/// @return true on success, false on any configuration or engine error.
bool TRTInference::Init(YAML::Node& configNode){
    if(!BaseUnit::Init(configNode)){
        LOG_WARN << "Init trt engine error";
        return false;
    }

    inputHeight = configNode["inputHeight"].as<int>();
    inputWidth = configNode["inputWidth"].as<int>();
    inferBatchSize = configNode["inferBatchSize"].as<int>();
    engineFile = configNode["engineFile"].as<std::string>();
    inferDynamic = configNode["inferDynamic"].as<bool>();
    engineFile = common::GetAbsolutePath(common::ConfigRoot::GetRootPath(), engineFile);
    if(!waytous::deepinfer::common::PathExists(engineFile)){
        LOG_INFO << "Tensorrt engine haven't been built, built from saved weights.";
        // Fail early with a clear message instead of a later "read failed".
        if(!BuildEngine(configNode)){
            LOG_ERROR << "Build tensorrt engine " << engineFile << " failed.";
            return false;
        }
    }
    initLibNvInferPlugins(&gLogger, "");

    // Load the whole serialized engine. Opening at the end (ios::ate) gives the
    // size directly. std::streamoff avoids the truncation an `int` would cause.
    std::ifstream file(engineFile, std::ios::binary | std::ios::ate);
    if(!file.is_open())
    {
        LOG_WARN << "Read engine file " << engineFile <<" failed, please check.";
        return false;
    }
    const std::streamoff length = file.tellg();
    if(length <= 0){
        // tellg() returns -1 on failure; an empty file cannot be an engine.
        LOG_WARN << "Read engine file " << engineFile <<" failed, please check.";
        return false;
    }
    std::vector<char> data(static_cast<size_t>(length));
    file.seekg(0, std::ios::beg);
    if(!file.read(data.data(), static_cast<std::streamsize>(length))){
        LOG_WARN << "Read engine file " << engineFile <<" failed, please check.";
        return false;
    }
    file.close();

    // Explicit checks rather than assert(): asserts vanish under NDEBUG and a
    // null runtime/engine/context would then be dereferenced.
    LOG_INFO << "Deserializing tensorrt engine.";
    mRunTime = nvinfer1::createInferRuntime(gLogger);
    if(mRunTime == nullptr){
        LOG_ERROR << "Create tensorrt runtime failed.";
        return false;
    }
    mEngine = mRunTime->deserializeCudaEngine(data.data(), data.size(), nullptr);
    if(mEngine == nullptr){
        LOG_ERROR << "Deserialize tensorrt engine " << engineFile << " failed.";
        return false;
    }

    mContext = mEngine->createExecutionContext();
    if(mContext == nullptr){
        LOG_ERROR << "Create tensorrt execution context failed.";
        return false;
    }
    mContext->setProfiler(&mProfiler);
    const int nbBindings = mEngine->getNbBindings();

    if(static_cast<size_t>(nbBindings) != inputNames.size() + outputNames.size()){
        LOG_ERROR << " model input(" << inputNames.size() << ")+output(" << 
            outputNames.size() << ") != nbBindings:" << nbBindings;
        return false;
    }

    // NOTE(review): input pointers are appended in `inputNames` order and output
    // pointers in engine binding order. This matches the engine's binding indices
    // only if the engine lists its inputs first, in the same order as
    // `inputNames`. Confirm against the model configs.
    for (const auto& inputName: inputNames){
        auto input = std::dynamic_pointer_cast<ios::NormalIO>(interfaces::GetIOPtr(inputName));
        if (input == nullptr){
            LOG_ERROR << "inference engine input " << inputName << " haven't been init or doesn't exist.";
            return false;
        }
        blobs_.insert(std::make_pair(inputName, input->data_));
        auto binding = input->data_->mutable_gpu_data();
        mBindings.emplace_back(static_cast<void*>(binding));
    }

    for (int i = 0; i < nbBindings; ++i)
    {
        if(mEngine->bindingIsInput(i)){
            continue;
        }
        const nvinfer1::Dims dims = mEngine->getBindingDimensions(i);
        const std::string name = mEngine->getBindingName(i);
        LOG_INFO << "engine output name: " << name;

        // Output blob shape: [inferBatchSize, <engine binding dims...>].
        std::vector<int> shape;
        shape.push_back(inferBatchSize);
        for(int dindex = 0; dindex < dims.nbDims; dindex++){
            shape.push_back(dims.d[dindex]);
        }
        //auto blob = std::make_shared<base::Blob<float>>(base::Blob<float>(shape)); // Blob(Blob)=delete
        base::BlobPtr<float> blob;
        blob.reset(new base::Blob<float>(shape));
        blobs_.insert(std::make_pair(name, blob));
        auto output = std::make_shared<ios::NormalIO>(ios::NormalIO(blob));
        interfaces::SetIOPtr(name, output);
        auto binding = blob->mutable_gpu_data();
        mBindings.emplace_back(static_cast<void*>(binding));
    }

    CUDA_CHECK(cudaStreamCreate(&mCudaStream));
    LOG_INFO << "Deserialized tensorrt engine.";
    return true;
}


namespace {

// Deleter for TensorRT / ONNX-parser interface objects (IBuilder,
// INetworkDefinition, IParser, ICudaEngine, IHostMemory). These objects are
// released through their destroy() method rather than `delete`.
struct TrtObjectDestroyer {
    template <typename T>
    void operator()(T* obj) const {
        if (obj != nullptr) {
            obj->destroy();
        }
    }
};

// Owning pointer for TensorRT objects, so every early return releases them.
template <typename T>
using TrtUniquePtr = std::unique_ptr<T, TrtObjectDestroyer>;

}  // namespace


/// Build a TensorRT engine from ONNX weights and serialize it to `engineFile`.
///
/// Config keys read from @p configNode:
/// - `maxBatchSize`
/// - `weightsPath`, resolved against the config root
/// - `runMode`: 1 = FP16; 2 = INT8, currently unsupported, logged and then
///   built without INT8; any other value uses the builder defaults.
///
/// @param configNode unit configuration.
/// @return true when the engine was built and fully written to `engineFile`;
///         false on any create/parse/build/serialize/write failure.
bool TRTInference::BuildEngine(YAML::Node& configNode){
    // default convert onnx to engine

    int maxBatchSize = configNode["maxBatchSize"].as<int>();
    initLibNvInferPlugins(&gLogger, "");
    // int number;
    // auto lst = getPluginRegistry()->getPluginCreatorList(&number);
    // for(int i=0; i<number; i++){
    //     LOG_INFO << lst[i]->getPluginName()<<","<<lst[i]->getPluginVersion()<<","<<lst[i]->getPluginNamespace();
    // }
    // int verbosity = (int) nvinfer1::ILogger::Severity::kVERBOSE;
    int verbosity = static_cast<int>(nvinfer1::ILogger::Severity::kINFO);

    // RAII owners: the original leaked all three when parsing failed.
    TrtUniquePtr<nvinfer1::IBuilder> builder(nvinfer1::createInferBuilder(gLogger));
    if (!builder){
        LOG_ERROR << "Unable to create tensorrt builder";
        return false;
    }
    // 1U sets bit 0 of the creation flags (NetworkDefinitionCreationFlag::kEXPLICIT_BATCH).
    TrtUniquePtr<nvinfer1::INetworkDefinition> network(builder->createNetworkV2(1U));
    if (!network){
        LOG_ERROR << "Unable to create tensorrt network definition";
        return false;
    }
    TrtUniquePtr<nvonnxparser::IParser> parser(nvonnxparser::createParser(*network, gLogger));
    if (!parser){
        LOG_ERROR << "Unable to create onnx parser";
        return false;
    }

    std::string weightsPath = configNode["weightsPath"].as<std::string>();
    weightsPath = common::GetAbsolutePath(common::ConfigRoot::GetRootPath(), weightsPath);
    if (!parser->parseFromFile(weightsPath.c_str(), verbosity))
    {
        LOG_ERROR << "Failed to parse weights file: " << weightsPath;
        return false;
    }

    builder->setMaxBatchSize(maxBatchSize);
    builder->setMaxWorkspaceSize(2UL << 30);// 2G

    int8EntroyCalibrator *calibrator = nullptr;
    int runMode = configNode["runMode"].as<int>();
    if (runMode == 1)//fp16
    {
        LOG_INFO << "Set FP16 Mode.";
        if (!builder->platformHasFastFp16())
            LOG_INFO << "Notice: the platform do not has fast for fp16" ;
        builder->setFp16Mode(true);
    }
    else if(runMode == 2){//int8
        // INT8 is not supported yet: this is only logged, and the build
        // continues without INT8 mode enabled.
        LOG_ERROR << "No supported int8";
        /*LOG_INFO <<"Set Int8 Mode";
        if (!builder->platformHasFastInt8())
            LOG_INFO << "Notice: the platform do not has fast for int8.";
        builder->setInt8Mode(true);
        if(configNode["calibImgPathFile"].as<std::string>().size() > 0){
            std::vector<float> inputMean = configNode["inputMean"].as<std::vector<float>>();
            std::vector<float> inputStd = configNode["inputStd"].as<std::vector<float>>();
            calibrator = new int8EntroyCalibrator(
                configNode["calibImgPathFile"].as<std::string>(), 
                configNode["calibTableCache"].as<std::string>(), maxBatchSize, 
                configNode["inputWidth"].as<int>(), configNode["inputHeight"].as<int>(), 
                inputMean, inputStd, useBGR, fixAspectRatio
            );
            builder->setInt8Calibrator(calibrator);
        }
        else{
            LOG_ERROR << "Not imgs for calib int8. " << configNode["calibImgPathFile"].as<std::string>();
            return false;
        }
        */
    }

    LOG_INFO << "Begin building engine..." ;
    TrtUniquePtr<nvinfer1::ICudaEngine> engine(builder->buildCudaEngine(*network));
    // The calibrator is only needed during the build. Free it before any
    // later early return.
    if(calibrator){
        delete calibrator;
        calibrator = nullptr;
    }
    if (!engine){
        std::string error_message ="Unable to create engine";
        gLogger.log(nvinfer1::ILogger::Severity::kERROR, error_message.c_str());
        // The original fell through and dereferenced the null engine here.
        return false;
    }
    LOG_INFO << "End building engine..." ;

    // Serialize the engine; builder objects are released by their owners.
    TrtUniquePtr<nvinfer1::IHostMemory> modelStream(engine->serialize());
    if (!modelStream){
        LOG_ERROR << "Unable to serialize engine";
        return false;
    }

    // write: checked explicitly (assert() is a no-op in release builds).
    std::ofstream file(engineFile, std::ios::binary);
    if (!file){
        LOG_ERROR << "Unable to open engine file " << engineFile << " for writing";
        return false;
    }
    file.write(static_cast<const char*>(modelStream->data()),
               static_cast<std::streamsize>(modelStream->size()));
    file.close();
    if (file.fail()){
        LOG_ERROR << "Failed to write engine file " << engineFile;
        return false;
    }
    // No CUDA stream is created here. Init() creates mCudaStream after loading
    // the engine, so creating one here leaked it.
    LOG_INFO << "End writing engine";
    return true;
}


/// Look up a blob registered during Init() (engine inputs and outputs) by name.
///
/// @param name blob / binding name.
/// @return the registered blob, or nullptr if no blob has that name.
std::shared_ptr<base::Blob<float>> TRTInference::get_blob(const std::string &name) {
    const auto found = blobs_.find(name);
    if (found != blobs_.end()) {
        return found->second;
    }
    return nullptr;
}


bool TRTInference::Exec(){
    CUDA_CHECK(cudaStreamSynchronize(mCudaStream));
    for (auto name : inputNames) {
        auto blob = get_blob(name);
        if (blob != nullptr) {
            blob->gpu_data();
        }
    }
    // If `out_blob->mutable_cpu_data()` is invoked outside,
    // HEAD will be set to CPU, and `out_blob->mutable_gpu_data()`
    // after `enqueue` will copy data from CPU to GPU,
    // which will overwrite the `inference` results.
    // `out_blob->gpu_data()` will set HEAD to SYNCED,
    // then no copy happends after `enqueue`.
    for (auto name : outputNames) {
        auto blob = get_blob(name);
        if (blob != nullptr) {
            blob->gpu_data();
        }
    }
    if(inferDynamic){
        mContext->enqueue(inferBatchSize, &mBindings[0], mCudaStream, nullptr);
    }else{
        mContext->executeV2(&mBindings[0]);
    }
    CUDA_CHECK(cudaStreamSynchronize(mCudaStream));

    for (auto name : outputNames) {
        auto blob = get_blob(name);
        if (blob != nullptr) {
            // LOG_INFO << "output name: " << name;
            blob->mutable_gpu_data();
        }
    }
    return true;
}


/// Identifier of this inference unit type.
std::string TRTInference::Name(){
    static const char kUnitName[] = "TRTInference";
    return std::string(kUnitName);
}


}  // namespace inference
}  // namespace deepinfer
}  // namespace waytous



