diff --git a/interactive_CA/interactive.py b/interactive_CA/interactive.py
index fb458baf0777b2459402313302f225a52b5cfc66..cf41ddfc6d7cb05a089ebe53b1a3d14e6c52fe1c 100644
--- a/interactive_CA/interactive.py
+++ b/interactive_CA/interactive.py
@@ -15,6 +15,16 @@ from model import CAModel, CellularAutomataModel
 from utils import load_emoji, to_rgb, adv_attack
 
 
+def make_seed(size, n_channels):
+    x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
+    x[:, 3:, size // 2, size // 2] = 1
+    return x
+
+
+def to_rgb_ad(img_rgba):
+    rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
+    return torch.clamp(1.0 - a + rgb, 0, 1)
+
 
 class Interactive:
     def __init__(self, args):
@@ -27,8 +37,11 @@ class Interactive:
         self.eps = args.eps
         self.emoji_size = args.emoji_size
 
-        if self.es: self.target_img = load_emoji(args.img, self.emoji_size)
-        else: self.target_img = torch.from_numpy(load_emoji(args.img, self.emoji_size)).permute(2, 0, 1)[None, ...]
+        if self.es:
+            self.target_img = load_emoji(args.img, self.emoji_size)
+        else:
+            self.target_img = torch.from_numpy(load_emoji(
+                args.img, self.emoji_size)).permute(2, 0, 1)[None, ...]
         self.writer = SummaryWriter(self.logdir)
         # auto calculate padding
         p = (self.size-self.emoji_size)//2
@@ -36,14 +49,17 @@ class Interactive:
 
         if self.es:
             self.pad_target = F.pad(tt(self.target_img), (0, 0, p, p, p, p))
-            self.net = CellularAutomataModel(n_channels=self.n_channels, fire_rate=self.fire_rate, hidden_channels=self.hidden_size)
+            self.net = CellularAutomataModel(
+                n_channels=self.n_channels, fire_rate=self.fire_rate, hidden_channels=self.hidden_size)
             h, w = self.pad_target.shape[:2]
             self.seed = np.zeros([h, w, self.n_channels], np.float64)
             self.seed[h // 2, w // 2, 3:] = 1.0
         else:
-            self.net = CAModel(n_channels=args.n_channels, hidden_channels=args.hidden_size)
+            self.net = CAModel(n_channels=args.n_channels,
+                               hidden_channels=args.hidden_size)
             self.seed = make_seed(self.size, args.n_channels)
-            self.pad_target = F.pad(self.target_img, (p, p, p, p), "constant", 0)
+            self.pad_target = F.pad(
+                self.target_img, (p, p, p, p), "constant", 0)
             self.pad_target = self.pad_target.repeat(1, 1, 1, 1)
 
         # whidden = torch.concat((self.pad_target, torch.zeros((self.size,self.size,12))), axis=2)
@@ -55,23 +71,40 @@ class Interactive:
     def load_model(self, path):
         """Load a PyTorch model from path."""
         self.net.load_state_dict(torch.load(path))
-        if self.es: self.net.double()
+        if self.es:
+            self.net.double()
 
     def seedclone(self):
-        if self.es: return tt(np.repeat(self.seed[None, ...], 1, 0))
-        else: return self.seed.clone()
+        if self.es:
+            return tt(np.repeat(self.seed[None, ...], 1, 0))
+        else:
+            return self.seed.clone()
+
+    def game_update(self, surface, cur_img, sz):
+        nxt = np.zeros((cur_img.shape[0], cur_img.shape[1]))
 
+        for r, c, _ in np.ndindex(cur_img.shape):
+            pygame.draw.rect(surface, cur_img[r, c], (c*sz, r*sz, sz, sz))
+
+        return nxt
+
+    def save_cell(self, x, path):
+        if self.es:
+            image = to_rgb(x).permute(0, 3, 1, 2)
+        else:
+            image = to_rgb_ad(x[:, :4].detach().cpu())
+        save_image(image, path, nrow=1, padding=0)
 
     # Do damage on model using pygame, cannot run through ssh
     def interactive(self):
-
         x_eval = self.seedclone()
 
         cellsize = 20
         imgpath = '%s/one.png' % (self.logdir)
-        
+
         pygame.init()
-        surface = pygame.display.set_mode((self.size * cellsize, self.size * cellsize))
+        surface = pygame.display.set_mode(
+            (self.size * cellsize, self.size * cellsize))
         pygame.display.set_caption("Interactive CA-ES")
 
         damaged = 100
@@ -91,10 +124,14 @@ class Interactive:
                         # mpos_y = (self.size // 2) + 1
                         # mpos_x = 0
                         dmg_size = self.size
-                        if self.es: x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0
-                        else:       x_eval[:, :, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size] = 0
+                        if self.es:
+                            x_eval[:, mpos_y:mpos_y + dmg_size,
+                                   mpos_x:mpos_x + dmg_size, :] = 0
+                        else:
+                            x_eval[:, :, mpos_y:mpos_y + dmg_size,
+                                   mpos_x:mpos_x + dmg_size] = 0
                         # damaged = 0 # number of steps to record loss after damage has occurred
-                        
+
                         # # For noise:
                         # l_func = torch.nn.MSELoss()
                         # e = x_eval.detach().cpu()
@@ -105,30 +142,34 @@ class Interactive:
                         # x_eval = adv_attack(x_eval, self.eps, e.grad.data)
                         pygame.display.set_caption("Saving loss...")
                 elif event.type == pygame.KEYDOWN:
-                    if event.key == pygame.K_r: # reset
+                    if event.key == pygame.K_r:  # reset when pressing r
                         x_eval = self.seedclone()
-                        counter = 0
-
 
             x_eval = self.net(x_eval)
-            if self.es: image = to_rgb(x_eval).permute(0, 3, 1, 2)
-            else: image = to_rgb_ad(x_eval[:, :4].detach().cpu())
-            
-            save_image(image, imgpath, nrow=1, padding=0)
-            
+            self.save_cell(x_eval, imgpath)
+
             # Damage at 51:
-            # if counter == 51:
-            #     # For lower half:
-            #     mpos_y = (self.size // 2) + 1
-            #     mpos_x = 0
-            #     dmg_size = self.size
-            #     if self.es: x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0
-            #     else:       x_eval[:, :, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size] = 0
-            
+            if counter == 40:
+                self.save_cell(x_eval, f'{self.logdir}/{counter}.png')
+            elif counter == 51:
+                # For lower half:
+                mpos_y = (self.size // 2) + 1
+                mpos_x = 0
+                dmg_size = self.size
+                # damage then save image
+                if self.es:
+                    x_eval[:, mpos_y:mpos_y + dmg_size,
+                           mpos_x:mpos_x + dmg_size, :] = 0
+                else:
+                    x_eval[:, :, mpos_y:mpos_y + dmg_size,
+                           mpos_x:mpos_x + dmg_size] = 0
+                self.save_cell(x_eval, f'{self.logdir}/{counter}.png')
+            elif counter == 60:
+                self.save_cell(x_eval, f'{self.logdir}/{counter}.png')
 
             loss = self.net.loss(x_eval, self.pad_target)
             self.writer.add_scalar("train/fit", loss, counter)
-            
+
             # # For manual damage:
             # if damaged < 100:
             #     loss = self.net.loss(x_eval, self.pad_target)
@@ -142,28 +183,10 @@ class Interactive:
             image = np.asarray(Image.open(imgpath))
 
             self.game_update(surface, image, cellsize)
-            time.sleep(0.05) # update delay
-            if counter == 400: 
+            time.sleep(0.00)  # update delay
+            if counter == 400:
                 print('Reached 400 iterations. Shutting down...')
                 pygame.quit()
                 sys.exit()
             counter += 1
             pygame.display.update()
-
-    def game_update(self, surface, cur_img, sz):
-        nxt = np.zeros((cur_img.shape[0], cur_img.shape[1]))
-
-        for r, c, _ in np.ndindex(cur_img.shape):
-            pygame.draw.rect(surface, cur_img[r,c], (c*sz, r*sz, sz, sz))
-
-        return nxt
-
-def make_seed(size, n_channels):
-    x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
-    x[:, 3:, size // 2, size // 2] = 1
-    return x
-
-def to_rgb_ad(img_rgba):
-    rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
-    return torch.clamp(1.0 - a + rgb, 0, 1)
-        
\ No newline at end of file