
#include "base/syncmem.h"


namespace waytous{
namespace deepinfer{
namespace base{


SyncedMemory::SyncedMemory(bool use_cuda)
    : cpu_ptr_(NULL),
      gpu_ptr_(NULL),
      size_(0),
      head_(UNINITIALIZED),
      own_cpu_data_(false),
      cpu_malloc_use_cuda_(use_cuda),
      own_gpu_data_(false),
      device_(-1) {
#ifdef PERCEPTION_DEBUG
  CUDA_CHECK(cudaGetDevice(&device_));
#endif
}

// Creates a SyncedMemory that describes `size` bytes. Nothing is allocated
// yet: the state starts UNINITIALIZED and no buffers are owned, so host or
// device storage is only created when the data is first accessed.
// `use_cuda` selects how PerceptionMallocHost/PerceptionFreeHost manage the
// host buffer. In PERCEPTION_DEBUG builds the current CUDA device is recorded
// so check_device() can later verify every access happens on that device.
SyncedMemory::SyncedMemory(size_t size, bool use_cuda)
    : cpu_ptr_(nullptr),
      gpu_ptr_(nullptr),
      size_(size),
      head_(UNINITIALIZED),
      own_cpu_data_(false),
      cpu_malloc_use_cuda_(use_cuda),
      own_gpu_data_(false),
      device_(-1) {
#ifdef PERCEPTION_DEBUG
  CUDA_CHECK(cudaGetDevice(&device_));
#endif
}

// Releases only the buffers this object allocated itself. Pointers supplied
// through set_cpu_data()/set_gpu_data() are borrowed and left untouched.
SyncedMemory::~SyncedMemory() {
    check_device();
    const bool release_host = (cpu_ptr_ != nullptr) && own_cpu_data_;
    const bool release_device = (gpu_ptr_ != nullptr) && own_gpu_data_;
    if (release_host) {
        PerceptionFreeHost(cpu_ptr_, cpu_malloc_use_cuda_);
    }
    if (release_device) {
        CUDA_CHECK(cudaFree(gpu_ptr_));
    }
}

// Makes the host copy current, allocating it on demand.
//  - UNINITIALIZED: allocate an owned, zero-filled host buffer;
//                   head_ -> HEAD_AT_CPU.
//  - HEAD_AT_GPU:   allocate an owned host buffer if none exists, then copy
//                   device -> host; head_ -> SYNCED.
//  - HEAD_AT_CPU / SYNCED: host copy is already current; nothing to do.
// If PerceptionMallocHost leaves cpu_ptr_ null, an error is logged and the
// function returns with head_ unchanged, so callers observe cpu_ptr_ == nullptr.
inline void SyncedMemory::to_cpu() {
    check_device();
    switch (head_) {
        case UNINITIALIZED:
            PerceptionMallocHost(&cpu_ptr_, size_, cpu_malloc_use_cuda_);
            if (cpu_ptr_ == nullptr) {
                LOG_ERROR << "cpu_ptr_ is null";
                return;
            }
            memset(cpu_ptr_, 0, size_);
            head_ = HEAD_AT_CPU;
            own_cpu_data_ = true;
            break;
        case HEAD_AT_GPU:
            if (cpu_ptr_ == nullptr) {
                PerceptionMallocHost(&cpu_ptr_, size_, cpu_malloc_use_cuda_);
                // Same guard as the UNINITIALIZED branch: never cudaMemcpy into
                // a null host pointer, and never claim ownership of nothing.
                if (cpu_ptr_ == nullptr) {
                    LOG_ERROR << "cpu_ptr_ is null";
                    return;
                }
                own_cpu_data_ = true;
            }
            CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDefault));
            head_ = SYNCED;
            break;
        case HEAD_AT_CPU:
        case SYNCED:
            break;
    }
}

// Makes the device copy current, allocating it on demand.
//  - UNINITIALIZED: allocate an owned, zero-filled device buffer;
//                   head_ -> HEAD_AT_GPU.
//  - HEAD_AT_CPU:   allocate an owned device buffer if none exists, then copy
//                   host -> device; head_ -> SYNCED.
//  - HEAD_AT_GPU / SYNCED: device copy is already current; nothing to do.
inline void SyncedMemory::to_gpu() {
    check_device();
    if (head_ == UNINITIALIZED) {
        CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
        CUDA_CHECK(cudaMemset(gpu_ptr_, 0, size_));
        own_gpu_data_ = true;
        head_ = HEAD_AT_GPU;
    } else if (head_ == HEAD_AT_CPU) {
        const bool needs_device_buffer = (gpu_ptr_ == nullptr);
        if (needs_device_buffer) {
            CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
            own_gpu_data_ = true;
        }
        CUDA_CHECK(cudaMemcpy(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyDefault));
        head_ = SYNCED;
    }
    // HEAD_AT_GPU and SYNCED fall through untouched.
}

// Read-only access to the host copy, syncing from the device first if the
// device holds the newer data. Does not change which side is authoritative.
// May return nullptr if to_cpu() could not allocate the host buffer.
const void* SyncedMemory::cpu_data() {
    check_device();
    to_cpu();
    const void* host_view = cpu_ptr_;
    return host_view;
}

// Points the host side at caller-provided memory. Any host buffer this object
// allocated is released first. The new pointer is borrowed (it is never freed
// by this object) and becomes the authoritative copy (HEAD_AT_CPU).
// NOTE(review): passing back the buffer this object currently owns frees it
// before storing it, leaving cpu_ptr_ dangling.
void SyncedMemory::set_cpu_data(void* data) {
    check_device();
    CHECK(data);
    if (own_cpu_data_) {
        PerceptionFreeHost(cpu_ptr_, cpu_malloc_use_cuda_);
    }
    own_cpu_data_ = false;
    cpu_ptr_ = data;
    head_ = HEAD_AT_CPU;
}

// Read-only access to the device copy, syncing from the host first if the
// host holds the newer data. Does not change which side is authoritative.
const void* SyncedMemory::gpu_data() {
    check_device();
    to_gpu();
    const void* device_view = gpu_ptr_;
    return device_view;
}

// Points the device side at caller-provided memory. Any device buffer this
// object allocated is cudaFree'd first. The new pointer is borrowed (it is
// never freed by this object) and becomes the authoritative copy (HEAD_AT_GPU).
void SyncedMemory::set_gpu_data(void* data) {
    check_device();
    CHECK(data);
    if (own_gpu_data_) {
        CUDA_CHECK(cudaFree(gpu_ptr_));
    }
    own_gpu_data_ = false;
    gpu_ptr_ = data;
    head_ = HEAD_AT_GPU;
}

// Writable access to the host copy. Syncs from the device if needed, then
// marks the host as the authoritative copy (HEAD_AT_CPU), so the next to_gpu()
// copies host -> device.
// Returns nullptr if to_cpu() could not allocate the host buffer. In that case
// head_ is left exactly as to_cpu() left it. Marking HEAD_AT_CPU with no
// buffer behind it would make a later to_gpu() copy from nullptr, and would
// stop to_cpu() from ever retrying the allocation.
void* SyncedMemory::mutable_cpu_data() {
    check_device();
    to_cpu();
    if (cpu_ptr_ != nullptr) {
        head_ = HEAD_AT_CPU;
    }
    return cpu_ptr_;
}

// Writable access to the device copy. Syncs from the host if needed, then
// marks the device as the authoritative copy (HEAD_AT_GPU), so the next
// to_cpu() copies device -> host.
void* SyncedMemory::mutable_gpu_data() {
    check_device();
    to_gpu();
    void* device_buffer = gpu_ptr_;
    head_ = HEAD_AT_GPU;
    return device_buffer;
}

void SyncedMemory::async_gpu_push(const cudaStream_t& stream) {
    check_device();
    CHECK_EQ(head_, HEAD_AT_CPU);
    if (gpu_ptr_ == nullptr) {
        CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
        own_gpu_data_ = true;
        LOG_INFO << "gpu async malloc: " << size_;
    }
    const cudaMemcpyKind put = cudaMemcpyHostToDevice;
    CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream));
    // Assume caller will synchronize on the stream before use
    head_ = SYNCED;
}

// Debug-only consistency check, compiled in only under PERCEPTION_DEBUG and a
// no-op otherwise. It verifies that:
//  - the current CUDA device is the one recorded at construction, and
//  - an owned device buffer lives on that same device.
void SyncedMemory::check_device() {
#ifdef PERCEPTION_DEBUG
    int device = -1;
    // Checked the same way the constructors check it. An ignored failure here
    // would leave `device` unset and surface as a misleading CHECK_EQ failure
    // instead of the real CUDA error.
    CUDA_CHECK(cudaGetDevice(&device));
    CHECK_EQ(device, device_);
    if (gpu_ptr_ && own_gpu_data_) {
        cudaPointerAttributes attributes;
        CUDA_CHECK(cudaPointerGetAttributes(&attributes, gpu_ptr_));
        CHECK_EQ(attributes.device, device_);
    }
#endif
}




} //namespace base
} //namspace deepinfer
} //namespace waytous
