#include "libs/postprocessors/whenet_post.h"

namespace waytous {
namespace deepinfer {
namespace postprocess {


/// Initializes the WHENet post-processing unit from its YAML config.
///
/// After the common BaseUnit initialization, reads the number of bins of each
/// angle head from `outputYawLength`, `outputPitchLength` and `outputRollLength`.
///
/// @param node            This unit's config node.
/// @param globalParamNode Global parameter node, forwarded to BaseUnit::Init.
/// @return false if BaseUnit::Init fails, if any of the three length keys is
///         missing, or if any length is not positive; true otherwise.
bool WHENetPostProcess::Init(YAML::Node& node, YAML::Node& globalParamNode) {
    if(!BaseUnit::Init(node, globalParamNode)){
        LOG_WARN << "Init WHENet postprocess error";
        return false;
    }

    // Check presence up front: as<int>() on an undefined node throws instead of
    // letting Init report failure through its return value.
    static const std::string kLengthKeys[] = {
        "outputYawLength", "outputPitchLength", "outputRollLength"};
    for (const std::string& key : kLengthKeys) {
        if (!node[key]) {
            LOG_ERROR << "WHENet postprocess config is missing " << key;
            return false;
        }
    }

    outputYawLength = node["outputYawLength"].as<int>();
    outputPitchLength = node["outputPitchLength"].as<int>();
    outputRollLength = node["outputRollLength"].as<int>();

    // Exec() iterates over this many bins per head; a non-positive count
    // cannot describe a valid classification head.
    if (outputYawLength <= 0 || outputPitchLength <= 0 || outputRollLength <= 0) {
        LOG_ERROR << "WHENet postprocess bin lengths must be positive, got yaw="
                  << outputYawLength << " pitch=" << outputPitchLength
                  << " roll=" << outputRollLength;
        return false;
    }
    return true;
}


bool WHENetPostProcess::Exec(){
    std::vector<base::BlobPtr<float>> inputs; // model outputs
    for(int i=0; i<inputNames.size(); i++){
        auto inputName = inputNames[i];
        auto ptr = std::dynamic_pointer_cast<ios::NormalIO>(interfaces::GetIOPtr(inputName));
        if (ptr == nullptr){
            LOG_ERROR << "WHENet postprocess input " << inputName << " haven't been init or doesn't exist.";
            return false;
        }
        inputs.push_back(ptr->data_);
    }

    //post
    float* yaw = inputs[0]->mutable_cpu_data();
    float* roll = inputs[1]->mutable_cpu_data();
    float* pitch = inputs[2]->mutable_cpu_data();

    auto headpose = std::make_shared<ios::HeadPose>(ios::HeadPose());
    interfaces::SetIOPtr(outputNames[0], headpose);
    
    float yawSum = 0, pitchSum=0, rollSum = 0;
    // yaw
    common::softmax(yaw, outputYawLength);
    for(int i=0; i<outputYawLength; i++){
        yawSum += yaw[i] * i;
    }
    yawSum *= 3.0;
    yawSum -= 180;
    headpose->yaw = yawSum;

    // roll
    common::softmax(roll, outputRollLength);
    for(int i=0; i<outputRollLength; i++){
        rollSum += roll[i] * i;
    }
    rollSum *= 3.0;
    rollSum -= 99;
    headpose->roll = rollSum;

    // pitch
    common::softmax(pitch, outputPitchLength);
    for(int i=0; i<outputPitchLength; i++){
        pitchSum += pitch[i] * i;
    }
    pitchSum *= 3.0;
    pitchSum -= 99;
    headpose->pitch = pitchSum;
    return true;
}


/// Returns this unit's name, "WHENetPostProcess".
std::string WHENetPostProcess::Name() {
    return "WHENetPostProcess";
}


}  // namespace postprocess
}  // namespace deepinfer
}  // namespace waytous


