
# from para_config import *
from settings import *
import numpy as np
import sched
import time
import matplotlib.pyplot as plt
# from path_plan.path_plannner import *
import networkx as nx

class Topo:
    """Road-network topology used to route trucks between work areas.

    Description:
        Builds two undirected NetworkX graphs from the map tables queried
        through ``session_postgre``:

        * ``unload_G`` -- routes towards the dump (unload) areas, from each
          digging area's exit node to each dump area's entrance node;
        * ``load_G`` -- routes towards the digging (load) areas, from each
          dump area's exit node to each digging area's entrance node.

        Consecutive lanes of a route are merged into one edge until a fork
        (an end node from which more than one lane starts) or the route's
        terminal node is reached.  Every edge carries ``real_distance``,
        ``locked_distance``, ``start_node`` (start node of the first merged
        lane) and ``lane`` (``{lane_id: [length, length]}``).  Terminal
        nodes carry a ``name`` attribute of ``'digging'`` or ``'dump'``.
        Shortest paths are computed with Dijkstra on ``real_distance``.
    Attribute:
        unload_G / load_G: the two ``nx.Graph`` instances.
        *_all_nodes, *_digging_nodes, *_dump_nodes: node-id lists collected
            while building the graphs.
        *_nodes, *_edges, *_num_of_nodes, *_num_of_edges: cached numpy
            arrays / counts of the graphs, refreshed after every build.
        work_area_distance_info: route rows loaded from ``WalkTime``.
    """

    # Lanes leaving the standby parking area.  Together they form one extra
    # merged load_G edge.  Override on a subclass/instance if the map changes.
    PARK_EXIT_LANE_IDS = (
        "7b9c8e89-7134-63ac-9c44-f183674b090c",
        "7e120229-b426-461b-91a4-05e01fb37ed8",
    )

    def __init__(self):
        # Graph of routes towards the dump (unload) areas.
        self.unload_G = nx.Graph()
        self.unload_G_all_nodes = []
        self.unload_G_digging_nodes = []
        self.unload_G_dump_nodes = []

        # Graph of routes towards the digging (load) areas.
        self.load_G = nx.Graph()
        self.load_G_all_nodes = []
        self.load_G_digging_nodes = []
        self.load_G_dump_nodes = []

        # Rows of [[load_area_id, unload_area_id], to_unload_lanes, to_load_lanes].
        self.work_area_distance_info = []

        # Node/edge arrays and counts; recomputed after every graph build so
        # the getters never return a stale snapshot of the empty graphs.
        self._refresh_graph_stats()

        self.logger = get_logger("zxt.topo")

    def _refresh_graph_stats(self):
        """Recompute the cached node/edge arrays and counts of both graphs."""
        self.unload_G_nodes = np.array(list(self.unload_G.nodes))
        self.unload_G_num_of_nodes = self.unload_G.number_of_nodes()
        self.unload_G_edges = np.array(list(self.unload_G.edges))
        self.unload_G_num_of_edges = self.unload_G.number_of_edges()

        self.load_G_nodes = np.array(list(self.load_G.nodes))
        self.load_G_num_of_nodes = self.load_G.number_of_nodes()
        self.load_G_edges = np.array(list(self.load_G.edges))
        self.load_G_num_of_edges = self.load_G.number_of_edges()

    # ---- unload_G accessors -------------------------------------------------
    def get_unload_G(self):
        return self.unload_G

    def get_unload_G_nodes(self):
        return self.unload_G_nodes

    def get_unload_G_num_of_nodes(self):
        return self.unload_G_num_of_nodes

    def get_unload_G_edges(self):
        return self.unload_G_edges

    def get_unload_G_num_of_edges(self):
        return self.unload_G_num_of_edges

    # ---- load_G accessors ---------------------------------------------------
    def get_load_G(self):
        return self.load_G

    def get_load_G_nodes(self):
        return self.load_G_nodes

    def get_load_G_num_of_nodes(self):
        return self.load_G_num_of_nodes

    def get_load_G_edges(self):
        return self.load_G_edges

    def get_load_G_num_of_edges(self):
        return self.load_G_num_of_edges

    def get_work_area_distance_info(self):
        """Reload the load/unload area pairs and their lane lists from WalkTime.

        Each entry of ``self.work_area_distance_info`` is
        ``[[load_area_id, unload_area_id], to_unload_lanes, to_load_lanes]``
        with the area ids converted to ``str``.  The list is cleared first so
        repeated calls do not accumulate duplicate rows.  Query errors are
        logged, not raised.

        Returns:
            list: ``self.work_area_distance_info``.
        """
        self.work_area_distance_info.clear()
        try:
            for item in session_postgre.query(WalkTime).all():
                self.work_area_distance_info.append(
                    [[str(item.load_area_id), str(item.unload_area_id)],
                     item.to_unload_lanes, item.to_load_lanes])
        except Exception as es:
            self.logger.error(es)
            self.logger.error("获取地图信息出错")
        return self.work_area_distance_info

    def generate_topo_graph(self):
        """Build ``unload_G`` and ``load_G`` from the ``WalkTime`` routes.

        For each route the terminal nodes are registered (with a ``name``
        attribute) and the route's lanes are merged into edges.  A failure
        while adding the lanes of one route is logged and the remaining
        routes are still processed.  A missing digging/dump area row
        (``.first()`` returning ``None``) still propagates as an exception.
        The cached node/edge arrays and counts are refreshed afterwards,
        even when an exception propagates.
        """
        routes = self.get_work_area_distance_info()

        # The parking-exit lanes are the same for every route: look them up
        # once.  If the lookup fails, route lanes are still added to load_G.
        park_lanes = self._query_park_exit_lanes() if routes else None

        try:
            for item in routes:
                self._add_unload_route(item)
                self._add_load_route(item, park_lanes)
        finally:
            self._refresh_graph_stats()

    @staticmethod
    def _query_lane(lane_id):
        """Return ``(start_node, end_node, length, lane_id)`` for one Lane row."""
        lane = session_postgre.query(Lane).filter_by(Id=lane_id).first()
        return str(lane.StartNodeId), str(lane.EndNodeId), float(lane.Length), lane_id

    @staticmethod
    def _add_merged_edge(graph, lane_rows):
        """Add one edge to *graph* covering the ordered *lane_rows*.

        *lane_rows* are ``(start, end, length, lane_id)`` tuples.  The edge
        runs from the first lane's start node to the last lane's end node.
        """
        lanes = {lane_id: [length, length] for _, _, length, lane_id in lane_rows}
        graph.add_edge(
            lane_rows[0][0],
            lane_rows[-1][1],
            real_distance=sum(value[0] for value in lanes.values()),
            locked_distance=sum(value[1] for value in lanes.values()),
            lane=lanes,
            # An undirected nx.Graph does not keep edge orientation; record
            # where the merged run starts so source-node lookups can use it.
            start_node=lane_rows[0][0],
        )

    def _add_lane_segments(self, graph, lane_ids, terminal_nodes, all_nodes):
        """Merge a route's lanes (in the given order) into edges of *graph*.

        Lanes are accumulated until the current lane ends at a fork (more
        than one lane starts at its end node) or at a node in
        *terminal_nodes*; the accumulated run then becomes one edge.  Lanes
        after the last fork/terminal node are not added.  The current lane's
        start and end nodes are appended to *all_nodes* at each flush.
        """
        pending = []
        for lane_id in lane_ids:
            lane_row = self._query_lane(lane_id)
            start_node, end_node = lane_row[0], lane_row[1]
            pending.append(lane_row)

            child_lane_num = session_postgre.query(Lane).filter_by(StartNodeId=end_node).count()

            # Edge boundaries: fork junctions or the route's terminal node.
            if child_lane_num > 1 or end_node in terminal_nodes:
                self._add_merged_edge(graph, pending)
                pending = []
                if start_node not in all_nodes:
                    all_nodes.append(start_node)
                if end_node not in all_nodes:
                    all_nodes.append(end_node)

    def _add_unload_route(self, item):
        """Add one route (digging exit -> dump entrance) to ``unload_G``."""
        digging_exit = str(session_postgre.query(DiggingWorkArea).filter_by(Id=item[0][0]).first().ExitNodeId)
        dump_entrance = str(session_postgre.query(DumpArea).filter_by(Id=item[0][1]).first().EntranceNodeId)

        if digging_exit not in self.unload_G_all_nodes:
            self.unload_G_all_nodes.append(digging_exit)
            self.unload_G_digging_nodes.append(digging_exit)
        if dump_entrance not in self.unload_G_all_nodes:
            self.unload_G_all_nodes.append(dump_entrance)
            self.unload_G_dump_nodes.append(dump_entrance)

        self.unload_G.add_node(digging_exit, name='digging')
        self.unload_G.add_node(dump_entrance, name='dump')

        try:
            self._add_lane_segments(self.unload_G, item[1],
                                    self.unload_G_dump_nodes, self.unload_G_all_nodes)
        except Exception as es:
            self.logger.error(es)
            self.logger.error("去卸载区拓扑图出错")

    def _query_park_exit_lanes(self):
        """Look up the parking-exit lanes; return their rows, or ``None`` on failure (logged)."""
        try:
            return [self._query_lane(lane_id) for lane_id in self.PARK_EXIT_LANE_IDS]
        except Exception as es:
            self.logger.error(es)
            self.logger.error("去装载区拓扑图生成失败")
            return None

    def _add_load_route(self, item, park_lanes):
        """Add one route (dump exit -> digging entrance) to ``load_G``."""
        digging_entrance = str(session_postgre.query(DiggingWorkArea).filter_by(Id=item[0][0]).first().EntranceNodeId)
        dump_exit = str(session_postgre.query(DumpArea).filter_by(Id=item[0][1]).first().ExitNodeId)

        if dump_exit not in self.load_G_all_nodes:
            self.load_G_all_nodes.append(dump_exit)
            self.load_G_dump_nodes.append(dump_exit)
        if digging_entrance not in self.load_G_all_nodes:
            self.load_G_all_nodes.append(digging_entrance)
            self.load_G_digging_nodes.append(digging_entrance)

        self.load_G.add_node(dump_exit, name='dump')
        self.load_G.add_node(digging_entrance, name='digging')

        # Re-applied before each route's lanes so that, when a route edge
        # shares the same node pair, attributes are overwritten in the same
        # order as a per-route lookup would produce.
        if park_lanes:
            self._add_merged_edge(self.load_G, park_lanes)

        try:
            self._add_lane_segments(self.load_G, item[2],
                                    self.load_G_digging_nodes, self.load_G_all_nodes)
        except Exception as es:
            self.logger.error(es)
            self.logger.error("去装载区拓扑图生成失败")

    def return_node_path(self, source_node, target_node):
        """Return the shortest ``unload_G`` node path by ``real_distance``.

        Returns:
            list: node ids from *source_node* to *target_node*, or ``[]``
            when no path can be computed (the error is logged).
        """
        try:
            _, path = nx.single_source_dijkstra(self.unload_G, source=source_node,
                                                target=target_node, weight="real_distance")
        except Exception as es:
            self.logger.error(es)
            self.logger.error("dijkstra最短路径生成失败")
            return []
        return list(path)

    def _find_source_node(self, graph, truck_location_lane, error_message):
        """Return the start node of the first *graph* edge whose lanes include *truck_location_lane*.

        Returns ``None`` when no edge contains the lane or the lookup fails
        (the failure is logged with *error_message*).
        """
        try:
            for u, _, data in graph.edges(data=True):
                if truck_location_lane in data.get('lane', {}):
                    # Fall back to the iteration endpoint for edges built
                    # without a ``start_node`` attribute.
                    return data.get('start_node', u)
        except Exception as es:
            self.logger.error(es)
            self.logger.error(error_message)
        return None

    def get_unload_source_node(self, truck_location_lane):
        """Return the ``unload_G`` node to reschedule from for a truck on *truck_location_lane*, or ``None``."""
        return self._find_source_node(self.unload_G, truck_location_lane, "卸载图，矿卡所在路段返回失败")

    def _shortest_path_to_targets(self, graph, source_node, target_list, graph_label):
        """Return the shortest path (by ``real_distance``) from *source_node* to any of *target_list*.

        Unreachable targets are logged at info level and skipped.  On equal
        distances the last target examined wins.

        Raises:
            IndexError: no target is reachable (including an empty
                *target_list* or an unknown/``None`` *source_node*).
        """
        best_distance, best_path = None, None
        for target in target_list:
            try:
                distance, path = nx.single_source_dijkstra(graph, source=source_node,
                                                           target=target, weight="real_distance")
            except nx.NetworkXException as es:
                self.logger.info(es)
                self.logger.info(f"{graph_label}中{source_node} 与 {target} 之间道路不通")
                continue
            if best_distance is None or distance <= best_distance:
                best_distance, best_path = distance, path

        if best_path is None:
            self.logger.error(f"{graph_label}: no reachable target from {source_node}")
            raise IndexError(f"no reachable target from source node {source_node!r}")
        return best_path

    def get_unload_target_node_real(self, truck_location_lane, pre_target):
        """Pick the nearest reachable dump area for a truck, excluding *pre_target*.

        Returns:
            tuple: ``(path, entrance_point, dump_area_id)`` where *path* is
            the ``unload_G`` node path and *entrance_point* its last node.
        Raises:
            IndexError: no dump-area entrance node is reachable.
        """
        source_node = self.get_unload_source_node(truck_location_lane)
        target_list = [node for node, name in self.unload_G.nodes.data('name')
                       if name == 'dump' and node != pre_target]
        if not target_list:
            self.logger.error("当前无可去卸载区！")

        min_dis_path = self._shortest_path_to_targets(self.unload_G, source_node, target_list, "卸载图")
        entrance_point = min_dis_path[-1]
        target_dump_area = str(session_postgre.query(DumpArea).filter_by(EntranceNodeId=entrance_point).first().Id)
        return min_dis_path, entrance_point, target_dump_area

    def get_load_source_node(self, truck_location_lane):
        """Return the ``load_G`` node to reschedule from for a truck on *truck_location_lane*, or ``None``."""
        return self._find_source_node(self.load_G, truck_location_lane, "装载图，矿卡所在路段返回失败")

    def get_load_target_node_real(self, truck_location_lane, pre_target):
        """Pick the nearest reachable digging area for a truck, excluding *pre_target*.

        Returns:
            tuple: ``(path, entrance_point, digging_area_id)`` where *path*
            is the ``load_G`` node path and *entrance_point* its last node.
        Raises:
            IndexError: no digging-area entrance node is reachable.
        """
        source_node = self.get_load_source_node(truck_location_lane)
        target_list = [node for node, name in self.load_G.nodes.data('name')
                       if name == 'digging' and node != pre_target]
        if not target_list:
            self.logger.error("当前无可去装载区！")

        min_dis_path = self._shortest_path_to_targets(self.load_G, source_node, target_list, "装载图")
        entrance_point = min_dis_path[-1]
        target_digging_area = str(session_postgre.query(DiggingWorkArea).filter_by(EntranceNodeId=entrance_point).first().Id)
        return min_dis_path, entrance_point, target_digging_area

    def update_load_G_locked_distance(self, path, alpha, beta):
        """Reset ``locked_distance`` of every ``load_G`` edge along *path*.

        For each consecutive node pair, the edge's ``locked_distance`` is set
        to the sum of the lane lengths (first entry of each value in the
        edge's ``lane`` dict).  The ``lane`` dict itself is not modified.

        Args:
            path: sequence of ``load_G`` nodes.
            alpha, beta: currently unused; kept for interface compatibility.
        """
        for u, v in zip(path, path[1:]):
            edge = self.load_G[u][v]
            edge['locked_distance'] = sum(lane[0] for lane in edge['lane'].values())