#include "libs/postprocessors/mobilefacenet_post.h"

namespace waytous {
namespace deepinfer {
namespace postprocess {


/**
 * @brief Initialize the unit from its YAML configuration.
 *
 * Delegates common setup to BaseUnit::Init, then reads the required
 * `landmarkNumber` key, which Exec uses as the number of (x, y) pairs to
 * read per image.
 *
 * @param node            Unit-specific YAML configuration.
 * @param globalParamNode Global YAML parameters forwarded to BaseUnit::Init.
 * @return false if base initialization fails, or if `landmarkNumber` is
 *         missing or not positive; true otherwise.
 */
bool MobileFaceNetPostProcess::Init(YAML::Node& node, YAML::Node& globalParamNode) {
    if(!BaseUnit::Init(node, globalParamNode)){
        LOG_WARN << "Init MobileFaceNet postprocess error";
        return false;
    }
    // Report a missing key through the bool contract instead of letting
    // yaml-cpp throw from as<int>().
    if(!node["landmarkNumber"]){
        LOG_ERROR << "MobileFaceNet postprocess, missing required parameter landmarkNumber.";
        return false;
    }
    landmarkNumber = node["landmarkNumber"].as<int>();
    if(landmarkNumber <= 0){
        LOG_ERROR << "MobileFaceNet postprocess, landmarkNumber must be positive, got " << landmarkNumber << ".";
        return false;
    }
    return true;
}


bool MobileFaceNetPostProcess::Exec(){
    if (inputNames.size() != 1 + outputNames.size()){
        LOG_ERROR << "yolov5 postprocess, inputsize != 4 + ouputsize.";
        return false;
    }

    auto input = std::dynamic_pointer_cast<ios::NormalIO>(interfaces::GetIOPtr(inputNames[0]));
    if (input == nullptr){
        LOG_ERROR << "WHENet postprocess input " << inputNames[0] << " haven't been init or doesn't exist.";
        return false;
    }

    std::vector<base::Image8UPtr> inputImages;
    for(int j=1; j<inputNames.size(); j++){
        auto iName = inputNames[j];
        auto iptr = std::dynamic_pointer_cast<ios::CameraSrcOut>(interfaces::GetIOPtr(iName));
        if (iptr == nullptr){
            LOG_ERROR << "YoloV5 postprocess input " << iName << " haven't been init or doesn't exist.";
            return false;
        }
        inputImages.push_back(iptr->img_ptr_);
    }

    const float* points = input->data_->cpu_data();
    for(int b=0; b<inputImages.size(); b++){
        auto outName = outputNames[b];
        float img_width = float(inputImages[b]->cols());
        float img_height = float(inputImages[b]->rows());
        float scalex = inputWidth / img_width;
        float scaley = inputHeight / img_height;
        if(fixAspectRatio){
            scalex = scaley = std::min(scalex, scaley);
        }
        auto face_landmarks = std::make_shared<ios::Landmarks>(ios::Landmarks());
        for(int l=0; l < landmarkNumber; l++){
            ios::Point2DPtr lptr = std::make_shared<ios::Point2D>(ios::Point2D());
            lptr->x = points[l * 2] * inputWidth / scalex;
            lptr->y = points[l * 2 + 1] * inputHeight / scaley;
            face_landmarks->landmarks.push_back(lptr);
        }
        interfaces::SetIOPtr(outputNames[b], face_landmarks);
    }
    
    return true;
}


/// @brief Identifier under which this post-processing unit is reported.
/// @return The constant string "MobileFaceNetPostProcess".
std::string MobileFaceNetPostProcess::Name() {
    const std::string unitName{"MobileFaceNetPostProcess"};
    return unitName;
}


}  // namespace postprocess
}  // namespace deepinfer
}  // namespace waytous


