#include "libs/inferences/tensorrt/trt_calibrator.h"


namespace waytous {
namespace deepinfer {
namespace inference {


/// Resize an 8-bit, 3-channel image to the network input size and write it to
/// `hostImage` as normalized, planar float data.
///
/// Output layout: three consecutive planes of `inputWidth * inputHeight` floats.
/// Every value is `(pixel / 255 - inputMean[c]) / inputStd[c]`, where `c` is the
/// channel index of the pixel in the *source* image (mean/std are indexed by
/// source channel, not by output plane).
///
/// @param img         Source image; must be non-empty and of type CV_8UC3.
/// @param hostImage   Destination buffer with room for 3 * inputWidth * inputHeight floats.
/// @param inputWidth  Target width in pixels (> 0).
/// @param inputHeight Target height in pixels (> 0).
/// @param inputMean   Three per-channel means, indexed by source channel.
/// @param inputStd    Three per-channel standard deviations, indexed by source channel.
/// @param fixAspect   When true, the image is scaled by one common factor so it fits
///                    inside the target size, placed in the top-left corner, and the
///                    rest of the canvas is left at pixel value 0 (before normalization).
/// @param bgr         When true, output planes follow source channel order (0, 1, 2);
///                    otherwise the first and last channels are swapped (2, 1, 0).
///
/// Violated preconditions raise a cv::Exception through CV_Assert.
void resizeImageCPU(cv::Mat& img, float* hostImage, int inputWidth, int inputHeight,
    const float* inputMean, const float* inputStd, bool fixAspect, bool bgr)
{
    // Fail loudly on inputs that previously led to division by zero or to
    // reading pixels through the wrong element type.
    CV_Assert(!img.empty() && img.type() == CV_8UC3);
    CV_Assert(hostImage != nullptr && inputMean != nullptr && inputStd != nullptr);
    CV_Assert(inputWidth > 0 && inputHeight > 0);

    cv::Mat resized;
    if (fixAspect) {
        const float scale = cv::min(float(inputWidth) / img.cols, float(inputHeight) / img.rows);
        // Truncation is kept on purpose so the geometry matches the previous
        // behaviour; the clamp only protects against a zero-sized or
        // out-of-canvas result for extreme aspect ratios.
        const int scaledWidth = cv::min(inputWidth, cv::max(1, int(img.cols * scale)));
        const int scaledHeight = cv::min(inputHeight, cv::max(1, int(img.rows * scale)));

        cv::resize(img, resized, cv::Size(scaledWidth, scaledHeight));
        cv::Mat canvas = cv::Mat::zeros(inputHeight, inputWidth, CV_8UC3);
        resized.copyTo(canvas(cv::Rect(0, 0, scaledWidth, scaledHeight)));
        resized = canvas;
    } else {
        cv::resize(img, resized, cv::Size(inputWidth, inputHeight));
    }

    const int planeSize = inputWidth * inputHeight;
    // Source channel feeding each output plane.
    const int srcChannel[3] = {bgr ? 0 : 2, 1, bgr ? 2 : 0};

    for (int y = 0; y < inputHeight; ++y) {
        const cv::Vec3b* row = resized.ptr<cv::Vec3b>(y);
        float* dstRow = hostImage + y * inputWidth;
        for (int x = 0; x < inputWidth; ++x) {
            const cv::Vec3b& pixel = row[x];
            for (int plane = 0; plane < 3; ++plane) {
                const int c = srcChannel[plane];
                // Computed in double (255. literal), then narrowed, as before.
                dstRow[plane * planeSize + x] =
                    static_cast<float>((pixel[c] / 255. - inputMean[c]) / inputStd[c]);
            }
        }
    }
}

/// Build a calibrator from a text file that lists one image path per line.
///
/// @param imgPath_        Path of the image-list file, resolved against the config root.
/// @param calibTablePath_ Path of the calibration table, resolved against the config root.
/// @param batchSize_      Number of images per calibration batch.
/// @param inputWidth_     Network input width in pixels.
/// @param inputHeight_    Network input height in pixels.
/// @param inputMean_      Per-channel means; at least 3 values are required.
/// @param inputStd_       Per-channel standard deviations; at least 3 values are required.
/// @param useBGR_         Channel-order flag forwarded to image preprocessing.
/// @param fixAspectRatio_ Aspect-ratio flag forwarded to image preprocessing.
///
/// Throws std::out_of_range when `inputMean_` or `inputStd_` holds fewer than 3 values.
/// If the list file cannot be opened, the image list stays empty.
int8EntroyCalibrator::int8EntroyCalibrator(const std::string &imgPath_, const std::string &calibTablePath_, 
    int batchSize_, int inputWidth_, int inputHeight_, std::vector<float>& inputMean_, std::vector<float>& inputStd_, bool useBGR_, bool fixAspectRatio_):
    batchSize(batchSize_), inputWidth(inputWidth_), inputHeight(inputHeight_), imageIndex(0), useBGR(useBGR_), fixAspectRatio(fixAspectRatio_)
{
    calibTablePath = common::GetAbsolutePath(common::ConfigRoot::GetRootPath(), calibTablePath_);

    // .at() turns a too-short vector into an exception instead of an
    // out-of-bounds read.
    for (int channel = 0; channel < 3; ++channel) {
        inputMean[channel] = inputMean_.at(channel);
        inputStd[channel] = inputStd_.at(channel);
    }
    inputCount = 3 * inputWidth * inputHeight;

    // Read-only stream: std::fstream defaults to in|out and therefore fails
    // to open a list file that is not writable.
    std::ifstream listFile(common::GetAbsolutePath(common::ConfigRoot::GetRootPath(), imgPath_));
    if (listFile.is_open()) {
        std::string line;
        while (std::getline(listFile, line)) {
            // Tolerate CRLF line endings and blank lines so they do not turn
            // into bogus image paths.
            if (!line.empty() && line.back() == '\r') {
                line.pop_back();
            }
            if (line.empty()) {
                continue;
            }
            imgPaths.push_back(common::GetAbsolutePath(common::ConfigRoot::GetRootPath(), line));
        }
    }

    // Widen before multiplying so large batch/input sizes cannot overflow int.
    const std::size_t batchElements =
        static_cast<std::size_t>(batchSize) * static_cast<std::size_t>(inputCount);
    batchData = new float[batchElements];
    CUDA_CHECK(cudaMalloc(&deviceInput, batchElements * sizeof(float)));
}

/// Release the device input buffer and the host staging buffer.
int8EntroyCalibrator::~int8EntroyCalibrator()
{
    // Device memory first, then the host-side batch buffer.
    CUDA_CHECK(cudaFree(deviceInput));

    // delete[] on a null pointer is a no-op, so no explicit check is needed.
    delete[] batchData;
}

/// @return The number of images per calibration batch, as given to the constructor.
int int8EntroyCalibrator::getBatchSize() const
{
    return batchSize;
}

bool int8EntroyCalibrator::getBatch(void *bindings[], const char *names[], int nbBindings){
    if (imageIndex + batchSize > int(imgPaths.size()))
        return false;
    // load batch
    float* ptr = batchData;
    for (size_t j = imageIndex; j < imageIndex + batchSize; ++j)
    {
        auto img = cv::imread(imgPaths[j]);
        LOG_INFO << "load image " << imgPaths[j] << "  " << (j+1)*100./imgPaths.size() << "%";
        resizeImageCPU(img, ptr, inputWidth, inputHeight, inputMean, inputStd, fixAspectRatio, useBGR);
        ptr += inputCount;
    }
    imageIndex += batchSize;
    CUDA_CHECK(cudaMemcpy(deviceInput, batchData, batchSize * inputCount * sizeof(float), cudaMemcpyHostToDevice));
    bindings[0] = deviceInput;
    return true;
};

/// Load the calibration table from `calibTablePath` into `calibrationCache`.
///
/// The cache is always cleared first. The file content is read only when
/// `readCache` is set and the file stream is in a good state.
///
/// @param length Receives the number of bytes held in the cache afterwards.
/// @return Pointer to the first cached byte, or nullptr when the cache is empty.
const void * int8EntroyCalibrator::readCalibrationCache(std::size_t &length)
{
    calibrationCache.clear();

    std::ifstream tableFile(calibTablePath, std::ios::binary);
    // Whitespace bytes are part of the table and must not be skipped.
    std::noskipws(tableFile);

    const bool loadFromDisk = readCache && tableFile.good();
    if (loadFromDisk) {
        std::istream_iterator<char> cursor(tableFile);
        const std::istream_iterator<char> endOfStream;
        std::copy(cursor, endOfStream, std::back_inserter(calibrationCache));
    }

    length = calibrationCache.size();
    if (length == 0) {
        return nullptr;
    }
    return &calibrationCache[0];
}

/// Write the given calibration data to `calibTablePath` in binary mode,
/// replacing any existing file content.
///
/// @param cache  Pointer to the calibration bytes.
/// @param length Number of bytes to write.
void int8EntroyCalibrator::writeCalibrationCache(const void *cache, std::size_t length)
{
    const char* bytes = static_cast<const char*>(cache);

    std::ofstream tableFile(calibTablePath, std::ios::binary);
    tableFile.write(bytes, length);
}


}  // namespace inference
}  // namespace deepinfer
}  // namespace waytous

