import socket
import select
import pickle
import time


from MarioConfig import *
from threading import *
import pickle
import numpy as np
from MarioConfig import MarioConfig

def reward_function(total_x, total_y, score, ticks):
    """Combine four run statistics into a single reward value.

    Each statistic is raised to its own fixed exponent and the four
    resulting terms are added together.

    Args:
        total_x: Statistic raised to the power 1.9.
        total_y: Statistic raised to the power 2.3.
        score: Statistic raised to the power 1.5.
        ticks: Statistic raised to the power 1.1.

    Returns:
        The sum of the four powered terms.

    Note:
        The exponents are fractional, so a negative plain-Python number
        produces a complex term rather than a real one.
    """
    x_term = total_x ** 1.9
    y_term = total_y ** 2.3
    score_term = score ** 1.5
    tick_term = ticks ** 1.1
    return x_term + y_term + score_term + tick_term
    
def NES(npop, learning_rate, sigma, layers, W, b, Nw, Nb, R):
    """Apply one reward-weighted perturbation update to ``W`` and ``b``.

    The rewards ``R`` are centered on their mean and, unless their standard
    deviation is zero, divided by that standard deviation. For every index
    of ``layers`` the matching entries of ``W`` and ``b`` are replaced by
    the old value plus ``learning_rate / (npop * sigma)`` times the
    perturbations contracted against the normalized rewards.

    Args:
        npop: Population size used in the step scaling.
        learning_rate: Step size numerator.
        sigma: Perturbation scale used in the step scaling.
        layers: Only its length is used; one update is done per element.
        W: List of weight arrays; entries are rebound in place.
        b: List of bias arrays; entries are rebound in place.
        Nw: Per-layer 3-D weight perturbations; the first axis is the one
            contracted against ``R``.
        Nb: Per-layer 2-D bias perturbations; the first axis is the one
            contracted against ``R``.
        R: Rewards, one per population member.

    Returns:
        None. The results are written back into the ``W`` and ``b`` lists.
    """
    spread = np.std(R)
    centered = R - np.mean(R)
    # Identical rewards give zero spread; skip the division in that case.
    advantages = centered if spread == 0 else centered / spread

    for layer, _ in enumerate(layers):
        step = learning_rate / (npop * sigma)
        # Move the population axis last so the dot product sums over it.
        weight_grad = np.dot(Nw[layer].transpose(1, 2, 0), advantages)
        bias_grad = np.dot(Nb[layer].T, advantages)
        W[layer] = W[layer] + step * weight_grad
        b[layer] = b[layer] + step * bias_grad

# Numeric type codes for the client_* packets, keyed by packet name.
client_server_packets = dict(
    client_hello=0,
    client_cpu_cores=1,
    client_ready=2,
    client_complete=3,
    client_goodbye=4,
    client_ping_response=5,
)

# Numeric type codes for the server_* packets, keyed by packet name.
# Code 1 is not assigned to any packet in this table.
server_client_packets = dict(
    server_hello=0,
    server_step=2,
    server_stop=3,
    # Sent once all rewards have been received; clients react by updating.
    server_update=4,
    server_client_dc=5,
    server_config_self=6,
    server_config_others=8,
    server_update_config=7,
    server_ping=9,
)

def utility_upgrade_configs(
    configs: "list[MarioConfig]",
    npop: int,
    results: "list[float]",
    Nw,
    Nb,
) -> None:
    """Call ``update`` on every config with the same arguments.

    The annotations are written as strings so they are valid generic type
    forms (a bare ``[X]`` list literal is not a type) and are not evaluated
    when the function is defined.

    Args:
        configs: Iterable of objects exposing
            ``update(npop, results, Nw, Nb)``.
        npop: Forwarded unchanged to each ``update`` call.
        results: Forwarded unchanged to each ``update`` call.
        Nw: Forwarded unchanged to each ``update`` call.
        Nb: Forwarded unchanged to each ``update`` call.

    Returns:
        None. Any effect comes from the configs' own ``update`` methods.
    """
    for cfg in configs:
        cfg.update(npop, results, Nw, Nb)

###################################
# Packet types for easy pickling
###################################
##################################
# Client packets
##################################
class ClientConfigPacket:
    """Picklable client packet payload holding a ``cpu_cores`` value."""

    def __init__(self, cpu_cores):
        # Stored as given; no validation or conversion is performed.
        self.cpu_cores = cpu_cores
class ClientReadyPacket:
    """Picklable client packet payload pairing ``ready`` with ``nodeids``."""

    def __init__(self, ready, nodeids):
        # Both values are stored exactly as passed in.
        self.ready, self.nodeids = ready, nodeids
class ClientCompletePacket:
    """Picklable client packet payload holding ``rewards`` and ``stages``."""

    def __init__(self, rewards, stages):
        # Both values are stored exactly as passed in.
        self.rewards, self.stages = rewards, stages

##################################
# Server packets
##################################
class ServerClientCompletePacket:
    """Picklable server packet payload holding ``rewards`` and a ``nodeid``."""

    def __init__(self, rewards, nodeid):
        # Both values are stored exactly as passed in.
        self.rewards, self.nodeid = rewards, nodeid
class ServerClientDCPacket:
    """Picklable server packet payload holding a single ``nodeid``."""

    def __init__(self, nodeid):
        # Stored as given; no validation or conversion is performed.
        self.nodeid = nodeid
class ServerClientConfigOthers:
    """Picklable server packet payload holding a ``configs`` value."""

    def __init__(self, configs):
        # Stored by reference; the caller's object is not copied.
        self.configs = configs
class ServerClientUpdateConfig:
    """Picklable server packet payload for a config update.

    Holds ``rewards``, ``nodeids``, ``population`` and ``npop_total``
    exactly as they were passed to the constructor.
    """

    def __init__(self, rewards, nodeids, population, npop_total):
        self.rewards, self.nodeids = rewards, nodeids
        self.population, self.npop_total = population, npop_total