
#include "tasks/task_dms.h"

namespace waytous{
namespace deepinfer{
namespace task{


bool TaskDMS::Init(std::string& taskConfigPath){
    if(!interfaces::BaseTask::Init(taskConfigPath)){
        LOG_ERROR << "Init task detect error";
        return false;
    };
    std::string faceDetectorName = taskNode["faceDetectorName"].as<std::string>();
    std::string faceDetectorConfigPath = taskNode["faceDetectorConfigPath"].as<std::string>();
    std::string landmarkDetectorName = taskNode["landmarkDetectorName"].as<std::string>();
    std::string landmarkDetectorConfigPath = taskNode["landmarkDetectorConfigPath"].as<std::string>();
    std::string headposeDetectorName = taskNode["headposeDetectorName"].as<std::string>();
    std::string headposeDetectorConfigPath = taskNode["headposeDetectorConfigPath"].as<std::string>();

    mouthARThreshold = taskNode["mouthARThreshold"].as<float>();
    eyeARThreshold = taskNode["eyeARThreshold"].as<float>();
    eyeARConsecFrames = taskNode["eyeARConsecFrames"].as<int>();
    distractionConsecFrames = taskNode["distractionConsecFrames"].as<int>();
    unmaskedConsecFrames = taskNode["unmaskedConsecFrames"].as<int>();
    nodriverConsecFrames = taskNode["nodriverConsecFrames"].as<int>();
    sightYawThresholdLeft = taskNode["sightYawThresholdLeft"].as<int>();
    sightYawThresholdRight = taskNode["sightYawThresholdRight"].as<int>();
    sightPitchThresholdLeft = taskNode["sightPitchThresholdLeft"].as<int>();
    sightPitchThresholdRight = taskNode["sightPitchThresholdRight"].as<int>();

    faceDetector.reset(interfaces::BaseModelRegisterer::GetInstanceByName(faceDetectorName));
    if(!faceDetector->Init(faceDetectorConfigPath)){
        LOG_ERROR << faceDetectorName << " faceDetector init problem";
        return false;
    };
    landmarkDetector.reset(interfaces::BaseModelRegisterer::GetInstanceByName(landmarkDetectorName));
    if(!landmarkDetector->Init(landmarkDetectorConfigPath)){
        LOG_ERROR << landmarkDetectorName << " face landmarks detector init problem";
        return false;
    };
    headposeDetector.reset(interfaces::BaseModelRegisterer::GetInstanceByName(headposeDetectorName));
    if(!headposeDetector->Init(headposeDetectorConfigPath)){
        LOG_ERROR << headposeDetectorName << " face headpose detector init problem";
        return false;
    };

    res = std::make_shared<DMSResult>(DMSResult());
    return true;
}


/**
 * @brief Run one driver-monitoring step on a single frame.
 *
 * Pipeline: face detection -> (no face: no-driver streak) -> mask check ->
 * square face crop -> face landmarks (eye / mouth aspect ratios) ->
 * head pose (sight / distraction). The verdict is written into the shared
 * `res` object, which is appended to `outputs` once face detection succeeds.
 *
 * `res->msg` is reset to DMS_NONE at the start of every frame, so an alert is
 * only reported while its condition holds. When several checks fire in the
 * same frame the later one wins: unmask < eye < mouth < distraction.
 *
 * @param inputs  exactly one image pointer (the current frame).
 * @param outputs receives `res`.
 * @return false on invalid input or when a sub-model fails or returns an
 *         unexpected output; true otherwise (including "no driver" frames).
 */
bool TaskDMS::Exec(std::vector<cv::Mat*> inputs, std::vector<interfaces::BaseIOPtr>& outputs){
    if(inputs.size() != 1){
        LOG_ERROR << "Now only support infer one image once.";
        return false;
    }
    if(inputs[0] == nullptr || inputs[0]->empty()){
        LOG_ERROR << "Task DMS got an empty input image";
        return false;
    }

    std::vector<interfaces::BaseIOPtr> faceDets;
    if(!faceDetector->Exec(inputs, faceDets)){
        LOG_ERROR << "Task DMS face detector Exec error";
        return false;
    }
    outputs.push_back(res);
    // Clear last frame's verdict; every check below re-raises its alert while its condition holds.
    res->msg = DMSMsg::DMS_NONE;

    std::shared_ptr<ios::Detection2Ds> faceDetections;
    if(!faceDets.empty()){
        faceDetections = std::dynamic_pointer_cast<ios::Detection2Ds>(faceDets[0]);
    }
    if(faceDetections == nullptr){
        LOG_ERROR << "Task DMS face detector returned no Detection2Ds output";
        return false;
    }
    const auto& faces = faceDetections->detections;

    // driver presence
    if(faces.empty()){
        NODRIVER_COUNTER++;
        if(NODRIVER_COUNTER >= nodriverConsecFrames){
            res->msg = DMSMsg::DMS_NODRIVER;
            LOG_INFO << "No driver!";
        }
        return true;
    }
    NODRIVER_COUNTER = 0; // a face is visible again: restart no-driver counting
    auto faceObj = faces[0]; // only analyse the first (best) face
    res->faceBBox = faceObj;

    // mask: class_label 0 is the "no mask" class, 1 ends the unmasked streak
    if(faceObj->class_label == 1){
        UNMASKED_COUNTER = 0;
    }else if(faceObj->class_label == 0){
        UNMASKED_COUNTER++;
        if(UNMASKED_COUNTER >= unmaskedConsecFrames){
            res->msg = DMSMsg::DMS_UNMASK; // TODO
        }
    }

    // Square crop centred on the face box: side = 1.2 * the shorter box edge,
    // then clipped to the image (so it may no longer be square at the borders).
    int image_height = inputs[0]->rows;
    int image_width = inputs[0]->cols;
    int boxWidth = faceObj->x2 - faceObj->x1 + 1;
    int boxHeight = faceObj->y2 - faceObj->y1 + 1;
    int boxSquareSize = (int)std::min(boxWidth * 1.2, boxHeight * 1.2);
    int boxSquareX1 = faceObj->x1 + boxWidth / 2 - boxSquareSize / 2;
    int boxSquareX2 = boxSquareX1 + boxSquareSize;
    int boxSquareY1 = faceObj->y1 + boxHeight / 2 - boxSquareSize / 2;
    int boxSquareY2 = boxSquareY1 + boxSquareSize;
    // clip
    boxSquareX1 = std::max(0, boxSquareX1);
    boxSquareY1 = std::max(0, boxSquareY1);
    boxSquareX2 = std::min(image_width, boxSquareX2);
    boxSquareY2 = std::min(image_height, boxSquareY2);
    // A box lying (almost) outside the frame leaves nothing to crop; cv::Mat ROI would throw.
    if(boxSquareX2 <= boxSquareX1 || boxSquareY2 <= boxSquareY1){
        LOG_ERROR << "Task DMS face crop is empty after clipping to the image";
        return false;
    }

    auto faceROI = cv::Rect2i(boxSquareX1, boxSquareY1, boxSquareX2 - boxSquareX1, boxSquareY2 - boxSquareY1);
    res->faceSquare = faceROI;

    // The same face crop feeds both the landmark and the head-pose models.
    std::vector<cv::Mat*> subInputs;
    auto faceImage = (*inputs[0])(faceROI);
    subInputs.push_back(&faceImage);

    // face landmarks
    std::vector<interfaces::BaseIOPtr> landmarks_vec;
    if(!landmarkDetector->Exec(subInputs, landmarks_vec)){
        LOG_ERROR << "face landmarks detect doesnot end normally!";
        return false;
    }
    std::shared_ptr<ios::Landmarks> landmarks;
    if(!landmarks_vec.empty()){
        landmarks = std::dynamic_pointer_cast<ios::Landmarks>(landmarks_vec[0]);
    }
    if(landmarks == nullptr){
        LOG_ERROR << "face landmarks detector returned no Landmarks output";
        return false;
    }
    res->face_landmarks = landmarks;
    // Landmarks come back in crop coordinates; shift them into full-image coordinates.
    for(auto& m: res->face_landmarks->landmarks){
        m->x += res->faceSquare.x;
        m->y += res->faceSquare.y;
    }

    // eyes: mean aspect ratio of both eyes; a long enough run of low values means closed eyes
    eyeAR = (accOrganAspectRatio(eyeLandmarkIndices[0]) + accOrganAspectRatio(eyeLandmarkIndices[1])) / 2.0;
    if(eyeAR < eyeARThreshold){
        EYE_CLOSE_COUNTER++;
        if(EYE_CLOSE_COUNTER >= eyeARConsecFrames){
            res->msg = DMSMsg::DMS_NOD_EYE;
        }
    }else{
        EYE_CLOSE_COUNTER = 0;
    }
    // mouth: a single frame above the threshold raises the alert
    mouthAR = accOrganAspectRatio(mouthLandmarkIndices);
    if(mouthAR > mouthARThreshold){
        res->msg = DMSMsg::DMS_NOD_MOUTH;
    }

    // headpose
    std::vector<interfaces::BaseIOPtr> headpose_vec;
    if(!headposeDetector->Exec(subInputs, headpose_vec)){
        LOG_ERROR << "face headpose detect doesnot end normally!";
        return false;
    }
    std::shared_ptr<ios::HeadPose> headPose;
    if(!headpose_vec.empty()){
        headPose = std::dynamic_pointer_cast<ios::HeadPose>(headpose_vec[0]);
    }
    if(headPose == nullptr){
        LOG_ERROR << "face headpose detector returned no HeadPose output";
        return false;
    }
    res->headPose = headPose;

    // sight: yaw or pitch outside the configured window counts as looking away
    const bool lookingAway =
        res->headPose->yaw < sightYawThresholdLeft || res->headPose->yaw > sightYawThresholdRight ||
        res->headPose->pitch < sightPitchThresholdLeft || res->headPose->pitch > sightPitchThresholdRight;
    if(lookingAway){
        DISTRACTION_COUNTER++;
        if(DISTRACTION_COUNTER >= distractionConsecFrames){
            res->msg = DMSMsg::DMS_INATTENTION;
        }
    }else{
        DISTRACTION_COUNTER = 0;
    }

    return true;
}


/**
 * @brief Draw the current DMS state onto `image`.
 *
 * Always draws the alert banner for `res->msg` and the no-driver streak.
 * When a face was seen in the last frame (NODRIVER_COUNTER == 0) it also draws
 * the face crop, counters, EAR/MAR, head-pose angles, landmarks and a pose
 * axis gizmo. Each per-face item is skipped if its data is not available yet
 * (e.g. before the first successful Exec, or after Exec failed midway).
 *
 * @param image frame to draw on (in place); nothing is drawn if null.
 * @param outs  unused. NOTE(review): the overlay is drawn from this task's
 *              cached `res`, not from `outs` — confirm that is intended.
 */
void TaskDMS::Visualize(cv::Mat* image, interfaces::BaseIOPtr outs){
    (void)outs;
    if(image == nullptr || res == nullptr){
        return;
    }

    switch(res->msg){
        case DMSMsg::DMS_NODRIVER:
            cv::putText(*image, "NODRIVER!", cv::Point(0, 20), cv::FONT_HERSHEY_SIMPLEX, 0.4, cv::Scalar(255, 0, 0), 1, cv::LINE_AA);
            break;
        case DMSMsg::DMS_UNMASK:
            cv::putText(*image, "UNMASKED!", cv::Point(0, 40), cv::FONT_HERSHEY_SIMPLEX, 0.4, cv::Scalar(255, 0, 0), 1, cv::LINE_AA);
            break;
        case DMSMsg::DMS_NOD_EYE:
            cv::putText(*image, "WARNING: DROWSINESS ALERT BY EYE!", cv::Point(10, 60), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 255), 1);
            break;
        case DMSMsg::DMS_NOD_MOUTH:
            cv::putText(*image, "WARNING: DROWSINESS ALERT BY MOUTH!", cv::Point(10, 80), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 255), 1);
            break;
        case DMSMsg::DMS_INATTENTION:
            cv::putText(*image, "WARNING: DISTRACTION ALERT!", cv::Point(10, 100), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 255, 255), 1);
            break;
        default: // DMS_NONE (or any message without a banner)
            break;
    }

    // Drawn in every frame: inside the face-only section below it could only ever show 0.
    cv::putText(*image, "nodriverCount:" + std::to_string(NODRIVER_COUNTER) + "/" + std::to_string(nodriverConsecFrames), cv::Point(10, 180), 
                cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 0), 1);

    if(NODRIVER_COUNTER != 0){
        return; // no face in the last frame: the per-face data in `res` is stale
    }

    cv::rectangle(*image, res->faceSquare, cv::Scalar(0, 0, 255), 2);
    if(res->faceBBox){
        cv::putText(*image, res->faceBBox->class_name + ":" + common::formatValue(res->faceBBox->confidence, 2), 
                    cv::Point(res->faceSquare.x, res->faceSquare.y), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 0, 255), 1);
    }
    cv::putText(*image, "eyeCount:" + std::to_string(EYE_CLOSE_COUNTER) + "/" + std::to_string(eyeARConsecFrames), cv::Point(10, 120), 
                cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 0), 1);
    cv::putText(*image, "distrCount:" + std::to_string(DISTRACTION_COUNTER) + "/" + std::to_string(distractionConsecFrames), cv::Point(10, 140), 
                cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 0), 1);
    cv::putText(*image, "unmaskCount:" + std::to_string(UNMASKED_COUNTER) + "/" + std::to_string(unmaskedConsecFrames), cv::Point(10, 160), 
                cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 0), 1);

    cv::putText(*image, "EAR: " + common::formatValue(eyeAR, 2), cv::Point(400, 20), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 255), 1);
    cv::putText(*image, "MAR: " + common::formatValue(mouthAR, 2), cv::Point(400, 40), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 255), 1);

    if(res->headPose){
        cv::putText(*image, "yaw: " + common::formatValue(res->headPose->yaw, 2), cv::Point(res->faceSquare.x, res->faceSquare.y + res->faceSquare.height), 
                    cv::FONT_HERSHEY_SIMPLEX, 0.4, cv::Scalar(100, 255, 0), 1);
        cv::putText(*image, "pitch: " + common::formatValue(res->headPose->pitch, 2), cv::Point(res->faceSquare.x, res->faceSquare.y + res->faceSquare.height - 15), 
                    cv::FONT_HERSHEY_SIMPLEX, 0.4, cv::Scalar(100, 255, 0), 1);
        cv::putText(*image, "roll: " + common::formatValue(res->headPose->roll, 2), cv::Point(res->faceSquare.x, res->faceSquare.y + res->faceSquare.height - 30), 
                    cv::FONT_HERSHEY_SIMPLEX, 0.4, cv::Scalar(100, 255, 0), 1);
    }

    if(res->face_landmarks){
        for(auto& point: res->face_landmarks->landmarks){
            cv::circle(*image, cv::Point(point->x, point->y), 2, cv::Scalar(0, 255, 0), -1);
        }
    }

    if(res->headPose){
        drawAxis(*image, res->headPose->pitch, res->headPose->yaw, res->headPose->roll, 
            res->faceSquare.x + res->faceSquare.width / 2, res->faceSquare.y + res->faceSquare.height / 2,
            res->faceSquare.width / 2);
    }
}


/**
 * @brief Aspect ratio of an eye / mouth outline from six landmark indices.
 *
 * With p0..p5 = landmarks[indices[0..5]] this returns
 * (|p1-p5| + |p2-p4|) / (2 * |p0-p3|): the two vertical openings over the
 * horizontal span. Reads `res->face_landmarks`, so Exec must have filled it.
 *
 * @param indices six landmark indices (caller must ensure all are in range).
 * @return the ratio, or 0 when p0 and p3 (almost) coincide. This avoids
 *         feeding inf/NaN into the eye/mouth thresholds and the EAR/MAR display.
 */
float TaskDMS::accOrganAspectRatio(int* indices){
    const auto& pts = res->face_landmarks->landmarks;
    // Distance between the landmarks named by positions a and b of `indices`.
    auto dist = [&](int a, int b){
        return common::euclidean(pts[indices[a]]->x, pts[indices[a]]->y,
                                 pts[indices[b]]->x, pts[indices[b]]->y);
    };
    float A = dist(1, 5);
    float B = dist(2, 4);
    float C = dist(0, 3);
    if(C < 1e-6f){ // degenerate outline: horizontal span collapsed to a point
        return 0.f;
    }
    return (A + B) / (2.0 * C);
}


/// Zero every consecutive-frame counter (eye-closed, distraction, unmasked,
/// no-driver) so alert detection starts again from a clean slate.
void TaskDMS::resetCounter(){
    NODRIVER_COUNTER = UNMASKED_COUNTER = DISTRACTION_COUNTER = EYE_CLOSE_COUNTER = 0;
}


/**
 * @brief Draw a 3-axis head-pose gizmo centred at (cx, cy).
 *
 * Angles are given in degrees and converted to radians (yaw is negated first).
 * The X axis (pointing right) is drawn in red, the Y axis (pointing down) in
 * green and the Z axis (out of the screen) in blue, each `size` pixels long
 * before projection.
 */
void TaskDMS::drawAxis(cv::Mat& img, float& pitch_, float& yaw_, float& roll_, int cx, int cy, int size){
    const float pitch = pitch_ * PI / 180.;
    const float yaw = -yaw_ * PI / 180.;
    const float roll = roll_ * PI / 180.;

    const float sinPitch = std::sin(pitch);
    const float cosPitch = std::cos(pitch);
    const float sinYaw = std::sin(yaw);
    const float cosYaw = std::cos(yaw);
    const float sinRoll = std::sin(roll);
    const float cosRoll = std::cos(roll);

    // Projected end point of the X axis (right).
    const float xAxisX = size * (cosYaw * cosRoll) + cx;
    const float xAxisY = size * (cosPitch * sinRoll + cosRoll * sinPitch * sinYaw) + cy;
    // Projected end point of the Y axis (down).
    const float yAxisX = size * (-cosYaw * sinRoll) + cx;
    const float yAxisY = size * (cosPitch * cosRoll - sinPitch * sinYaw * sinRoll) + cy;
    // Projected end point of the Z axis (out of the screen).
    const float zAxisX = size * (sinYaw) + cx;
    const float zAxisY = size * (-cosYaw * sinPitch) + cy;

    const cv::Point origin(cx, cy);
    cv::line(img, origin, cv::Point(int(xAxisX), int(xAxisY)), cv::Scalar(0,0,255), 2);
    cv::line(img, origin, cv::Point(int(yAxisX), int(yAxisY)), cv::Scalar(0,255,0), 2);
    cv::line(img, origin, cv::Point(int(zAxisX), int(zAxisY)), cv::Scalar(255,0,0), 2);
}


/// Name under which this task identifies itself.
std::string TaskDMS::Name(){
    static const char kTaskName[] = "TaskDMS";
    return std::string(kTaskName);
}


} //namespace task
} //namspace deepinfer
} //namespace waytous









