from flask import Flask
from http.server import HTTPServer, BaseHTTPRequestHandler
from pymongo import MongoClient
from pymongo.server_api import ServerApi
from consts import DB_NAME, COLLECTION, MONGO_URI, MONGO_CERT_PATH, SSL_CERT_PATH, SSL_KEY_PATH, HOST, PORT
from map.get_markers import get_markers
from data_structs import Measurement, MarkerTemplate, DateAndTime, Sensor
import ssl
import keyboard
from flask import json

# NOTE(review): this flag is never read or written anywhere else in this
# file; it looks vestigial — confirm no other module imports it before
# removing it.
terminate_server = 0

# NOTE(review): no routes are registered on this Flask app and it is never
# run in this file; requests are served by HTTPServer + IceHTTP instead.
# Confirm whether another module imports `app` before removing it.
app = Flask(__name__)


# Initialise MongoDB connection
def initDatabase():
    try:
        client = MongoClient(MONGO_URI,
                            tls=True,
                            tlsCertificateKeyFile=MONGO_CERT_PATH,
                            server_api=ServerApi('1'))

        db = client[DB_NAME]
        collection = db[COLLECTION]
        print("Connected to MongoDB")
        return client
    except Exception as e:
        print(f"Failed to connect to MongoDB: {e}")

    
class IceHTTP(BaseHTTPRequestHandler):
    """Request handler for the ice-map server.

    Routes (GET only):
        ``/``            -- plain-text liveness message.
        ``/update_map``  -- marker data from ``get_markers()`` encoded as JSON,
                            sent with the status code ``get_markers()`` returns.
        anything else    -- 404 Not Found.
    """

    def do_GET(self):
        """Route a GET request according to ``self.path``."""
        if self.path == '/':
            self._send_body(200, "text/plain", b"Root path hit!")

        # Update_map endpoint.
        # NOTE(review): GET is appropriate as long as get_markers() is
        # read-only, as its name suggests — confirm before switching to POST.
        elif self.path == '/update_map':
            try:
                # get_markers() returns (payload, HTTP status code).
                markers_data, resp_code = get_markers()
                body = self._to_json_bytes(markers_data)
            except Exception as e:
                # Answer with 500 instead of dropping the connection with
                # no response at all.
                self.log_error("get_markers() failed: %r", e)
                self.send_error(500)
                return
            self._send_body(resp_code, "application/json", body)

        else:
            # Always answer: without this branch, unknown paths produced an
            # empty reply.
            self.send_error(404)

    def _send_body(self, code, content_type, body):
        """Send status *code*, content headers, and the *body* bytes."""
        self.send_response(code)
        self.send_header("Content-type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    @staticmethod
    def _to_json_bytes(data):
        """Encode *data* as a UTF-8 JSON response body.

        ``str()`` is deliberately not used: a Python repr (single quotes,
        ``True``, ``None``) is not valid JSON.
        NOTE(review): a str/bytes payload is assumed to already be serialized
        JSON and is passed through unchanged — confirm against get_markers().
        """
        if isinstance(data, bytes):
            return data
        if not isinstance(data, str):
            data = json.dumps(data)
        return data.encode('utf-8')

def on_key_press(server, event):
    """Stop *server* when the ``q`` key is pressed; ignore every other key.

    Intended as a ``keyboard.on_press`` callback. The keyboard library invokes
    such callbacks on its own listener thread, not on the thread blocked in
    ``server.serve_forever()``.

    Args:
        server: the running ``HTTPServer``.
        event: the keyboard event; only ``event.name`` is read.
    """
    if event.name != 'q':
        return

    print('Terminating server...')
    # shutdown() makes serve_forever() return and waits for it to do so. It
    # must be called from a thread other than the serving one, which holds
    # for a keyboard-hook callback. Closing the socket alone (the old
    # behavior) left serve_forever() spinning on a closed socket.
    server.shutdown()
    server.server_close()
    keyboard.unhook_all()
    # No quit() here: in this callback thread it would only raise SystemExit
    # in that thread. The main thread exits once serve_forever() returns.


def _seed_test_data(collection):
    """Insert three hard-coded sample markers into *collection*.

    NB: temporary test data. Nothing in this file removes earlier inserts, so
    every run adds another copy of these documents.
    """
    sensor1 = Sensor(ID=1, type="Type1", active=True)
    sensor2 = Sensor(ID=2, type="Type2", active=False)

    datetime1 = DateAndTime(2023, 12, 31, 15, 43)
    datetime2 = DateAndTime(2024, 1, 15, 12, 2)
    datetime3 = DateAndTime(2024, 1, 31, 18, 10)

    measurement1 = Measurement(longitude=10.9771, latitude=60.7066, datetime=datetime1, sensor=sensor1,
                               precipitation=0.0, thickness=0.0, max_weight=0.0, safety_level=0.0, accuracy=2.5)
    measurement2 = Measurement(longitude=10.8171, latitude=60.6366, datetime=datetime2, sensor=sensor2,
                               precipitation=0.0, thickness=0.0, max_weight=0.0, safety_level=0.0, accuracy=1.5)
    measurement3 = Measurement(longitude=10.8471, latitude=60.7366, datetime=datetime3, sensor=sensor1,
                               precipitation=0.0, thickness=0.0, max_weight=0.0, safety_level=0.0, accuracy=4.0)

    test_data = [
        MarkerTemplate(measurement1, 30.0 - measurement1.accuracy, "Green"),
        MarkerTemplate(measurement2, 10.0 - measurement2.accuracy, "Red"),
        MarkerTemplate(measurement3, 20.0 - measurement3.accuracy, "Yellow"),
    ]

    # One round trip instead of one insert_one() call per document.
    collection.insert_many([marker.to_dict() for marker in test_data])


def _run_server():
    """Serve IceHTTP over HTTPS on HOST:PORT, blocking until serving ends.

    Startup and serving errors are printed rather than raised. The listening
    socket is always closed before returning.
    """
    server = None
    try:
        # Load SSL certificate and private key
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.load_cert_chain(SSL_CERT_PATH, SSL_KEY_PATH)

        # Create HTTP server with SSL support
        server = HTTPServer((HOST, PORT), IceHTTP)
        server.socket = ssl_context.wrap_socket(server.socket, server_side=True)

        print("Server running on port ", PORT)

        # Register key press event handler (q stops the server, see on_key_press)
        keyboard.on_press(lambda event: on_key_press(server, event))

        # Run server until shutdown() is requested or Ctrl+C is pressed
        server.serve_forever()
    except KeyboardInterrupt:
        print('Terminating server...')
    except Exception as e:
        print(f"Failed to start server on port {PORT}: {e}")
    finally:
        # Release the listening socket however serving ended; closing an
        # already-closed socket is a no-op.
        if server is not None:
            server.server_close()


# Entry point: connect to MongoDB, seed test data, then serve IceHTTP over
# HTTPS on HOST:PORT (both from consts) until stopped.
if __name__ == "__main__":
    # Initialise database connection
    client = initDatabase()
    if client is None:
        # initDatabase() has already printed why the connection failed.
        raise SystemExit(1)

    try:
        _seed_test_data(client[DB_NAME]["TestCollection"])
        _run_server()
    finally:
        client.close()