#include "libs/postprocessors/trades_post.h"

namespace waytous {
namespace deepinfer {
namespace postprocess {


/**
 * Initialises the unit from its YAML node and allocates the buffers Exec() fills.
 *
 * Reads (via node[...].as<T>()): scoreThreshold, truncatedThreshold, classNumber,
 * classNames, topK, downScale, segDims, maxCntsLength.
 *
 * Allocates, with B = inferBatchSize rows:
 *   output_length_ptr     [B, 1]                    detection counter written by the kernel
 *   bboxes_ptr            [B, topK, 6]              x1 y1 x2 y2 score label
 *   maskCnts_lengths_ptr  [B, topK]                 number of mask counts per detection
 *   maskCnts_ptr          [B, topK, maxCntsLength]  mask counts per detection
 * Each cpu_data() call forces the host-side allocation up front.
 *
 * @return false if BaseUnit::Init fails or topK / downScale / maxCntsLength is not positive.
 */
bool TraDesPostProcess::Init(YAML::Node& node, YAML::Node& globalParamNode) {
    if(!BaseUnit::Init(node, globalParamNode)){
        LOG_WARN << "Init trades postprocess error";
        return false;
    }

    scoreThreshold = node["scoreThreshold"].as<float>();
    truncatedThreshold = node["truncatedThreshold"].as<float>();
    classNumber = node["classNumber"].as<int>();
    classNames = node["classNames"].as<std::vector<std::string>>();
    topK = node["topK"].as<int>();
    downScale = node["downScale"].as<int>();
    segDims = node["segDims"].as<int>();
    maxCntsLength = node["maxCntsLength"].as<int>();

    // topK is a modulus/slot count in the kernel, downScale a divisor in Exec(), and
    // maxCntsLength sizes the per-detection count buffer: all must be positive.
    if (topK <= 0 || downScale <= 0 || maxCntsLength <= 0) {
        LOG_ERROR << "trades postprocess, topK, downScale and maxCntsLength must be positive.";
        return false;
    }
    // Exec() looks up classNames[label] for labels in [0, classNumber).
    if (static_cast<int>(classNames.size()) < classNumber) {
        LOG_WARN << "trades postprocess, classNames has fewer entries than classNumber.";
    }

    output_length_ptr.reset(new base::Blob<int>({inferBatchSize, 1}));
    output_length_ptr->cpu_data();
    bboxes_ptr.reset(new base::Blob<float>({inferBatchSize, topK, 4 + 1 + 1})); // x1 y1 x2 y2 score label
    bboxes_ptr->cpu_data(); // init, cpu malloc
    maskCnts_lengths_ptr.reset(new base::Blob<int>({inferBatchSize, topK})); // mask cnts lengths
    maskCnts_lengths_ptr->cpu_data();
    maskCnts_ptr.reset(new base::Blob<int>({inferBatchSize, topK, maxCntsLength})); // mask cnts
    maskCnts_ptr->cpu_data();
    return true;
}


/**
 * Runs TraDes post-processing: launches the decode kernel on the model outputs and turns
 * its results into one ios::Detection2Ds per input image.
 *
 * IO layout (validated below):
 *   inputNames[0..4] : ios::NormalIO blobs passed to the kernel in this order
 *                      (hm, reg, wh, seg_weights, seg_feat).
 *   inputNames[5..]  : ios::CameraSrcOut images; image b feeds outputNames[b].
 *
 * Boxes are mapped from network-input coordinates back to image coordinates with
 * scalex/scaley (one shared scale when fixAspectRatio is set). Each detection gets an
 * ios::InstanceMask built from the kernel's mask counts on the
 * (inputWidth / downScale) x (inputHeight / downScale) grid.
 *
 * NOTE(review): the kernel applies no batch offset, so only batch row 0 is decoded; other
 * rows come out empty. Confirm whether batched inference is expected here.
 *
 * @return false if the IO layout is inconsistent or an input has an unexpected type.
 */
bool TraDesPostProcess::Exec() {
    // Tensor inputs consumed by the kernel; the images follow them in inputNames.
    const size_t kNumTensorInputs = 5;
    // Floats per detection in bboxes_ptr: x1 y1 x2 y2 score label.
    const int kBoxStride = 6;

    if (inputNames.size() != kNumTensorInputs + outputNames.size()){
        LOG_ERROR << "trades postprocess, inputsize != 5 + ouputsize.";
        return false;
    }

    std::vector<base::BlobPtr<float>> inputs;
    inputs.reserve(kNumTensorInputs);
    for (size_t i = 0; i < kNumTensorInputs; i++) {
        auto p = std::dynamic_pointer_cast<ios::NormalIO>(interfaces::GetIOPtr(inputNames[i]));
        if (!p) {
            LOG_ERROR << "trades postprocess, input " << inputNames[i] << " is not a NormalIO.";
            return false;
        }
        inputs.push_back(p->data_);
    }

    // Images start right after the tensor inputs. Index 4 is seg_feat; starting there
    // would cast it to CameraSrcOut (null) and yield one image more than there are outputs.
    std::vector<base::Image8UPtr> inputImages;
    inputImages.reserve(outputNames.size());
    for (size_t j = kNumTensorInputs; j < inputNames.size(); j++) {
        auto iptr = std::dynamic_pointer_cast<ios::CameraSrcOut>(interfaces::GetIOPtr(inputNames[j]));
        if (!iptr) {
            LOG_ERROR << "trades postprocess, input " << inputNames[j] << " is not a CameraSrcOut.";
            return false;
        }
        inputImages.push_back(iptr->img_ptr_);
    }
    // The output buffers have inferBatchSize rows; more images would read past them.
    if (inputImages.size() > static_cast<size_t>(inferBatchSize)) {
        LOG_ERROR << "trades postprocess, more images than inferBatchSize.";
        return false;
    }

    // The kernel accumulates into output_length with atomicAdd, so the counters must be
    // cleared before every launch or they grow across inferences. Clear every batch row
    // so rows the kernel does not write read as 0 instead of stale values.
    auto lengths = output_length_ptr->mutable_cpu_data();
    for (int b = 0; b < inferBatchSize; b++) {
        lengths[b] = 0;
    }
    trades_postprocess(
        inputs[0]->gpu_data(),
        inputs[1]->gpu_data(),
        inputs[2]->gpu_data(),
        inputs[3]->gpu_data(),
        inputs[4]->gpu_data(),
        output_length_ptr->mutable_gpu_data(),
        bboxes_ptr->mutable_gpu_data(),
        maskCnts_lengths_ptr->mutable_gpu_data(),
        maskCnts_ptr->mutable_gpu_data(),
        topK, maxCntsLength,
        inputWidth / downScale, inputHeight / downScale,
        classNumber, 3, segDims, scoreThreshold
    );

    auto outputLength = output_length_ptr->cpu_data();
    auto outputBoxes = bboxes_ptr->cpu_data();
    auto maskCntsLength = maskCnts_lengths_ptr->cpu_data();
    auto maskCnts = maskCnts_ptr->cpu_data();

    for (size_t b = 0; b < inputImages.size(); b++) {
        const float img_width = float(inputImages[b]->cols());
        const float img_height = float(inputImages[b]->rows());
        float scalex = inputWidth / img_width;
        float scaley = inputHeight / img_height;
        if (fixAspectRatio) {
            scalex = scaley = std::min(scalex, scaley);
        }

        // The kernel keeps counting detections past topK but only fills topK slots, so the
        // counter must be clamped before it is used as a slot bound.
        const int numDets = std::min(static_cast<int>(outputLength[b]), static_cast<int>(topK));

        auto dets = std::make_shared<ios::Detection2Ds>();
        for (int i = 0; i < numDets; i++) {
            const int slot = static_cast<int>(b) * topK + i;
            auto box = outputBoxes + slot * kBoxStride;
            if (box[4] < scoreThreshold) {
                continue;
            }
            const int label = int(box[5]);
            if (label < 0 || label >= static_cast<int>(classNames.size())) {
                continue;  // no name configured for this class id
            }
            auto obj = std::make_shared<ios::Det2D>();
            obj->confidence = box[4];
            obj->class_label = label;
            obj->class_name = classNames[label];
            obj->x1 = box[0] / scalex;
            obj->y1 = box[1] / scaley;
            obj->x2 = box[2] / scalex;
            obj->y2 = box[3] / scaley;
            obj->image_height = img_height;
            obj->image_width = img_width;
            obj->validCoordinate();
            // Mark boxes that come within truncatedThreshold (as an image fraction) of a border.
            if ((obj->x1 / img_width < truncatedThreshold) || (obj->y1 / img_height < truncatedThreshold) ||
                (obj->x2 / img_width > (1 - truncatedThreshold)) || (obj->y2 / img_height > (1 - truncatedThreshold))) {
                obj->truncated = true;
            }
            obj->mask_ptr.reset(new ios::InstanceMask(
                inputWidth / downScale, inputHeight / downScale,
                maskCnts + slot * maxCntsLength,
                maskCntsLength[slot]
            ));
            dets->detections.push_back(obj);
        }
        interfaces::SetIOPtr(outputNames[b], dets);
    }
    return true;
}


// Logistic sigmoid 1 / (1 + e^-data), evaluated in single precision.
// Float literals and expf keep the math in float; `1.` / `exp` would promote to double.
__device__ float Logist(float data){ return 1.0f / (1.0f + expf(-data)); }


/**
 * Decodes the TraDes heads of one image into boxes and run-length-coded instance masks.
 *
 * Launch: 1-D blocks over a (possibly 2-D) grid; the flat block id is
 * blockIdx.x + blockIdx.y * gridDim.x. There is one thread per heat-map element,
 * idx = cls * (w*h) + y * w + x, with idx < w*h*classes. No shared memory is used.
 *
 * A location is kept when hm[idx] > score_threshold and it is the first maximum of its
 * kernel_size x kernel_size window in the same class channel. Out-of-window neighbours
 * count as -1. Each kept location reserves a slot with atomicAdd(output_length, 1). Only
 * the first topK reservations are written, but *output_length keeps counting past topK,
 * so the host must clamp it. Which detections get slots depends on scheduling, not score.
 *
 * Per kept detection in slot n:
 *   output_bboxes[n*6 + 0..5]  x1, y1, x2, y2, score, class, in feature-map coordinates.
 *                              Box size is sigmoid(wh)^2 * (w, h) around the regressed center.
 *   output_mask_cnts[n*max_cnts_length + k]
 *                              Alternating background/foreground run lengths, row-major over
 *                              the w x h map, starting with a background run. Foreground runs
 *                              are closed at the end of each window row. Trailing background
 *                              after the last run is not emitted.
 *   output_mask_cnt_lengths[n] Number of counts written (<= max_cnts_length).
 *
 * Layout assumed from the indexing (TODO confirm against the model): hm is [classes, h, w];
 * reg and wh are [2, h, w]; seg_weights and seg_feat are [seg_dims, h, w].
 * No batch offset is applied, so only the first image of a batch is decoded.
 */
__global__ void trades_postprocess_kernel(const float *hm, const float *reg, const float *wh, const float* seg_weights, const float* seg_feat,
        int* output_length, float *output_bboxes, int* output_mask_cnt_lengths, int* output_mask_cnts, const int topK, const int max_cnts_length, 
        const int w, const int h, const int classes, const int kernel_size, const int seg_dims, const float score_threshold) {
    const int idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (idx >= w * h * classes) return;

    const int stride = w * h;               // elements per channel
    const int grid_x = idx % w;
    const int grid_y = (idx / w) % h;
    const int cls = idx / stride;
    const int reg_index = idx - cls * stride;  // spatial offset inside one channel

    // hm is compared directly against the threshold (no sigmoid applied here).
    // Written as !(a > b) so NaN scores are rejected exactly as before.
    const float objProb = hm[idx];
    if (!(objProb > score_threshold)) return;

    // Max-pool NMS: keep only the first maximum of the window (scan: x offset outer, y inner).
    const int offset = -((kernel_size - 1) / 2);
    float best = -1.f;
    int best_index = 0;
    for (int l = 0; l < kernel_size; ++l) {
        for (int m = 0; m < kernel_size; ++m) {
            const int cur_x = offset + l + grid_x;
            const int cur_y = offset + m + grid_y;
            const bool valid = cur_x >= 0 && cur_x < w && cur_y >= 0 && cur_y < h;
            const int cur_index = cur_y * w + cur_x + stride * cls;
            const float val = valid ? hm[cur_index] : -1.f;
            if (val > best) {
                best = val;
                best_index = cur_index;
            }
        }
    }
    if (idx != best_index) return;

    const int slot = atomicAdd(output_length, 1);
    if (slot >= topK) return;  // counter still advanced; host clamps to topK

    const float c_x = grid_x + reg[reg_index];
    const float c_y = grid_y + reg[reg_index + stride];
    const float sw = Logist(wh[reg_index]);
    const float sh = Logist(wh[reg_index + stride]);
    const float boxWidth = sw * sw * w;
    const float boxHeight = sh * sh * h;

    float* box = output_bboxes + slot * 6;
    box[0] = c_x - boxWidth / 2;
    box[1] = c_y - boxHeight / 2;
    box[2] = c_x + boxWidth / 2;
    box[3] = c_y + boxHeight / 2;
    box[4] = objProb;
    box[5] = cls;

    // Pixel window [x1, x2) x [y1, y2) around the box, clipped to the map. The upper
    // bounds are exclusive, so they clip to w / h (clipping to w-1 / h-1 would make the
    // last column and row of the map impossible to include in any mask).
    const int x1 = iMAX((int)(c_x - boxWidth / 2), 0);
    const int y1 = iMAX((int)(c_y - boxHeight / 2), 0);
    const int x2 = iMIN((int)ceilf(c_x + boxWidth / 2), w);
    const int y2 = iMIN((int)ceilf(c_y + boxHeight / 2), h);

    int* cnts = output_mask_cnts + slot * max_cnts_length;
    const int row_left_num = w - (x2 - x1);  // pixels outside the window between two window rows
    int cnt = y1 * w + x1;                   // background before the first window pixel
    int num = 0;
    bool inMask = false;
    for (int i = y1; i < y2; i++) {
        for (int j = x1; j < x2; j++) {
            const int pos = i * w + j;
            // Dynamic 1x1 conv: weights from this detection's location, features from pos.
            float mask_value = 0;
            for (int k = 0; k < seg_dims; k++) {
                mask_value += (seg_weights[k * stride + reg_index] * seg_feat[k * stride + pos]);
            }
            const bool fg = mask_value >= 0;  // sigmoid(mask_value) >= 0.5
            if (fg != inMask && num < max_cnts_length) {
                cnts[num++] = cnt;
                cnt = 0;
                inMask = fg;
            }
            cnt++;
        }
        if (inMask && num < max_cnts_length) {
            // Close the foreground run at the row end; the rest of the row is background.
            cnts[num++] = cnt;
            cnt = row_left_num;
            inMask = false;
        } else {
            cnt += row_left_num;
        }
    }
    output_mask_cnt_lengths[slot] = num;
}


/**
 * Launches trades_postprocess_kernel with one thread per heat-map element
 * (w * h * number_classes) on the default stream, then blocks until it finishes.
 *
 * Launch-configuration errors (cudaGetLastError) and errors raised while the kernel
 * ran (cudaDeviceSynchronize) both go through CUDA_CHECK. An empty problem returns
 * without launching, because a zero-sized grid is an invalid launch.
 */
void TraDesPostProcess::trades_postprocess(const float *hm, const float *reg, const float *wh, const float* seg_weights, const float* seg_feat, 
        int* output_length, float *output_bboxes, int* output_mask_cnt_lengths, int* output_mask_cnts, const int topK, const int max_cnts_length, 
        const int w, const int h, const int number_classes, const int kernerl_size,  const int seg_dims, const float score_threshold){
    const int total = w * h * number_classes;
    if (total <= 0) {
        return;  // nothing to decode
    }
    const uint num = static_cast<uint>(total);
    const uint block = 512;
    const dim3 threads = dim3(block, 1, 1);
    trades_postprocess_kernel<<<cudaGridSize(num, block), threads>>>(hm, reg, wh, seg_weights, seg_feat,
             output_length, output_bboxes, output_mask_cnt_lengths, output_mask_cnts, topK, max_cnts_length,
             w, h, number_classes, kernerl_size, seg_dims, score_threshold);
    CUDA_CHECK(cudaGetLastError());       // bad launch configuration
    CUDA_CHECK(cudaDeviceSynchronize());  // faults during kernel execution
}


// Returns this post-processing unit's identifier.
std::string TraDesPostProcess::Name() {
    static const char kUnitName[] = "TraDesPostProcess";
    return std::string(kUnitName);
}

// Registers TraDesPostProcess through the project's unit-registration macro (defined
// elsewhere). Presumably this makes the unit creatable by name from configuration; confirm.
DEEPINFER_REGISTER_UNIT(TraDesPostProcess);

}  // namespace postprocess
}  // namespace deepinfer
}  // namespace waytous



