
#include "tasks/task_mots.h"

namespace waytous{
namespace deepinfer{
namespace task{


bool TaskMOTS::Init(std::string& taskConfigPath){
    if(!interfaces::BaseTask::Init(taskConfigPath)){
        LOG_ERROR << "Init task detect error";
        return false;
    };
    std::string segmentorName = taskNode["segmentorName"].as<std::string>();
    std::string segmentorConfigPath = taskNode["segmentorConfigPath"].as<std::string>();

    segmentor.reset(interfaces::BaseModelRegisterer::GetInstanceByName(segmentorName));
    if(!segmentor->Init(segmentorConfigPath)){
        LOG_ERROR << segmentorName << " segmentor init problem";
        return false;
    };
    return true;
}


/**
 * @brief Run the MOTS segmentor on a batch of input images.
 *
 * @param inputs  Input images, forwarded unchanged to the segmentor.
 * @param outputs Receives the segmentor's outputs.
 * @return false if the task was not successfully initialised or the
 *         segmentor's Exec fails; true otherwise.
 */
bool TaskMOTS::Exec(std::vector<cv::Mat*> inputs, std::vector<interfaces::BaseIOPtr>& outputs){
    // Guard against Exec being called after a failed or skipped Init.
    if(!segmentor){
        LOG_ERROR << "Task MOTS Exec error: segmentor is not initialized";
        return false;
    }
    if(!segmentor->Exec(inputs, outputs)){
        LOG_ERROR << "Task MOTS Exec error";
        return false;
    }
    return true;
}


/**
 * @brief Draw MOTS results onto an image, in place.
 *
 * For each detection, this draws:
 * - a "track_id:confidence" label just above the box's top-left corner;
 * - the bounding box, in a colour derived from class_label and track_id;
 * - the decoded instance mask, resized to the image and thresholded at 0.5,
 *   blended at 50% opacity in that same colour.
 *
 * @param image Image to draw on (modified in place). Assumed to be 8-bit,
 *              3-channel — TODO confirm with callers.
 * @param outs  Task output; expected to be an ios::Detection2Ds. If the image is
 *              missing/empty or the output has any other type, an error is logged
 *              and nothing is drawn.
 */
void TaskMOTS::Visualize(cv::Mat* image, interfaces::BaseIOPtr outs){
    const auto dets = std::dynamic_pointer_cast<ios::Detection2Ds>(outs);
    if(image == nullptr || image->empty() || !dets){
        LOG_ERROR << "TaskMOTS Visualize: empty image or outputs are not Detection2Ds";
        return;
    }

    // Scratch buffers reused across detections to avoid per-object allocation.
    cv::Mat realInstanceMask, colorInstance;
    for(const auto& obj: dets->detections){  // by reference: no copy of the vector
        if(!obj){
            continue;
        }
        const cv::Scalar color = get_color(obj->class_label * 100 + obj->track_id);
        cv::putText(*image, std::to_string(obj->track_id) + ":" + common::formatValue(obj->confidence, 2),
                    cv::Point(static_cast<int>(obj->x1), static_cast<int>(obj->y1) - 5),
                    0, 0.6, cv::Scalar(0, 0, 255), 2, cv::LINE_AA);

        cv::rectangle(*image, cv::Rect(static_cast<int>(obj->x1), static_cast<int>(obj->y1),
                    static_cast<int>(obj->x2 - obj->x1), static_cast<int>(obj->y2 - obj->y1)), color, 2);

        // A detection may carry no mask (or an empty one); cv::resize would throw on it.
        if(!obj->mask_ptr){
            continue;
        }
        cv::Mat instance_mask = obj->mask_ptr->decode();
        if(instance_mask.empty()){
            continue;
        }
        instance_mask.convertTo(instance_mask, CV_32FC1);
        cv::resize(instance_mask, realInstanceMask, image->size());

        // Paint every pixel whose resized mask value is >= 0.5 with the instance
        // colour (vectorised replacement for a per-pixel loop), then blend 50/50.
        // Unmasked pixels blend with themselves and so stay unchanged.
        image->copyTo(colorInstance);
        colorInstance.setTo(color, realInstanceMask >= 0.5);
        cv::addWeighted(*image, 0.5, colorInstance, 0.5, 0, *image);
    }
}


/// @brief Identifier of this task type.
/// @return The fixed string "TaskMOTS".
std::string TaskMOTS::Name(){
    const char* const kTaskName = "TaskMOTS";
    return std::string(kTaskName);
}


} //namespace task
} //namspace deepinfer
} //namespace waytous









