/*
 * @Descripttion: 
 * @version: 
 * @Author: wxin
 * @Date: 2023-09-03 08:07:19
 * @email: xin.wang@waytous.com
 * @LastEditors: wxin
 * @LastEditTime: 2023-09-03 11:37:24
 */
#include "undistort/undistort_v2.h"

namespace waytous {
namespace deepinfer {
namespace preprocess {


/**
 * @name: Init
 * @msg: 初始化畸变矫正节点
 * @param {Node&} node 默认接口是从node传参
 * @param {BaseIOMapPtr} pmap 保存模型pipeline的map
 * @return {bool} 是否初始化成功
 */
bool UndistortV2::Init(YAML::Node& node, interfaces::BaseIOMapPtr pmap) {
    // Initialize the undistortion node.
    // node: configuration node passed through the default unit interface.
    // pmap: map holding the IO objects of the model pipeline.
    // Returns false only when BaseUnit::Init fails. A failure to load the
    // camera parameters is NOT reported here: inited_ stays false and Exec
    // retries InitParam() on its next call.
    const bool base_ready = BaseUnit::Init(node, pmap);
    if (!base_ready) {
        LOG_WARN << "Init undistortv2 error";
        return false;
    }

    // Presumably the camera parameters may be published later in the
    // pipeline, so their absence here is tolerated (see Exec).
    inited_ = InitParam();
    return true;
}


/**
 * @name: InitParam
 * @msg: 从pmap中获取参数用于初始化畸变矫正
 * @return {bool} 是否回去参数成功
 */  
bool UndistortV2::InitParam(){
    // Fetch the camera parameters from pMap and prepare what Exec needs:
    // the remap tables, the destination image, and the output object
    // registered under outputNames[0].
    // Returns true if the camera parameters were found and setup completed.

    // The camera parameters are looked up under the first input name.
    param_ = std::dynamic_pointer_cast<ios::CameraParam>(pMap->GetIOPtr(inputNames[0]));
    if (!param_){
        LOG_INFO << "Undistort input " << inputNames[0] << ", camera param haven't been init or doesn't exist.";
        return false;
    }

    // One remap entry per destination pixel: tables are (dst_height_, dst_width_).
    const auto rows = param_->dst_height_;
    const auto cols = param_->dst_width_;
    d_mapx_.Reshape({rows, cols});
    d_mapy_.Reshape({rows, cols});
    InitUndistortRectifyMap();

    // Allocate the destination image once and publish it, wrapped in a
    // CameraSrcOut, as this unit's output.
    dst_img = std::make_shared<base::Image8U>(rows, cols, base::Color::BGR);
    pMap->SetIOPtr(outputNames[0], std::make_shared<ios::CameraSrcOut>(dst_img));
    return true;
}


/**
 * @name: InitUndistortRectifyMap
 * @msg: 初始化畸变矫正映射map
 * @return {*}
 */  
void UndistortV2::InitUndistortRectifyMap() {
    float fx = param_->camera_intrinsic(0, 0);
    float fy = param_->camera_intrinsic(1, 1);
    float cx = param_->camera_intrinsic(0, 2);
    float cy = param_->camera_intrinsic(1, 2);
    float nfx = param_->new_camera_intrinsic(0, 0);
    float nfy = param_->new_camera_intrinsic(1, 1);
    float ncx = param_->new_camera_intrinsic(0, 2);
    float ncy = param_->new_camera_intrinsic(1, 2);
    float k1 = param_->distortion_coefficients(0, 0);
    float k2 = param_->distortion_coefficients(0, 1);
    float p1 = param_->distortion_coefficients(0, 2);
    float p2 = param_->distortion_coefficients(0, 3);
    float k3 = param_->distortion_coefficients(0, 4);
    float k4 = param_->distortion_coefficients(0, 5);
    float k5 = param_->distortion_coefficients(0, 6);
    float k6 = param_->distortion_coefficients(0, 7);
    float s1 = param_->distortion_coefficients(0, 8);
    float s2 = param_->distortion_coefficients(0, 9);
    float s3 = param_->distortion_coefficients(0, 10);
    float s4 = param_->distortion_coefficients(0, 11);

    Eigen::Matrix3f R;
    R << 1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f;
    Eigen::Matrix3f Rinv = R.inverse();

    for (int v = 0; v < param_->dst_height_; ++v) {
        float *x_ptr = d_mapx_.mutable_cpu_data() + d_mapx_.offset(v);
        float *y_ptr = d_mapy_.mutable_cpu_data() + d_mapy_.offset(v);
        for (int u = 0; u < param_->dst_width_; ++u) {
            Eigen::Matrix<float, 3, 1> xy1;
            xy1 << (static_cast<float>(u) - ncx) / nfx,
                (static_cast<float>(v) - ncy) / nfy, 1;
            auto XYW = Rinv * xy1;
            double nx = XYW(0, 0) / XYW(2, 0);
            double ny = XYW(1, 0) / XYW(2, 0);
            double r_square = nx * nx + ny * ny;
            // double kr = (1 + ((k3*r2 + k2)*r2 + k1)*r2)/(1 + ((k6*r2 + k5)*r2 + k4)*r2);
            // double xd = (x*kr + p1*_2xy + p2*(r2 + 2*x2) + s1*r2+s2*r2*r2);
            // double yd = (y*kr + p1*(r2 + 2*y2) + p2*_2xy + s3*r2+s4*r2*r2);
            double scale = (1 + r_square * (k1 + r_square * (k2 + r_square * k3)));
            scale = scale / (1 + r_square * (k4 + r_square * (k5 + r_square * k6)));
            double nnx =
                nx * scale + 2 * p1 * nx * ny + p2 * (r_square + 2 * nx * nx) + s1 * r_square + s2 * r_square * r_square;
            double nny =
                ny * scale + p1 * (r_square + 2 * ny * ny) + 2 * p2 * nx * ny + s3 * r_square + s4 * r_square * r_square;
            x_ptr[u] = static_cast<float>(nnx * fx + cx);
            y_ptr[u] = static_cast<float>(nny * fy + cy);
        }
    }
}


/**
 * @name: Exec
 * @msg: 推理，执行畸变矫正
 * @return {bool} 是否正常执行
 */ 
bool UndistortV2::Exec(){
    // Undistort the incoming camera image into dst_img on the GPU, using the
    // remap tables built by InitUndistortRectifyMap.
    // Returns true if the remap ran and NPP reported success.
    auto iptr = std::dynamic_pointer_cast<ios::CameraSrcOut>(pMap->GetIOPtr(inputNames[1]));
    if (iptr == nullptr){
        LOG_ERROR << "Undistort input " << inputNames[1] << " haven't been init or doesn't exist.";
        return false;
    }
    // Camera parameters may not have been available during Init; retry here.
    if (!inited_) {
        inited_ = InitParam();
        if(!inited_){
            LOG_WARN << "Undistortion param not init.";
            return false;
        }
    }

    auto src_img = iptr->img_ptr_;
    // A registered input may still carry no image; the remap below would
    // otherwise dereference a null image pointer.
    if (src_img == nullptr){
        LOG_ERROR << "Undistort input " << inputNames[1] << " has no image.";
        return false;
    }

    NppiInterpolationMode remap_mode = NPPI_INTER_LINEAR;
    // NOTE(review): the source size/ROI come from param_, not from src_img.
    // This assumes every incoming frame is exactly src_width_ x src_height_;
    // a smaller frame would let NPP read past the source buffer — confirm upstream.
    NppiSize src_image_size, dst_image_size;
    src_image_size.width = param_->src_width_;
    src_image_size.height = param_->src_height_;
    dst_image_size.width = param_->dst_width_;
    dst_image_size.height = param_->dst_height_;
    NppiRect remap_roi = {0, 0, param_->src_width_, param_->src_height_};

    NppStatus status;
    // Row pitch of the remap tables in bytes (one float per destination pixel).
    int d_map_step = static_cast<int>(d_mapx_.shape(1) * sizeof(float));
    switch (src_img->channels()) {
        case 1:
            // NOTE(review): dst_img is allocated in InitParam as base::Color::BGR
            // (presumably 3 channels), while the C1R variant writes one byte per
            // pixel. The single-channel output therefore does not fill a BGR row
            // layout — TODO confirm whether gray input is expected here.
            status = nppiRemap_8u_C1R(
                src_img->gpu_data(), src_image_size, src_img->width_step(), remap_roi,
                d_mapx_.gpu_data(), d_map_step, d_mapy_.gpu_data(), d_map_step,
                dst_img->mutable_gpu_data(), dst_img->width_step(), dst_image_size,
                remap_mode);
            break;
        case 3:
            status = nppiRemap_8u_C3R(
                src_img->gpu_data(), src_image_size, src_img->width_step(), remap_roi,
                d_mapx_.gpu_data(), d_map_step, d_mapy_.gpu_data(), d_map_step,
                dst_img->mutable_gpu_data(), dst_img->width_step(), dst_image_size,
                remap_mode);
            break;
        default:
            LOG_ERROR << "Invalid number of channels: " << src_img->channels();
            return false;
    }

    if (status != NPP_SUCCESS) {
        LOG_ERROR << "NPP_CHECK_NPP - status = " << status;
        return false;
    }

    return true;
}


/**
 * @name: Name
 * @msg: 节点名称
 * @return {std::string}
 */ 
std::string UndistortV2::Name() {
    // Name of this pipeline node; always "UndistortV2".
    return "UndistortV2";
}


}  // namespace preprocess
}  // namespace deepinfer
}  // namespace waytous


