
#include "libs/preprocessors/resize_norm.h"

namespace waytous {
namespace deepinfer {
namespace preprocess {


/**
 * @brief Configure the resize + normalize preprocessor from its YAML node.
 *
 * Reads the network input geometry, batch size, aspect-ratio / colour-order
 * flags and per-channel mean/std from @p node. It then allocates the output
 * blob of shape {inferBatchSize, 3, inputHeight, inputWidth}, registers it
 * under outputNames[0], and builds {3, 1} mean/std blobs on the GPU.
 *
 * @param node YAML configuration. It is read for inputHeight, inputWidth,
 *             inferBatchSize, fixAspectRatio, useBGR, inputMean and inputStd.
 * @return true on success. Returns false if BaseUnit::Init fails or the
 *         configuration is inconsistent: non-positive sizes, an input count
 *         different from inferBatchSize, no output name, or fewer than 3
 *         mean/std values.
 */
bool ResizeNorm::Init(YAML::Node& node){
    // NOTE(review): the stream is created before any validation, so a failed
    // Init still leaves stream_ allocated. It is kept first because the
    // destructor (not visible here) may assume stream_ always exists.
    CUDA_CHECK(cudaStreamCreate(&stream_));
    if(!BaseUnit::Init(node)){
        LOG_WARN << "Init resize_norm error";
        return false;
    }
    inputHeight = node["inputHeight"].as<int>();
    inputWidth = node["inputWidth"].as<int>();
    inferBatchSize = node["inferBatchSize"].as<int>();
    fixAspectRatio = node["fixAspectRatio"].as<bool>();

    // These values become blob dimensions below; reject nonsensical sizes early.
    if(inputHeight <= 0 || inputWidth <= 0 || inferBatchSize <= 0){
        LOG_ERROR << "Resize norm got invalid input shape: " << inputHeight << "x" << inputWidth
                  << " with infer batchsize: " << inferBatchSize;
        return false;
    }

    // One input image per batch slot. inferBatchSize is known positive here,
    // so the cast avoids a signed/unsigned comparison.
    if(inputNames.size() != static_cast<std::size_t>(inferBatchSize)){
        LOG_ERROR << "Resize norm got wrong inputs number: " << inputNames.size() << " with infer batchsize: " << inferBatchSize;
        return false;
    }

    // outputNames[0] is used to publish the output blob; guard the index.
    if(outputNames.empty()){
        LOG_ERROR << "Resize norm got no output name";
        return false;
    }

    useBGR = node["useBGR"].as<bool>();
    std::vector<float> inputMean = node["inputMean"].as<std::vector<float>>();
    std::vector<float> inputStd = node["inputStd"].as<std::vector<float>>();

    // The mean/std blobs are shaped {3, 1} and built from .data(). A vector
    // with fewer than 3 entries would therefore be read out of bounds.
    constexpr std::size_t kNumChannels = 3;
    if(inputMean.size() < kNumChannels || inputStd.size() < kNumChannels){
        LOG_ERROR << "Resize norm expects 3 inputMean/inputStd values, got "
                  << inputMean.size() << " and " << inputStd.size();
        return false;
    }

    dst.reset(new base::Blob<float>({inferBatchSize, 3, inputHeight, inputWidth}));
    // Construct the NormalIO in place instead of copying a temporary into it.
    auto dst_ptr = std::make_shared<ios::NormalIO>(dst);
    interfaces::SetIOPtr(outputNames[0], dst_ptr);

    // mutable_gpu_data() is called right away, presumably to force the host ->
    // device upload while inputMean/inputStd are still alive. TODO confirm
    // whether Blob copies the host buffer or only keeps the pointer.
    mean.reset(new base::Blob<float>({3, 1}, inputMean.data()));
    mean->mutable_gpu_data();
    std.reset(new base::Blob<float>({3, 1}, inputStd.data()));
    std->mutable_gpu_data();
    return true;
}


/**
 * @brief Resize and normalize every configured input image into the batched output blob.
 *
 * For each name in inputNames, fetches the ios::CameraSrcOut registered under
 * that name. It then calls resizeGPU on stream_ to write batch slot b of dst,
 * which is 3 * inputHeight * inputWidth floats per slot.
 *
 * @return true if every input was found and dispatched. Returns false if an
 *         input is missing, has the wrong type, or carries no image.
 */
bool ResizeNorm::Exec(){
    // Number of floats occupied by one image in the {N, 3, H, W} output blob.
    // This is loop-invariant, so compute it once.
    const std::size_t imageSize = static_cast<std::size_t>(3) * inputHeight * inputWidth;

    for(std::size_t b = 0; b < inputNames.size(); ++b){
        const auto& inputName = inputNames[b];
        auto input = std::dynamic_pointer_cast<ios::CameraSrcOut>(interfaces::GetIOPtr(inputName));
        if(input == nullptr){
            LOG_ERROR << "resize norm input " << inputName << " haven't init";
            return false;
        }
        auto img = input->img_ptr_;
        // The IO slot may exist before a frame has been attached to it.
        if(img == nullptr){
            LOG_ERROR << "resize norm input " << inputName << " has no image";
            return false;
        }
        resizeGPU(
            img->mutable_gpu_data(),
            img->cols(),
            img->rows(),
            img->width_step(),
            dst->mutable_gpu_data() + b * imageSize, // batch slot b
            inputWidth, inputHeight,
            mean->mutable_gpu_data(),
            std->mutable_gpu_data(),
            useBGR, fixAspectRatio, stream_
        );
    }
    // NOTE(review): work is only handed to stream_ here and nothing waits on
    // it. Confirm that resizeGPU or the downstream consumer synchronizes
    // before dst is read.
    return true;
}



/// @brief Identifier under which this preprocessing unit is known.
/// @return The fixed string "ResizeNorm".
std::string ResizeNorm::Name(){
    static const char kUnitName[] = "ResizeNorm";
    return std::string(kUnitName);
}



}  // namespace preprocess
}  // namespace deepinfer
}  // namespace waytous



