diff --git a/interactive_CA/interactive.py b/interactive_CA/interactive.py index dcffc4083793d8342269e8d3a46422a27e418240..b95fdfb93657a56e7b0d8b2a1d41f5c94354f3f8 100644 --- a/interactive_CA/interactive.py +++ b/interactive_CA/interactive.py @@ -12,19 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from torchvision.utils import save_image from model import CAModel, CellularAutomataModel -from utils import load_emoji, to_rgb, adv_attack - - -def make_seed(size, n_channels): - x = torch.zeros((1, n_channels, size, size), dtype=torch.float32) - x[:, 3:, size // 2, size // 2] = 1 - return x - - -def to_rgb_ad(img_rgba): - rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1) - return torch.clamp(1.0 - a + rgb, 0, 1) - +from utils import load_emoji, to_rgb, adv_attack, make_seed, to_rgb_ad class Interactive: def __init__(self, args): @@ -94,9 +82,9 @@ class Interactive: image = to_rgb_ad(x[:, :4].detach().cpu()) save_image(image, path, nrow=1, padding=0) - # Do damage on model using pygame, cannot run through ssh def interactive(self): + """Do damage on model using pygame""" x_eval = self.seedclone() cellsize = 20 diff --git a/interactive_CA/main.py b/interactive_CA/main.py index 1ffcc2247147ac8e2d4188b67c1c26fac35ab606..60817dfb85f09ed7873b5904ea1402e8abaf9db2 100644 --- a/interactive_CA/main.py +++ b/interactive_CA/main.py @@ -29,6 +29,9 @@ adam_sample_models = [ es_nonsample_models = [ ['final_models\\ES\\NonSamplePools\\9-CARROT-train_05-05-2022_09-14-06\\models\\model_2212000', '🥕', 9, True] + ['final_models\\ES\NonSamplePools\\9-RABBIT FACE-train_06-05-2022_12-18-56\models\\model_1998000', '🐰', 9, True] + ['final_models\\ES\\NonSamplePools\\15-CARROT-train_06-05-2022_11-06-58\models\model_1998000', '🥕', 15, True] + ['rip_rabbit_15_evolution doth not favour the meek', '🐰', 0, None] ] es_sample_models = [ diff --git a/interactive_CA/utils.py b/interactive_CA/utils.py index 0488d1c81a432b1c5cf9fbd4fee6567bf0f94714..8f190761e7bc91d3942340f064ac29befa81f006 100644 --- a/interactive_CA/utils.py +++ b/interactive_CA/utils.py @@ -40,37 +40,13 @@ def to_rgba(x): """Return the four first channels (RGBA) of an image.""" return x[..., :4] -def save_model(ca, base_fn): - """Save a PyTorch model to a specific path.""" - torch.save(ca.state_dict(), base_fn) - -def visualize(xs, step_i, nrow=1): - """Save a batch of multiple x's to file""" - for i in range(len(xs)): - xs[i] = to_rgb(xs[i]).permute(0, 3, 1, 2) - save_image(torch.cat(xs, dim=0), './logg/pic/p%04d.png' % step_i, nrow=nrow, padding=0) - -class Pool: - """Class for storing and providing samples of different stages of growth.""" - def __init__(self, seed, size): - self.size = size - self.slots = np.repeat([seed], size, 0) - self.seed = seed - - def commit(self, batch): - """Replace existing slots with a batch.""" - indices = batch["indices"] - for i, x in enumerate(batch["x"]): - if (x[:, :, 3] > 0.1).any(): # Avoid committing dead image - self.slots[indices[i]] = x.copy() - - def sample(self, c): - """Retrieve a batch from the pool.""" - indices = np.random.choice(self.size, c, False) - batch = { - "indices": indices, - "x": self.slots[indices] - } - return batch +def make_seed(size, n_channels): + x = torch.zeros((1, n_channels, size, size), dtype=torch.float32) + x[:, 3:, size // 2, size // 2] = 1 + return x + +def to_rgb_ad(img_rgba): + rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1) + return torch.clamp(1.0 - a + rgb, 0, 1) \ No newline at end of file