# from para_config import *
from settings import *
from tables import *
import numpy as np
import sched
import time
import matplotlib.pyplot as plt
# from path_plan.path_plannner import *
import networkx as nx
from settings import *


class Topo:
    """Road-network topology used for truck dispatching and rescheduling.

    Builds two undirected ``networkx`` graphs from the map tables:

    * ``unload_G`` -- routes from digging-area exit nodes to dump-area
      entrance nodes (``WalkTime.to_unload_lanes``);
    * ``load_G`` -- routes from dump-area exit nodes to digging-area entrance
      nodes (``WalkTime.to_load_lanes``) plus routes from park areas
      (``WalkTimePark.park_load_lanes``).

    Node ids are strings. Runs of consecutive lanes are collapsed into one
    edge between "key" nodes (forks and route terminals); every edge carries
    ``real_distance``, ``locked_distance`` and ``lane`` -- a dict
    ``{lane_id: [length, length]}`` in route order.
    """

    def __init__(self):
        # Routes towards dump areas.
        self.unload_G = nx.Graph()
        self.unload_G_all_nodes = []      # every node registered while building unload_G
        self.unload_G_digging_nodes = []  # digging-area exit nodes (route starts)
        self.unload_G_dump_nodes = []     # dump-area entrance nodes (route ends)

        # Routes towards digging areas (from dump areas and park areas).
        self.load_G = nx.Graph()
        self.load_G_all_nodes = []        # every node registered while building load_G
        self.load_G_digging_nodes = []    # digging-area entrance nodes (route ends)
        self.load_G_dump_nodes = []       # dump-area exit nodes (route starts)

        # Cached node/edge arrays and counts exposed through the getters;
        # recomputed at the end of generate_topo_graph().
        self._refresh_graph_stats()

        # Raw route rows loaded by get_work_area_distance_info() / get_park_distance_info().
        self.work_area_distance_info = []
        self.park_distance_info = []

        self.logger = get_logger("zxt.topo")

        # Fork ("cross") nodes of load_G and the lanes leading into them.
        self.cross_nodes = []
        self.cross_bf_lanes = []

    # ------------------------------------------------------------------
    # unload_G accessors
    # ------------------------------------------------------------------

    def get_unload_G(self):
        return self.unload_G

    def get_unload_G_nodes(self):
        return self.unload_G_nodes

    def get_unload_G_num_of_nodes(self):
        return self.unload_G_num_of_nodes

    def get_unload_G_edges(self):
        return self.unload_G_edges

    def get_unload_G_num_of_edges(self):
        return self.unload_G_num_of_edges

    # ------------------------------------------------------------------
    # load_G accessors
    # ------------------------------------------------------------------

    def get_load_G(self):
        return self.load_G

    def get_load_G_nodes(self):
        return self.load_G_nodes

    def get_load_G_num_of_nodes(self):
        return self.load_G_num_of_nodes

    def get_load_G_edges(self):
        return self.load_G_edges

    def get_load_G_num_of_edges(self):
        return self.load_G_num_of_edges

    def _refresh_graph_stats(self):
        """Recompute the cached node/edge arrays and counts of both graphs."""
        self.unload_G_nodes = np.array(list(self.unload_G.nodes))
        self.unload_G_num_of_nodes = self.unload_G.number_of_nodes()
        self.unload_G_edges = np.array(list(self.unload_G.edges))
        self.unload_G_num_of_edges = self.unload_G.number_of_edges()

        self.load_G_nodes = np.array(list(self.load_G.nodes))
        self.load_G_num_of_nodes = self.load_G.number_of_nodes()
        self.load_G_edges = np.array(list(self.load_G.edges))
        self.load_G_num_of_edges = self.load_G.number_of_edges()

    # ------------------------------------------------------------------
    # map data loading
    # ------------------------------------------------------------------

    def get_work_area_distance_info(self):
        """Reload ``WalkTime`` rows into ``self.work_area_distance_info``.

        Each entry is ``[[load_area_id, unload_area_id], to_unload_lanes,
        to_load_lanes]`` with both area ids converted to ``str``. On a query
        error the error is logged and both sessions are rolled back; the list
        then holds whatever was appended before the failure.
        """
        self.work_area_distance_info = []
        try:
            for item in session_postgre.query(WalkTime).all():
                self.work_area_distance_info.append(
                    [[str(item.load_area_id), str(item.unload_area_id)], item.to_unload_lanes, item.to_load_lanes])
        except Exception as es:
            self.logger.error(es)
            self.logger.error("获取地图信息出错")
            session_mysql.rollback()
            session_postgre.rollback()

    def return_work_area_distance_info(self):
        """Reload and return ``self.work_area_distance_info``."""
        self.get_work_area_distance_info()
        return self.work_area_distance_info

    def get_park_distance_info(self):
        """Reload ``WalkTimePark`` rows into ``self.park_distance_info``.

        Each entry is ``[[park_area_id, load_area_id], park_load_lanes]`` with
        both ids converted to ``str``. Errors are logged and both sessions are
        rolled back.
        """
        self.park_distance_info = []
        try:
            for item in session_postgre.query(WalkTimePark).all():
                self.park_distance_info.append([[str(item.park_area_id), str(item.load_area_id)], item.park_load_lanes])
        except Exception as es:
            self.logger.error(es)
            self.logger.error("获取地图备停区信息出错")
            session_mysql.rollback()
            session_postgre.rollback()

    @staticmethod
    def _current_map_version():
        """Return the ``Version`` of the active (``Status == "1"``) map."""
        return session_postgre.query(Distribute_Library).filter_by(Status="1").first().Version

    @staticmethod
    def _register_node(node, *node_lists):
        """Append *node* to every list in *node_lists* that does not contain it yet."""
        for node_list in node_lists:
            if node not in node_list:
                node_list.append(node)

    # ------------------------------------------------------------------
    # graph construction
    # ------------------------------------------------------------------

    def _add_route_edges(self, graph, lane_ids, terminal_nodes, all_nodes, map_version,
                         merge_at_existing_nodes, route_tag):
        """Collapse one route's lane sequence into edges of *graph*.

        Lanes are accumulated until the current lane ends at a fork (more than
        one lane of *map_version* starts there), at a node in *terminal_nodes*,
        or -- when *merge_at_existing_nodes* is true -- at a node already in
        *graph*. The accumulated lanes then become a single edge. Lanes after
        the last such key node are dropped. Exceptions propagate to the caller.

        :param graph: graph to extend (``unload_G`` or ``load_G``).
        :param lane_ids: lane ids of the route, in driving order.
        :param terminal_nodes: nodes that always end an edge (route ends).
        :param all_nodes: list collecting nodes touched when an edge is closed.
        :param map_version: map version used to count outgoing lanes.
        :param merge_at_existing_nodes: also close an edge at nodes already in *graph*.
        :param route_tag: prefix for log messages ('unload', 'load' or 'park').
        """
        saved_lanes = []  # [start_node, end_node, length, lane_id] since the last key node
        for lane_id in lane_ids:
            lane = session_postgre.query(Lane).filter_by(Id=lane_id).first()
            lane_startpoint = str(lane.StartNodeId)
            lane_endpoint = str(lane.EndNodeId)
            saved_lanes.append([lane_startpoint, lane_endpoint, float(lane.Length), lane_id])
            self.logger.info('%s route lane id: %s', route_tag, lane_id)

            son_lane_num = len(session_postgre.query(Lane).filter_by(
                StartNodeId=lane_endpoint, MapVersion=map_version).all())

            # Key nodes: forks, route terminals, or (optionally) nodes already in the graph.
            at_key_node = (son_lane_num > 1
                           or lane_endpoint in terminal_nodes
                           or (merge_at_existing_nodes and lane_endpoint in graph))
            if not at_key_node:
                continue

            lanes = {saved_id: [length, length] for _, _, length, saved_id in saved_lanes}
            graph.add_edge(saved_lanes[0][0], saved_lanes[-1][1],
                           real_distance=sum(value[0] for value in lanes.values()),
                           locked_distance=sum(value[1] for value in lanes.values()),
                           lane=lanes)
            self.logger.info('%s route cross node: %s', route_tag, saved_lanes[0][0])
            self.logger.info('%s route cross node: %s', route_tag, saved_lanes[-1][1])

            saved_lanes = []
            self._register_node(lane_startpoint, all_nodes)
            self._register_node(lane_endpoint, all_nodes)

    def generate_topo_graph(self):
        """Populate ``unload_G`` and ``load_G`` from the map tables.

        Reloads the work-area and park route tables and adds every route to
        the graphs (existing nodes/edges are kept, not cleared). A failure
        while walking one route's lanes is logged, both sessions are rolled
        back, and the remaining routes are still processed. Finally refreshes
        the cached graph statistics, ``cross_nodes`` and ``cross_bf_lanes``.
        """
        self.get_work_area_distance_info()
        self.get_park_distance_info()
        map_version = self._current_map_version()

        for (load_area_id, unload_area_id), to_unload_lanes, to_load_lanes in self.work_area_distance_info:
            digging_area = session_postgre.query(DiggingWorkArea).filter_by(Id=load_area_id).first()
            dump_area = session_postgre.query(DumpArea).filter_by(Id=unload_area_id).first()
            # Read every attribute up front so a later rollback does not force reloads.
            digging_name = str(digging_area.Name)
            dump_name = str(dump_area.Name)
            digging_exit = str(digging_area.ExitNodeId)
            digging_entrance = str(digging_area.EntranceNodeId)
            dump_entrance = str(dump_area.EntranceNodeId)
            dump_exit = str(dump_area.ExitNodeId)

            # unload_G: digging-area exit -> dump-area entrance.
            self.unload_G.add_node(digging_exit, name='digging', type=digging_name)
            self.unload_G.add_node(dump_entrance, name='group_dumps', type=dump_name)
            self.logger.info('unload route start digging name: %s', digging_name)
            self.logger.info('unload route end dump name: %s', dump_name)
            self._register_node(digging_exit, self.unload_G_all_nodes, self.unload_G_digging_nodes)
            self._register_node(dump_entrance, self.unload_G_all_nodes, self.unload_G_dump_nodes)
            try:
                self._add_route_edges(self.unload_G, to_unload_lanes, self.unload_G_dump_nodes,
                                      self.unload_G_all_nodes, map_version,
                                      merge_at_existing_nodes=False, route_tag='unload')
            except Exception as es:
                self.logger.error(es)
                self.logger.error("去卸载区拓扑图出错")
                session_mysql.rollback()
                session_postgre.rollback()

            # load_G: dump-area exit -> digging-area entrance.
            self.logger.info('load route start dump name: %s', dump_name)
            self.logger.info('load route end digging name: %s', digging_name)
            self._register_node(dump_exit, self.load_G_all_nodes, self.load_G_dump_nodes)
            self._register_node(digging_entrance, self.load_G_all_nodes, self.load_G_digging_nodes)
            self.load_G.add_node(dump_exit, name='group_dumps', type=dump_name)
            self.load_G.add_node(digging_entrance, name='digging', type=digging_name)
            try:
                self._add_route_edges(self.load_G, to_load_lanes, self.load_G_digging_nodes,
                                      self.load_G_all_nodes, map_version,
                                      merge_at_existing_nodes=True, route_tag='load')
            except Exception as es:
                self.logger.error(es)
                self.logger.error("去装载区拓扑图生成失败")
                session_mysql.rollback()
                session_postgre.rollback()

        self.logger.debug('load_G nodes: %s', self.load_G.nodes.data())

        # Park area -> digging-area entrance routes are merged into load_G.
        for _area_ids, park_load_lanes in self.park_distance_info:
            try:
                self._add_route_edges(self.load_G, park_load_lanes, self.load_G_digging_nodes,
                                      self.load_G_all_nodes, map_version,
                                      merge_at_existing_nodes=True, route_tag='park')
            except Exception as es:
                self.logger.error(es)
                self.logger.error("备停区部分装载拓扑图生成失败")
                session_mysql.rollback()
                session_postgre.rollback()

        self._refresh_graph_stats()
        self.update_cross_nodes()
        self.update_cross_bf_lanes()

    # ------------------------------------------------------------------
    # unload_G queries
    # ------------------------------------------------------------------

    def get_unload_edge_node(self, truck_location_lane):
        """Return ``(start_node, end_node)`` of the unload_G edge containing *truck_location_lane*.

        Returns ``None`` when no edge contains the lane (lookup errors are logged).
        """
        try:
            for start_node, end_node, lanes in self.unload_G.edges.data('lane', default={}):
                if truck_location_lane in lanes:
                    return start_node, end_node
        except Exception as es:
            self.logger.error(es)
            self.logger.error("卸载图，矿卡所在路段返回失败")
        return None

    def get_unload_target_node_real(self, truck_location_lane, pre_target, allow):
        """Pick the nearest reachable dump area for a truck on ``unload_G``.

        :param truck_location_lane: lane id the truck is currently on.
        :param pre_target: dump-entrance node id of the previous target; it is
            excluded from the candidates unless *allow* is true.
        :param allow: whether *pre_target* may be chosen again.
        :return: ``(path, entrance_node, dump_area_id)`` of the shortest path
            (weight ``real_distance``) from the start node of the truck's edge.
        :raises IndexError: when no dump area is reachable.
        """
        edge = self.get_unload_edge_node(truck_location_lane)
        # Search from the start node of the truck's edge, as get_load_target_node_real does.
        source_node = edge[0] if edge is not None else None
        map_version = self._current_map_version()

        target_list = [node for node, name in self.unload_G.nodes.data('name')
                       if name == 'group_dumps' and (allow or node != pre_target)]
        if not target_list:
            self.logger.error("当前无可去卸载区！")

        # target node -> (distance, path, dump_area_id); keyed by node so equal distances never collide.
        candidates = {}
        for target in target_list:
            try:
                distance, path = nx.single_source_dijkstra(self.unload_G, source=source_node, target=target,
                                                           weight="real_distance")
                dump_area_id = str(session_postgre.query(DumpArea).filter_by(
                    EntranceNodeId=target, MapVersion=map_version).first().Id)
                candidates[target] = (distance, path, dump_area_id)
            except Exception as es:
                self.logger.info(es)
                self.logger.info(f"卸载图中{source_node} 与 {target} 之间道路不通")

        if not candidates:
            raise IndexError("no reachable dump area from lane %s" % truck_location_lane)

        _, min_dis_path, target_dump_area = min(candidates.values(), key=lambda cand: cand[0])
        entrance_point = min_dis_path[-1]
        return min_dis_path, entrance_point, target_dump_area

    # ------------------------------------------------------------------
    # load_G queries
    # ------------------------------------------------------------------

    def get_load_edge_node(self, truck_location_lane):
        """Return ``(start_node, end_node)`` of the load_G edge containing *truck_location_lane*.

        Returns ``None`` when no edge contains the lane (lookup errors are logged).
        """
        try:
            for start_node, end_node, lanes in self.load_G.edges.data('lane', default={}):
                if truck_location_lane in lanes:
                    return start_node, end_node
        except Exception as es:
            self.logger.error(es)
            self.logger.error("装载图，矿卡所在路段返回失败")
        return None

    def update_load_G_locked_distance(self, path, alpha, beta):
        """Reset ``locked_distance`` of every load_G edge along *path* to its summed lane lengths.

        The per-lane ``lane`` dicts are not modified. *alpha* and *beta* are
        currently unused (kept for interface compatibility).
        """
        for u, v in zip(path, path[1:]):
            edge = self.load_G[u][v]
            edge['locked_distance'] = sum(lengths[0] for lengths in edge['lane'].values())

    def relative_distance(self, truck_location_lane, graph_type):
        """Return the truck's distances to both ends of its edge.

        The truck is assumed to sit at the midpoint of *truck_location_lane*.

        :param graph_type: 1 for ``unload_G``, anything else for ``load_G``.
        :return: ``(distance_to_start_node, distance_to_end_node)``.
        :raises TypeError: if the lane is not on any edge of the chosen graph.
        """
        if graph_type == 1:
            graph, find_edge = self.unload_G, self.get_unload_edge_node
        else:
            graph, find_edge = self.load_G, self.get_load_edge_node

        start_node, end_node = find_edge(truck_location_lane)
        edge = graph[start_node][end_node]

        distance_start_node = 0
        # Lanes are stored in route order, so sum lengths up to the truck's lane.
        for lane_id, lengths in edge['lane'].items():
            if lane_id == truck_location_lane:
                distance_start_node += lengths[0] / 2
                break
            distance_start_node += lengths[0]

        return distance_start_node, edge['real_distance'] - distance_start_node

    def get_load_target_node_real(self, truck_location_lane, pre_target, allow):
        """List every reachable digging area for a truck on ``load_G``.

        :param truck_location_lane: lane id the truck is currently on.
        :param pre_target: ``DiggingWorkArea.Id`` of the previous target; its
            entrance is excluded unless *allow* is true.
        :param allow: whether the previous target may be chosen again.
        :return: ``(reachable_destinations, load_area_lane_dict)`` where
            ``reachable_destinations`` maps digging-area id (str) to
            ``[distance, flag]`` (flag 1: no U-turn needed, 0: U-turn needed)
            and ``load_area_lane_dict`` maps digging-area id to the lane ids
            along the path.
        :raises TypeError: if the truck lane is not on any load_G edge.
        """
        source_node, end_node = self.get_load_edge_node(truck_location_lane)
        distance_source_node, distance_end_node = self.relative_distance(truck_location_lane, graph_type=0)

        map_version = self._current_map_version()

        pre_destination_node = str(
            session_postgre.query(DiggingWorkArea).filter_by(Id=pre_target).first().EntranceNodeId)
        target_list = [node for node, name in self.load_G.nodes.data('name')
                       if name == 'digging' and (allow or node != pre_destination_node)]
        if not target_list:
            self.logger.error("当前无可去装载区！")

        load_area_lane_dict = {}
        reachable_destinations = {}
        for target in target_list:
            try:
                distance, path = nx.single_source_dijkstra(self.load_G, source=source_node, target=target,
                                                           weight="real_distance")
                load_area_id = str(session_postgre.query(DiggingWorkArea).filter_by(
                    EntranceNodeId=target, MapVersion=map_version).first().Id)
                load_area_lane_dict[load_area_id] = [
                    lane_id for u, v in zip(path, path[1:]) for lane_id in self.load_G[u][v]['lane']]
            except Exception as es:
                self.logger.info(es)
                self.logger.info(f"装载图中{source_node} 与 {target} 之间道路不通")
                continue

            # Flag 1: the path passes the far end of the truck's edge (no U-turn);
            # flag 0: the truck has to turn around.
            # NOTE(review): when the path runs through end_node, `distance` usually
            # already includes the start->end edge, so the remaining drive may be
            # `distance - distance_source_node`; confirm before changing.
            if end_node in path:
                reachable_destinations[load_area_id] = [distance + distance_end_node, 1]
            else:
                reachable_destinations[load_area_id] = [distance + distance_source_node, 0]

        return reachable_destinations, load_area_lane_dict

    # ------------------------------------------------------------------
    # cross (fork) nodes
    # ------------------------------------------------------------------

    @staticmethod
    def _is_cross_node_name(name):
        """Return True for load_G nodes without an area name, i.e. fork nodes."""
        return name is None or name == 'None'

    def update_cross_nodes(self):
        """Recompute ``self.cross_nodes``: load_G nodes that are not area nodes.

        :return: None
        """
        self.cross_nodes = [node_id for node_id, name in self.load_G.nodes.data('name')
                            if self._is_cross_node_name(name)]

    def update_cross_bf_lanes(self):
        """Recompute ``self.cross_bf_lanes``: lanes leading into each cross node.

        For each cross node, collects the id of one lane ending at it and of
        one lane ending at that lane's start node (active map version only).
        Query errors are logged and the node is skipped.

        :return: None
        """
        map_version = self._current_map_version()
        self.cross_bf_lanes = []
        for node_id, name in self.load_G.nodes.data('name'):
            if not self._is_cross_node_name(name):
                continue

            try:
                lane = session_postgre.query(Lane).filter_by(EndNodeId=node_id, MapVersion=map_version).first()
            except Exception as es:
                self.logger.error(es)
                continue
            if lane is None:
                continue
            self.cross_bf_lanes.append(str(lane.Id))

            try:
                prev_lane = session_postgre.query(Lane).filter_by(
                    EndNodeId=lane.StartNodeId, MapVersion=map_version).first()
                if prev_lane is not None:
                    self.cross_bf_lanes.append(str(prev_lane.Id))
            except Exception as es:
                self.logger.error(es)

    def get_cross_nodes(self):
        return self.cross_nodes
