#include "libs/postprocessors/whenet_post.h"

namespace waytous {
namespace deepinfer {
namespace postprocess {


// Initializes the WHENet post-process unit from its YAML config.
// After BaseUnit::Init succeeds, reads the per-head bin counts
// `outputYawLength`, `outputPitchLength` and `outputRollLength`. Each must be
// present and positive, because Exec() uses them both as the softmax length
// and as the per-output stride into the model output buffers.
// Returns false (after logging) if the base unit fails to init, or if any of
// these keys is missing or non-positive.
bool WHENetPostProcess::Init(YAML::Node& node) {
    if(!BaseUnit::Init(node)){
        LOG_WARN << "Init WHENet postprocess error";
        return false;
    }

    // Reads one required, strictly positive integer from `node[key]` into `out`.
    // A missing key is reported and turned into a `false` return instead of
    // letting yaml-cpp throw from as<int>() on an undefined node.
    auto readLength = [&node](const char* key, int& out) -> bool {
        const YAML::Node value = node[key];
        if (!value) {
            LOG_ERROR << "WHENet postprocess config is missing " << key;
            return false;
        }
        out = value.as<int>();
        if (out <= 0) {
            LOG_ERROR << "WHENet postprocess " << key << " must be positive, got " << out;
            return false;
        }
        return true;
    };

    return readLength("outputYawLength", outputYawLength)
        && readLength("outputPitchLength", outputPitchLength)
        && readLength("outputRollLength", outputRollLength);
}


bool WHENetPostProcess::Exec(){
    std::vector<base::BlobPtr<float>> inputs; // model outputs
    for(int i=0; i<inputNames.size(); i++){
        auto inputName = inputNames[i];
        auto ptr = std::dynamic_pointer_cast<ios::NormalIO>(interfaces::GetIOPtr(inputName));
        if (ptr == nullptr){
            LOG_ERROR << "WHENet postprocess input " << inputName << " haven't been init or doesn't exist.";
            return false;
        }
        inputs.push_back(ptr->data_);
    }

    //post
    float* yaw = inputs[0]->mutable_cpu_data();
    float* roll = inputs[1]->mutable_cpu_data();
    float* pitch = inputs[2]->mutable_cpu_data();

    for(int b=0; b<outputNames.size(); b++){
        auto headpose = std::make_shared<ios::HeadPose>(ios::HeadPose());
        interfaces::SetIOPtr(outputNames[b], headpose);
        
        float yawSum = 0, pitchSum=0, rollSum = 0;
        // yaw
        common::softmax(yaw + b * outputYawLength, outputYawLength);
        for(int i=0; i<outputYawLength; i++){
            yawSum += yaw[i + b * outputYawLength] * i;
        }
        yawSum *= 3.0;
        yawSum -= 180;
        headpose->yaw = yawSum;

        // roll
        common::softmax(roll + b * outputRollLength, outputRollLength);
        for(int i=0; i<outputRollLength; i++){
            rollSum += roll[i + b * outputRollLength] * i;
        }
        rollSum *= 3.0;
        rollSum -= 99;
        headpose->roll = rollSum;

        // pitch
        common::softmax(pitch + b * outputPitchLength, outputPitchLength);
        for(int i=0; i<outputPitchLength; i++){
            pitchSum += pitch[i + b * outputPitchLength] * i;
        }
        pitchSum *= 3.0;
        pitchSum -= 99;
        headpose->pitch = pitchSum;
    }
    return true;
}


// Returns the fixed identifier string for this post-process unit.
std::string WHENetPostProcess::Name() {
    const std::string unitName{"WHENetPostProcess"};
    return unitName;
}


}  // namespace postprocess
}  // namespace deepinfer
}  // namespace waytous


