#!E:\Pycharm Projects\Waytous
# -*- coding: utf-8 -*-
# @Time : 2021/7/26 14:35
# @Author : Opfer
# @Site :
# @File : path_plannner.py    
# @Software: PyCharm

from equipment.truck import TruckInfo, DumpInfo, ExcavatorInfo
# from path_plan.priority_control import PriorityController
from path_plan.priority_control import PriorityController
from para_config import *
from tables import *

# Large placeholder cost: fills the area-level cost matrices before (or instead of) a computed value.
M = 1_000_000


class PathPlanner(WalkManage):
    """Compute congestion-aware travel costs over the mine road network.

    Area-level costs (unload area <-> load area, park area -> load area) are built from
    the ``WalkTime`` / ``WalkTimePark`` records plus a per-lane blockage estimate, then
    mapped onto the device-level matrices (dump <-> excavator) that dispatch uses.
    Attributes such as ``distance_park_to_excavator`` and the ``*_index_to_*_index_dict``
    maps are read from ``self`` but not set here — presumably provided by ``WalkManage``.
    """

    def __init__(self, dump, excavator, truck):
        # Lane state: real-time lane speeds and occupied lanes
        self.lane = LaneInfo()
        self.lane.lane_speed_generate()
        # Equipment objects
        self.dump = dump
        self.excavator = excavator
        self.truck = truck
        # Priority / dispatch-plan controller
        self.controller = PriorityController(self.dump, self.excavator, self.truck)

        # Route travel cost.
        # NOTE(review): np.array((a, b)) builds a 1-D array holding the two area counts,
        # not an (a, b) matrix; it is not used in this file — confirm intent before relying on it.
        self.rout_cost = np.array((get_value("unload_area_num"), get_value("load_area_num")))
        # Lane collection (not used in this file)
        self.lane_set = {}
        # Average truck length (provisional); used as the jam spacing in lane_cost_generate.
        try:
            truck_lengths = self.truck.get_length()
            self.truck_length = float(sum(truck_lengths.values())) / len(truck_lengths)
        except Exception as es:
            logger.error("无矿卡数据")
            logger.error(es)
            self.truck_length = 3
        # Number of load / unload / park areas
        self.num_of_load_area = len(set(update_load_area()))
        self.num_of_unload_area = len(set(update_unload_area()))
        self.num_of_park_area = len(set(update_park_area()))
        # Area-level travel costs; entries never computed keep the placeholder M.
        self.cost_to_load_area = np.full((self.num_of_unload_area, self.num_of_load_area), M)
        self.cost_to_unload_area = np.full((self.num_of_unload_area, self.num_of_load_area), M)
        # Area-level travel costs from park areas
        self.cost_park_to_load_area = np.full((self.num_of_park_area, self.num_of_load_area), M)
        # Device-level travel costs, shaped like the corresponding distance matrices
        self.cost_to_excavator = np.zeros_like(get_value("distance_to_excavator"))
        self.cost_to_dump = np.zeros_like(get_value("distance_to_dump"))
        self.cost_park_to_excavator = np.zeros_like(get_value("distance_park_to_excavator"))

    @staticmethod
    def _query_first_record(model, **filters):
        """Return the first ``model`` row matching ``filters``.

        On a query error both sessions are rolled back and the error is re-raised;
        if no row matches, LookupError is raised (callers log and fall back).
        """
        try:
            record = session_postgre.query(model).filter_by(**filters).first()
        except Exception:
            session_postgre.rollback()
            session_mysql.rollback()
            raise
        if record is None:
            raise LookupError(f'{model.__name__} record not found: {filters}')
        return record

    def _lane_blockage_sum(self, lane_ids, lane_cost_memory):
        """Sum the blockage of the occupied lanes in ``lane_ids``, memoizing per lane in ``lane_cost_memory``."""
        total = 0
        for lane_id in lane_ids:
            # Only lanes that currently carry a truck contribute blockage.
            if lane_id not in self.lane.used_lane_set:
                continue
            if lane_id not in lane_cost_memory:
                lane_cost_memory[lane_id] = self.lane_cost_generate(lane_id)
            total += lane_cost_memory[lane_id]
        return total

    def path_cost_generate(self, load_area_id, unload_area_id, is_park, lane_cost_memory, alpha, beta):
        """Compute the travel cost of one area-to-area route.

        cost = alpha' * (sum of blockage over occupied lanes) + beta' * route distance,
        where alpha' / beta' are the normalized weights computed below. The same form
        is used for load <-> unload routes and park -> load routes.

        :param load_area_id: (str) load area id
        :param unload_area_id: (str) unload area id, or the park area id when ``is_park`` is true
        :param is_park: (bool) cost the park-area -> load-area route instead of the load/unload pair
        :param lane_cost_memory: (dict) lane id -> blockage; shared across calls and filled in place
        :param alpha: blockage-cost weight
        :param beta: distance-cost weight
        :return:
            (to_load_cost, to_unload_cost). For park routes ``to_unload_cost`` is 0.
            If the route record cannot be read or costed, the error is logged and both
            values are ``M`` so that a broken route is never treated as free.
        """
        # Correction factor applied to the blockage weight
        weight = 10

        session_mysql.commit()

        try:
            if alpha > 0:
                # NOTE(review): beta / (beta + 0.001) is ~1 for typical weights, so the
                # blockage weight is effectively alpha * weight — confirm this is intended.
                beta /= (beta + 0.001)
                alpha = alpha / beta * weight
            else:
                beta = 1

            if is_park:
                path = self._query_first_record(WalkTimePark, park_area_id=unload_area_id,
                                                load_area_id=load_area_id)
                cost_to_load_blockage = self._lane_blockage_sum(path.park_load_lanes, lane_cost_memory)
                # Total cost = blockage cost + distance cost (beta applies to distance only,
                # consistent with the load/unload routes below).
                to_load_cost = alpha * cost_to_load_blockage + beta * path.park_load_distance
                return to_load_cost, 0

            path = self._query_first_record(WalkTime, load_area_id=load_area_id,
                                            unload_area_id=unload_area_id)
            cost_to_unload_blockage = self._lane_blockage_sum(path.to_unload_lanes, lane_cost_memory)
            cost_to_load_blockage = self._lane_blockage_sum(path.to_load_lanes, lane_cost_memory)
            to_unload_cost = alpha * cost_to_unload_blockage + beta * path.to_unload_distance
            to_load_cost = alpha * cost_to_load_blockage + beta * path.to_load_distance
            return to_load_cost, to_unload_cost

        except Exception as es:
            logger.error(f'道路{load_area_id}-{unload_area_id}行驶成本计算异常')
            logger.error(es)
            return M, M

    def lane_cost_generate(self, lane_id):
        """Compute the congestion (blockage) cost of one lane.

        blockage = (1 - actual_speed / max_speed) * (lane_length / truck_length)

        :param lane_id: (uuid) lane id
        :return:
            lane_blockage: (float) lane blockage; 0 when the lane record or its real-time
            speed is unavailable (the error is logged).
        """
        lane_blockage = 0  # default when the cost cannot be computed

        try:
            lane_rec = self._query_first_record(Lane, Id=lane_id)

            lane_length = lane_rec.Length  # lane length
            clear_speed = lane_rec.MaxSpeed  # free-flow speed

            # 1. Jam density: number of trucks that fit on the lane = lane length / truck length
            truck_density = lane_length / self.truck_length
            # 2. Actual average truck speed on the lane (KeyError if none was recorded)
            actual_speed = self.lane.lane_speed_dict[lane_id]
            # 3. Blockage = (1 - actual speed / max speed) scaled by the jam density
            lane_blockage = (1 - actual_speed / clear_speed) * truck_density
        except Exception as es:
            logger.error('路段拥堵成本计算异常')
            logger.error(es)

        return lane_blockage

    @staticmethod
    def _rule_weight(rule_id, default):
        """Return the weight of dispatch rule ``rule_id`` if it exists and is enabled, else ``default``."""
        try:
            rule = session_mysql.query(DispatchRule).filter_by(id=rule_id).first()
        except Exception as es:
            session_postgre.rollback()
            session_mysql.rollback()
            logger.error(f'调度规则{rule_id}读取异常')
            logger.error(es)
            return default
        if rule is None or rule.disabled != 0:
            return default
        return rule.rule_weight

    def _update_area_costs(self, unload_area_uuid_to_index_dict, load_area_uuid_to_index_dict,
                           lane_cost_memory, alpha, beta):
        """Fill cost_to_load_area / cost_to_unload_area from every WalkTime record."""
        try:
            for walk_time in session_postgre.query(WalkTime).all():
                load_area_id, unload_area_id = str(walk_time.load_area_id), str(walk_time.unload_area_id)
                unload_area_index = unload_area_uuid_to_index_dict.get(unload_area_id)
                load_area_index = load_area_uuid_to_index_dict.get(load_area_id)
                # Skip records for areas outside the current index maps instead of letting
                # the KeyError abort every remaining record.
                if unload_area_index is None or load_area_index is None:
                    continue
                self.cost_to_load_area[unload_area_index][load_area_index], \
                    self.cost_to_unload_area[unload_area_index][load_area_index] = \
                    self.path_cost_generate(load_area_id, unload_area_id, False, lane_cost_memory, alpha, beta)
        except Exception as es:
            session_postgre.rollback()
            session_mysql.rollback()
            logger.error('装卸载区路网行驶成本读取异常')
            logger.error(es)

    def _update_park_costs(self, park_uuid_to_index_dict, load_area_uuid_to_index_dict,
                           lane_cost_memory, alpha, beta):
        """Fill cost_park_to_load_area from every WalkTimePark record."""
        try:
            for walk_time_park in session_postgre.query(WalkTimePark).all():
                park_area_id = str(walk_time_park.park_area_id)
                load_area_id = str(walk_time_park.load_area_id)
                park_area_index = park_uuid_to_index_dict.get(park_area_id)
                load_area_index = load_area_uuid_to_index_dict.get(load_area_id)
                # Skip records for areas outside the current index maps.
                if park_area_index is None or load_area_index is None:
                    continue
                self.cost_park_to_load_area[park_area_index][load_area_index], _ = \
                    self.path_cost_generate(load_area_id, park_area_id, True, lane_cost_memory, alpha, beta)
        except Exception as es:
            session_postgre.rollback()
            session_mysql.rollback()
            logger.error('备停区路网行驶成本读取异常')
            logger.error(es)

    def walk_cost_cal(self):
        """
        Compute road-network travel costs.
        :return:
            cost_to_excavator:  (np.ndarray) cost from unload side (dumps) to excavators
            cost_to_dump:  (np.ndarray) cost from excavators to dumps
            cost_park_to_excavator:  (np.ndarray) cost from the park area to excavators
        """
        # Dispatch rule 1 enables the distance cost, rule 2 the blockage cost.
        beta = self._rule_weight(1, 1)  # distance-cost weight
        alpha = self._rule_weight(2, 0)  # blockage-cost weight

        lane_cost_memory = {}  # lane id -> blockage, memoized across routes

        try:
            unload_area_uuid_to_index_dict = get_value("unload_area_uuid_to_index_dict")
            load_area_uuid_to_index_dict = get_value("load_area_uuid_to_index_dict")
            park_uuid_to_index_dict = get_value("park_uuid_to_index_dict")

            # Area-level costs between load and unload areas
            self._update_area_costs(unload_area_uuid_to_index_dict, load_area_uuid_to_index_dict,
                                    lane_cost_memory, alpha, beta)
            # Area-level costs from park areas
            self._update_park_costs(park_uuid_to_index_dict, load_area_uuid_to_index_dict,
                                    lane_cost_memory, alpha, beta)

            logger.info(self.cost_park_to_load_area)
            logger.info(self.distance_park_to_excavator)
        except Exception as es:
            logger.error('路网信息计成本计算异常')
            logger.error(es)

        # Map the area-level costs onto the device-level (dump <-> excavator) network,
        # which is what the dispatch algorithm works on.
        self.cost_to_excavator = np.zeros_like(get_value("distance_to_excavator"))
        self.cost_to_dump = np.zeros_like(get_value("distance_to_dump"))
        self.cost_park_to_excavator = np.zeros_like(get_value("distance_park_to_excavator"))

        logger.info("distance_park_to_excavator")
        logger.info(self.distance_park_to_excavator)

        # Errors here propagate to the caller (no dispatch plan means no weights).
        walk_to_excavator_weight, walk_to_dump_weight, park_walk_weight = self.controller.weighted_walk_calc()

        logger.info("path_weight")
        logger.info(walk_to_excavator_weight)
        logger.info(walk_to_dump_weight)
        logger.info(park_walk_weight)

        try:
            for i in range(get_value("dynamic_dump_num")):
                for j in range(get_value("dynamic_excavator_num")):
                    load_area_index = self.excavator_index_to_load_area_index_dict[j]
                    unload_area_index = self.dump_index_to_unload_area_index_dict[i]
                    self.cost_to_excavator[i][j] = self.cost_to_load_area[unload_area_index][load_area_index] / \
                        walk_to_excavator_weight[i][j]
                    # NOTE(review): the dump weight is indexed [j][i] while the excavator weight is
                    # [i][j]; presumably it is stored excavator-major — confirm in weighted_walk_calc.
                    self.cost_to_dump[i][j] = self.cost_to_unload_area[unload_area_index][load_area_index] / \
                        walk_to_dump_weight[j][i]

            for j in range(get_value("dynamic_excavator_num")):
                load_area_index = self.excavator_index_to_load_area_index_dict[j]
                # NOTE(review): only the first park area (row 0) is mapped.
                self.cost_park_to_excavator[0][j] = self.cost_park_to_load_area[0][load_area_index] / \
                    park_walk_weight[0][j]

        except Exception as es:
            logger.error(es)
            logger.error("路网映射异常")

        logger.info("真实路网距离-驶往挖机:")
        logger.info(self.distance_to_excavator)

        logger.info("真实路网距离-驶往卸点:")
        logger.info(self.distance_to_dump)

        logger.info("真实备停区路网距离-驶往挖机:")
        logger.info(self.distance_park_to_excavator)

        logger.info("阻塞路网距离-驶往挖机:")
        logger.info(self.cost_to_excavator)

        logger.info("阻塞路网距离-驶往卸点:")
        logger.info(self.cost_to_dump)

        logger.info("阻塞备停区路网距离-驶往挖机：")
        logger.info(self.cost_park_to_excavator)

        return self.cost_to_excavator, self.cost_to_dump, self.cost_park_to_excavator


class LaneInfo:
    """Real-time lane state built from device telemetry stored in Redis.

    Attributes:
        lane_speed_dict: str(lane id) -> average speed of the trucks on that lane.
        used_lane_set: lane ids currently occupied by online trucks. Despite the name
            this is a list, with one entry per located truck.
    """

    def __init__(self):
        self.lane_speed_dict = {}  # str(lane id) -> average truck speed
        self.used_lane_set = []  # lanes occupied by online trucks (kept a list for existing callers)

    def update_truck_speed(self):
        """Read each truck's real-time speed from Redis.

        Only devices whose ``type`` field is ``"1"`` (trucks, presumably) and that carry a
        ``speed`` field are included. A failure on one device is logged and that device
        is skipped, so one malformed entry does not drop every remaining truck.

        :return:
            truck_speed_dict:  (Dict{key:truck_id, value:speed}) truck speed table
        """
        truck_speed_dict = {}
        try:
            truck_name_to_uuid_dict = get_value("truck_name_to_uuid_dict")
            device_names = redis2.keys()
        except Exception as es:
            logger.error('矿卡实时速度读取异常')
            logger.error(es)
            return truck_speed_dict

        for raw_name in device_names:
            item = raw_name  # bound before decoding so the error log below can always use it
            try:
                item = raw_name.decode(encoding='utf-8')
                key_value_dict = redis2.hgetall(item)
                device_type = key_value_dict.get(str_to_byte('type'))
                if device_type == str_to_byte("1") and str_to_byte('speed') in key_value_dict:
                    truck_speed = float(key_value_dict[str_to_byte('speed')])
                    truck_speed_dict[truck_name_to_uuid_dict[item]] = truck_speed
            except Exception as es:
                logger.error(f'矿卡{item}实时速度读取异常')
                logger.error(es)

        return truck_speed_dict

    def update_truck_loacate(self):
        """Read which lane each online truck is on, from Redis.

        Side effect: rebuilds ``self.used_lane_set`` with the lane of every located truck.
        A device is used only if its ``type`` is ``"1"``, its ``online`` field is
        "true"/"True" and it has a ``laneId`` field. Trucks missing from
        ``truck_name_to_uuid_dict`` are logged and skipped.

        :return:
            truck_locate_dict:  (Dict{key:truck_id, value:lane_id}) truck -> lane table
        """
        truck_name_to_uuid_dict = get_value("truck_name_to_uuid_dict")
        self.used_lane_set = []
        truck_locate_dict = {}

        for raw_name in redis2.keys():
            item = raw_name.decode(encoding='utf-8')
            key_value_dict = redis2.hgetall(item)
            device_type = key_value_dict.get(str_to_byte('type'))
            # Use .get so a device without an 'online' field is skipped instead of raising KeyError.
            is_online = key_value_dict.get(str_to_byte('online'))
            truck_locate = key_value_dict.get(str_to_byte('laneId'))
            if device_type != str_to_byte("1") or is_online is None or truck_locate is None:
                continue
            # Accept both spellings; the former `in ["true" or "True"]` only matched "true".
            if bytes.decode(is_online) not in ("true", "True"):
                continue
            truck_uuid = truck_name_to_uuid_dict.get(item)
            if truck_uuid is None:
                logger.error(f'矿卡{item}所在路段信息读取异常')
                continue
            # SECURITY(review): eval() on Redis data executes arbitrary expressions; consider
            # ast.literal_eval(truck_locate.decode('utf-8')) once the stored laneId format is confirmed.
            lane_id = eval(truck_locate)
            truck_locate_dict[truck_uuid] = lane_id
            self.used_lane_set.append(lane_id)

        logger.info("truck_locate_dict")
        logger.info(truck_locate_dict)

        return truck_locate_dict

    @staticmethod
    def _average_lane_speeds(lane_ids, truck_locate_dict, truck_speed_dict):
        """Average the reported speeds of the trucks on each lane.

        :param lane_ids: lanes to include even without speed data (normally ``used_lane_set``)
        :param truck_locate_dict: truck id -> lane id
        :param truck_speed_dict: truck id -> speed
        :return: (Dict{key:str(lane_id), value:avg_speed}); lanes whose trucks reported
            no speed keep 0 rather than raising ZeroDivisionError.
        """
        speed_sum = {str(lane_id): 0 for lane_id in lane_ids}
        truck_count = dict.fromkeys(speed_sum, 0)
        for truck, lane_id in truck_locate_dict.items():
            if truck not in truck_speed_dict:
                continue
            key = str(lane_id)  # same key form as the initialization above
            speed_sum[key] = speed_sum.get(key, 0) + truck_speed_dict[truck]
            truck_count[key] = truck_count.get(key, 0) + 1
        return {
            lane: (total / truck_count[lane] if truck_count[lane] else total)
            for lane, total in speed_sum.items()
        }

    def lane_speed_generate(self):
        """Compute the real-time average speed of every lane that carries a truck.

        :return:
            lane_speed_dict: (Dict{key:lane_id, value:avg_speed}) average truck speed per lane
        """
        truck_locate_dict = self.update_truck_loacate()  # truck -> lane

        logger.info("矿卡位于路段:")
        logger.info(truck_locate_dict)

        truck_speed_dict = self.update_truck_speed()  # truck -> speed

        logger.info("矿卡当前速度:")
        logger.info(truck_speed_dict)

        for lane_id in truck_locate_dict.values():
            logger.info("lane_speed_generate-lane_id")
            logger.info(lane_id)

        self.lane_speed_dict = {}  # lane -> avg_speed

        try:
            self.lane_speed_dict = self._average_lane_speeds(
                self.used_lane_set, truck_locate_dict, truck_speed_dict)

            # Lanes that currently carry a truck
            logger.info("存在矿卡的路段:")
            logger.info(self.used_lane_set)
        except Exception as es:
            logger.error("路段实时速度计算异常")
            logger.error(es)

        return self.lane_speed_dict

# path_planner = PathPlanner()
#
# path_planner.walk_cost()
