Commit e7fa6040 authored by Bård Sørensen Hestmark's avatar Bård Sørensen Hestmark
Browse files

Merge branch 'nonmodel' into 'main'

Nonmodel

See merge request !1
parents ba5c6d72 251bf98c
# Code from https://github.com/jankrepl/mildlyoverfitted/tree/master/github_adventures/automata, project had no licence
import torch
import torch.nn as nn
class CAModel(nn.Module):
"""Cell automata model.
Parameters
----------
n_channels : int
Number of channels of the grid.
hidden_channels : int
Hidden channels that are related to the pixelwise 1x1 convolution.
fire_rate : float
Number between 0 and 1. The lower it is the more likely it is for
cells to be set to zero during the `stochastic_update` process.
device : torch.device
Determines on what device we perfrom all the computations.
Attributes
----------
update_module : nn.Sequential
The only part of the network containing trainable parameters. Composed
of 1x1 convolution, ReLu and 1x1 convolution.
filters : torch.Tensor
Constant tensor of shape `(3 * n_channels, 1, 3, 3)`.
"""
def __init__(self, n_channels=16, hidden_channels=128, fire_rate=0.5, device=None):
super().__init__()
self.fire_rate = 0.5
self.n_channels = n_channels
self.device = device or torch.device("cpu")
# Perceive step
sobel_filter_ = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
scalar = 8.0
sobel_filter_x = sobel_filter_ / scalar
sobel_filter_y = sobel_filter_.t() / scalar
identity_filter = torch.tensor(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 0],
],
dtype=torch.float32,
)
filters = torch.stack(
[identity_filter, sobel_filter_x, sobel_filter_y]
) # (3, 3, 3)
filters = filters.repeat((n_channels, 1, 1)) # (3 * n_channels, 3, 3)
self.filters = filters[:, None, ...].to(
self.device
) # (3 * n_channels, 1, 3, 3)
# Update step
self.update_module = nn.Sequential(
nn.Conv2d(
3 * n_channels,
hidden_channels,
kernel_size=1, # (1, 1)
),
nn.ReLU(),
nn.Conv2d(
hidden_channels,
n_channels,
kernel_size=1,
bias=False,
),
)
with torch.no_grad():
self.update_module[2].weight.zero_()
self.to(self.device)
def perceive(self, x):
"""Approximate channelwise gradient and combine with the input.
This is the only place where we include information on the
neighboring cells. However, we are not using any learnable
parameters here.
Parameters
----------
x : torch.Tensor
Shape `(n_samples, n_channels, grid_size, grid_size)`.
Returns
-------
torch.Tensor
Shape `(n_samples, 3 * n_channels, grid_size, grid_size)`.
"""
return nn.functional.conv2d(x, self.filters, padding=1, groups=self.n_channels)
def update(self, x):
"""Perform update.
Note that this is the only part of the forward pass that uses
trainable parameters
Paramters
---------
x : torch.Tensor
Shape `(n_samples, 3 * n_channels, grid_size, grid_size)`.
Returns
-------
torch.Tensor
Shape `(n_samples, n_channels, grid_size, grid_size)`.
"""
return self.update_module(x)
@staticmethod
def stochastic_update(x, fire_rate):
"""Run pixel-wise dropout.
Unlike dropout there is no scaling taking place.
Parameters
----------
x : torch.Tensor
Shape `(n_samples, n_channels, grid_size, grid_size)`.
fire_rate : float
Number between 0 and 1. The higher the more likely a given cell
updates.
Returns
-------
torch.Tensor
Shape `(n_samples, n_channels, grid_size, grid_size)`.
"""
device = x.device
mask = (torch.rand(x[:, :1, :, :].shape) <= fire_rate).to(device, torch.float32)
return x * mask # broadcasted over all channels
@staticmethod
def get_living_mask(x):
"""Identify living cells.
Parameters
----------
x : torch.Tensor
Shape `(n_samples, n_channels, grid_size, grid_size)`.
Returns
-------
torch.Tensor
Shape `(n_samples, 1, grid_size, grid_size)` and the
dtype is bool.
"""
return (
nn.functional.max_pool2d(
x[:, 3:4, :, :], kernel_size=3, stride=1, padding=1
)
> 0.1
)
def forward(self, x):
"""Run the forward pass.
Parameters
----------
x : torch.Tensor
Shape `(n_samples, n_channels, grid_size, grid_size)`.
Returns
-------
torch.Tensor
Shape `(n_sample, n_channels, grid_size, grid_size)`.
"""
pre_life_mask = self.get_living_mask(x)
y = self.perceive(x)
dx = self.update(y)
dx = self.stochastic_update(dx, fire_rate=self.fire_rate)
x = x + dx
post_life_mask = self.get_living_mask(x)
life_mask = (pre_life_mask & post_life_mask).to(torch.float32)
return x * life_mask
\ No newline at end of file
Based on video 'Growing neural cellular automata in PyTorch': https://www.youtube.com/watch?v=21ACbWoF2Oo
Everything in this folder is from https://github.com/jankrepl/mildlyoverfitted/tree/master/github_adventures/automata
The project had no license
The log folder contains a trained example for tensorboard that ran for 10000 batches, that showed a decent result.
Run tensorboard with:
`tensorboard --logdir=logs`
Get parameters:
`python train.py --help`
Run args suggestion:
`python train.py -p 0 -n 10000 -s 9 -hch 32 -pool=true -th 3.33e-3`
Remove '-d cuda' to train using cpu
In pytorch edit configuration and add parameters:
`-d cuda -n 10000 -b 4`
import argparse
import pathlib
import time
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from model import CAModel
import io
import os
import requests
import unicodedata
import json
def load_image(path, size=40):
"""Load an image.
Parameters
----------
path : pathlib.Path
Path to where the image is located. Note that the image needs to be
RGBA.
size : int
The image will be resized to a square wit ha side length of `size`.
Returns
-------
torch.Tensor
4D float image of shape `(1, 4, size, size)`. The RGB channels
are premultiplied by the alpha channel.
"""
img = Image.open(path)
img = img.resize((size, size), Image.ANTIALIAS)
img = np.float32(img) / 255.0
img[..., :3] *= img[..., 3:]
return torch.from_numpy(img).permute(2, 0, 1)[None, ...]
def load_emoji(emoji_code, img_size):
"""Loads image of emoji with code 'emoji' from google's emojirepository"""
emoji_code = hex(ord(emoji_code))[2:].lower()
url = 'https://raw.githubusercontent.com/googlefonts/noto-emoji/main/png/128/emoji_u%s.png' % emoji_code
req = requests.get(url)
img = Image.open(io.BytesIO(req.content))
img.thumbnail((img_size, img_size), Image.ANTIALIAS)
img = np.float64(img) / 255.0
img[..., :3] *= img[..., 3:]
return torch.from_numpy(img).permute(2, 0, 1)[None, ...]
def to_rgb(img_rgba):
"""Convert RGBA image to RGB image.
Parameters
----------
img_rgba : torch.Tensor
4D tensor of shape `(1, 4, size, size)` where the RGB channels
were already multiplied by the alpha.
Returns
-------
img_rgb : torch.Tensor
4D tensor of shape `(1, 3, size, size)`.
"""
rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
return torch.clamp(1.0 - a + rgb, 0, 1)
def save_model(PATH, model):
torch.save(model.state_dict(), PATH)
def make_seed(size, n_channels):
"""Create a starting tensor for training.
The only active pixels are going to be in the middle.
Parameters
----------
size : int
The height and the width of the tensor.
n_channels : int
Overall number of channels. Note that it needs to be higher than 4
since the first 4 channels represent RGBA.
Returns
-------
torch.Tensor
4D float tensor of shape `(1, n_chanels, size, size)`.
"""
x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
x[:, 3:, size // 2, size // 2] = 1
return x
def main(argv=None):
parser = argparse.ArgumentParser(
description="Training script for the Celluar Automata"
)
parser.add_argument("-img", "--img", type=str, default="rabbit.png",
help="Path to the image we want to reproduce")
parser.add_argument(
"-b",
"--batch-size",
type=int,
default=8,
help="Batch size. Samples will always be taken randomly from the pool."
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cpu",
help="Device to use",
choices=("cpu", "cuda"),
)
parser.add_argument(
"-e",
"--eval-frequency",
type=int,
default=500,
help="Evaluation frequency.",
)
parser.add_argument(
"-i",
"--eval-iterations",
type=int,
default=300,
help="Number of iterations when evaluating.",
)
parser.add_argument(
"-n",
"--n-batches",
type=int,
default=10000,
help="Number of batches to train for.",
)
parser.add_argument(
"-c",
"--n-channels",
type=int,
default=16,
help="Number of channels of the input tensor",
)
parser.add_argument(
"-l",
"--logdir",
type=str,
default="logs",
help="Folder where all the logs and outputs are saved.",
)
parser.add_argument(
"-p",
"--padding",
type=int,
default=0,
help="Padding. The shape after padding is (h + 2 * p, w + 2 * p).",
)
parser.add_argument(
"--pool-size",
type=int,
default=1024,
help="Size of the training pool",
)
parser.add_argument(
"-s",
"--size",
type=int,
default=9,
help="Image size",
)
parser.add_argument(
"-hch",
"--hidden-channels",
type=int,
default=32,
help="Number of hidden channels"
)
parser.add_argument(
"-pool",
"--pool",
type=str,
default="true",
help="True to train with pools, false to train without them"
)
parser.add_argument(
"-threshhold",
"--th",
type=float,
default=3.3e-3,
help="Stop training at certain loss threshhold"
)
# Parse arguments
args = parser.parse_args()
print(vars(args))
# pools on/off
if args.pool.lower() == "true":
print("Training with sample pools")
args.pool = True
else:
args.pool = False
args.img = "🥕"#"🐰" # switch emoji here
args.mode = "train"
if not os.path.isdir(args.logdir):
raise Exception(
"Logging directory '%s' not found in base folder" % args.logdir)
# make log dir
args.logdir = "%s/%s-%s-%s_%s" % (args.logdir, args.size, unicodedata.name(
args.img), args.mode, time.strftime("%d-%m-%Y_%H-%M-%S"))
os.mkdir(args.logdir)
os.mkdir(args.logdir + "/models")
os.mkdir(args.logdir + "/pic")
print(f'logs saved to dir: {args.logdir}')
with open(f'{args.logdir}/args.json', 'w') as f:
f.write(json.dumps(vars(args), indent=4))
# Misc
device = torch.device(args.device)
log_path = pathlib.Path(args.logdir)
log_path.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(log_path)
# Target image
#target_img_ = load_image(args.img, size=args.size)
target_img_ = load_emoji(args.img, args.size)
p = args.padding
target_img_ = nn.functional.pad(target_img_, (p, p, p, p), "constant", 0)
target_img = target_img_.to(device)
target_img = target_img.repeat(args.batch_size, 1, 1, 1)
writer.add_image("ground truth", to_rgb(target_img_)[0])
# Model and optimizer
model = CAModel(n_channels=args.n_channels,
hidden_channels=args.hidden_channels, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
# Pool initialization
seed = make_seed(args.size, args.n_channels).to(device)
seed = nn.functional.pad(seed, (p, p, p, p), "constant", 0)
pool = seed.clone().repeat(args.pool_size, 1, 1, 1)
for it in tqdm(range(args.n_batches)):
batch_ixs = np.random.choice(
args.pool_size, args.batch_size, replace=False
).tolist()
x = pool[batch_ixs]
for i in range(np.random.randint(30, 40)):
x = model(x)
loss_batch = ((target_img - x[:, :4, ...]) ** 2).mean(dim=[1, 2, 3])
loss = loss_batch.mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
writer.add_scalar("train/loss", -loss, it)
if args.pool:
argmax_batch = loss_batch.argmax().item()
argmax_pool = batch_ixs[argmax_batch]
remaining_batch = [i for i in range(
args.batch_size) if i != argmax_batch]
remaining_pool = [i for i in batch_ixs if i != argmax_pool]
pool[argmax_pool] = seed.clone()
pool[remaining_pool] = x[remaining_batch].detach()
if it % args.eval_frequency == 0 or loss < args.th:
save_model(f'{args.logdir}/models/model_{it}.pt', model)
x_eval = seed.clone() # (1, n_channels, size, size)
eval_video = torch.empty(
1, args.eval_iterations, 3, *x_eval.shape[2:])
for it_eval in range(args.eval_iterations):
x_eval = model(x_eval)
x_eval_out = to_rgb(x_eval[:, :4].detach().cpu())
eval_video[0, it_eval] = x_eval_out
save_image(x_eval_out, f'{args.logdir}/pic/im_{it}.png')
writer.add_video("eval", eval_video, it, fps=60)
if loss < args.th:
break
if __name__ == "__main__":
main()
python train.py -p 0 -n 10000 -s 9 -hch 32 -pool=true -th 0
\ No newline at end of file
{
"img": "\ud83e\udd55",
"batch_size": 8,
"device": "cpu",
"eval_frequency": 500,
"eval_iterations": 300,
"n_batches": 20000,
"n_channels": 16,
"logdir": "logs/15-CARROT-train_05-05-2022_17-00-31",
"padding": 0,
"pool_size": 1024,
"size": 15,
"hidden_channels": 32,
"pool": true,
"th": 0.0,
"mode": "train"
}
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment