
#include "libs/postprocessors/trades_post_gpu.h"

namespace waytous {
namespace deepinfer {
namespace postprocess {

// Logistic sigmoid 1 / (1 + e^-x), evaluated entirely in single precision
// (float literals + expf) so the device code does not promote to double.
__device__ float Logist(float data){ return 1.0f / (1.0f + expf(-data)); }


/**
 * Decodes heat-map peaks into bounding boxes and run-length encoded instance masks.
 *
 * One thread handles one heat-map cell (class, y, x). A cell becomes a detection when
 * its score exceeds `score_threshold` and it is the maximum of the
 * `kernel_size` x `kernel_size` window centred on it within its own class plane.
 *
 * Launch layout: 1D blocks; the block index is linearised as
 * blockIdx.x + blockIdx.y * gridDim.x, so both 1D and 2D grids work. Threads whose
 * linear index is >= w * h * classes exit immediately.
 *
 * Buffer layouts, as implied by the indexing below:
 *   hm          : classes planes of h * w scores, read as-is (no sigmoid is applied here;
 *                 presumably already activated upstream - confirm with the caller).
 *   reg         : 2 planes of h * w (x offset, then y offset of the centre).
 *   wh          : 2 planes of h * w (raw width, then raw height); box size is
 *                 sigmoid(raw)^2 scaled by w / h respectively.
 *   seg_weights : seg_dims planes of h * w; the vector at the peak cell is the
 *                 per-instance mask weight.
 *   seg_feat    : seg_dims planes of h * w; per-pixel mask features.
 *
 * Outputs:
 *   output_length           : global detection counter, incremented atomically once per
 *                             peak. It keeps counting past topK, so the final value can be
 *                             larger than the number of slots actually written.
 *                             NOTE(review): presumably zeroed by the caller before launch.
 *   output_bboxes           : 6 floats per slot: x1, y1, x2, y2, score, class.
 *   output_mask_cnts        : max_cnts_length ints per slot holding the RLE counts.
 *   output_mask_cnt_lengths : number of RLE counts written for each slot.
 *
 * RLE convention: counts alternate background / foreground runs, starting with a
 * background run, in row-major order over the w x h image. Only pixels inside the
 * clamped box (end-exclusive on both axes) are classified; everything else is
 * background. The trailing background run is not emitted, and encoding silently stops
 * recording runs once max_cnts_length counts have been written.
 */
__global__ void trades_postprocess_kernel(const float *hm, const float *reg, const float *wh, const float* seg_weights, const float* seg_feat,
        int* output_length, float *output_bboxes, int* output_mask_cnt_lengths, int* output_mask_cnts, const int topK, const int max_cnts_length, 
        const int w, const int h, const int classes, const int kernel_size, const int seg_dims, const float score_threshold) {
    const int idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (idx >= w * h * classes) return;

    const int stride = w * h;                   // elements in one h * w plane
    const int grid_x = idx % w;
    const int grid_y = (idx / w) % h;
    const int cls = idx / w / h;
    const int reg_index = idx - cls * stride;   // position inside a single plane

    // Scores are compared directly against the threshold.
    const float objProb = hm[idx];
    if (!(objProb > score_threshold)) return;

    // Local-maximum test over the window, restricted to this class plane. Cells outside
    // the image are skipped. Ties go to the first cell visited (x outer, y inner) because
    // the comparison is strict.
    const int offset = -((kernel_size - 1) / 2);
    float best_val = -1.0f;
    int best_index = 0;
    for (int l = 0; l < kernel_size; ++l) {
        const int cur_x = grid_x + offset + l;
        for (int m = 0; m < kernel_size; ++m) {
            const int cur_y = grid_y + offset + m;
            if (cur_x < 0 || cur_x >= w || cur_y < 0 || cur_y >= h) continue;
            const int cur_index = cur_y * w + cur_x + stride * cls;
            const float val = hm[cur_index];
            if (val > best_val) {
                best_val = val;
                best_index = cur_index;
            }
        }
    }
    if (idx != best_index) return;

    // Reserve an output slot. The counter is incremented for every peak, but only the
    // first topK reservations own a slot. Checking the range before using the value also
    // avoids the modulo-by-zero the old `resCount % topK` performed when topK == 0.
    const int slot = atomicAdd(output_length, 1);
    if (slot < 0 || slot >= topK) return;

    // Decode the box: centre = cell + regressed offset, size = sigmoid(raw)^2 * extent.
    const float c_x = grid_x + reg[reg_index];
    const float c_y = grid_y + reg[reg_index + stride];
    const float sig_w = Logist(wh[reg_index]);
    const float sig_h = Logist(wh[reg_index + stride]);
    const float boxWidth = sig_w * sig_w * w;
    const float boxHeight = sig_h * sig_h * h;
    const float left = c_x - boxWidth * 0.5f;
    const float top = c_y - boxHeight * 0.5f;
    const float right = c_x + boxWidth * 0.5f;
    const float bottom = c_y + boxHeight * 0.5f;

    float* box = output_bboxes + slot * 6;
    box[0] = left;
    box[1] = top;
    box[2] = right;
    box[3] = bottom;
    box[4] = objProb;
    box[5] = cls;

    // Integer box used for the mask, clamped to the image. x2 / y2 are exclusive bounds.
    int x1 = (int)left, y1 = (int)top;
    int x2 = (int)ceilf(right), y2 = (int)ceilf(bottom);
    x1 = iMAX(x1, 0); y1 = iMAX(y1, 0);
    x2 = iMIN(x2, w - 1); y2 = iMIN(y2, h - 1);

    const int row_left_num = w - (x2 - x1);     // pixels of one image row outside the box
    const float* inst_weights = seg_weights + reg_index;
    int* cnts = output_mask_cnts + slot * max_cnts_length;

    bool maskFlag = false;                      // value of the run currently being counted
    int num = 0;                                // number of counts written so far
    int cnt = y1 * w + x1;                      // leading background pixels before the box
    for (int i = y1; i < y2; i++) {
        for (int j = x1; j < x2; j++) {
            const float* feat = seg_feat + (i * w + j);
            float mask_value = 0.0f;
            for (int k = 0; k < seg_dims; k++) {
                mask_value += (inst_weights[k * stride] * feat[k * stride]);
            }
            const bool mask_j = mask_value >= 0.0f;  // same as sigmoid(mask_value) >= 0.5
            if (mask_j != maskFlag && num < max_cnts_length) {
                // The run value flipped: emit the finished run and start a new one.
                cnts[num] = cnt;
                num++;
                cnt = 0;
                maskFlag = mask_j;
            }
            cnt++;
        }
        if (maskFlag && num < max_cnts_length) {
            // The row ended inside a foreground run: close it, then start the next
            // background run with the pixels between this row's box end and the next
            // row's box start.
            cnts[num] = cnt;
            num++;
            cnt = row_left_num;
            maskFlag = false;
        } else {
            cnt += row_left_num;
        }
    }
    output_mask_cnt_lengths[slot] = num;
}


/**
 * Host launcher for trades_postprocess_kernel.
 *
 * Launches one thread per heat-map cell (w * h * number_classes) in blocks of 512
 * threads on the default stream, reports launch-configuration errors, and then blocks
 * until the device is idle. All arguments are forwarded unchanged to the kernel.
 *
 * NOTE(review): the pointers are presumably device buffers, with output_length zeroed
 * by the caller before this call - confirm at the call site.
 *
 * When w * h * number_classes is not positive there is nothing to process, so the
 * launch is skipped; the device synchronisation is still performed.
 */
void trades_postprocess(const float *hm, const float *reg, const float *wh, const float* seg_weights, const float* seg_feat, 
        int* output_length, float *output_bboxes, int* output_mask_cnt_lengths, int* output_mask_cnts, const int topK, const int max_cnts_length, 
        const int w, const int h, const int number_classes, const int kernerl_size,  const int seg_dims, const float score_threshold){
    const int total = w * h * number_classes;
    if (total > 0) {
        uint num = static_cast<uint>(total);
        uint block = 512;
        dim3 threads = dim3(block, 1, 1);
        trades_postprocess_kernel<<<common::cudaGridSize(num, block), threads>>>(hm, reg, wh, seg_weights, seg_feat, 
                 output_length, output_bboxes, output_mask_cnt_lengths, output_mask_cnts, topK, max_cnts_length,
                 w, h, number_classes, kernerl_size, seg_dims, score_threshold);
        // A launch does not return a status itself; configuration errors are only
        // visible through cudaGetLastError().
        CUDA_CHECK(cudaGetLastError());
    }
    // Surfaces asynchronous execution errors and makes the outputs safe to read.
    CUDA_CHECK(cudaDeviceSynchronize());
}

}  // namespace postprocess
}  // namespace deepinfer
}  // namespace waytous
