Commit 9cbfdff2 authored by Bård Sørensen Hestmark

Add CA-ES code from server and restructure


Former-commit-id: 2a89cc3b
parent ac0f29be
@@ -2,4 +2,5 @@
 /workspace.xml
 /.idea
 /venv
-/CA
\ No newline at end of file
+/CA
+**/__pycache__/
\ No newline at end of file
import copy
import math
from tqdm import trange
import multiprocessing as mp
import numpy as np
import torch
import torch.multiprocessing as tmp
import torch.nn.functional as F
from torch import tensor as tt
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
from model import CellularAutomataModel
from utils import load_emoji, save_model, to_rgb, Pool
import pygame
import time
from PIL import Image
import logging
class ES:
def __init__(self, args):
self.population_size = args.population_size
self.n_iterations = args.n_iterations
self.pool_size = args.pool_size
self.batch_size = args.batch_size
self.eval_freq = args.eval_freq
self.n_channels = args.n_channels
self.hidden_size = args.hidden_size
self.fire_rate = args.fire_rate
self.lr = args.lr
self.sigma = args.sigma
self.size = args.size + 2 * args.padding
self.target_img = load_emoji(args.img, self.size)
self.padding = args.padding
self.logdir = args.logdir
self.decay_state = 0
p = self.padding
self.pad_target = F.pad(tt(self.target_img), (0, 0, p, p, p, p))
h, w = self.pad_target.shape[:2]
self.seed = np.zeros([h, w, self.n_channels], np.float64)
self.seed[h // 2, w // 2, 3:] = 1.0
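# The seed is a single living cell at the grid centre: alpha and all hidden channels set to 1.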
self.net = CellularAutomataModel(n_channels=self.n_channels, fire_rate=self.fire_rate, hidden_channels=self.hidden_size)
self.param_shape = [tuple(p.shape) for p in self.net.parameters()]
self.pool = Pool(self.seed, self.pool_size)
if args.load_model_path != "":
self.load_model(args.load_model_path)
self.lr = 0.00075 # test
self.decay_state = 2
t_rgb = to_rgb(self.pad_target).permute(2, 0, 1)
if args.mode == "train":
save_image(t_rgb, "%s/target_image.png" % self.logdir)
self.writer = SummaryWriter(self.logdir)
def load_model(self, path):
"""Load a PyTorch model from path."""
self.net.load_state_dict(torch.load(path))
self.net.double()
def fitness_shape(self, x):
"""Sort x and and map x to linear values between -0.5 and 0.5
Return standard score of x
"""
shaped = np.zeros(len(x))
shaped[x.argsort()] = np.arange(len(x), dtype=np.float64)
shaped /= (len(x) - 1)
shaped -= 0.5
shaped = (shaped - shaped.mean()) / shaped.std()
return shaped
def update_parameters(self, fitnesses, epsilons):
"""Update parent network weights using evaluated mutants and fitness."""
fitnesses = self.fitness_shape(fitnesses)
for i, e in enumerate(epsilons):
for j, w in enumerate(self.net.parameters()):
w.data += self.lr * 1 / (self.population_size * self.sigma) * fitnesses[i] * e[j]
def get_population(self):
"""Return an array with values sampled from N(0, sigma)"""
epsilons = []
for _ in range(int(self.population_size / 2)):
e = []
e2 = []
for w in self.param_shape:
j = np.random.randn(*w) * self.sigma
e.append(j)
e2.append(-j)
epsilons.append(e)
epsilons.append(e2)
return np.array(epsilons, dtype=object)
def step(self, model_try, x):
"""Perform a generation of CA using trained net.
Return output x and loss
"""
torch.seed()
iter_n = torch.randint(30, 40, (1,)).item()
for _ in range(iter_n): x = model_try(x)
loss = self.net.loss(x, self.pad_target)
loss = torch.mean(loss)
return x, loss.item()
def fitness(self, epsilon, x0, pid, q=None):
"""Method that start a generation of ES.
Return output from generation x and its fitness
"""
model_try = copy.deepcopy(self.net)
if epsilon is not None:
for i, w in enumerate(model_try.parameters()):
w.data += torch.tensor(epsilon[i])
x, loss = self.step(model_try, x0)
fitness = -loss
if not math.isfinite(fitness):
raise ValueError('Encountered non-number value in loss. Fitness ' + str(fitness) + '. Loss: ' + str(loss))
q.put((x, fitness, pid))
return
def decay_lr(self, fitness):
# Fitness thresholds for adjusting the learning rate
# fit_t1 = -0.06 # testing for size ~20
# fit_t2 = -0.03
fit_t1 = -0.05 # works well for size ~15
fit_t2 = -0.02
# fit_t1 = -0.03 # used for size 9
# fit_t2 = -0.01
if not self.decay_state == 2:
if fitness >= fit_t1 and self.decay_state == 0:
reduce = 0.3
self.lr *= reduce
self.decay_state += 1
logging.info("Fitness higher than than %.3f, lr set to %.5f (*%.2f)" % (fit_t1, self.lr, reduce))
elif fitness >= fit_t2 and self.decay_state == 1:
reduce = 0.5
self.lr *= reduce
self.decay_state += 1
logging.info("Fitness higher than %.3f, lr set to %.5f (*%.2f)" % (fit_t2, self.lr, reduce))
def evaluate_main(self, x0):
"""Return output and fitness from a generation using unperturbed weights/coeffs"""
x_main, loss_main = self.step(self.net, x0.clone())
fit_main = -loss_main
return x_main, fit_main
def train(self):
"""main training loop"""
logging.info("Starting training")
x0 = tt(np.repeat(self.seed[None, ...], self.batch_size, 0)) #seed
_, _ = self.step(self.net, x0.clone())
processes = []
q = tmp.Manager().Queue()
t = trange(self.n_iterations, desc='Mean reward:', leave=True)
for iteration in t:
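# Sample a batch from the pool, rank it by loss (worst first) and replace the
# single worst grid with the fresh seed, so the model keeps learning to grow
# from scratch while also refining already-grown grids.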
batch = self.pool.sample(self.batch_size)
x0 = batch["x"]
loss_rank = self.net.loss(tt(x0), self.pad_target).numpy().argsort()[::-1]
x0 = x0[loss_rank]
x0[:1] = self.seed
x0 = tt(x0)
epsilons = self.get_population()
fitnesses = np.zeros(self.population_size, dtype=np.float64)
xs = torch.zeros(self.population_size, *x0.shape, dtype=torch.float64)
for i in range(self.population_size):
p = tmp.Process(target=self.fitness, args=(epsilons[i], x0.clone(), i, q))
p.start()
processes.append(p)
for p in processes:
p.join()
x, fit, pid = q.get()
fitnesses[pid] = fit
xs[pid] = x
processes = []
idx = np.argmax(fitnesses)
batch["x"][:] = xs[idx]
self.pool.commit(batch)
fitnesses = np.array(fitnesses).astype(np.float64)
self.update_parameters(fitnesses, epsilons)
# Logging
mean_fit = np.mean(fitnesses)
self.writer.add_scalar("train/fit", mean_fit, iteration)
if iteration % 10 == 0:
t.set_description("Mean reward: %.4f " % mean_fit, refresh=True)
if (iteration+1) % self.eval_freq == 0:
self.decay_lr(mean_fit)
# Save picture of model
x_eval = x0.clone()
model = self.net
pics = []
for eval in range(36):
x_eval = model(x_eval)
if eval % 5 == 0:
pics.append(to_rgb(x_eval).permute(0, 3, 1, 2))
save_image(torch.cat(pics, dim=0), '%s/pic/big%04d.png' % (self.logdir, iteration), nrow=1, padding=0)
save_model(self.net, self.logdir + "/models/model_" + str(iteration))
# Interactively damage the model using pygame; cannot be run over ssh
def interactive(self):
model = self.net
x_eval = tt(np.repeat(self.seed[None, ...], self.batch_size, 0))
cellsize = 50
imgpath = '%s/one.png' % (self.logdir)
pygame.init()
surface = pygame.display.set_mode((self.size * cellsize, self.size * cellsize))
pygame.display.set_caption("Interactive CA-ES")
damaged = 0
losses = []
while True:
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
return
elif event.type == pygame.MOUSEBUTTONDOWN:
# damage
if damaged == 0:
dmg_size = 20
mpos_x, mpos_y = event.pos
mpos_x, mpos_y = mpos_x // cellsize, mpos_y // cellsize
x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0
damaged = 100 # number of steps to record the loss after damage has occurred
x_eval = model(x_eval)
image = to_rgb(x_eval).permute(0, 3, 1, 2)
# Quick hack: write the image to disk and reload it to drop the batch dimension
save_image(image, imgpath, nrow=1, padding=0)
image = np.asarray(Image.open(imgpath))
if damaged > 0:
loss = self.net.loss(x_eval, self.pad_target)
losses.append(loss)
if damaged == 1:
lossfile = open('%s/losses.txt' % self.logdir, "w")
for l in range(len(losses)):
lossfile.write('%d,%.06f\n' % (l, losses[l]))
lossfile.close()
losses.clear()
damaged -= 1
self.game_update(surface, image, cellsize)
time.sleep(0.05)
pygame.display.update()
def game_update(self, surface, cur_img, sz):
nxt = np.zeros((cur_img.shape[0], cur_img.shape[1]))
for r, c, _ in np.ndindex(cur_img.shape):
pygame.draw.rect(surface, cur_img[r,c], (c*sz, r*sz, sz, sz))
return nxt
def generate_graphic(self):
model = self.net
x_eval = tt(np.repeat(self.seed[None, ...], self.batch_size, 0))
pics = []
pics.append(to_rgb(x_eval).permute(0, 3, 1, 2))
for eval in range(40):
x_eval = model(x_eval)
if eval in [10, 20, 30, 39]: # frames to save img of
pics.append(to_rgb(x_eval).permute(0, 3, 1, 2))
save_image(torch.cat(pics, dim=0), '%s/graphic.png' % (self.logdir), nrow=len(pics), padding=0)
\ No newline at end of file
INFO:root:
Arguments:
mode: 'graphic'
population_size: 16
n_iterations: 100
pool_size: 1024
batch_size: 1
eval_freq: 500
n_channels: 16
hidden_size: 32
fire_rate: 0.5
lr: 0.005
sigma: 0.01
img: '🥕'
size: 15
padding: 0
logdir: 'logs/CARROT-graphic_12-04-2022_10-15-16'
load_model_path: 'model_999500'
import torch
import argparse
import time
import os
import unicodedata
import logging
from es import ES
# rabbit 🐰
# carrot 🥕
# watermelon 🍉
if __name__ == '__main__':
emoji = '🥕'
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, default="graphic", metavar="train/interactive/graphic", help="Decides mode to run e.g. train or damage")
parser.add_argument("--population_size", type=int, default=16, metavar=128, help="Population size")
parser.add_argument("--n_iterations", type=int, default=100, help="Number of iterations to train for.")
parser.add_argument("--pool_size", type=int, default=1024, help="Size of the training pool, zero if training without pool")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
parser.add_argument("--eval_freq", type=int, default=500, help="Frequency for various saving/evaluating/logging",)
parser.add_argument("--n_channels", type=int, default=16, help="Number of channels of the input tensor")
parser.add_argument("--hidden_size", type=int, default=32, help="Number of hidden channels")
parser.add_argument("--fire_rate", type=float, default=0.5, metavar=0.5, help="Cell fire rate")
parser.add_argument("--lr", type=float, default=0.005, metavar=0.005, help="Learning rate")
parser.add_argument("--sigma", type=float, default=0.01, metavar=0.01, help="Sigma")
parser.add_argument("--img", type=str, default=emoji, metavar="🐰", help="The emoji to train on")
parser.add_argument("--size", type=int, default=15, help="Image size")
parser.add_argument("--padding", type=int, default=0, help="Padding. The shape after padding is (h + 2 * p, w + 2 * p).")
parser.add_argument("--logdir", type=str, default="logs", help="Logging folder for new model")
parser.add_argument("--load_model_path", type=str, default="model_999500", help="Path to pre trained model")
args = parser.parse_args()
if not os.path.isdir(args.logdir):
raise Exception("Logging directory '%s' not found in base folder" % args.logdir)
args.logdir = "%s/%s-%s_%s" % (args.logdir, unicodedata.name(args.img), args.mode, time.strftime("%d-%m-%Y_%H-%M-%S"))
os.mkdir(args.logdir)
logging.basicConfig(filename='%s/logfile.log' % args.logdir, encoding='utf-8', level=logging.INFO)
argprint = "\nArguments:\n"
for arg, value in vars(args).items():
argprint += ("%s: %r\n" % (arg, value))
logging.info(argprint)
es = ES(args)
torch.set_num_threads(1) # disable PyTorch's built-in parallelization
match args.mode:
case "train":
os.mkdir(args.logdir + "/models")
os.mkdir(args.logdir + "/pic")
es.train()
case "interactive": es.interactive()
case "graphic": es.generate_graphic()
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class CellularAutomataModel(nn.Module):
def __init__(self, n_channels, hidden_channels, fire_rate):
super().__init__()
self.n_channels = n_channels
self.hidden_channels = hidden_channels
self.fire_rate = fire_rate
self.fc0 = nn.Linear(self.n_channels * 3, self.hidden_channels, bias=False)
self.fc1 = nn.Linear(self.hidden_channels, self.n_channels, bias=False)
with torch.no_grad(): self.fc1.weight.zero_()
identity = np.float64([0, 1, 0])
identity = torch.from_numpy(np.outer(identity, identity))
sobel_x = torch.from_numpy(np.outer([1, 2, 1], [-1, 0, 1]) / 8.0) # sobel filter
sobel_y = sobel_x.T
self.kernel = torch.cat([
identity[None, None, ...],
sobel_x[None, None, ...],
sobel_y[None, None, ...]],
dim=0).repeat(self.n_channels, 1, 1, 1)
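# The kernel has shape (3 * n_channels, 1, 3, 3): identity, Sobel-x and Sobel-y
# filters for every channel, applied depthwise in perceive().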
for param in self.parameters(): param.requires_grad = False
self.double()
def perceive(self, x):
"""Percieve neighboors with two sobel filters and one single-entry filter"""
y = F.conv2d(x.permute(0, 3, 1, 2), self.kernel, groups=16, padding=1)
y = y.permute(0, 2, 3, 1)
return y
def loss(self, x, y):
"""mean squared error"""
return torch.mean(torch.square(x[..., :4] - y), [-2, -3, -1])
def forward(self, x, fire_rate=None, step_size=1.0):
"""Forward a cell grid through the network and return the cell grid with changes applied."""
y = self.perceive(x)
pre_life_mask = get_living_mask(x)
dx1 = self.fc0(y)
dx1 = F.relu(dx1)
dx2 = self.fc1(dx1)
dx = dx2 * step_size
if fire_rate is None:
fire_rate = self.fire_rate
update_mask_rand = torch.rand(*x[:, :, :, :1].shape)
update_mask = update_mask_rand <= fire_rate
x += dx * update_mask.double()
post_life_mask = get_living_mask(x)
life_mask = pre_life_mask.bool() & post_life_mask.bool()
res = x * life_mask.double()
return res
def get_living_mask(x):
"""returns boolean vector of the same shape as x, except for the last dimension.
The last dimension is a single value, true/false, that determines if alpha > 0.1"""
alpha = x[:, :, :, 3:4]
m = F.max_pool3d(alpha, kernel_size=3, stride=1, padding=1) > 0.1
return m
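# Minimal usage sketch (hypothetical example, not part of the training code):
# ca = CellularAutomataModel(n_channels=16, hidden_channels=32, fire_rate=0.5)
# x = torch.zeros(1, 24, 24, 16, dtype=torch.float64)
# x[:, 12, 12, 3:] = 1.0  # single living seed cell in the centre
# for _ in range(30):
#     x = ca(x)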
import requests
import torch
from torchvision.utils import save_image
import numpy as np
import PIL.Image, PIL.ImageDraw
import io
def load_emoji(emoji_code, img_size):
"""Loads image of emoji with code 'emoji' from google's emojirepository"""
emoji_code = hex(ord(emoji_code))[2:].lower()
url = 'https://raw.githubusercontent.com/googlefonts/noto-emoji/main/png/128/emoji_u%s.png' % emoji_code
req = requests.get(url)
img = PIL.Image.open(io.BytesIO(req.content))
img.thumbnail((img_size, img_size), PIL.Image.LANCZOS)
img = np.float64(img) / 255.0
img[..., :3] *= img[..., 3:]
return img
def to_alpha(x):
"""Return the alpha channel of an image."""
return torch.clamp(x[..., 3:4], 0.0, 1.0)
def to_rgb(x):
"""Return the three first channels (RGB) with alpha deducted."""
rgb, a = x[..., :3], to_alpha(x)
return 1.0 - a + rgb
def to_rgba(x):
"""Return the four first channels (RGBA) of an image."""
return x[..., :4]
def save_model(ca, base_fn):
"""Save a PyTorch model to a specific path."""
torch.save(ca.state_dict(), base_fn)
def visualize(xs, step_i, nrow=1):
"""Save a batch of multiple x's to file"""
for i in range(len(xs)):
xs[i] = to_rgb(xs[i]).permute(0, 3, 1, 2)
save_image(torch.cat(xs, dim=0), './logg/pic/p%04d.png' % step_i, nrow=nrow, padding=0)
class Pool:
"""Class for storing and providing samples of different stages of growth."""
def __init__(self, seed, size):
self.size = size
self.slots = np.repeat([seed], size, 0)
self.seed = seed
def commit(self, batch):
"""Replace existing slots with a batch."""
indices = batch["indices"]
for i, x in enumerate(batch["x"]):
if (x[:, :, 3] > 0.1).any(): # Avoid committing dead image
self.slots[indices[i]] = x.copy()
def sample(self, c):
"""Retrieve a batch from the pool."""
indices = np.random.choice(self.size, c, False)
batch = {
"indices": indices,
"x": self.slots[indices]
}
return batch
\ No newline at end of file