Commit b6736647 authored by Bård Sørensen Hestmark's avatar Bård Sørensen Hestmark
Browse files

paths interactive

parent 90e1b245
......@@ -12,19 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
from model import CAModel, CellularAutomataModel
from utils import load_emoji, to_rgb, adv_attack
def make_seed(size, n_channels):
x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
x[:, 3:, size // 2, size // 2] = 1
return x
def to_rgb_ad(img_rgba):
rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
return torch.clamp(1.0 - a + rgb, 0, 1)
from utils import load_emoji, to_rgb, adv_attack, make_seed, to_rgb_ad
class Interactive:
def __init__(self, args):
......@@ -94,9 +82,9 @@ class Interactive:
image = to_rgb_ad(x[:, :4].detach().cpu())
save_image(image, path, nrow=1, padding=0)
# Do damage on model using pygame, cannot run through ssh
def interactive(self):
"""Do damage on model using pygame"""
x_eval = self.seedclone()
cellsize = 20
......
......@@ -29,6 +29,9 @@ adam_sample_models = [
es_nonsample_models = [
['final_models\\ES\\NonSamplePools\\9-CARROT-train_05-05-2022_09-14-06\\models\\model_2212000', '🥕', 9, True]
['final_models\\ES\NonSamplePools\\9-RABBIT FACE-train_06-05-2022_12-18-56\models\\model_1998000', '🐰', 9, True]
['final_models\\ES\\NonSamplePools\\15-CARROT-train_06-05-2022_11-06-58\models\model_1998000', '🥕', 15, True]
['rip_rabbit_15_evolution doth not favour the meek', '🐰', 0, None]
]
es_sample_models = [
......
......@@ -40,37 +40,13 @@ def to_rgba(x):
"""Return the four first channels (RGBA) of an image."""
return x[..., :4]
def save_model(ca, base_fn):
"""Save a PyTorch model to a specific path."""
torch.save(ca.state_dict(), base_fn)
def visualize(xs, step_i, nrow=1):
"""Save a batch of multiple x's to file"""
for i in range(len(xs)):
xs[i] = to_rgb(xs[i]).permute(0, 3, 1, 2)
save_image(torch.cat(xs, dim=0), './logg/pic/p%04d.png' % step_i, nrow=nrow, padding=0)
class Pool:
"""Class for storing and providing samples of different stages of growth."""
def __init__(self, seed, size):
self.size = size
self.slots = np.repeat([seed], size, 0)
self.seed = seed
def commit(self, batch):
"""Replace existing slots with a batch."""
indices = batch["indices"]
for i, x in enumerate(batch["x"]):
if (x[:, :, 3] > 0.1).any(): # Avoid committing dead image
self.slots[indices[i]] = x.copy()
def sample(self, c):
"""Retrieve a batch from the pool."""
indices = np.random.choice(self.size, c, False)
batch = {
"indices": indices,
"x": self.slots[indices]
}
return batch
def make_seed(size, n_channels):
x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
x[:, 3:, size // 2, size // 2] = 1
return x
def to_rgb_ad(img_rgba):
rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
return torch.clamp(1.0 - a + rgb, 0, 1)
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment