#ifndef DEEPINFER_BASE_BLOB_H_
#define DEEPINFER_BASE_BLOB_H_

#include <cstddef>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "common/common.h"
#include "base/syncmem.h"

namespace waytous{
namespace deepinfer{
namespace base{

/// Maximum number of axes a Blob may have; Blob::Reshape rejects longer shapes.
constexpr size_t kMaxBlobAxes = 32;

/// @brief N-dimensional array of Dtype elements backed by a SyncedMemory buffer.
///
/// The shape is kept twice: as a host-side std::vector<int> (shape()) and in a
/// second SyncedMemory buffer whose device view is exposed by gpu_shape().
/// Element storage is allocated lazily by Reshape() and is only re-allocated
/// when the requested element count exceeds the current capacity.
///
/// Invariant maintained by every mutator: whenever data_ is non-null it holds
/// room for at least capacity_ elements, and count_ <= capacity_.
///
/// Blobs are non-copyable; share them through BlobPtr / ShareData().
///
/// @tparam Dtype element type stored in the blob.
template <typename Dtype>
class Blob {
public:
    /// @brief Creates an empty blob: no axes, count() == 0, no storage.
    Blob() = default;

    /// @brief Creates an empty blob and records the allocation flag.
    /// @param use_cuda_host_malloc forwarded to every SyncedMemory this blob
    ///        creates (presumably selects pinned host memory -- confirm
    ///        against SyncedMemory).
    explicit Blob(bool use_cuda_host_malloc)
        : use_cuda_host_malloc_(use_cuda_host_malloc) {}

    /// @brief Creates a blob and immediately gives it the requested shape.
    /// @param shape dimensions, see Reshape().
    /// @param use_cuda_host_malloc forwarded to SyncedMemory.
    explicit Blob(const std::vector<int>& shape,
                  const bool use_cuda_host_malloc = false)
        : use_cuda_host_malloc_(use_cuda_host_malloc) {
        Reshape(shape);
    }

    /// @brief Creates a blob of the requested shape and copies host data in.
    /// @param shape dimensions, see Reshape().
    /// @param cpu_d_ host buffer holding at least count() elements; it is
    ///        copied, not adopted. May be null only when the shape is empty
    ///        of elements (some dimension is 0).
    /// @param use_cuda_host_malloc forwarded to SyncedMemory.
    explicit Blob(const std::vector<int>& shape,
                  Dtype* cpu_d_,
                  const bool use_cuda_host_malloc = false)
        : use_cuda_host_malloc_(use_cuda_host_malloc) {
        Reshape(shape);
        // Reshape allocates nothing when count_ == 0, so data_ is null in
        // that case and there is nothing to copy.
        if (count_ > 0) {
            CHECK(cpu_d_);
            std::memcpy(data_->mutable_cpu_data(), cpu_d_,
                        static_cast<size_t>(count_) * sizeof(Dtype));
        }
    }

    Blob(const Blob&) = delete;
    void operator=(const Blob&) = delete;

    /// @brief Changes the blob's dimensions, growing storage when required.
    ///
    /// Storage is re-allocated only if the new element count exceeds the
    /// current capacity; a fresh allocation replaces data_, so any buffer
    /// previously shared through ShareData() is no longer referenced.
    /// A shape containing a 0 dimension yields count() == 0 and never
    /// allocates element storage.
    ///
    /// @param shape new dimensions; at most kMaxBlobAxes entries, each >= 0,
    ///        and their product must fit in an int.
    void Reshape(const std::vector<int>& shape) {
        CHECK_LE(shape.size(), kMaxBlobAxes);

        const size_t shape_bytes = shape.size() * sizeof(int);
        if (!shape_data_ || shape_data_->size() < shape_bytes) {
            shape_data_.reset(
                new SyncedMemory(shape_bytes, use_cuda_host_malloc_));
        }
        int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());

        shape_.resize(shape.size());
        count_ = 1;
        for (size_t i = 0; i < shape.size(); ++i) {
            CHECK_GE(shape[i], 0);
            // Guard the multiplication below against signed overflow; once
            // count_ has become 0 the product stays 0 and cannot overflow.
            if (count_ != 0) {
                CHECK_LE(shape[i], std::numeric_limits<int>::max() / count_)
                    << "blob size exceeds std::numeric_limits<int>::max()";
            }
            count_ *= shape[i];
            shape_[i] = shape[i];
            shape_data[i] = shape[i];
        }

        if (count_ > capacity_) {
            capacity_ = count_;
            data_.reset(
                new SyncedMemory(capacity_ * sizeof(Dtype), use_cuda_host_malloc_));
        }
    }

    /// @brief Reshapes this blob to the dimensions of @p other.
    void ReshapeLike(const Blob& other) {
        Reshape(other.shape());
    }

    /// @brief Returns all dimensions.
    inline const std::vector<int>& shape() const { return shape_; }

    /// @brief Returns one dimension.
    /// @param index axis in [-num_axes(), num_axes()); negative values count
    ///        from the last axis.
    inline int shape(int index) const {
        return shape_[CanonicalAxisIndex(index)];
    }

    /// @brief Maps a possibly negative axis index to its non-negative form.
    /// @param axis_index axis in [-num_axes(), num_axes()); CHECK-fails
    ///        otherwise.
    /// @return axis_index itself when non-negative, else
    ///         axis_index + num_axes().
    inline int CanonicalAxisIndex(int axis_index) const {
        CHECK_GE(axis_index, -num_axes())
            << "axis " << axis_index << " out of range for " << num_axes()
            << "-D Blob with shape " << ShapeString();
        CHECK_LT(axis_index, num_axes())
            << "axis " << axis_index << " out of range for " << num_axes()
            << "-D Blob with shape " << ShapeString();
        return axis_index < 0 ? axis_index + num_axes() : axis_index;
    }

    /// @brief Returns the number of axes.
    inline int num_axes() const { return static_cast<int>(shape_.size()); }

    /// @brief Returns the number of elements (product of all dimensions).
    inline int count() const { return count_; }

    /// @brief Returns the element storage handle; CHECK-fails when the blob
    ///        has no storage yet.
    inline const std::shared_ptr<SyncedMemory>& data() const {
        CHECK(data_);
        return data_;
    }

    /// @brief Deprecated legacy shape accessor num: use shape(0) instead.
    inline int num() const { return LegacyShape(0); }
    /// @brief Deprecated legacy shape accessor channels: use shape(1) instead.
    inline int channels() const { return LegacyShape(1); }
    /// @brief Deprecated legacy shape accessor height: use shape(2) instead.
    inline int height() const { return LegacyShape(2); }
    /// @brief Deprecated legacy shape accessor width: use shape(3) instead.
    inline int width() const { return LegacyShape(3); }

    /// @brief Shape lookup for blobs with at most 4 axes, padding missing
    ///        axes with 1.
    /// @param index axis in [-4, 4).
    /// @return shape(index) when the axis exists, otherwise 1.
    inline int LegacyShape(int index) const {
        CHECK_LE(num_axes(), 4)
            << "Cannot use legacy accessors on Blobs with > 4 axes.";
        CHECK_LT(index, 4);
        CHECK_GE(index, -4);
        if (index >= num_axes() || index < -num_axes()) {
            // Axis is out of range, but still in [0, 3] (or [-4, -1] for
            // reverse indexing) -- this special case simulates the
            // one-padding used to fill extraneous axes of legacy blobs.
            return 1;
        }
        return shape(index);
    }

    /// @brief Linear element offset of (n, c, h, w) in a blob of <= 4 axes.
    ///
    /// Each index must be non-negative. The upper bounds are inclusive
    /// (index <= extent), which permits forming one-past-the-end offsets.
    ///
    /// @return ((n * channels() + c) * height() + h) * width() + w.
    inline int offset(const int n, const int c = 0, const int h = 0,
                    const int w = 0) const {
        CHECK_GE(n, 0);
        CHECK_LE(n, num());
        // Validate the indices themselves: extents are already guaranteed
        // non-negative by Reshape, so checking them would verify nothing.
        CHECK_GE(c, 0);
        CHECK_LE(c, channels());
        CHECK_GE(h, 0);
        CHECK_LE(h, height());
        CHECK_GE(w, 0);
        CHECK_LE(w, width());
        return ((n * channels() + c) * height() + h) * width() + w;
    }

    /// @brief Linear element offset for a (possibly partial) index vector.
    /// @param indices one index per leading axis; axes beyond indices.size()
    ///        are treated as index 0. Each index must lie in
    ///        [0, shape(axis)).
    inline int offset(const std::vector<int>& indices) const {
        CHECK_LE(indices.size(), static_cast<size_t>(num_axes()));
        const int given = static_cast<int>(indices.size());
        int linear = 0;
        for (int axis = 0; axis < num_axes(); ++axis) {
            linear *= shape(axis);
            if (axis < given) {
                CHECK_GE(indices[axis], 0);
                CHECK_LT(indices[axis], shape(axis));
                linear += indices[axis];
            }
        }
        return linear;
    }

    /// @brief Read-only host view of the elements (data_->cpu_data()).
    const Dtype* cpu_data() const {
        CHECK(data_);
        return static_cast<const Dtype*>(data_->cpu_data());
    }

    /// @brief Hands an external host buffer to the underlying SyncedMemory.
    /// @param data non-null buffer, expected to hold count() elements;
    ///        lifetime/ownership rules are those of
    ///        SyncedMemory::set_cpu_data -- confirm there.
    void set_cpu_data(Dtype* data) {
        CHECK(data);
        MatchStorageToCount();
        data_->set_cpu_data(data);
    }

    /// @brief Read-only device view of the shape (shape_data_->gpu_data()).
    const int* gpu_shape() const {
        CHECK(shape_data_);
        return static_cast<const int*>(shape_data_->gpu_data());
    }

    /// @brief Read-only device view of the elements (data_->gpu_data()).
    const Dtype* gpu_data() const {
        CHECK(data_);
        return static_cast<const Dtype*>(data_->gpu_data());
    }

    /// @brief Hands an external device buffer to the underlying SyncedMemory.
    /// @param data non-null buffer, expected to hold count() elements;
    ///        lifetime/ownership rules are those of
    ///        SyncedMemory::set_gpu_data -- confirm there.
    void set_gpu_data(Dtype* data) {
        CHECK(data);
        MatchStorageToCount();
        data_->set_gpu_data(data);
    }

    /// @brief Writable host view of the elements.
    Dtype* mutable_cpu_data() {
        CHECK(data_);
        return static_cast<Dtype*>(data_->mutable_cpu_data());
    }

    /// @brief Writable device view of the elements.
    Dtype* mutable_gpu_data() {
        CHECK(data_);
        return static_cast<Dtype*>(data_->mutable_gpu_data());
    }

    /// @brief Forwards to SyncedMemory::set_head_gpu(); requires storage.
    void set_head_gpu() {
        CHECK(data_);
        data_->set_head_gpu();
    }

    /// @brief Forwards to SyncedMemory::set_head_cpu(); requires storage.
    void set_head_cpu() {
        CHECK(data_);
        data_->set_head_cpu();
    }

    /// @brief Returns SyncedMemory::head() of the element storage; requires
    ///        storage.
    SyncedMemory::SyncedHead head() const {
        CHECK(data_);
        return data_->head();
    }

    // void CopyFrom(const Blob<Dtype>& source, bool reshape = false);

    /// @brief Makes this blob reference the element storage of @p other.
    ///
    /// Both blobs must have the same count(), and @p other must already own
    /// storage. A later Reshape() that needs more room than the shared buffer
    /// offers replaces data_ and thereby ends the sharing.
    void ShareData(const Blob& other) {
        CHECK_EQ(count_, other.count());
        data_ = other.data();
        // Track the real size of the adopted buffer so that Reshape neither
        // trusts a stale, larger capacity nor writes past the shared memory.
        capacity_ = static_cast<int>(data_->size() / sizeof(Dtype));
    }

protected:
    /// @brief Ensures data_ exists and is exactly count_ elements large.
    ///
    /// Used before adopting an external buffer so that CPU and GPU sizes
    /// remain equal. capacity_ is updated alongside so it never overstates
    /// the buffer that Reshape() would otherwise reuse.
    void MatchStorageToCount() {
        const size_t bytes = static_cast<size_t>(count_) * sizeof(Dtype);
        if (!data_ || data_->size() != bytes) {
            data_.reset(new SyncedMemory(bytes, use_cuda_host_malloc_));
        }
        capacity_ = count_;
    }

    /// @brief Formats the shape as "d0 d1 ... (count)" for diagnostics.
    std::string ShapeString() const {
        std::string text;
        for (const int dim : shape_) {
            text += std::to_string(dim);
            text += ' ';
        }
        text += '(';
        text += std::to_string(count_);
        text += ')';
        return text;
    }

    std::shared_ptr<SyncedMemory> data_;        ///< element storage (may be null)
    std::shared_ptr<SyncedMemory> shape_data_;  ///< shape mirrored for gpu_shape()
    std::vector<int> shape_;                    ///< host-side dimensions
    int count_ = 0;                             ///< product of shape_
    int capacity_ = 0;                          ///< elements data_ can hold
    bool use_cuda_host_malloc_ = false;         ///< forwarded to SyncedMemory
};

/// Shared-ownership handle to a mutable Blob<Dtype>.
template <typename Dtype>
using BlobPtr = std::shared_ptr<Blob<Dtype>>;
/// Shared-ownership handle to a Blob<Dtype> accessed through its const interface.
template <typename Dtype>
using BlobConstPtr = std::shared_ptr<const Blob<Dtype>>;


} //namespace base
} //namespace deepinfer
} //namespace waytous

#endif


