#include "libs/postprocessors/trades_post.h"

#include <algorithm>

namespace waytous {
namespace deepinfer {
namespace postprocess {


/**
 * Reads the post-processing configuration from `node` and allocates the
 * output blobs that Exec() fills.
 *
 * Keys read: inputHeight, inputWidth, inferBatchSize, fixAspectRatio,
 * scoreThreshold, truncatedThreshold, classNumber, classNames, topK,
 * downScale, segDims, maxCntsLength. yaml-cpp's `as<T>()` throws if a key
 * is missing or cannot be converted to the requested type.
 *
 * @param node YAML configuration for this unit.
 * @return false if BaseUnit::Init fails or a size/divisor parameter is not
 *         positive; true otherwise.
 */
bool TraDesPostProcess::Init(YAML::Node& node) {
    if (!BaseUnit::Init(node)) {
        LOG_WARN << "Init trades postprocess error";
        return false;
    }

    inputHeight = node["inputHeight"].as<int>();
    inputWidth = node["inputWidth"].as<int>();
    inferBatchSize = node["inferBatchSize"].as<int>();
    fixAspectRatio = node["fixAspectRatio"].as<bool>();
    scoreThreshold = node["scoreThreshold"].as<float>();
    truncatedThreshold = node["truncatedThreshold"].as<float>();
    classNumber = node["classNumber"].as<int>();
    classNames = node["classNames"].as<std::vector<std::string>>();
    topK = node["topK"].as<int>();
    downScale = node["downScale"].as<int>();
    segDims = node["segDims"].as<int>();
    maxCntsLength = node["maxCntsLength"].as<int>();

    // downScale is an integer divisor in Exec(), and the remaining values size
    // the blobs allocated below, so none of them may be zero or negative.
    if (inputHeight <= 0 || inputWidth <= 0 || inferBatchSize <= 0 ||
        topK <= 0 || downScale <= 0 || maxCntsLength <= 0) {
        LOG_WARN << "Init trades postprocess error: inputHeight, inputWidth, inferBatchSize, "
                 << "topK, downScale and maxCntsLength must be positive";
        return false;
    }
    // Exec() names each detection via classNames[label]; a list shorter than
    // classNumber cannot name every class the kernel may emit.
    if (static_cast<int>(classNames.size()) < classNumber) {
        LOG_WARN << "trades postprocess: classNames has " << classNames.size()
                 << " entries but classNumber is " << classNumber;
    }

    // One detection counter per image; Exec() reads entry b as image b's count.
    output_length_ptr.reset(new base::Blob<int>({inferBatchSize, 1}));
    output_length_ptr->cpu_data();
    // topK detection rows per image: x1 y1 x2 y2 score label.
    bboxes_ptr.reset(new base::Blob<float>({inferBatchSize, topK, 4 + 1 + 1}));
    bboxes_ptr->cpu_data(); // init, cpu malloc
    // Per-detection mask cnts length, and the mask cnts themselves
    // (at most maxCntsLength values per detection).
    maskCnts_lengths_ptr.reset(new base::Blob<int>({inferBatchSize, topK}));
    maskCnts_lengths_ptr->cpu_data();
    maskCnts_ptr.reset(new base::Blob<int>({inferBatchSize, topK, maxCntsLength}));
    maskCnts_ptr->cpu_data();
    return true;
}


bool TraDesPostProcess::Exec() {
    if (inputNames.size() != 5 + outputNames.size()){
        LOG_ERROR << "trades postprocess, inputsize != 5 + ouputsize.";
        return false;
    }

    std::vector<base::BlobPtr<float>> inputs;
    for(int i=0; i<5; i++){
        auto iName = inputNames[i];
        auto p = std::dynamic_pointer_cast<ios::NormalIO>(interfaces::GetIOPtr(iName));
        if (p == nullptr){
            LOG_ERROR << "TraDeS postprocess input " << iName << " haven't been init or doesn't exist.";
            return false;
        }
        inputs.push_back(p->data_);
    }

    std::vector<base::Image8UPtr> inputImages;
    for(int j=5; j<inputNames.size(); j++){
        auto iName = inputNames[j];
        auto iptr = std::dynamic_pointer_cast<ios::CameraSrcOut>(interfaces::GetIOPtr(iName));
        if (iptr == nullptr){
            LOG_ERROR << "TraDeS postprocess input image " << iName << " haven't been init or doesn't exist.";
            return false;
        }
        inputImages.push_back(iptr->img_ptr_);
    }

    // reset output_length=0, otherwise, it will increase after every inference.
    output_length_ptr->mutable_cpu_data()[0] = 0;
    trades_postprocess(
        inputs[0]->gpu_data(),
        inputs[1]->gpu_data(),
        inputs[2]->gpu_data(),
        inputs[3]->gpu_data(),
        inputs[4]->gpu_data(),
        output_length_ptr->mutable_gpu_data(),
        bboxes_ptr->mutable_gpu_data(),
        maskCnts_lengths_ptr->mutable_gpu_data(),
        maskCnts_ptr->mutable_gpu_data(),
        topK, maxCntsLength, 
        inputWidth / downScale, inputHeight / downScale,
        classNumber, 3, segDims, scoreThreshold
    );

    auto outputLength = output_length_ptr->cpu_data();
    auto outputBoxes = bboxes_ptr->cpu_data();
    auto maskCntsLength = maskCnts_lengths_ptr->cpu_data();
    auto maskCnts = maskCnts_ptr->cpu_data();

    for(int b=0; b<inputImages.size(); b++){
        auto outName = outputNames[b];
        float img_width = float(inputImages[b]->cols());
        float img_height = float(inputImages[b]->rows());
        float scalex = inputWidth / img_width;
        float scaley = inputHeight / img_height;
        if(fixAspectRatio){
            scalex = scaley = std::min(scalex, scaley);
        }
        auto dets = std::make_shared<ios::Detection2Ds>(ios::Detection2Ds());
        for(int i = 0; i < outputLength[b]; i++){
            if(outputBoxes[b * topK * 6 + i * 6 + 4] < scoreThreshold){
                continue;
            }
            auto obj = std::make_shared<ios::Det2D>(ios::Det2D());
            obj->confidence = outputBoxes[b * topK * 6 + i * 6 + 4];
            obj->class_label = int(outputBoxes[b * topK * 6 + i * 6 + 5]);
            obj->class_name = classNames[obj->class_label];
            obj->x1= outputBoxes[b * topK * 6 + i * 6 + 0] * downScale / scalex;
            obj->y1 = outputBoxes[b * topK * 6 + i * 6 + 1] * downScale / scaley;
            obj->x2 = outputBoxes[b * topK * 6 + i * 6 + 2] * downScale/ scalex;
            obj->y2 = outputBoxes[b * topK * 6 + i * 6 + 3] * downScale/ scaley;
            obj->image_height = img_height;
            obj->image_width = img_width;
            obj->validCoordinate(); //
            // LOG_INFO << "box:" << obj->x1 << ","<< obj->y1 << ","<< obj->x2 << ","<< obj->y2;
            if((obj->x1 / img_width  < truncatedThreshold) || (obj->y1 / img_height  < truncatedThreshold) ||
            (obj->x2 / img_width  > (1 - truncatedThreshold)) || (obj->y2 / img_height  > (1 - truncatedThreshold))){
                obj->truncated = true;
            }
            obj->mask_ptr.reset(new ios::InstanceMask(
                inputWidth / downScale, inputHeight / downScale, 
                maskCnts + b * topK * maxCntsLength + i * maxCntsLength, 
                maskCntsLength[b * topK + i]
            ));
            dets->detections.push_back(obj);
        }
        interfaces::SetIOPtr(outName, dets);
    }
    return true;

};


/// Returns the identifier under which this post-processing unit is reported.
std::string TraDesPostProcess::Name() {
    static const char kUnitName[] = "TraDesPostProcess";
    return std::string(kUnitName);
}


}  // namespace postprocess
}  // namespace deepinfer
}  // namespace waytous



