
#include "libs/trackers/byte_tracker.h"

#include <algorithm>
#include <climits>
#include <cstddef>
#include <map>
#include <memory>
#include <unordered_set>
#include <vector>

namespace waytous{
namespace deepinfer{
namespace tracker{



/**
 * Initialise the tracker from its YAML configuration node.
 *
 * Reads frame_rate, track_buffer, track_thresh, high_thresh, match_thresh and
 * init_frames from `node`, resets the frame counter and derives max_time_lost,
 * the number of frames a lost track is kept before Exec() removes it.
 *
 * NOTE(review): yaml-cpp's as<T>() throws when a key is missing or not
 * convertible; confirm callers expect an exception rather than a false return.
 *
 * @param node            Unit-specific configuration.
 * @param globalParamNode Global parameters, forwarded to BaseUnit::Init.
 * @return false when BaseUnit::Init fails, true otherwise.
 */
bool ByteTracker::Init(YAML::Node& node, YAML::Node& globalParamNode)
{
	if(!BaseUnit::Init(node, globalParamNode)){
		// The previous message named another unit ("resize_norm"), which made logs misleading.
		LOG_WARN << "Init ByteTracker error";
		return false;
	}
	const int frame_rate = node["frame_rate"].as<int>();
	const int track_buffer = node["track_buffer"].as<int>();
	track_thresh = node["track_thresh"].as<float>();
	high_thresh = node["high_thresh"].as<float>();
	match_thresh = node["match_thresh"].as<float>();
	init_frames = node["init_frames"].as<int>();

	frame_id = 0;
	// The /30.0 factor treats track_buffer as a frame count at 30 fps and rescales it
	// to the configured frame rate.
	max_time_lost = int(frame_rate / 30.0 * track_buffer);
	LOG_INFO << "Init ByteTracker!";
	return true;
}


bool ByteTracker::Exec()
{
	if(interfaces::GetIOPtr(inputNames[0]) == nullptr){
		LOG_ERROR << "ByteTracker input" << inputNames[0] << " haven't init";
		return false;
	}
	auto input = std::dynamic_pointer_cast<ios::Detection2Ds>(interfaces::GetIOPtr(inputNames[0]));
	auto objects = input->detections;

	////////////////// Step 1: Get detections //////////////////
	this->frame_id++;
	std::vector<Track> activated_tracks;
	std::vector<Track> refind_tracks;
	std::vector<Track> removed_tracks;
	std::vector<Track> lost_tracks;
	std::vector<Track> detections;
	std::vector<Track> detections_low;

	std::vector<Track> detections_cp;
	std::vector<Track> tracked_tracks_swap;
	std::vector<Track> resa, resb;
	// std::vector<Track> output_tracks;

	std::vector<Track*> unconfirmed;
	std::vector<Track*> tracked_tracks;
	std::vector<Track*> track_pool;
	std::vector<Track*> r_tracked_tracks;

	if (objects.size() > 0)
	{
		for (int i = 0; i < objects.size(); i++)
		{
			Track track(objects[i]);
			if (objects[i]->confidence >= track_thresh)
			{
				detections.push_back(track);
			}else
			{
				detections_low.push_back(track);
			}
			
		}
	}

	// Add newly detected tracklets to tracked_tracks
	for (int i = 0; i < this->tracked_tracks.size(); i++)
	{
		if (!this->tracked_tracks[i].is_activated)
			unconfirmed.push_back(&this->tracked_tracks[i]);
		else
			tracked_tracks.push_back(&this->tracked_tracks[i]);
	}

	////////////////// Step 2: First association, with IoU //////////////////
	track_pool = joint_tracks(tracked_tracks, this->lost_tracks);
	Track::multi_predict(track_pool, this->kalman_filter);

	std::vector<std::vector<float> > dists;
	int dist_size = 0, dist_size_size = 0;
	dists = iou_distance(track_pool, detections, dist_size, dist_size_size);

	std::vector<std::vector<int> > matches;
	std::vector<int> u_track, u_detection;
	linear_assignment(dists, dist_size, dist_size_size, match_thresh, matches, u_track, u_detection);

	for (int i = 0; i < matches.size(); i++)
	{
		Track *track = track_pool[matches[i][0]];
		Track *det = &detections[matches[i][1]];
		if (track->state == TrackState::Tracked)
		{
			track->update(*det, this->frame_id);
			activated_tracks.push_back(*track);
		}
		else
		{
			track->re_activate(*det, this->frame_id, false);
			refind_tracks.push_back(*track);
		}
	}

	////////////////// Step 3: Second association, using low score dets //////////////////
	for (int i = 0; i < u_detection.size(); i++)
	{
		detections_cp.push_back(detections[u_detection[i]]);
	}
	detections.clear();
	detections.assign(detections_low.begin(), detections_low.end());
	
	for (int i = 0; i < u_track.size(); i++)
	{
		if (track_pool[u_track[i]]->state == TrackState::Tracked)
		{
			r_tracked_tracks.push_back(track_pool[u_track[i]]);
		}
	}

	dists.clear();
	dists = iou_distance(r_tracked_tracks, detections, dist_size, dist_size_size);

	matches.clear();
	u_track.clear();
	u_detection.clear();
	linear_assignment(dists, dist_size, dist_size_size, 0.5, matches, u_track, u_detection);

	for (int i = 0; i < matches.size(); i++)
	{
		Track *track = r_tracked_tracks[matches[i][0]];
		Track *det = &detections[matches[i][1]];
		if (track->state == TrackState::Tracked)
		{
			track->update(*det, this->frame_id);
			activated_tracks.push_back(*track);
		}
		else
		{
			track->re_activate(*det, this->frame_id, false);
			refind_tracks.push_back(*track);
		}
	}

	for (int i = 0; i < u_track.size(); i++)
	{
		Track *track = r_tracked_tracks[u_track[i]];
		if (track->state != TrackState::Lost)
		{
			track->mark_lost();
			lost_tracks.push_back(*track);
		}
	}

	// Deal with unconfirmed tracks, usually tracks with only one beginning frame
	detections.clear();
	detections.assign(detections_cp.begin(), detections_cp.end());

	dists.clear();
	dists = iou_distance(unconfirmed, detections, dist_size, dist_size_size);

	matches.clear();
	std::vector<int> u_unconfirmed;
	u_detection.clear();
	linear_assignment(dists, dist_size, dist_size_size, 0.7, matches, u_unconfirmed, u_detection);

	for (int i = 0; i < matches.size(); i++)
	{
		unconfirmed[matches[i][0]]->update(detections[matches[i][1]], this->frame_id);
		activated_tracks.push_back(*unconfirmed[matches[i][0]]);
	}

	for (int i = 0; i < u_unconfirmed.size(); i++)
	{
		Track *track = unconfirmed[u_unconfirmed[i]];
		track->mark_removed();
		removed_tracks.push_back(*track);
	}

	////////////////// Step 4: Init new tracks //////////////////
	for (int i = 0; i < u_detection.size(); i++)
	{
		Track *track = &detections[u_detection[i]];
		if (track->obj_->confidence < this->high_thresh)
			continue;
		track->activate(this->kalman_filter, this->frame_id);
		activated_tracks.push_back(*track);
	}

	////////////////// Step 5: Update state //////////////////
	for (int i = 0; i < this->lost_tracks.size(); i++)
	{
		if (this->frame_id - this->lost_tracks[i].end_frame() > this->max_time_lost)
		{
			this->lost_tracks[i].mark_removed();
			removed_tracks.push_back(this->lost_tracks[i]);
		}
	}
	
	for (int i = 0; i < this->tracked_tracks.size(); i++)
	{
		if (this->tracked_tracks[i].state == TrackState::Tracked)
		{
			tracked_tracks_swap.push_back(this->tracked_tracks[i]);
		}
	}
	this->tracked_tracks.clear();
	this->tracked_tracks.assign(tracked_tracks_swap.begin(), tracked_tracks_swap.end());

	this->tracked_tracks = joint_tracks(this->tracked_tracks, activated_tracks);
	this->tracked_tracks = joint_tracks(this->tracked_tracks, refind_tracks);

	//std::cout << activated_tracks.size() << std::endl;

	this->lost_tracks = sub_tracks(this->lost_tracks, this->tracked_tracks);
	for (int i = 0; i < lost_tracks.size(); i++)
	{
		this->lost_tracks.push_back(lost_tracks[i]);
	}

	this->lost_tracks = sub_tracks(this->lost_tracks, this->removed_tracks);
	for (int i = 0; i < removed_tracks.size(); i++)
	{
		this->removed_tracks.push_back(removed_tracks[i]);
	}
	
	remove_duplicate_tracks(resa, resb, this->tracked_tracks, this->lost_tracks);

	this->tracked_tracks.clear();
	this->tracked_tracks.assign(resa.begin(), resa.end());
	this->lost_tracks.clear();
	this->lost_tracks.assign(resb.begin(), resb.end());
	
	auto tracked_bboxes = std::make_shared<ios::Detection2Ds>(ios::Detection2Ds());
	for (int i = 0; i < this->tracked_tracks.size(); i++)
	{
		if (this->tracked_tracks[i].is_activated && (this->tracked_tracks[i].tracklet_len) >= this->init_frames)
		{
			// output_tracks.push_back(this->tracked_tracks[i]);
			auto t = this->tracked_tracks[i].obj_->copy();
			t->validCoordinate();
            tracked_bboxes->detections.push_back(t);
		}
	}
	interfaces::SetIOPtr(outputNames[0], tracked_bboxes);
    LOG_INFO << "Get " << tracked_bboxes->detections.size() << " tracked objs.";
    return true;
}



/**
 * Concatenate a list of track pointers with a list of tracks, skipping every element
 * of `tlistb` whose track_id already appeared.
 *
 * All entries of `tlista` are kept in order, even repeated ids. Ids are de-duplicated
 * only when appending from `tlistb`, again in order.
 *
 * The returned pointers alias `tlista`'s pointees and elements of `tlistb`, so they are
 * only valid while `tlistb` is not resized or destroyed.
 */
std::vector<Track*> ByteTracker::joint_tracks(std::vector<Track*> &tlista, std::vector<Track> &tlistb)
{
	std::unordered_set<int> seen_ids;
	std::vector<Track*> res;
	res.reserve(tlista.size() + tlistb.size());
	for (Track *track : tlista)
	{
		seen_ids.insert(track->track_id);
		res.push_back(track);
	}
	for (Track &track : tlistb)
	{
		// insert().second is true only for an id not seen before.
		if (seen_ids.insert(track.track_id).second)
			res.push_back(&track);
	}
	return res;
}

std::vector<Track> ByteTracker::joint_tracks(std::vector<Track> &tlista, std::vector<Track> &tlistb)
{
	std::map<int, int> exists;
	std::vector<Track> res;
	for (int i = 0; i < tlista.size(); i++)
	{
		exists.insert(std::pair<int, int>(tlista[i].track_id, 1));
		res.push_back(tlista[i]);
	}
	for (int i = 0; i < tlistb.size(); i++)
	{
		int tid = tlistb[i].track_id;
		if (!exists[tid] || exists.count(tid) == 0)
		{
			exists[tid] = 1;
			res.push_back(tlistb[i]);
		}
	}
	return res;
}

std::vector<Track> ByteTracker::sub_tracks(std::vector<Track> &tlista, std::vector<Track> &tlistb)
{
	std::map<int, Track> tracks;
	for (int i = 0; i < tlista.size(); i++)
	{
		tracks.insert(std::pair<int, Track>(tlista[i].track_id, tlista[i]));
	}
	for (int i = 0; i < tlistb.size(); i++)
	{
		int tid = tlistb[i].track_id;
		if (tracks.count(tid) != 0)
		{
			tracks.erase(tid);
		}
	}

	std::vector<Track> res;
	std::map<int, Track>::iterator  it;
	for (it = tracks.begin(); it != tracks.end(); ++it)
	{
		res.push_back(it->second);
	}

	return res;
}

void ByteTracker::remove_duplicate_tracks(std::vector<Track> &resa, std::vector<Track> &resb, std::vector<Track> &tracksa, std::vector<Track> &tracksb)
{
	std::vector<std::vector<float> > pdist = iou_distance(tracksa, tracksb);
	std::vector<std::pair<int, int> > pairs;
	for (int i = 0; i < pdist.size(); i++)
	{
		for (int j = 0; j < pdist[i].size(); j++)
		{
			if (pdist[i][j] < 0.15)
			{
				pairs.push_back(std::pair<int, int>(i, j));
			}
		}
	}

	std::vector<int> dupa, dupb;
	for (int i = 0; i < pairs.size(); i++)
	{
		int timep = tracksa[pairs[i].first].frame_id - tracksa[pairs[i].first].start_frame;
		int timeq = tracksb[pairs[i].second].frame_id - tracksb[pairs[i].second].start_frame;
		if (timep > timeq)
			dupb.push_back(pairs[i].second);
		else
			dupa.push_back(pairs[i].first);
	}

	for (int i = 0; i < tracksa.size(); i++)
	{
		std::vector<int>::iterator iter = find(dupa.begin(), dupa.end(), i);
		if (iter == dupa.end())
		{
			resa.push_back(tracksa[i]);
		}
	}

	for (int i = 0; i < tracksb.size(); i++)
	{
		std::vector<int>::iterator iter = find(dupb.begin(), dupb.end(), i);
		if (iter == dupb.end())
		{
			resb.push_back(tracksb[i]);
		}
	}
}

/**
 * Thresholded linear assignment between the rows and columns of a cost matrix.
 *
 * @param cost_matrix           Row-major costs; empty when either side has no elements.
 * @param cost_matrix_size      Logical row count, used only when cost_matrix is empty.
 * @param cost_matrix_size_size Logical column count, used only when cost_matrix is empty.
 * @param thresh                Passed to lapjv as cost_limit.
 * @param matches               Out (appended): {row, col} pairs.
 * @param unmatched_a           Out (appended): unmatched row indices.
 * @param unmatched_b           Out (appended): unmatched column indices.
 */
void ByteTracker::linear_assignment(std::vector<std::vector<float> > &cost_matrix, int cost_matrix_size, int cost_matrix_size_size, float thresh,
	std::vector<std::vector<int> > &matches, std::vector<int> &unmatched_a, std::vector<int> &unmatched_b)
{
	if (cost_matrix.empty())
	{
		for (int i = 0; i < cost_matrix_size; i++)
			unmatched_a.push_back(i);
		for (int i = 0; i < cost_matrix_size_size; i++)
			unmatched_b.push_back(i);
		return;
	}

	std::vector<int> rowsol;
	std::vector<int> colsol;
	// The total cost is never used here, so don't ask lapjv to compute it
	// (the result was previously stored in an unused, narrowed float).
	lapjv(cost_matrix, rowsol, colsol, true, thresh, false);

	for (int i = 0; i < static_cast<int>(rowsol.size()); i++)
	{
		if (rowsol[i] >= 0)
			matches.push_back(std::vector<int>{i, rowsol[i]});
		else
			unmatched_a.push_back(i);
	}
	for (int i = 0; i < static_cast<int>(colsol.size()); i++)
	{
		if (colsol[i] < 0)
			unmatched_b.push_back(i);
	}
}

/**
 * Pairwise intersection-over-union between two lists of boxes.
 *
 * Each box is a 4-element vector where indices 0/2 bound one axis and 1/3 the other
 * (tlbr naming suggests [left, top, right, bottom] — TODO confirm). Widths and heights
 * use the inclusive "+1" pixel convention.
 *
 * @return An atlbrs.size() x btlbrs.size() matrix with result[n][k] = IoU(a_n, b_k);
 *         non-overlapping pairs get 0. Empty when either input is empty.
 */
std::vector<std::vector<float> > ByteTracker::ious(std::vector<std::vector<float> > &atlbrs, std::vector<std::vector<float> > &btlbrs)
{
	if (atlbrs.empty() || btlbrs.empty())
		return std::vector<std::vector<float> >();

	// The b areas are reused for every a box.
	std::vector<float> b_areas(btlbrs.size());
	for (std::size_t k = 0; k < btlbrs.size(); k++)
	{
		b_areas[k] = (btlbrs[k][2] - btlbrs[k][0] + 1)*(btlbrs[k][3] - btlbrs[k][1] + 1);
	}

	// Zero-initialised, so only overlapping pairs need to be written. Rows are filled in
	// order, which is cache-friendly (the old loop wrote column by column).
	std::vector<std::vector<float> > overlaps(atlbrs.size(), std::vector<float>(btlbrs.size(), 0.0f));
	for (std::size_t n = 0; n < atlbrs.size(); n++)
	{
		const std::vector<float> &a = atlbrs[n];
		for (std::size_t k = 0; k < btlbrs.size(); k++)
		{
			const std::vector<float> &b = btlbrs[k];
			const float iw = std::min(a[2], b[2]) - std::max(a[0], b[0]) + 1;
			if (iw > 0)
			{
				const float ih = std::min(a[3], b[3]) - std::max(a[1], b[1]) + 1;
				if (ih > 0)
				{
					const float ua = (a[2] - a[0] + 1)*(a[3] - a[1] + 1) + b_areas[k] - iw * ih;
					overlaps[n][k] = iw * ih / ua;
				}
			}
		}
	}

	return overlaps;
}

/**
 * IoU distance (1 - IoU) between tracks referenced by pointer and tracks held by value.
 *
 * dist_size / dist_size_size always receive the logical row / column counts, even when
 * the returned matrix is empty (either side empty). linear_assignment uses them in
 * that case to report every index as unmatched.
 */
std::vector<std::vector<float> > ByteTracker::iou_distance(std::vector<Track*> &atracks, std::vector<Track> &btracks, int &dist_size, int &dist_size_size)
{
	dist_size = atracks.size();
	dist_size_size = btracks.size();
	if (atracks.empty() || btracks.empty())
		return std::vector<std::vector<float> >();

	std::vector<std::vector<float> > boxes_a;
	std::vector<std::vector<float> > boxes_b;
	boxes_a.reserve(atracks.size());
	boxes_b.reserve(btracks.size());
	for (const Track *t : atracks)
		boxes_a.push_back(t->tlbr);
	for (const Track &t : btracks)
		boxes_b.push_back(t.tlbr);

	// Turn the IoU matrix into a distance matrix in place.
	std::vector<std::vector<float> > distance = ious(boxes_a, boxes_b);
	for (std::vector<float> &row : distance)
	{
		for (float &cell : row)
			cell = 1 - cell;
	}
	return distance;
}

std::vector<std::vector<float> > ByteTracker::iou_distance(std::vector<Track> &atracks, std::vector<Track> &btracks)
{
	std::vector<std::vector<float> > atlbrs, btlbrs;
	for (int i = 0; i < atracks.size(); i++)
	{
		atlbrs.push_back(atracks[i].tlbr);
	}
	for (int i = 0; i < btracks.size(); i++)
	{
		btlbrs.push_back(btracks[i].tlbr);
	}

	std::vector<std::vector<float> > _ious = ious(atlbrs, btlbrs);
	std::vector<std::vector<float> > cost_matrix;
	for (int i = 0; i < _ious.size(); i++)
	{
		std::vector<float> _iou;
		for (int j = 0; j < _ious[i].size(); j++)
		{
			_iou.push_back(1 - _ious[i][j]);
		}
		cost_matrix.push_back(_iou);
	}

	return cost_matrix;
}

/**
 * Solve a (possibly rectangular) linear assignment problem with lapjv_internal.
 *
 * When extend_cost is set or cost_limit < LONG_MAX, the matrix is padded to
 * (rows + cols) x (rows + cols), which lets any row or column stay unassigned:
 *  - real/dummy cells cost cost_limit / 2, or (max cost + 1) when there is no limit;
 *  - dummy/dummy cells cost 0.
 * With a finite cost_limit, leaving a row and a column both unassigned therefore costs
 * cost_limit in total, so cost_limit effectively caps the cost of an accepted pair.
 *
 * @param cost        cost[i][j] for row i / column j; every row is assumed to have
 *                    cost[0].size() entries.
 * @param rowsol      Out: rowsol[i] = matched column of row i, or -1.
 * @param colsol      Out: colsol[j] = matched row of column j, or -1.
 * @param extend_cost Allow non-square matrices by padding (see above).
 * @param cost_limit  Padding cost source when < LONG_MAX (see above).
 * @param return_cost When true, return the summed cost of the matched pairs.
 * @return The summed matched cost when return_cost is set, otherwise 0.
 *         Everything is reported unmatched (returns 0) for an empty matrix, an
 *         unpaddable rectangular matrix, or a solver failure.
 */
double ByteTracker::lapjv(const std::vector<std::vector<float> > &cost, std::vector<int> &rowsol, std::vector<int> &colsol,
	bool extend_cost, float cost_limit, bool return_cost)
{
	const int n_rows = static_cast<int>(cost.size());
	// Guarded: cost[0] used to be read even for an empty matrix.
	const int n_cols = n_rows > 0 ? static_cast<int>(cost[0].size()) : 0;

	// Every row / column starts unassigned; only a successful solve overwrites these.
	rowsol.assign(n_rows, -1);
	colsol.assign(n_cols, -1);
	if (n_rows == 0 || n_cols == 0)
		return 0.0;

	if (n_rows != n_cols && !extend_cost)
	{
		LOG_ERROR << "lapjv set extend_cost=True.";
	}

	const bool extend = extend_cost || cost_limit < LONG_MAX;
	if (!extend && n_rows != n_cols)
	{
		// Unpadded rectangular input used to run the solver with n = 0 and then read
		// past its (empty) output buffers.
		return 0.0;
	}
	const int n = extend ? n_rows + n_cols : n_rows;

	float pad = 0.0f;
	if (extend)
	{
		if (cost_limit < LONG_MAX)
		{
			pad = static_cast<float>(cost_limit / 2.0);
		}
		else
		{
			float cost_max = -1;
			for (const std::vector<float> &row : cost)
			{
				for (float value : row)
				{
					if (value > cost_max)
						cost_max = value;
				}
			}
			pad = cost_max + 1;
		}
	}

	// One contiguous n x n buffer plus row pointers for the C-style solver interface.
	// RAII replaces the old new[]/delete[] pairs, which also over-allocated by sizeof(T).
	std::vector<double> storage(static_cast<std::size_t>(n) * static_cast<std::size_t>(n));
	std::vector<double *> cost_ptr(n);
	for (int i = 0; i < n; i++)
	{
		double *row = storage.data() + static_cast<std::size_t>(i) * static_cast<std::size_t>(n);
		cost_ptr[i] = row;
		for (int j = 0; j < n; j++)
		{
			if (i < n_rows && j < n_cols)
				row[j] = cost[i][j];
			else if (i >= n_rows && j >= n_cols)
				row[j] = 0.0;
			else
				row[j] = pad;
		}
	}

	std::vector<int> x_c(n, -1);
	std::vector<int> y_c(n, -1);
	const int ret = lapjv_internal(n, cost_ptr.data(), x_c.data(), y_c.data());
	if (ret != 0)
	{
		LOG_ERROR << "lapjv_internal Calculate Wrong!";
		// Don't trust a partially written solution; report everything unmatched.
		return 0.0;
	}

	// Assignments to padding become -1. This used to happen only in the padded case, so
	// an unpadded square solve left rowsol/colsol all zeros.
	for (int i = 0; i < n_rows; i++)
		rowsol[i] = x_c[i] >= n_cols ? -1 : x_c[i];
	for (int j = 0; j < n_cols; j++)
		colsol[j] = y_c[j] >= n_rows ? -1 : y_c[j];

	double opt = 0.0;
	if (return_cost)
	{
		for (int i = 0; i < n_rows; i++)
		{
			if (rowsol[i] != -1)
				opt += cost_ptr[i][rowsol[i]];
		}
	}
	return opt;
}


/// Identifier string of this unit.
std::string ByteTracker::Name()
{
	static const char kUnitName[] = "ByteTracker";
	return std::string(kUnitName);
}



} //namespace tracker
} //namspace deepinfer
} //namespace waytous


