import os
import json
import random
from math import cos
import geopandas as gpd
from matplotlib import pyplot as plt
from shapely.ops import linemerge, unary_union, polygonize
from shapely.geometry import Polygon, LineString

from server.consts import LAKE_RELATIONS_PATH


def cut_map_handler(self, cursor, lake_name: str, cell_size_in_km: float = 0.5):
    """
    Divides a lake into tiles via cut_map and sends the outcome as a JSON HTTP response.

            Parameters:
                    self: The HTTP request handler owning the response (uses send_response,
                          send_header, end_headers and wfile)
                    cursor (cursor): Database cursor forwarded to cut_map
                    lake_name (str): The name of the lake to divide
                    cell_size_in_km (float): Requested tile size in kilometers
    """
    result = cut_map(cursor, lake_name, cell_size_in_km)
    code, payload = result

    # Status line and headers go out before the body
    self.send_response(code)
    self.send_header("Content-type", "application/json")
    self.end_headers()

    # Body: the divided map on success, or the error message otherwise
    serialized = json.dumps(payload)
    self.wfile.write(serialized.encode('utf-8'))


def cut_map(cursor, lake_name: str, cell_size_in_km: float = 0.5) -> "tuple[int, dict | str]":
    """
    Cuts a map into a grid based on a selected cell size

    Reads <LAKE_RELATIONS_PATH>/<lake_name>.geojson, intersects its first Polygon with a regular
    grid, registers the lake in the BodyOfWater table if it is missing, plots the tiles and
    writes the resulting GeoJSON FeatureCollection to disk.

            Parameters:
                    cursor (cursor): An Sqlite3 cursor object that points to the database
                    lake_name (str): The name of the lake to be cut
                    cell_size_in_km (float): The selected cell size in kilometers

            Returns:
                    (int, dict | str): An HTTP status code and either the new FeatureCollection
                    (200) or an error message (400 invalid input, 404 a file was not found,
                    500 any other failure)
    """
    try:
        # lake_name is used to build file paths that are read and written below, so only accept
        # a plain file name that cannot point outside LAKE_RELATIONS_PATH
        if not lake_name or lake_name in ('.', '..') or os.path.basename(lake_name) != lake_name:
            message = f"Invalid lake name: {lake_name!r}"
            print(message)
            return 400, message

        # A zero, negative or NaN cell size would make the grid construction loop forever
        if not cell_size_in_km > 0:
            message = f"Invalid cell size, must be a positive number of kilometers: {cell_size_in_km}"
            print(message)
            return 400, message

        # Read relation from GeoJson file and extract all geometry of type Polygon.
        # NOTE(review): MultiPolygon features are ignored, and Polygon(polygon.exterior) drops
        # any interior rings (holes) of each polygon.
        geo_data = gpd.read_file(os.path.join(LAKE_RELATIONS_PATH, lake_name + ".geojson"))
        polygon_data = geo_data[geo_data['geometry'].geom_type == 'Polygon']
        polygons = [Polygon(polygon.exterior) for polygon in polygon_data['geometry']]

        if not polygons:
            raise Exception("Failed to convert JSON object to Shapely Polygons")

        # Minimum x (longitude in GeoJSON order) of the first polygon's bounding box
        start_x = polygons[0].bounds[0]

        # Convert the cell size from km to degrees (one degree of latitude is ~111.32 km).
        # NOTE(review): a longitude correction would normally divide by cos(latitude), i.e.
        # cos(start_y); this uses start_x (longitude) and doubles the width when building the
        # grid instead. Kept as-is because the tile layout and the reported cell_width /
        # cell_height may be relied on by clients — confirm before changing.
        cell_width = cell_size_in_km / 111.3200
        cell_height = cell_width / cos(start_x * 0.01745)

        features = []  # New GeoJSON feature objects, one per tile
        divided_map = []  # Raw tile geometries, used only for plotting
        sub_div_id = 0  # Running counter giving each tile a unique id

        # NOTE(review): only the first polygon is tiled (the loop used to end in an unconditional
        # "test break"). Confirm whether the remaining polygons should be tiled too; each would
        # currently get its own grid, unaligned with the others.
        for polygon in polygons[:1]:
            # Grid lines over the polygon's bounding box plus the polygon outline, merged and
            # polygonized into candidate cells
            lines = create_grid(polygon, cell_width * 2, cell_height)
            lines.append(polygon.boundary)
            cells = list(polygonize(linemerge(unary_union(lines))))

            # Clip every candidate cell that overlaps the polygon to form the subdivisions
            for grid_cell in cells:
                if not grid_cell.intersects(polygon):
                    continue
                cell = grid_cell.intersection(polygon)
                divided_map.append(cell)

                tile_feature = _tile_feature(cell, sub_div_id)
                if tile_feature is None:
                    continue  # Skip empty / non-Polygon tiles

                features.append(tile_feature)
                sub_div_id += 1

        # Create new GeoJSON object containing all the new feature objects
        feature_collection = {
            'type': 'FeatureCollection',
            'cell_count': sub_div_id,  # Number of tiles emitted
            'cell_width': cell_width,
            'cell_height': cell_height,
            'cell_size_in_km': cell_size_in_km,
            'features': features
        }

        # Register the lake in the database if it is not already there
        cursor.execute('''
            SELECT Name FROM BodyOfWater WHERE Name = ?;
        ''', (lake_name,))
        if cursor.fetchone() is None:
            cursor.execute('''
                INSERT INTO BodyOfWater(Name) VALUES (?);
            ''', (lake_name,))

        # Plot the newly created map and save it to a new file
        plot_map(divided_map)
        write_json_to_file(lake_name, feature_collection)

        # Return OK and the newly divided map
        return 200, feature_collection

    except FileNotFoundError as e:
        print(f"Failed to find the map file: {e}")
        return 404, f"Failed to find the map file: {e}"
    except Exception as e:
        print(f"Error in adding new map: {e}")
        return 500, f"Error in adding new map: {e}"


def _tile_feature(cell, sub_div_id: int):
    """
    Builds the GeoJSON Feature for one tile, or returns None if the tile cannot be represented.

            Parameters:
                    cell (Geometry): The intersection of a grid cell with the lake polygon
                    sub_div_id (int): The id to assign to the tile

            Returns:
                    feature (dict | None): A GeoJSON Feature with the tile's rounded outline and
                    center, or None when the cell is not a Polygon or has no coordinates
    """
    # Only plain Polygons become tiles; only their exterior ring is kept.
    # NOTE(review): MultiPolygon cells (a grid cell split in two by the shoreline) are dropped.
    if not isinstance(cell, Polygon):
        return None

    # Round the outline to 4 decimals (about 11 m of latitude) to keep the JSON compact
    rounded_tile = Polygon([(round(c[0], 4), round(c[1], 4)) for c in cell.exterior.coords])
    geometry = rounded_tile.__geo_interface__
    if not geometry['coordinates']:
        return None

    # Midpoint of the cell's bounding box, as (y, x), rounded to 6 decimals
    min_x, min_y, max_x, max_y = cell.bounds
    center = round((min_y + max_y) / 2, 6), round((min_x + max_x) / 2, 6)

    return {
        'type': 'Feature',
        'properties': {
            'sub_div_id': str(sub_div_id),
            'sub_div_center': center
        },
        'geometry': geometry
    }


def create_grid(poly: Polygon, cell_width: float, cell_height: float):
    """
    Returns a list of vertical and horizontal LineStrings that create a grid.

    Lines start at the (min_x, min_y) corner of the polygon's bounding box and are spaced
    cell_width / cell_height apart up to the box's maximum; vertical lines come first.

            Parameters:
                    poly (Polygon): A Shapely Polygon representing a map or part of a map
                    cell_width (float): The width of the grid cells in degrees
                    cell_height (float): The height of the grid cells in degrees

            Returns:
                    grid_lines (list): List of LineString objects defining the grid

            Raises:
                    ValueError: If cell_width or cell_height is not a positive number
    """
    # A zero, negative or NaN step would never move past the bounds and loop forever
    if not (cell_width > 0 and cell_height > 0):
        raise ValueError(
            f"Grid cell dimensions must be positive, got width={cell_width}, height={cell_height}"
        )

    min_x, min_y, max_x, max_y = poly.bounds

    vertical_lines = [
        LineString([(x, min_y), (x, max_y)])
        for x in _grid_positions(min_x, max_x, cell_width)
    ]
    horizontal_lines = [
        LineString([(min_x, y), (max_x, y)])
        for y in _grid_positions(min_y, max_y, cell_height)
    ]
    return vertical_lines + horizontal_lines


def _grid_positions(start: float, stop: float, step: float):
    """
    Yields start, start + step, start + 2 * step, ... for as long as the value is <= stop.

    Each value is computed from its index rather than by repeated addition, so floating-point
    rounding errors do not accumulate across the grid. step must be positive.
    """
    index = 0
    value = start
    while value <= stop:
        yield value
        index += 1
        value = start + index * step


def write_json_to_file(lake_name: str, map_data: dict):
    """
    Writes a divided map to <lake_name>_div.json and registers the lake in all_lake_names.json

            Parameters:
                    lake_name (str): Name of the lake and file to write to
                    map_data (dict): The divided map as a JSON-serializable dictionary

            Raises:
                    ValueError: If lake_name is empty or contains directory components
                    Exception: If LAKE_RELATIONS_PATH does not exist
    """
    # lake_name becomes part of a file path, so refuse anything that could write outside
    # LAKE_RELATIONS_PATH
    if not lake_name or lake_name in ('.', '..') or os.path.basename(lake_name) != lake_name:
        raise ValueError(f"Invalid lake name: {lake_name!r}")

    # Create and write divided map to new file
    print("Writing to file...")
    if not os.path.exists(LAKE_RELATIONS_PATH):
        raise Exception("Directory from path does not exist")

    map_path = os.path.join(LAKE_RELATIONS_PATH, lake_name + '_div.json')
    with open(map_path, 'w', encoding='utf-8') as f:
        json.dump(map_data, f)

    # Read all_lake_names.json; start a new list if it does not exist yet instead of failing
    # after the map file has already been written
    registry_path = os.path.join(LAKE_RELATIONS_PATH, 'all_lake_names.json')
    try:
        with open(registry_path, 'r', encoding='utf-8') as f:
            lake_names = json.load(f)
    except FileNotFoundError:
        lake_names = []

    # Only rewrite the registry when the lake is new
    if lake_name not in lake_names:
        lake_names.append(lake_name)
        with open(registry_path, 'w', encoding='utf-8') as f:
            json.dump(lake_names, f, ensure_ascii=False, indent=2)


# Plotting the map can take a considerable amount of time, especially when creating maps with many
# subdivisions. Removing calls to plot_map will speed up the process, but it is highly recommended
# to plot the map after each division to ensure that the map was divided as intended.
def plot_map(divided_map):
    """
    Plots the divisions of a map using matplotlib.

        Parameters:
            divided_map (list): Shapely geometries, one per tile, to be plotted
    """
    print("Plotting... This may take some time...")

    fig, ax = plt.subplots()

    # Plot each tile separately with its own random color so the grid is easy to inspect
    for tile in divided_map:
        random_color = "#{:06x}".format(random.randint(0, 0xFFFFFF))
        gpd.GeoSeries([tile]).plot(ax=ax, facecolor=random_color, edgecolor='none')

    # Stretch the y-axis relative to x.
    # NOTE(review): 1.8 is a fixed value; it presumably approximates 1 / cos(latitude) for the
    # lakes this was tuned on, so degrees of longitude and latitude look similar — confirm.
    ax.set_aspect(1.8)

    # Display plot (may block until the window is closed on interactive backends), then
    # release the figure so repeated calls do not accumulate open figures
    plt.show()
    plt.close(fig)