#include "libs/fusioner/bayesian_fusioner.h"

namespace waytous {
namespace deepinfer {
namespace fusioner {


/**
 * @brief Configure the fusioner from its YAML node.
 *
 * Keys read (in addition to whatever BaseUnit::Init consumes):
 *  - NMSThreshold (float): minimum Measure() score above which a detection
 *    from another camera is suppressed by a higher-ranked one in Exec().
 *  - matchMethod (string): name looked up in MethodName2Type to select the
 *    similarity used by Measure().
 *
 * @param node YAML configuration node for this unit.
 * @return false (after logging a warning) if BaseUnit::Init fails, a required
 *         key is missing, or matchMethod is not a supported name; true otherwise.
 */
bool BayesianFusioner::Init(YAML::Node& node) {
    if(!BaseUnit::Init(node)){
        LOG_WARN << "Init BayesianFusioner error";
        return false;
    }
    // A missing key yields an undefined node whose as<T>() throws; report the
    // misconfiguration through the bool return like the other failure paths.
    if(!node["NMSThreshold"] || !node["matchMethod"]){
        LOG_WARN << "BayesianFusioner config requires both NMSThreshold and matchMethod";
        return false;
    }
    NMSThreshold = node["NMSThreshold"].as<float>();
    const std::string methodStr = node["matchMethod"].as<std::string>();
    const auto method = MethodName2Type.find(methodStr);
    if(method == MethodName2Type.end()){
        LOG_WARN << "BayesianFusioner not supported match method " << methodStr;
        return false;
    }
    matchMethod = method->second;
    return true;
}


bool BayesianFusioner::Exec(){
    std::vector<ios::Det2DPtr> all_objects; 
    for(int i=0; i<inputNames.size(); i++){
        auto inputName = inputNames[i];
        auto ptr = std::dynamic_pointer_cast<ios::Detection2Ds>(interfaces::GetIOPtr(inputName));
        if (ptr == nullptr){
            LOG_ERROR << "BayesianFusioner input " << inputName << " haven't been init or doesn't exist.";
            return false;
        }
        for(auto d: ptr->detections){
            d->camera_id = i;
            all_objects.push_back(d);
        }
    }
    // nms
    std::vector<int> indices;
    std::vector<bool> deleted(all_objects.size(), false);
    for(int i=0; i<all_objects.size(); i++){
        indices.push_back(i);
    }
    std::sort(indices.begin(), indices.end(), 
        [&](int a, int b){ return (all_objects[a]->confidence - float(all_objects[a]->truncated)) > (all_objects[b]->confidence - float(all_objects[b]->truncated));}
    );

    auto fusioned_objects = std::make_shared<ios::Detection2Ds>(ios::Detection2Ds());
    interfaces::SetIOPtr(outputNames[0], fusioned_objects);
    for(size_t i=0; i < indices.size(); i++){
        if(!deleted[indices[i]]){
            std::vector<ios::Det2DPtr> matched_objs;
            auto main_obj = all_objects[indices[i]];
            matched_objs.push_back(main_obj);
            for(size_t j=i+1; j < indices.size(); j++){
                // only merge obj from different sensor
                if(!deleted[indices[j]] && (main_obj->camera_id != all_objects[indices[j]]->camera_id)
                && Measure(main_obj, all_objects[indices[j]]) > NMSThreshold){
                        deleted[indices[j]] = true;
                        matched_objs.push_back(all_objects[indices[j]]);
                }
            }
            /* merge score, useless
            float pos_sum = 0;
            float neg_sum = 0;
            for(auto& obj : matched_objs){
                // merge matched mask
                if(main_obj->mask_ptr == nullptr){
                    main_obj->mask_ptr = obj->mask_ptr;
                }
                pos_sum += std::log(obj->confidence);
                neg_sum += std::log(1 - obj->confidence);
            }
            pos_sum = std::exp(pos_sum);
            neg_sum = std::exp(neg_sum);
            main_obj->confidence = pos_sum / (pos_sum + neg_sum);
            */
            // if()
            fusioned_objects->detections.push_back(main_obj);
        }
    }
    return true;
}


/**
 * @brief Similarity score between two 2D detections, used by Exec() as the
 *        fusion/NMS matching criterion.
 *
 * Detections whose class_name differs score 0. Otherwise the score depends on
 * the configured matchMethod:
 *  - IOU : intersection / union.
 *  - IOA : intersection / area(objb) if objb is truncated, else / area(obja).
 *  - GIOU: IoU - (C - U) / C, with U the union and C = obja->uniArea(objb)
 *          (presumably the smallest enclosing box area — TODO confirm).
 *  - CIOU: IoU - rho^2 / c^2 - alpha * v (Complete-IoU).
 * Any other method yields 0.
 *
 * Degenerate geometry (zero union or enclosing area, zero-height boxes) is
 * mapped to a finite value instead of inf/NaN. In particular, identical boxes
 * now get CIoU = 1 rather than NaN, so they can match.
 *
 * @note 0 only means "no match" when NMSThreshold >= 0; GIoU and CIoU can be
 *       negative.
 */
float BayesianFusioner::Measure(ios::Det2DPtr obja, ios::Det2DPtr objb){
    if(obja->class_name != objb->class_name){
        return 0;
    }
    // Ratio that is 0 (instead of inf/NaN) for a non-positive denominator.
    auto safe_div = [](float num, float den) -> float {
        return den > 0.f ? num / den : 0.f;
    };
    const float insection_area = obja->insectionArea(objb);
    const float union_area = obja->area() + objb->area() - insection_area;
    float measurement = 0;
    switch (matchMethod)
    {
    case BayesianFusionMatchMethod::IOU:{
        measurement = safe_div(insection_area, union_area);
        break;
    }
    case BayesianFusionMatchMethod::IOA:{
        measurement = safe_div(insection_area, objb->truncated ? objb->area() : obja->area());
        break;
    }
    case BayesianFusionMatchMethod::GIOU:{
        const float out_area = obja->uniArea(objb);
        measurement = safe_div(insection_area, union_area)
                    - safe_div(out_area - union_area, out_area);
        break;
    }
    case BayesianFusionMatchMethod::CIOU:{
        constexpr float kPi = 3.14159265358979f;
        const float iou = safe_div(insection_area, union_area);
        const float wa = static_cast<float>(obja->x2 - obja->x1);
        const float ha = static_cast<float>(obja->y2 - obja->y1);
        const float wb = static_cast<float>(objb->x2 - objb->x1);
        const float hb = static_cast<float>(objb->y2 - objb->y1);
        // Aspect-ratio consistency term. atan2(w, h) equals atan(w / h) for
        // h > 0 but stays finite for zero-height boxes.
        const float dtheta = std::atan2(wa, ha) - std::atan2(wb, hb);
        const float v = (4.0f / (kPi * kPi)) * dtheta * dtheta;
        // alpha = v / (1 - iou + v). When v == 0 the product alpha * v is 0;
        // evaluating alpha there would be 0/0 = NaN for identical boxes.
        const float alpha_v = v > 0.f ? (v / (1.0f - iou + v)) * v : 0.f;
        // Squared distance between the box centres.
        const float dx = (obja->x1 + wa / 2) - (objb->x1 + wb / 2);
        const float dy = (obja->y1 + ha / 2) - (objb->y1 + hb / 2);
        const float rho2 = dx * dx + dy * dy;
        // Squared diagonal of the smallest box enclosing both detections.
        const float out_width = static_cast<float>(std::max(obja->x2, objb->x2) - std::min(obja->x1, objb->x1));
        const float out_height = static_cast<float>(std::max(obja->y2, objb->y2) - std::min(obja->y1, objb->y1));
        const float c2 = out_width * out_width + out_height * out_height;
        measurement = iou - safe_div(rho2, c2) - alpha_v;
        break;
    }
    default:
        break;
    }

    return measurement;
}

/// @brief Identifier of this processing unit.
/// @return The fixed string "BayesianFusioner".
std::string BayesianFusioner::Name() {
    const std::string unit_name{"BayesianFusioner"};
    return unit_name;
}


}  // namespace fusioner
}  // namespace deepinfer
}  // namespace waytous


