#include "libs/postprocessors/yolov5_post.h"

namespace waytous {
namespace deepinfer {
namespace postprocess {


/**
 * @brief Initialise the YOLOv5 post-processor from its YAML configuration.
 *
 * Common unit setup is delegated to BaseUnit::Init; the unit-specific keys are
 * then read from @p node. Every key is read with yaml-cpp's `as<T>()` without a
 * fallback, so a missing or mistyped key makes yaml-cpp throw.
 *
 * @param node            Unit configuration: `classNames` (list of strings,
 *                        indexed by predicted class id), `truncatedThreshold`
 *                        (float), `scoreThreshold` (float) and `keepTopK` (int).
 * @param globalParamNode Global parameters, forwarded to BaseUnit::Init.
 * @return false if BaseUnit::Init fails or `keepTopK` is not positive,
 *         true otherwise.
 */
bool YoloV5PostProcess::Init(YAML::Node& node, YAML::Node& globalParamNode) {
    if (!BaseUnit::Init(node, globalParamNode)) {
        LOG_WARN << "Init yolov5 postprocess error";
        return false;
    }
    classNames = node["classNames"].as<std::vector<std::string>>();
    // Border margin, as a fraction of image width/height, used by Exec to flag
    // boxes touching the image edge as truncated (e.g. 0.05 for a 5% margin).
    // No default is applied here: the key must be present in the config.
    truncatedThreshold = node["truncatedThreshold"].as<float>();
    scoreThreshold = node["scoreThreshold"].as<float>();
    keepTopK = node["keepTopK"].as<int>();
    // Exec uses keepTopK as the per-image stride into the NMS output buffers;
    // a non-positive value would make every image read the wrong slots.
    if (keepTopK <= 0) {
        LOG_WARN << "Init yolov5 postprocess error: keepTopK must be positive, got " << keepTopK;
        return false;
    }
    return true;
}


bool YoloV5PostProcess::Exec(){
    if (inputNames.size() != 4 + outputNames.size()){
        LOG_ERROR << "yolov5 postprocess, inputsize != 4 + ouputsize.";
        return false;
    }
    std::vector<base::BlobPtr<float>> inputs; // model outputs
    for(int i=0; i<4; i++){
        auto inputName = inputNames[i];
        auto ptr = std::dynamic_pointer_cast<ios::NormalIO>(interfaces::GetIOPtr(inputName));
        if (ptr == nullptr){
            LOG_ERROR << "YoloV5 postprocess input " << inputName << " haven't been init or doesn't exist.";
            return false;
        }
        inputs.push_back(ptr->data_);
    }

    std::vector<base::Image8UPtr> inputImages;
    for(int j=4; j<inputNames.size(); j++){
        auto iName = inputNames[j];
        auto iptr = std::dynamic_pointer_cast<ios::CameraSrcOut>(interfaces::GetIOPtr(iName));
        if (iptr == nullptr){
            LOG_ERROR << "YoloV5 postprocess input " << iName << " haven't been init or doesn't exist.";
            return false;
        }
        inputImages.push_back(iptr->img_ptr_);
    }

    //post
    int* num_detections = static_cast<int*>(static_cast<void*>(inputs[0]->mutable_cpu_data()));// it's type is int32
    const float* nmsed_boxes = inputs[1]->cpu_data();
    const float* nmsed_scores = inputs[2]->cpu_data();
    const float* nmsed_classes = inputs[3]->cpu_data();

    for(int b=0; b<inputImages.size(); b++){
        auto outName = outputNames[b];
        float img_width = float(inputImages[b]->cols());
        float img_height = float(inputImages[b]->rows());
        float scalex = inputWidth / img_width;
        float scaley = inputHeight / img_height;
        if(fixAspectRatio){
            scalex = scaley = std::min(scalex, scaley);
        }
        auto dets = std::make_shared<ios::Detection2Ds>(ios::Detection2Ds());
        for(int i = 0; i < num_detections[b]; i++){
            if(nmsed_scores[b * keepTopK * 1 + i] < scoreThreshold){
                continue;
            }
            ios::Det2DPtr obj = std::make_shared<ios::Det2D>(ios::Det2D());
            obj->confidence = nmsed_scores[b * keepTopK * 1 + i];
            obj->class_label = int(nmsed_classes[b * keepTopK * 1 + i]);
            obj->class_name = classNames[obj->class_label];
            obj->x1= nmsed_boxes[b * keepTopK * 4 + i * 4 + 0] / scalex;
            obj->y1 = nmsed_boxes[b * keepTopK * 4 + i * 4 + 1] / scaley;
            obj->x2 = nmsed_boxes[b * keepTopK * 4 + i * 4 + 2] / scalex;
            obj->y2 = nmsed_boxes[b * keepTopK * 4 + i * 4 + 3] / scaley;
            obj->image_height = img_height;
            obj->image_width = img_width;
            obj->validCoordinate(); //
            // LOG_INFO << "box:" << obj->x1 << ","<< obj->y1 << ","<< obj->x2 << ","<< obj->y2;
            if((obj->x1 / img_width  < truncatedThreshold) || (obj->y1 / img_height  < truncatedThreshold) ||
            (obj->x2 / img_width  > (1 - truncatedThreshold)) || (obj->y2 / img_height  > (1 - truncatedThreshold))){
                obj->truncated = true;
            }
            dets->detections.push_back(obj);
        }
        interfaces::SetIOPtr(outName, dets);
    }
    return true;
}


/**
 * @brief Fixed identifier of this post-processing unit.
 * @return The string "YoloV5PostProcess".
 */
std::string YoloV5PostProcess::Name() {
    const std::string unitName{"YoloV5PostProcess"};
    return unitName;
}


}  // namespace postprocess
}  // namespace deepinfer
}  // namespace waytous


