diff --git a/interactive_CA/interactive.py b/interactive_CA/interactive.py index fb458baf0777b2459402313302f225a52b5cfc66..cf41ddfc6d7cb05a089ebe53b1a3d14e6c52fe1c 100644 --- a/interactive_CA/interactive.py +++ b/interactive_CA/interactive.py @@ -15,6 +15,16 @@ from model import CAModel, CellularAutomataModel from utils import load_emoji, to_rgb, adv_attack +def make_seed(size, n_channels): + x = torch.zeros((1, n_channels, size, size), dtype=torch.float32) + x[:, 3:, size // 2, size // 2] = 1 + return x + + +def to_rgb_ad(img_rgba): + rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1) + return torch.clamp(1.0 - a + rgb, 0, 1) + class Interactive: def __init__(self, args): @@ -27,8 +37,11 @@ class Interactive: self.eps = args.eps self.emoji_size = args.emoji_size - if self.es: self.target_img = load_emoji(args.img, self.emoji_size) - else: self.target_img = torch.from_numpy(load_emoji(args.img, self.emoji_size)).permute(2, 0, 1)[None, ...] + if self.es: + self.target_img = load_emoji(args.img, self.emoji_size) + else: + self.target_img = torch.from_numpy(load_emoji( + args.img, self.emoji_size)).permute(2, 0, 1)[None, ...] self.writer = SummaryWriter(self.logdir) # auto calculate padding p = (self.size-self.emoji_size)//2 @@ -36,14 +49,17 @@ class Interactive: if self.es: self.pad_target = F.pad(tt(self.target_img), (0, 0, p, p, p, p)) - self.net = CellularAutomataModel(n_channels=self.n_channels, fire_rate=self.fire_rate, hidden_channels=self.hidden_size) + self.net = CellularAutomataModel( + n_channels=self.n_channels, fire_rate=self.fire_rate, hidden_channels=self.hidden_size) h, w = self.pad_target.shape[:2] self.seed = np.zeros([h, w, self.n_channels], np.float64) self.seed[h // 2, w // 2, 3:] = 1.0 else: - self.net = CAModel(n_channels=args.n_channels, hidden_channels=args.hidden_size) + self.net = CAModel(n_channels=args.n_channels, + hidden_channels=args.hidden_size) self.seed = make_seed(self.size, args.n_channels) - self.pad_target = F.pad(self.target_img, (p, p, p, p), "constant", 0) + self.pad_target = F.pad( + self.target_img, (p, p, p, p), "constant", 0) self.pad_target = self.pad_target.repeat(1, 1, 1, 1) # whidden = torch.concat((self.pad_target, torch.zeros((self.size,self.size,12))), axis=2) @@ -55,23 +71,40 @@ class Interactive: def load_model(self, path): """Load a PyTorch model from path.""" self.net.load_state_dict(torch.load(path)) - if self.es: self.net.double() + if self.es: + self.net.double() def seedclone(self): - if self.es: return tt(np.repeat(self.seed[None, ...], 1, 0)) - else: return self.seed.clone() + if self.es: + return tt(np.repeat(self.seed[None, ...], 1, 0)) + else: + return self.seed.clone() + + def game_update(self, surface, cur_img, sz): + nxt = np.zeros((cur_img.shape[0], cur_img.shape[1])) + for r, c, _ in np.ndindex(cur_img.shape): + pygame.draw.rect(surface, cur_img[r, c], (c*sz, r*sz, sz, sz)) + + return nxt + + def save_cell(self, x, path): + if self.es: + image = to_rgb(x).permute(0, 3, 1, 2) + else: + image = to_rgb_ad(x[:, :4].detach().cpu()) + save_image(image, path, nrow=1, padding=0) # Do damage on model using pygame, cannot run through ssh def interactive(self): - x_eval = self.seedclone() cellsize = 20 imgpath = '%s/one.png' % (self.logdir) - + pygame.init() - surface = pygame.display.set_mode((self.size * cellsize, self.size * cellsize)) + surface = pygame.display.set_mode( + (self.size * cellsize, self.size * cellsize)) pygame.display.set_caption("Interactive CA-ES") damaged = 100 @@ -91,10 +124,14 @@ class Interactive: # mpos_y = (self.size // 2) + 1 # mpos_x = 0 dmg_size = self.size - if self.es: x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0 - else: x_eval[:, :, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size] = 0 + if self.es: + x_eval[:, mpos_y:mpos_y + dmg_size, + mpos_x:mpos_x + dmg_size, :] = 0 + else: + x_eval[:, :, mpos_y:mpos_y + dmg_size, + mpos_x:mpos_x + dmg_size] = 0 # damaged = 0 # number of steps to record loss after damage has occurred - + # # For noise: # l_func = torch.nn.MSELoss() # e = x_eval.detach().cpu() @@ -105,30 +142,34 @@ class Interactive: # x_eval = adv_attack(x_eval, self.eps, e.grad.data) pygame.display.set_caption("Saving loss...") elif event.type == pygame.KEYDOWN: - if event.key == pygame.K_r: # reset + if event.key == pygame.K_r: # reset when pressing r x_eval = self.seedclone() - counter = 0 - x_eval = self.net(x_eval) - if self.es: image = to_rgb(x_eval).permute(0, 3, 1, 2) - else: image = to_rgb_ad(x_eval[:, :4].detach().cpu()) - - save_image(image, imgpath, nrow=1, padding=0) - + self.save_cell(x_eval, imgpath) + # Damage at 51: - # if counter == 51: - # # For lower half: - # mpos_y = (self.size // 2) + 1 - # mpos_x = 0 - # dmg_size = self.size - # if self.es: x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0 - # else: x_eval[:, :, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size] = 0 - + if counter == 40: + self.save_cell(x_eval, f'{self.logdir}/{counter}.png') + elif counter == 51: + # For lower half: + mpos_y = (self.size // 2) + 1 + mpos_x = 0 + dmg_size = self.size + # damage then save image + if self.es: + x_eval[:, mpos_y:mpos_y + dmg_size, + mpos_x:mpos_x + dmg_size, :] = 0 + else: + x_eval[:, :, mpos_y:mpos_y + dmg_size, + mpos_x:mpos_x + dmg_size] = 0 + self.save_cell(x_eval, f'{self.logdir}/{counter}.png') + elif counter == 60: + self.save_cell(x_eval, f'{self.logdir}/{counter}.png') loss = self.net.loss(x_eval, self.pad_target) self.writer.add_scalar("train/fit", loss, counter) - + # # For manual damage: # if damaged < 100: # loss = self.net.loss(x_eval, self.pad_target) @@ -142,28 +183,10 @@ class Interactive: image = np.asarray(Image.open(imgpath)) self.game_update(surface, image, cellsize) - time.sleep(0.05) # update delay - if counter == 400: + time.sleep(0.00) # update delay + if counter == 400: print('Reached 400 iterations. Shutting down...') pygame.quit() sys.exit() counter += 1 pygame.display.update() - - def game_update(self, surface, cur_img, sz): - nxt = np.zeros((cur_img.shape[0], cur_img.shape[1])) - - for r, c, _ in np.ndindex(cur_img.shape): - pygame.draw.rect(surface, cur_img[r,c], (c*sz, r*sz, sz, sz)) - - return nxt - -def make_seed(size, n_channels): - x = torch.zeros((1, n_channels, size, size), dtype=torch.float32) - x[:, 3:, size // 2, size // 2] = 1 - return x - -def to_rgb_ad(img_rgba): - rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1) - return torch.clamp(1.0 - a + rgb, 0, 1) - \ No newline at end of file