#ifndef WAYTOUS_DEEPINFER_INFERENCE_TRT_UTILS_H_
#define WAYTOUS_DEEPINFER_INFERENCE_TRT_UTILS_H_
#pragma once

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include "NvInfer.h"
#include "cuda.h"
#include "cuda_runtime.h"
#include "common/common.h"

namespace waytous {
namespace deepinfer {
namespace inference {

// Version-dependent qualifiers, selected on NV_TENSORRT_MAJOR:
//   TRT_NOEXCEPT      -> `noexcept` for TensorRT >= 8, empty otherwise.
//   TRT_CONST_ENQUEUE -> `const`    for TensorRT >= 8, empty otherwise.
// NOTE(review): presumably meant to be appended to declarations that override
// TensorRT virtual functions whose signatures differ between major versions
// (the name suggests the plugin `enqueue` method) -- confirm against the users
// of these macros, none of which are in this header.
#if NV_TENSORRT_MAJOR >= 8
#define TRT_NOEXCEPT noexcept
#define TRT_CONST_ENQUEUE const
#else
#define TRT_NOEXCEPT
#define TRT_CONST_ENQUEUE
#endif


// Helpers for packing / unpacking raw values into a byte buffer through a
// moving cursor.
namespace Tn
{
    /// Copies the bytes of `val` to the position `buffer` points at, then
    /// advances `buffer` by sizeof(T).
    ///
    /// The caller must provide at least sizeof(T) writable bytes. `T` is
    /// expected to be trivially copyable, since only its bytes are copied.
    template<typename T>
    void write(char*& buffer, const T& val)
    {
        // memcpy instead of `*reinterpret_cast<T*>(buffer) = val`: the cursor
        // is a plain char pointer and is generally not aligned for T, so a
        // typed store through it would be undefined behaviour.
        std::memcpy(buffer, &val, sizeof(T));
        buffer += sizeof(T);
    }

    /// Copies sizeof(T) bytes from the position `buffer` points at into
    /// `val`, then advances `buffer` by sizeof(T).
    ///
    /// The caller must provide at least sizeof(T) readable bytes. `T` is
    /// expected to be trivially copyable, since only its bytes are copied.
    template<typename T>
    void read(const char*& buffer, T& val)
    {
        // memcpy for the same alignment / aliasing reason as in write().
        std::memcpy(&val, buffer, sizeof(T));
        buffer += sizeof(T);
    }
} // Tn

namespace Yolo {

    // Upper bound on the number of anchors a single level may carry.
    static constexpr int maxNumAnchorsPerLevel = 10;

    // Plain-data description of one level.
    // NOTE(review): width/height presumably describe that level's grid and
    // `anchors` presumably stores two floats per anchor -- confirm with the
    // code that fills this struct; neither is visible in this header.
    struct YoloKernel {
        int width;
        int height;
        // Room for two floats for each of the maxNumAnchorsPerLevel anchors.
        float anchors[2 * maxNumAnchorsPerLevel];
    };

} // namespace Yolo


class Profiler : public nvinfer1::IProfiler
{
public:
    struct Record
    {
        float time{0};
        int count{0};
    };
    void printTime(const int& runTimes)
    {
        //std::cout << "========== " << mName << " profile ==========" << std::endl;
        float totalTime = 0;
        std::string layerNameStr = "TensorRT layer name";
        int maxLayerNameLength = std::max(static_cast<int>(layerNameStr.size()), 70);
        for (const auto& elem : mProfile)
        {
            totalTime += elem.second.time;
            maxLayerNameLength = std::max(maxLayerNameLength, static_cast<int>(elem.first.size()));
        }
        std::cout<< " total runtime = " << totalTime/(runTimes + 1e-5) << " ms " << std::endl;
    }

    virtual void reportLayerTime(const char* layerName, float ms)
    {
        mProfile[layerName].count++;
        mProfile[layerName].time += ms;
    }
private:
    std::map<std::string, Record> mProfile;
};


/// nvinfer1::ILogger implementation that writes messages to std::cerr,
/// prefixed with their severity, and drops messages less severe than a
/// configurable threshold.
class Logger : public nvinfer1::ILogger
{
public:
    /// @param severity Least severe level that is still printed
    ///                 (defaults to Severity::kWARNING).
    Logger(Severity severity = Severity::kWARNING)
            : reportableSeverity(severity)
    {
    }

    /// Logger callback: prints `msg` with a severity prefix to std::cerr.
    ///
    /// Declared noexcept because TensorRT >= 8 declares ILogger::log
    /// noexcept and an override may not have a looser exception
    /// specification; for older versions the stricter specification is
    /// still a valid override.
    void log(Severity severity, const char* msg) noexcept override
    {
        // suppress messages with severity enum value greater than the reportable
        if (severity > reportableSeverity)
            return;

        switch (severity)
        {
            case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
            case Severity::kERROR: std::cerr << "ERROR: "; break;
            case Severity::kWARNING: std::cerr << "WARNING: "; break;
            case Severity::kINFO: std::cerr << "INFO: "; break;
            // Any other severity value is reported with a generic prefix.
            default: std::cerr << "UNKNOWN: "; break;
        }
        std::cerr << msg << std::endl;
    }

    /// Messages whose severity enum value is greater than this are dropped.
    Severity reportableSeverity;
};


/// Returns the product of the first `d.nbDims` extents of `d`
/// (1 when `d.nbDims` is 0).
inline int64_t volume(const nvinfer1::Dims& d)
{
    // The initial value fixes std::accumulate's accumulator type. A plain `1`
    // would make it `int`, truncating every 64-bit partial product back to
    // int and overflowing for large tensors.
    return std::accumulate(d.d, d.d + d.nbDims, int64_t{1}, std::multiplies<int64_t>());
}

/// Returns the size in bytes of one element of TensorRT data type `t`.
///
/// @throws std::runtime_error if `t` is not one of the types handled below.
inline unsigned int getElementSize(nvinfer1::DataType t)
{
    switch (t)
    {
        case nvinfer1::DataType::kINT32:
        case nvinfer1::DataType::kFLOAT:
            return 4;
        case nvinfer1::DataType::kHALF:
            return 2;
        case nvinfer1::DataType::kINT8:
            return 1;
        // case nvinfer1::DataType::kBOOL: return 1;
        // NOTE(review): kBOOL is left disabled as in the original, presumably
        // for TensorRT versions that do not define it -- confirm before enabling.
    }
    // No default label above, so the compiler can still warn about
    // enumerators the switch does not handle; they all end up here.
    throw std::runtime_error("Invalid DataType.");
}

/// Allocates `memSize` bytes with cudaMalloc and returns the resulting pointer.
///
/// If the pointer is still null after the call, prints "Out of memory" to
/// std::cerr and terminates the process with exit status 1.
/// NOTE(review): CUDA_CHECK is defined outside this header; if it already
/// aborts on a failed call, the null check below is only a second line of
/// defence -- confirm.
inline void* safeCudaMalloc(size_t memSize)
{
    // Start from nullptr so the check below never reads an indeterminate
    // pointer when cudaMalloc fails without writing to its output argument.
    void* deviceMem = nullptr;
    CUDA_CHECK(cudaMalloc(&deviceMem, memSize));
    if (deviceMem == nullptr)
    {
        std::cerr << "Out of memory" << std::endl;
        exit(1);
    }
    return deviceMem;
}



}  // namespace inference
}  // namespace deepinfer
}  // namespace waytous


#endif // header





