#!E:\Pycharm Projects\Waytous
# -*- coding: utf-8 -*-
# @Time : 2021/7/26 14:35
# @Author : Opfer
# @Site :
# @File : path_plannner.py    
# @Software: PyCharm

from equipment.truck import TruckInfo, DumpInfo, ExcavatorInfo
# from path_plan.priority_control import PriorityController
from path_plan.priority_control import PriorityController
from path_plan.topo_graph import *
from para_config import *
from tables import *

M = 1000000


class PathPlanner(WalkManage):
    """Compute congestion-aware travel costs over the load/unload road network.

    Path costs combine lane distances with a per-lane blockage cost derived
    from real-time lane speeds (``LaneInfo``), and are then divided by the
    route weights returned by ``PriorityController.weighted_walk_calc``.
    """

    def __init__(self, topo):
        # Lane info: real-time average speed per lane
        self.lane = LaneInfo()
        self.lane.lane_speed_generate()
        self.topo = topo

        # Topology graphs provided by ``topo``
        self.unload_G = self.topo.get_unload_G()
        self.load_G = self.topo.get_load_G()

        # Equipment
        self.dump = DumpInfo()
        self.excavator = ExcavatorInfo()
        self.truck = TruckInfo(self.dump, self.excavator)
        self.truck.update_truck_size()
        # Priority controller (supplies route weights)
        self.controller = PriorityController(self.dump, self.excavator, self.truck)

        # Route travel cost.
        # NOTE(review): np.array((a, b)) is a 1-D array holding the two counts,
        # not an (a, b) matrix; this attribute is not read in this file.
        self.rout_cost = np.array((unload_area_num, load_area_num))
        # Lane collection (not read in this file)
        self.lane_set = {}

        self.logger = get_logger("zxt.path_planner")
        try:
            truck_lengths = self.truck.get_length()
            # Average truck length; converts a lane length into a truck count
            self.truck_length = float(sum(truck_lengths.values())) / len(truck_lengths)
        except Exception as es:
            self.logger.error("无矿卡数据")
            self.logger.error(es)
            self.truck_length = 3
        # Number of load areas
        self.num_of_load_area = len(set(update_load_area()))
        # Number of unload areas
        self.num_of_unload_area = len(set(update_unload_area()))
        # Number of park areas
        self.num_of_park_area = len(set(update_park_area()))
        # Area-level path costs, initialised to the sentinel M.
        # dtype=float: np.full with the int M would create an integer array
        # and silently truncate the fractional costs written into it.
        self.cost_to_load_area = np.full((self.num_of_unload_area, self.num_of_load_area), M, dtype=float)
        self.cost_to_unload_area = np.full((self.num_of_unload_area, self.num_of_load_area), M, dtype=float)
        # Area-level path costs from park areas
        self.cost_park_to_load_area = np.full((self.num_of_park_area, self.num_of_load_area), M, dtype=float)
        # Equipment-level path costs (same shape as the distance matrices)
        self.cost_to_excavator = np.zeros_like(get_value("distance_to_excavator"), dtype=float)
        self.cost_to_dump = np.zeros_like(get_value("distance_to_dump"), dtype=float)
        self.cost_park_to_excavator = np.zeros_like(get_value("distance_park_to_excavator"), dtype=float)

    def _update_locked_distance(self, graph, source_node, target_node, alpha, beta):
        """Refresh ``locked_distance`` on the edges of the shortest real-distance path.

        For each edge on the ``real_distance`` shortest path from
        ``source_node`` to ``target_node``, the edge's ``lane`` attribute
        (lane id -> sequence whose first item is the lane distance) is used to set::

            locked_distance = beta * sum(distance) + alpha * sum(beta * lane_blockage)

        NOTE(review): only edges on this path are updated; other edges keep
        their previous ``locked_distance`` (networkx uses weight 1 for edges
        lacking the attribute). Confirm ``topo`` initialises it on every edge.
        """
        _, path = nx.single_source_dijkstra(graph, source=source_node, target=target_node,
                                            weight="real_distance")
        for start, end in zip(path, path[1:]):
            edge = graph[start][end]
            total_distance = 0
            total_blockage = 0
            for lane_id, lane_value in edge['lane'].items():
                total_distance += lane_value[0]
                total_blockage += beta * self.lane_cost_generate(lane_id)
            edge['locked_distance'] = beta * total_distance + alpha * total_blockage

    def update_load_G_locked_distance(self, source_node, target_node, alpha, beta):
        """Refresh ``locked_distance`` along the shortest path in ``load_G``."""
        self._update_locked_distance(self.load_G, source_node, target_node, alpha, beta)

    def update_unload_G_locked_distance(self, source_node, target_node, alpha, beta):
        """Refresh ``locked_distance`` along the shortest path in ``unload_G``."""
        self._update_locked_distance(self.unload_G, source_node, target_node, alpha, beta)

    def path_cost_generate(self, load_area_id, unload_area_id, is_park):
        """Return the congestion-aware path costs between two work areas.

        Weights come from ``DispatchRule`` rows id=1 (distance weight, beta)
        and id=2 (congestion weight, alpha); each is used only when its row
        exists and has ``disabled == 0``.

        Args:
            load_area_id: id of the load (digging) area.
            unload_area_id: id of the unload (dump) area; unused when ``is_park``.
            is_park: when true, compute only the cost from the hard-coded node
                named "测试备停区出场点" to the load area entrance.

        Returns:
            ``(to_load_cost, to_unload_cost)``. A leg whose cost cannot be
            computed is reported as ``M``; in park mode ``to_unload_cost`` is 0.
        """
        # Correction factor applied to the congestion weight
        weight = 60
        # Congestion cost weight
        alpha = 1
        # Distance cost weight
        beta = 1
        # Presumably ends the current transaction so the rule queries below
        # see fresh data.
        session_mysql.commit()

        # Distance cost rule
        rule1 = session_mysql.query(DispatchRule).filter_by(id=1).first()
        if rule1 is not None and rule1.disabled == 0:
            beta = float(rule1.rule_weight)

        # Congestion cost rule
        rule2 = session_mysql.query(DispatchRule).filter_by(id=2).first()
        if rule2 is not None and rule2.disabled == 0:
            alpha = float(rule2.rule_weight)

        # Normalise relative to the distance weight: divide alpha by beta
        # *before* resetting beta to 1, so the configured ratio is kept.
        if beta:
            alpha = alpha / beta * weight
            beta = 1
        else:
            alpha = alpha * weight

        # M marks an unreachable / uncomputable leg so it is never preferred.
        to_load_cost = M
        to_unload_cost = 0 if is_park else M

        if is_park:
            try:
                # Park area exit node -> load area entrance
                park_source_node = str(session_postgre.query(Node).filter_by(Name="测试备停区出场点").first().Id)
                park_end_node = str(session_postgre.query(DiggingWorkArea).filter_by(Id=load_area_id).first().EntranceNodeId)
                self.update_load_G_locked_distance(park_source_node, park_end_node, alpha, beta)
                to_load_cost = nx.dijkstra_path_length(self.load_G, source=park_source_node, target=park_end_node,
                                                       weight="locked_distance")
            except Exception as es:
                self.logger.error(f"从备停区到装载区{load_area_id}cost更新异常")
                self.logger.error(es)
        else:
            try:
                # Unload area exit -> load area entrance, on load_G
                load_source_node = str(session_postgre.query(DumpArea).filter_by(Id=unload_area_id).first().ExitNodeId)
                load_end_node = str(session_postgre.query(DiggingWorkArea).filter_by(Id=load_area_id).first().EntranceNodeId)
                self.update_load_G_locked_distance(load_source_node, load_end_node, alpha, beta)
                to_load_cost = nx.dijkstra_path_length(self.load_G, source=load_source_node, target=load_end_node,
                                                       weight="locked_distance")
            except Exception as es:
                self.logger.error(f"从卸载区{unload_area_id}到装载区{load_area_id}cost更新异常")
                self.logger.error(es)

            try:
                # Load area exit -> unload area entrance, on unload_G
                unload_source_node = str(session_postgre.query(DiggingWorkArea).filter_by(Id=load_area_id).first().ExitNodeId)
                unload_end_node = str(session_postgre.query(DumpArea).filter_by(Id=unload_area_id).first().EntranceNodeId)
                self.update_unload_G_locked_distance(unload_source_node, unload_end_node, alpha, beta)
                to_unload_cost = nx.dijkstra_path_length(self.unload_G, source=unload_source_node,
                                                         target=unload_end_node, weight="locked_distance")
            except Exception as es:
                self.logger.error(f"从装载区{load_area_id}到卸载区{unload_area_id}cost更新异常")
                self.logger.error(es)

        return to_load_cost, to_unload_cost

    def lane_cost_generate(self, lane_id):
        """Return the blockage cost of one lane.

        ``(1 - actual_speed / max_speed) * (lane_length / truck_length)``:
        the fraction of free-flow speed lost, scaled by how many average-length
        trucks fit on the lane. Returns 0 (after logging) when the lane record
        or its real-time speed is unavailable.
        """
        lane_blockage = 0
        try:
            lane_rec = session_postgre.query(Lane).filter_by(Id=lane_id).first()

            lane_length = lane_rec.Length
            # Free-flow speed of the lane
            clear_speed = lane_rec.MaxSpeed

            # 1. Jam density = lane length / truck length
            truck_density = lane_length / self.truck_length
            # 2. Real-time average speed on the lane
            actual_speed = self.lane.lane_speed_dict[lane_id]
            # 3. Blockage = (1 - actual speed / max speed) * jam density
            lane_blockage = (1 - actual_speed / clear_speed) * truck_density
        except Exception as es:
            self.logger.error('路段拥堵成本计算异常')
            self.logger.error(es)

        return lane_blockage

    def walk_cost_cal(self):
        """Refresh equipment/lane state and return equipment-level travel costs.

        Returns:
            ``(cost_to_excavator, cost_to_dump, cost_park_to_excavator)``; each
            entry is the area-level path cost divided by the route weight
            from ``self.controller.weighted_walk_calc()``.
        """
        self.excavator.para_period_update()
        self.dump.para_period_update()
        self.truck.para_period_update(self.dump, self.excavator)
        self.truck.state_period_update()
        self.period_walk_para_load()
        self.period_map_para_load()

        # Refresh lane speeds before computing travel costs
        self.lane.lane_speed_generate()

        try:
            # Unload area <-> load area costs
            for walk_time in session_postgre.query(WalkTime).all():
                load_area_id, unload_area_id = str(walk_time.load_area_id), str(walk_time.unload_area_id)
                unload_area_index = unload_area_uuid_to_index_dict[unload_area_id]
                load_area_index = load_area_uuid_to_index_dict[load_area_id]

                to_load_cost, to_unload_cost = self.path_cost_generate(load_area_id, unload_area_id, False)
                self.cost_to_load_area[unload_area_index][load_area_index] = to_load_cost
                self.cost_to_unload_area[unload_area_index][load_area_index] = to_unload_cost

            # Park area -> load area costs
            for walk_time_park in session_postgre.query(WalkTimePark).all():
                park_area_id = str(walk_time_park.park_area_id)
                park_load_area_id = str(walk_time_park.load_area_id)
                park_area_index = park_uuid_to_index_dict[park_area_id]
                load_area_index = load_area_uuid_to_index_dict[park_load_area_id]
                self.cost_park_to_load_area[park_area_index][load_area_index], _ = \
                    self.path_cost_generate(park_load_area_id, park_area_id, True)
        except Exception as es:
            self.logger.error('路网信息计成本计算异常')
            self.logger.error(es)

        self.cost_to_excavator = np.zeros_like(get_value("distance_to_excavator"), dtype=float)
        self.cost_to_dump = np.zeros_like(get_value("distance_to_dump"), dtype=float)
        self.cost_park_to_excavator = np.zeros_like(get_value("distance_park_to_excavator"), dtype=float)

        self.logger.info("distance_park_to_excavator")
        self.logger.info(self.distance_park_to_excavator)

        # Route weights from the priority controller
        walk_to_excavator_weight, walk_to_dump_weight, park_walk_weight = self.controller.weighted_walk_calc()

        self.logger.info("path_weight")
        self.logger.info(walk_to_excavator_weight)
        self.logger.info(walk_to_dump_weight)
        self.logger.info(park_walk_weight)

        dump_num = get_value("dynamic_dump_num")
        excavator_num = get_value("dynamic_excavator_num")

        for i in range(dump_num):
            for j in range(excavator_num):
                load_area_index = self.excavator_index_to_load_area_index_dict[j]
                unload_area_index = self.dump_index_to_unload_area_index_dict[i]
                self.cost_to_excavator[i][j] = self.cost_to_load_area[unload_area_index][load_area_index] / \
                                               walk_to_excavator_weight[i][j]
                # NOTE(review): the dump weight is indexed [j][i] while the
                # excavator weight is [i][j]; confirm the orientation returned
                # by PriorityController.weighted_walk_calc.
                self.cost_to_dump[i][j] = self.cost_to_unload_area[unload_area_index][load_area_index] / \
                                          walk_to_dump_weight[j][i]

        # NOTE(review): only park area index 0 is handled here.
        for j in range(excavator_num):
            load_area_index = self.excavator_index_to_load_area_index_dict[j]
            self.cost_park_to_excavator[0][j] = self.cost_park_to_load_area[0][load_area_index] / park_walk_weight[0][j]

        self.logger.info("真实路网距离-驶往挖机:")
        self.logger.info(self.distance_to_excavator)

        self.logger.info("真实路网距离-驶往卸点:")
        self.logger.info(self.distance_to_dump)

        self.logger.info("真实备停区路网距离-驶往挖机:")
        self.logger.info(self.distance_park_to_excavator)

        self.logger.info("阻塞路网距离-驶往挖机:")
        self.logger.info(self.cost_to_excavator)

        self.logger.info("阻塞路网距离-驶往卸点:")
        self.logger.info(self.cost_to_dump)

        self.logger.info("阻塞备停区路网距离-驶往挖机：")
        self.logger.info(self.cost_park_to_excavator)

        return self.cost_to_excavator, self.cost_to_dump, self.cost_park_to_excavator


class LaneInfo:
    """Real-time per-lane average speed derived from truck telemetry in redis."""

    def __init__(self):
        # str(lane id) -> average speed of the trucks on that lane, or the
        # lane's MaxSpeed when no truck with a speed reading is on it
        self.lane_speed_dict = {}

        self.logger = get_logger("zxt.laneinfo")

    @staticmethod
    def _is_online(raw_flag):
        """Return True when a raw (bytes) ``online`` flag reads "true" or "True"."""
        return raw_flag.decode("utf-8") in ("true", "True")

    @staticmethod
    def _average_lane_speeds(lane_set, truck_locate_dict, truck_speed_dict, max_speed_of):
        """Average the current truck speed on every lane of the network.

        Args:
            lane_set: lane ids used by the road network.
            truck_locate_dict: truck uuid -> lane id the truck is on.
            truck_speed_dict: truck uuid -> current speed.
            max_speed_of: callable ``lane_id -> max speed``, used for lanes
                with no speed-reporting truck on them.

        Returns:
            ``(speeds, occupied)``: ``speeds`` maps str(lane id) -> average
            speed for every lane in ``lane_set``; ``occupied`` is the set of
            str lane ids that had at least one truck with a speed reading.
        """
        speed_sum = {str(lane_id): 0 for lane_id in lane_set}
        truck_count = dict.fromkeys(speed_sum, 0)
        occupied = set()

        for truck, lane_id in truck_locate_dict.items():
            # Trucks off the network, or without a speed reading, cannot
            # contribute to an average.
            if lane_id not in lane_set or truck not in truck_speed_dict:
                continue
            key = str(lane_id)
            speed_sum[key] += truck_speed_dict[truck]
            truck_count[key] += 1
            occupied.add(key)

        # Lanes with no truck on them are assumed to flow at their max speed.
        for lane_id in lane_set:
            key = str(lane_id)
            if key not in occupied:
                speed_sum[key] = max_speed_of(lane_id)
                truck_count[key] = 1

        speeds = {key: speed_sum[key] / truck_count[key] for key in truck_count}
        return speeds, occupied

    def update_truck_speed(self):
        """Read the current speed of every truck from redis.

        Returns:
            dict mapping truck uuid -> speed (float). Devices that are not
            trucks (type != "1") or have no ``speed`` field are skipped; a
            failure on one device is logged and does not drop the others.
        """
        truck_speed_dict = {}
        try:
            device_name_set = redis2.keys()
        except Exception as es:
            self.logger.error('矿卡实时速度读取异常')
            self.logger.error(es)
            return truck_speed_dict

        for raw_name in device_name_set:
            item = raw_name
            try:
                item = raw_name.decode(encoding='utf-8')
                key_value_dict = redis2.hgetall(item)
                if key_value_dict.get(str_to_byte('type')) != str_to_byte("1"):
                    continue
                raw_speed = key_value_dict.get(str_to_byte('speed'))
                if raw_speed is None:
                    continue
                truck_speed_dict[truck_name_to_uuid_dict[item]] = float(raw_speed)
            except Exception as es:
                self.logger.error(f'矿卡{item}实时速度读取异常')
                self.logger.error(es)

        self.logger.debug("truck_speed_dict")
        self.logger.debug(truck_speed_dict)

        return truck_speed_dict

    def update_truck_loacate(self):
        """Read the lane each online truck is currently on from redis.

        Returns:
            dict mapping truck uuid -> lane id parsed from the ``laneId``
            field. Only trucks (type == "1") whose ``online`` flag reads
            "true"/"True" and that have a ``laneId`` are included; a failure
            on one device is logged and does not drop the others.
        """
        import ast

        truck_locate_dict = {}
        try:
            device_name_set = redis2.keys()
        except Exception as es:
            self.logger.error('矿卡所在路段信息读取异常')
            self.logger.error(es)
            return truck_locate_dict

        online_key = str_to_byte('online')
        lane_key = str_to_byte('laneId')
        for raw_name in device_name_set:
            item = raw_name
            try:
                item = raw_name.decode(encoding='utf-8')
                key_value_dict = redis2.hgetall(item)
                if key_value_dict.get(str_to_byte('type')) != str_to_byte("1"):
                    continue
                # Check presence before reading the flag: not every device
                # hash carries ``online``.
                if online_key not in key_value_dict or not self._is_online(key_value_dict[online_key]):
                    continue
                if lane_key not in key_value_dict:
                    continue
                # laneId is stored as a literal (presumably a quoted uuid string).
                # literal_eval parses literals only; eval would execute
                # arbitrary expressions read from redis.
                truck_locate = ast.literal_eval(key_value_dict[lane_key].decode('utf-8'))
                truck_locate_dict[truck_name_to_uuid_dict[item]] = truck_locate
            except Exception as es:
                self.logger.error(f'矿卡{item}所在路段信息读取异常')
                self.logger.error(es)

        self.logger.debug("truck_locate_dict")
        self.logger.debug(truck_locate_dict)

        return truck_locate_dict

    def lane_speed_generate(self):
        """Recompute and return ``self.lane_speed_dict``.

        Collects the lanes referenced by ``WalkTime`` / ``WalkTimePark`` rows,
        then averages the speeds of trucks located on each lane. Lanes with no
        truck get their ``Lane.MaxSpeed``. On failure the dict is left empty.
        """
        # truck -> lane
        truck_locate_dict = self.update_truck_loacate()

        self.logger.info("矿卡位于路段:")
        self.logger.info(truck_locate_dict)

        # truck -> speed
        truck_speed_dict = self.update_truck_speed()

        self.logger.info("矿卡当前速度:")
        self.logger.info(truck_speed_dict)

        # Lanes used by the road network
        lane_set = set()
        try:
            for walk_time in session_postgre.query(WalkTime).all():
                lane_set.update(walk_time.to_load_lanes)
                lane_set.update(walk_time.to_unload_lanes)
            for walk_time_park in session_postgre.query(WalkTimePark).all():
                lane_set.update(walk_time_park.park_load_lanes)
        except Exception as es:
            self.logger.error('所用路网路段集合读取异常')
            self.logger.error(es)

        try:
            lane_speed_dict, occupied_lanes = self._average_lane_speeds(
                lane_set,
                truck_locate_dict,
                truck_speed_dict,
                lambda lane_id: session_postgre.query(Lane).filter_by(Id=lane_id).first().MaxSpeed,
            )
            # Lanes with at least one truck on them
            self.logger.info("存在矿卡的路段:")
            self.logger.info(occupied_lanes)
        except Exception as es:
            self.logger.error("路段实时速度计算异常")
            self.logger.error(es)
            lane_speed_dict = {}

        self.lane_speed_dict = lane_speed_dict
        return self.lane_speed_dict

# Example usage (topo: road-network topology providing get_load_G/get_unload_G):
# path_planner = PathPlanner(topo)
#
# path_planner.walk_cost_cal()
