Skip to content
Snippets Groups Projects
Commit b6736647 authored by Bård Sørensen Hestmark's avatar Bård Sørensen Hestmark
Browse files

paths interactive

parent 90e1b245
No related branches found
No related tags found
No related merge requests found
......@@ -12,19 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image
from model import CAModel, CellularAutomataModel
from utils import load_emoji, to_rgb, adv_attack
def make_seed(size, n_channels):
x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
x[:, 3:, size // 2, size // 2] = 1
return x
def to_rgb_ad(img_rgba):
rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
return torch.clamp(1.0 - a + rgb, 0, 1)
from utils import load_emoji, to_rgb, adv_attack, make_seed, to_rgb_ad
class Interactive:
def __init__(self, args):
......@@ -94,9 +82,9 @@ class Interactive:
image = to_rgb_ad(x[:, :4].detach().cpu())
save_image(image, path, nrow=1, padding=0)
# Do damage on model using pygame, cannot run through ssh
def interactive(self):
"""Do damage on model using pygame"""
x_eval = self.seedclone()
cellsize = 20
......
......@@ -29,6 +29,9 @@ adam_sample_models = [
es_nonsample_models = [
['final_models\\ES\\NonSamplePools\\9-CARROT-train_05-05-2022_09-14-06\\models\\model_2212000', '🥕', 9, True]
['final_models\\ES\NonSamplePools\\9-RABBIT FACE-train_06-05-2022_12-18-56\models\\model_1998000', '🐰', 9, True]
['final_models\\ES\\NonSamplePools\\15-CARROT-train_06-05-2022_11-06-58\models\model_1998000', '🥕', 15, True]
['rip_rabbit_15_evolution doth not favour the meek', '🐰', 0, None]
]
es_sample_models = [
......
......@@ -40,37 +40,13 @@ def to_rgba(x):
"""Return the four first channels (RGBA) of an image."""
return x[..., :4]
def save_model(ca, base_fn):
"""Save a PyTorch model to a specific path."""
torch.save(ca.state_dict(), base_fn)
def visualize(xs, step_i, nrow=1):
"""Save a batch of multiple x's to file"""
for i in range(len(xs)):
xs[i] = to_rgb(xs[i]).permute(0, 3, 1, 2)
save_image(torch.cat(xs, dim=0), './logg/pic/p%04d.png' % step_i, nrow=nrow, padding=0)
class Pool:
"""Class for storing and providing samples of different stages of growth."""
def __init__(self, seed, size):
self.size = size
self.slots = np.repeat([seed], size, 0)
self.seed = seed
def commit(self, batch):
"""Replace existing slots with a batch."""
indices = batch["indices"]
for i, x in enumerate(batch["x"]):
if (x[:, :, 3] > 0.1).any(): # Avoid committing dead image
self.slots[indices[i]] = x.copy()
def sample(self, c):
"""Retrieve a batch from the pool."""
indices = np.random.choice(self.size, c, False)
batch = {
"indices": indices,
"x": self.slots[indices]
}
return batch
def make_seed(size, n_channels):
x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
x[:, 3:, size // 2, size // 2] = 1
return x
def to_rgb_ad(img_rgba):
rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
return torch.clamp(1.0 - a + rgb, 0, 1)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment