#ifndef WAYTOUS_DEEPINFER_INFERENCE_TRT_YOLOV5_INFER_H_
#define WAYTOUS_DEEPINFER_INFERENCE_TRT_YOLOV5_INFER_H_

#include "libs/inferences/tensorrt/trt_infer.h"
#include "libs/inferences/tensorrt/trt_yolov5_layer.h"

using namespace nvinfer1;


namespace waytous {
namespace deepinfer {
namespace inference {


// TensorRT weight files have a simple space delimited format:
// [type] [size] <data x size in hex>
inline std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file) {
    LOG_INFO << "Loading weights: " << file ;
    std::map<std::string, nvinfer1::Weights> weightMap;

    // Open weights file
    std::ifstream input(file);
    assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!");

    // Read number of weight blobs
    int32_t count;
    input >> count;
    assert(count > 0 && "Invalid weight map file.");

    while (count--)
    {
        nvinfer1::Weights wt{ nvinfer1::DataType::kFLOAT, nullptr, 0 };
        uint32_t size;

        // Read name and type of blob
        std::string name;
        input >> name >> std::dec >> size;
        wt.type = nvinfer1::DataType::kFLOAT;

        // Load blob
        uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
        for (uint32_t x = 0, y = size; x < y; ++x)
        {
            input >> std::hex >> val[x];
        }
        wt.values = val;

        wt.count = size;
        weightMap[name] = wt;
    }

    return weightMap;
}


// Scales channel count `x` by the width multiple `gw`, then rounds the result
// up to the next multiple of `divisor`. Example: get_width(64, 0.5f) == 32 and
// get_width(10, 1.0f) == 16.
inline static int get_width(int x, float gw, int divisor = 8) {
    const float groups = x * gw / divisor;
    const int wholeGroups = static_cast<int>(ceil(groups));
    return wholeGroups * divisor;
}


// Scales repeat count `x` by the depth multiple `gd` and rounds the result.
// Exact .5 ties are rounded half-to-even, which is the rule Python 3's
// round() uses. The result is never below 1, and x == 1 is returned unchanged.
inline static int get_depth(int x, float gd) {
    if (x == 1) {
        return 1;
    }

    const float scaled = x * gd;
    const int truncated = static_cast<int>(scaled);
    int rounded = static_cast<int>(round(scaled));  // ties go away from zero

    // Pull an exact tie on an even integer part back down to that even value.
    const bool evenTie = (scaled - truncated == 0.5) && (truncated % 2 == 0);
    if (evenTie) {
        rounded -= 1;
    }

    return rounded < 1 ? 1 : rounded;
}



// TensorRT engine builder for YOLOv5 networks, derived from TRTInference.
//
// The public interface consists only of the two TRTInference overrides
// (BuildEngine, Name). The private members are layer-building helpers that
// each take the INetworkDefinition under construction and a name -> Weights
// map. Their definitions are not in this header.
// NOTE(review): the per-helper notes below are inferred from names and
// signatures only — confirm against the implementation file.
class YoloV5TRTInference : public TRTInference {

public:
    bool BuildEngine(YAML::Node& node) override;
    std::string Name() override;

private:
    // ---- Basic layers --------------------------------------------------

    // Presumably BatchNorm2d folded into a scale layer, using epsilon `eps`.
    IScaleLayer* addBatchNorm2d(INetworkDefinition *network,
                                std::map<std::string, Weights>& weightMap,
                                ITensor& input,
                                std::string lname,
                                float eps);

    // Presumably conv + batch-norm + activation (YOLOv5 "Conv").
    ILayer* convBlock(INetworkDefinition *network,
                      std::map<std::string, Weights>& weightMap,
                      ITensor& input,
                      int outch, int ksize, int s, int g,
                      std::string lname);

    // Presumably the YOLOv5 "Focus" stem; takes the network input size.
    ILayer* focus(INetworkDefinition *network,
                  std::map<std::string, Weights>& weightMap,
                  ITensor& input,
                  int inch, int outch, int ksize,
                  std::string lname,
                  int inputWidth, int inputHeight);

    // ---- Composite blocks ----------------------------------------------

    ILayer* bottleneck(INetworkDefinition *network,
                       std::map<std::string, Weights>& weightMap,
                       ITensor& input,
                       int c1, int c2, bool shortcut, int g, float e,
                       std::string lname);

    ILayer* bottleneckCSP(INetworkDefinition *network,
                          std::map<std::string, Weights>& weightMap,
                          ITensor& input,
                          int c1, int c2, int n, bool shortcut, int g, float e,
                          std::string lname);

    ILayer* C3(INetworkDefinition *network,
               std::map<std::string, Weights>& weightMap,
               ITensor& input,
               int c1, int c2, int n, bool shortcut, int g, float e,
               std::string lname);

    // ---- Spatial pyramid pooling ---------------------------------------

    // Presumably parallel pooling with three kernel sizes k1/k2/k3.
    ILayer* SPP(INetworkDefinition *network,
                std::map<std::string, Weights>& weightMap,
                ITensor& input,
                int c1, int c2, int k1, int k2, int k3,
                std::string lname);

    // Presumably the "fast" SPP variant with a single kernel size `k`.
    ILayer* SPPF(INetworkDefinition *network,
                 std::map<std::string, Weights>& weightMap,
                 ITensor& input,
                 int c1, int c2, int k,
                 std::string lname);

    // ---- Detection head and post-processing ----------------------------

    // Returns one float vector per entry (see implementation for layout).
    std::vector<std::vector<float>> getAnchors(std::map<std::string, Weights>& weightMap,
                                               std::string lname,
                                               int numAnchorsPerLevel);

    // Adds a plugin layer fed by the given convolution outputs.
    IPluginV2Layer* addYoLoLayer(INetworkDefinition *network,
                                 std::map<std::string, Weights>& weightMap,
                                 std::string lname,
                                 std::vector<IConvolutionLayer*> dets,
                                 int classNumber,
                                 int inputWidth, int inputHeight,
                                 int maxOutputBBoxCount,
                                 int numAnchorsPerLevel,
                                 int startScale=8);

    // Adds a plugin layer on top of the `yolo` layer's output; the name
    // suggests a batched-NMS plugin.
    IPluginV2Layer *addBatchedNMSLayer(INetworkDefinition *network,
                                       IPluginV2Layer *yolo,
                                       int num_classes,
                                       int top_k, int keep_top_k,
                                       float score_thresh, float iou_thresh,
                                       bool is_normalized = false,
                                       bool clip_boxes = false);

};



}  // namespace inference
}  // namespace deepinfer
}  // namespace waytous


#endif // header
