diff --git a/interactive_CA/interactive.py b/interactive_CA/interactive.py index 08d2019af2367db8d58281e3334335ccf78db813..fb458baf0777b2459402313302f225a52b5cfc66 100644 --- a/interactive_CA/interactive.py +++ b/interactive_CA/interactive.py @@ -2,6 +2,7 @@ import time import numpy as np +import sys import pygame import torch import torch.nn.functional as F @@ -14,20 +15,24 @@ from model import CAModel, CellularAutomataModel from utils import load_emoji, to_rgb, adv_attack + class Interactive: def __init__(self, args): self.n_channels = args.n_channels self.hidden_size = args.hidden_size self.fire_rate = args.fire_rate - self.size = args.size + 2 * args.padding self.logdir = args.logdir + self.size = args.size self.es = args.es self.eps = args.eps + self.emoji_size = args.emoji_size - if self.es: self.target_img = load_emoji(args.img, self.size) - else: self.target_img = torch.from_numpy(load_emoji(args.img, self.size)).permute(2, 0, 1)[None, ...] + if self.es: self.target_img = load_emoji(args.img, self.emoji_size) + else: self.target_img = torch.from_numpy(load_emoji(args.img, self.emoji_size)).permute(2, 0, 1)[None, ...] self.writer = SummaryWriter(self.logdir) - p = args.padding + # auto calculate padding + p = (self.size-self.emoji_size)//2 + self.size = self.emoji_size+2*p if self.es: self.pad_target = F.pad(tt(self.target_img), (0, 0, p, p, p, p)) @@ -37,8 +42,8 @@ class Interactive: self.seed[h // 2, w // 2, 3:] = 1.0 else: self.net = CAModel(n_channels=args.n_channels, hidden_channels=args.hidden_size) - self.seed = torch.nn.functional.pad(make_seed(args.size, args.n_channels), (p, p, p, p), "constant", 0) - self.pad_target = torch.nn.functional.pad(self.target_img, (p, p, p, p), "constant", 0) + self.seed = make_seed(self.size, args.n_channels) + self.pad_target = F.pad(self.target_img, (p, p, p, p), "constant", 0) self.pad_target = self.pad_target.repeat(1, 1, 1, 1) # whidden = torch.concat((self.pad_target, torch.zeros((self.size,self.size,12))), axis=2) @@ -100,8 +105,9 @@ class Interactive: # x_eval = adv_attack(x_eval, self.eps, e.grad.data) pygame.display.set_caption("Saving loss...") elif event.type == pygame.KEYDOWN: - if event.key == pygame.K_r: + if event.key == pygame.K_r: # reset x_eval = self.seedclone() + counter = 0 x_eval = self.net(x_eval) @@ -118,8 +124,7 @@ class Interactive: # dmg_size = self.size # if self.es: x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0 # else: x_eval[:, :, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size] = 0 - if counter == 400: pygame.quit() - counter += 1 + loss = self.net.loss(x_eval, self.pad_target) self.writer.add_scalar("train/fit", loss, counter) @@ -138,6 +143,11 @@ class Interactive: self.game_update(surface, image, cellsize) time.sleep(0.05) # update delay + if counter == 400: + print('Reached 400 iterations. Shutting down...') + pygame.quit() + sys.exit() + counter += 1 pygame.display.update() def game_update(self, surface, cur_img, sz): diff --git a/interactive_CA/main.py b/interactive_CA/main.py index 772c2523a6b70710e8bda691d34a8e97d2b62941..e1bd14745db3e64cd33483f07764f6c3f90f95b8 100644 --- a/interactive_CA/main.py +++ b/interactive_CA/main.py @@ -19,19 +19,19 @@ models = [ 'automata_ex\logs\CARROT-train_01-05-2022_14-19-47\models\model_16000.pt', 'automata_ex\logs\CARROT-train_01-05-2022_15-08-31\models\model_79000.pt', 'CA-ES\saved_models\\20_lizard', - 'final_models\Adam\SamplePools\\15-CARROT-train_05-05-2022_17-00-31\models\model_19500.pt' + 'final_models\Adam\SamplePools\\15-RABBIT-FACE-train_05-05-2022_13-44-43\models\model_19500.pt' ] if __name__ == '__main__': - emoji = '🥕' + emoji = '🐰' load_model = models[5] - size = 40 + size = 15 # canvas size + emoji_size = 15 # size of training image usually 9 or 15 es = False parser = argparse.ArgumentParser() parser.add_argument("-i", "--img", type=str, default=emoji, metavar="🐰", help="The emoji to train on") parser.add_argument("-s", "--size", type=int, default=size, help="Image size") - parser.add_argument("--padding", type=int, default=0, help="Padding. The shape after padding is (h + 2 * p, w + 2 * p).") parser.add_argument("--logdir", type=str, default="interactive_CA/logs", help="Logging folder for new model") parser.add_argument("-l", "--load_model_path", type=str, default=load_model, help="Path to pre trained model") parser.add_argument("--n_channels", type=int, default=16, help="Number of channels of the input tensor") @@ -41,11 +41,14 @@ if __name__ == '__main__': parser.add_argument("--eps", type=float, default=0.007, help="Epsilon scales the amount of damage done from adversarial attacks") args = parser.parse_args() + args.emoji_size = emoji_size if not os.path.isdir(args.logdir): raise Exception("Logging directory '%s' not found in base folder" % args.logdir) - args.logdir = "%s/ES-%s_%s" % (args.logdir, unicodedata.name(args.img), time.strftime("%d-%m-%Y_%H-%M-%S")) + method = 'ADAM' + if es: method='ES' + args.logdir = "%s/%s-%s_%s" % (args.logdir, method,unicodedata.name(args.img), time.strftime("%d-%m-%Y_%H-%M-%S")) os.mkdir(args.logdir) logging.basicConfig(filename='%s/logfile.log' % args.logdir, encoding='utf-8', level=logging.INFO)