import numpy as np

from traffic_flow.traffic_flow_planner import *
from para_config import *
from equipment.truck import TruckInfo
from equipment.excavator import ExcavatorInfo
from equipment.dump import DumpInfo
import sched
import time
from dispatcher import Dispatcher, PreSchedule
# Accumulator of [load_area_id, unload_area_id] pairs; get_area() appends to it.
area_pair_list = list()

import networkx as nx
import matplotlib.pyplot as plt


def get_area():
    """Record load/unload area pairs from ``Dispatch`` rows.

    Queries ``Dispatch`` through ``session_mysql`` for rows with ``isauto=1``
    and ``isdeleted=0``. Each row contributes ``[load_area_id, unload_area_id]``
    to the module-level ``area_pair_list``. Nothing is returned, and repeated
    calls keep appending.
    """
    matching_dispatches = session_mysql.query(Dispatch).filter_by(isauto=1, isdeleted=0).all()
    area_pair_list.extend(
        [dispatch.load_area_id, dispatch.unload_area_id]
        for dispatch in matching_dispatches
    )
# Accumulator of WalkTime.to_load_lanes values; get_load_lane() appends to it.
lanes = list()


def get_load_lane():
    """Collect ``to_load_lanes`` for every pair in ``area_pair_list``.

    For each ``[load_area_id, unload_area_id]`` pair, every matching
    ``WalkTime`` row (queried via ``session_postgre``) has its
    ``to_load_lanes`` value appended to the module-level ``lanes`` list.
    """
    for area_pair in area_pair_list:
        load_id, unload_id = area_pair[0], area_pair[1]
        walk_rows = session_postgre.query(WalkTime).filter_by(
            load_area_id=load_id, unload_area_id=unload_id
        ).all()
        lanes.extend(row.to_load_lanes for row in walk_rows)

# Accumulator of lane end-node ids (as strings); get_lane_endpoint() appends to it.
lane_endnodeid = list()


def get_lane_endpoint(laneid):
    """Look up lane ``laneid`` and append its ``EndNodeId`` (as str) to ``lane_endnodeid``."""
    matched_lane = session_postgre.query(Lane).filter_by(Id=laneid).first()
    end_node = matched_lane.EndNodeId
    lane_endnodeid.append(str(end_node))

# trip = []


# unload_G_nodes = []
# unload_G_edges = []
# def create_unload_graph():
#
#     unload_G = nx.Graph()
#
#     for item in trip:
#         if str(session_postgre.query(DiggingWorkArea).filter_by(Id =item[0][0]).first().ExitNodeId) not in unload_G_nodes:
#             unload_G_nodes.append(str(session_postgre.query(DiggingWorkArea).filter_by(Id =item[0][0]).first().ExitNodeId))
#         if str(session_postgre.query(DumpArea).filter_by(Id=item[0][1]).first().EntranceNodeId) not in unload_G_nodes:
#             unload_G_nodes.append(str(session_postgre.query(DumpArea).filter_by(Id=item[0][1]).first().EntranceNodeId))
#
#         saved_lane = {}
#         num_of_endpoint = 0
#
#         for i in item[1]:
#             i_startpoint = session_postgre.query(Lane).filter_by(Id=i).first().StartNodeId
#             i_endpoint = session_postgre.query(Lane).filter_by(Id=i).first().EndNodeId
#
#             if session_postgre.query(Lane).filter_by(Id=i).first().EndNodeId:
#                 num_of_endpoint += 1
#             if session_postgre.query(Lane).filter_by(Id=i).first().StartNodeId:
#                 num_of_endpoint += 1
#             if num_of_endpoint < 3:
#                 saved_lane[i] = [i_startpoint, i_endpoint, session_postgre.query(Lane).filter_by(Id=i).first().Length]
#             else:
#                 num_of_endpoint = 0
#                 unload_G.add_edge(saved_lane[list(saved_lane.keys())[0]][0], saved_lane[list(saved_lane.keys())[0]][1],
#                                   weight = (sum(n[2] for n in list(saved_lane.values()))))
#                 if i_endpoint not in unload_G_nodes:
#                     unload_G_nodes.append(i_endpoint)
#     print(unload_G.edges())
            # unload_G_nodes.append(session_postgre.query(DiggingWorkArea).filter_by(Id = i[0]).first().ExitNodeId)
            # unload_G_nodes.append(session_postgre.query(DumpArea).filter_by(Id = i[1]).first().EntranceNodeId)
def get_truck_location(lane_id, G=None):
    """Return the ``u`` endpoint of the first graph edge containing ``lane_id``.

    Edges built by ``generate_unloading_graph`` carry a ``lane`` attribute that
    maps lane ids to lane lengths. This scans those attributes and reports the
    first edge whose ``lane`` collection contains ``lane_id``.

    Args:
        lane_id: Lane identifier to locate.
        G: Graph to search. Defaults to the module-level ``unload_G`` created in
            the ``__main__`` section, which is what the original always used.

    Returns:
        The ``u`` endpoint of the first matching edge (also printed), or ``None``
        if no edge contains ``lane_id``.
    """
    graph = unload_G if G is None else G
    # NOTE(review): on an undirected nx.Graph the (u, v) order reflects adjacency
    # iteration order, not necessarily the direction of travel.
    for u, _v, edge_lanes in graph.edges.data('lane'):
        # Edges without a 'lane' attribute yield None; skip them instead of
        # raising TypeError on the membership test.
        if edge_lanes and lane_id in edge_lanes:
            print(u)
            return u
    return None

def get_dijstarpath(lane_id, dj_target, weight="real_weight"):
    """Return the shortest weighted path from a truck's lane to ``dj_target``.

    The source node is resolved from ``lane_id`` with ``get_truck_location``.
    The path is computed on the module-level ``unload_G``.

    Args:
        lane_id: Lane the truck is currently on.
        dj_target: Target graph node.
        weight: Edge attribute used as the Dijkstra edge length.
            ``generate_unloading_graph`` stores edge lengths in ``real_weight``.
            The previous call relied on networkx's default ``"weight"``
            attribute, which those edges do not have. Every edge then counted
            as 1, giving the fewest-hops path rather than the shortest-distance
            path.

    Returns:
        List of nodes on the shortest path (also printed).

    Raises:
        Whatever ``nx.dijkstra_path`` raises, e.g. when the source could not be
        resolved (``get_truck_location`` returned None) or the target is
        unreachable.
    """
    dj_source = get_truck_location(lane_id)
    # Shortest weighted path from the truck's node to the target node.
    min_path = nx.dijkstra_path(unload_G, source=dj_source, target=dj_target, weight=weight)
    print(f"{dj_source}到{dj_target}的最短加权路径: ", min_path)
    return min_path

def get_all_trip(trip):
    """Append one entry per ``WalkTime`` row to ``trip``, which is mutated in place.

    Each entry has the layout
    ``[[str(load_area_id), str(unload_area_id)], to_unload_lanes, to_load_lanes]``.
    ``generate_unloading_graph`` reads index 1 and ``generate_loading_graph``
    reads index 2.
    """
    walk_rows = session_postgre.query(WalkTime).all()
    for row in walk_rows:
        area_key = [str(row.load_area_id), str(row.unload_area_id)]
        trip.append([area_key, row.to_unload_lanes, row.to_load_lanes])


def generate_unloading_graph(G, trip):
    """Build the unloading (digging -> dump) road graph into ``G``.

    For every trip entry ``[[load_area_id, unload_area_id], to_unload_lanes, ...]``
    (as produced by ``get_all_trip``), two nodes are added to ``G``:
    the digging area's ``ExitNodeId`` tagged ``name='digging'``, and the dump
    area's ``EntranceNodeId`` tagged ``name='dump'``. The trip's
    ``to_unload_lanes`` are then walked in order. Consecutive lanes are merged
    into a single edge until the walk reaches a fork (more than one lane starts
    at the lane's end node) or a dump entrance registered so far.

    Each merged edge carries:
        real_weight: summed ``Lane.Length`` of the merged lanes;
        blocked_weight: 0;
        lane: dict mapping each merged lane id (str) to its length.

    Args:
        G: networkx graph to populate, mutated in place. The previous version
            ignored this argument and mutated the module-level ``unload_G``.
            Callers that pass ``unload_G`` see the same target graph.
        trip: List of trip entries as built by ``get_all_trip``.

    Note:
        Lanes after a trip's last fork / dump entrance are not turned into an
        edge. This is unchanged from the previous version.
    """
    # Every dump entrance seen so far. Tracked independently of other nodes:
    # before, an entrance already recorded as an ordinary lane endpoint was
    # never registered as a terminal.
    dump_entrance_nodes = set()

    for item in trip:
        load_area_id, unload_area_id = item[0][0], item[0][1]

        # Unloading runs leave the digging area by its exit node and enter the
        # dump area by its entrance node.
        exit_node_for_digging = str(
            session_postgre.query(DiggingWorkArea).filter_by(Id=load_area_id).first().ExitNodeId)
        entrance_node_for_dump = str(
            session_postgre.query(DumpArea).filter_by(Id=unload_area_id).first().EntranceNodeId)

        dump_entrance_nodes.add(entrance_node_for_dump)
        G.add_node(exit_node_for_digging, name='digging')
        G.add_node(entrance_node_for_dump, name='dump')

        # Lanes accumulated since the last graph node: (start, end, length, lane_id).
        saved_lane = []
        for lane_id in item[1]:
            # One query per lane instead of three.
            lane = session_postgre.query(Lane).filter_by(Id=lane_id).first()
            start_node = str(lane.StartNodeId)
            end_node = str(lane.EndNodeId)
            saved_lane.append((start_node, end_node, float(lane.Length), str(lane_id)))

            # Number of lanes leaving end_node; more than one means a fork.
            son_lane_num = session_postgre.query(Lane).filter_by(StartNodeId=end_node).count()

            # A graph node is created at a fork or at a dump entrance (terminal).
            if son_lane_num > 1 or end_node in dump_entrance_nodes:
                edge_lanes = {seg_id: length for _, _, length, seg_id in saved_lane}
                # add_edge also adds both endpoints as nodes if they are missing.
                G.add_edge(saved_lane[0][0], saved_lane[-1][1],
                           real_weight=sum(edge_lanes.values()),
                           blocked_weight=0,
                           lane=edge_lanes)
                saved_lane = []



def get_source_node(G, truck_location_lane):
    """Return the ``v`` endpoint of the first edge containing a lane.

    Args:
        G: Graph whose edges carry a ``lane`` attribute holding lane ids, e.g.
            the dict built by ``generate_unloading_graph``.
        truck_location_lane: Lane id the truck is currently on.

    Returns:
        The ``v`` endpoint of the first edge whose ``lane`` attribute contains
        ``truck_location_lane`` (also printed), or ``None`` if no edge does.
    """
    for _u, v, edge_lanes in G.edges.data('lane'):
        # The previous body tested an undefined ``lane_id`` instead of this
        # function's parameter. Edges without a 'lane' attribute yield None and
        # are skipped.
        if edge_lanes and truck_location_lane in edge_lanes:
            print(v)
            return v
    return None


def get_target_node(G, source_node, pre_target, weight="real_weight"):
    """Pick the nearest dump area reachable from ``source_node``.

    Candidate targets are all nodes of ``G`` tagged ``name='dump'`` except
    ``pre_target``. Dijkstra gives each candidate's shortest-path distance
    from ``source_node``. The closest candidate is mapped back to its
    ``DumpArea`` row via ``EntranceNodeId``.

    Args:
        G: Unloading graph, as built by ``generate_unloading_graph``.
        source_node: Graph node to start from.
        pre_target: Dump entrance node to exclude, or None.
        weight: Edge attribute used as edge length. ``generate_unloading_graph``
            stores it in ``real_weight``. The previous ``"weight"`` is absent on
            those edges, so every edge counted as 1.

    Returns:
        ``(target_dump_area_id, target_dump_area_name)`` as strings, also printed.

    Raises:
        ValueError: If no candidate dump node is reachable from ``source_node``.
    """
    target_list = [node for node, tag in G.nodes.data('name')
                   if tag == 'dump' and node != pre_target]

    best = None  # (distance, path) of the closest reachable target so far
    for target in target_list:
        try:
            # Was ``target=ai`` (undefined name). The broad except hid the
            # NameError, so no path was ever recorded.
            distance, path = nx.single_source_dijkstra(
                G, source=source_node, target=target, weight=weight)
        except (nx.NetworkXNoPath, nx.NodeNotFound) as es:
            # Unreachable target or unknown source: report it and try the next one.
            print(es)
            continue
        # Keep the minimum explicitly; a distance-keyed dict dropped ties.
        if best is None or distance < best[0]:
            best = (distance, path)

    if best is None:
        raise ValueError(f"no reachable dump node from {source_node!r}")

    # The path ends at the chosen dump area's entrance node.
    entrance_point = best[1][-1]
    dump_area = session_postgre.query(DumpArea).filter_by(EntranceNodeId=entrance_point).first()
    target_dump_area = str(dump_area.Id)
    target_dump_area_name = str(dump_area.Name)

    print(target_dump_area, target_dump_area_name)
    return target_dump_area, target_dump_area_name


def generate_loading_graph(G, trip):
    """Build the loading (dump -> digging) road graph into ``G``.

    Mirrors ``generate_unloading_graph`` but walks each trip's
    ``to_load_lanes`` (index 2 of a ``get_all_trip`` entry). The dump area's
    ``ExitNodeId`` is tagged ``name='dump'`` and the digging area's
    ``EntranceNodeId`` is tagged ``name='digging'``. Consecutive lanes are
    merged into one edge until the walk reaches a fork (more than one lane
    starts at the lane's end node) or a digging entrance registered so far.

    Each merged edge carries:
        real_weight: summed ``Lane.Length`` of the merged lanes;
        blocked_weight: 0;
        lane: dict mapping each merged lane id (str) to its length. This is the
            same format as the unloading graph. The attribute previously held
            the start/end node ids of the lanes, so lookups by lane id could
            never match.

    Args:
        G: networkx graph to populate, mutated in place. The previous version
            ignored this argument and mutated the module-level ``load_G``.
        trip: List of trip entries as built by ``get_all_trip``.

    Note:
        Lanes after a trip's last fork / digging entrance are not turned into an
        edge. This is unchanged from the previous version.
    """
    # Every digging entrance seen so far, tracked independently so a node
    # already recorded as an ordinary lane endpoint is still a terminal.
    digging_entrance_nodes = set()

    for item in trip:
        load_area_id, unload_area_id = item[0][0], item[0][1]

        # Loading runs leave the dump area by its exit node and enter the
        # digging area by its entrance node.
        entrance_node_for_digging = str(
            session_postgre.query(DiggingWorkArea).filter_by(Id=load_area_id).first().EntranceNodeId)
        exit_node_for_dump = str(
            session_postgre.query(DumpArea).filter_by(Id=unload_area_id).first().ExitNodeId)

        digging_entrance_nodes.add(entrance_node_for_digging)
        G.add_node(exit_node_for_dump, name='dump')
        G.add_node(entrance_node_for_digging, name='digging')

        # Lanes accumulated since the last graph node: (start, end, length, lane_id).
        saved_lane = []
        for lane_id in item[2]:
            # One query per lane instead of three.
            lane = session_postgre.query(Lane).filter_by(Id=lane_id).first()
            start_node = str(lane.StartNodeId)
            end_node = str(lane.EndNodeId)
            saved_lane.append((start_node, end_node, float(lane.Length), str(lane_id)))

            # Number of lanes leaving end_node; more than one means a fork.
            son_lane_num = session_postgre.query(Lane).filter_by(StartNodeId=end_node).count()

            # A graph node is created at a fork or at a digging entrance (terminal).
            if son_lane_num > 1 or end_node in digging_entrance_nodes:
                edge_lanes = {seg_id: length for _, _, length, seg_id in saved_lane}
                # add_edge also adds both endpoints as nodes if they are missing.
                G.add_edge(saved_lane[0][0], saved_lane[-1][1],
                           real_weight=sum(length for _, _, length, _ in saved_lane),
                           blocked_weight=0,
                           lane=edge_lanes)
                saved_lane = []

if __name__ == '__main__':
    # Helpers such as get_truck_location / get_dijstarpath read the module-level
    # graphs ``unload_G`` and ``load_G``. They are therefore created here at
    # module scope rather than inside a main() function.
    unload_G, load_G = nx.Graph(), nx.Graph()
    trip = []

    # Load every WalkTime trip, then build and dump the unloading graph.
    get_all_trip(trip)
    generate_unloading_graph(unload_G, trip)
    print(unload_G.edges(data=True))

    # Example follow-up calls (ids are sample data from earlier experiments):
    # get_target_node(G=unload_G, source_node="21c30b97-b134-627d-a816-14340c5cd7e3", pre_target=None)
    # get_dijstarpath("fa7f0363-d134-627e-e8b5-a0906a000dfd", "ec91d7cd-7134-3ef3-11a7-3911025a05f0")
    # generate_loading_graph(load_G, trip)
