"""Interactive pygame viewer for damaging and regrowing a trained cellular automaton."""

import time

import numpy as np
5
import sys
6
7
8
9
10
11
12
13
14
import pygame
import torch
import torch.nn.functional as F
from PIL import Image
from torch import tensor as tt
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

from model import CAModel, CellularAutomataModel
from utils import load_emoji, to_rgb, adv_attack, make_seed, to_rgb_ad
16

17
18
19
20
21
22
class Interactive:
    """Interactive pygame viewer for a trained cellular automaton.

    Grows a pattern from a seed with a trained model and renders every step in
    a pygame window. Clicking erases the state from the clicked cell towards the
    bottom-right corner of the grid, and pressing ``r`` resets to the seed.
    During the first ``NOISE_STEPS`` steps the state is also perturbed with
    ``adv_attack``. Snapshots are saved at ``SNAPSHOT_STEPS``, and the model's
    loss against the padded target image is logged to TensorBoard under
    ``train/fit``.

    ``args.es`` selects between two state layouts:

    * ``es=True``: ``CellularAutomataModel`` with a channels-last state,
      indexed as ``x[:, row, col, :]``.
    * ``es=False``: ``CAModel`` with a channels-first state, indexed as
      ``x[:, :, row, col]``.
    """

    # Number of initial steps during which adversarial noise is applied.
    NOISE_STEPS = 40
    # Steps at which a snapshot of the state is written to ``logdir``.
    SNAPSHOT_STEPS = frozenset({40, 60, 100, 150, 200, 400})
    # Step count at which the viewer shuts down.
    MAX_STEPS = 2000
    # Side length, in screen pixels, of one rendered cell.
    CELL_SIZE = 20
    # Seconds to sleep after each rendered frame.
    FRAME_DELAY = 0.0

    def __init__(self, args):
        self.n_channels = args.n_channels
        self.hidden_size = args.hidden_size
        self.fire_rate = args.fire_rate
        self.logdir = args.logdir
        self.size = args.size
        self.es = args.es
        self.eps = args.eps
        self.emoji_size = args.emoji_size
        # Each rendered frame is written here and read back for display.
        self.imgpath = f"{self.logdir}/one.png"
        self.isRepaired = False
        self.l_func = torch.nn.MSELoss()

        if self.es:
            self.target_img = load_emoji(args.img, self.emoji_size)
        else:
            self.target_img = torch.from_numpy(load_emoji(
                args.img, self.emoji_size)).permute(2, 0, 1)[None, ...]

        self.writer = SummaryWriter(self.logdir)

        # Pad the target symmetrically up to the grid size. The grid size is
        # rounded so that both sides receive the same padding.
        p = (self.size - self.emoji_size) // 2
        self.size = self.emoji_size + 2 * p

        if self.es:
            # Channels-last target: pad rows and columns, leave channels alone.
            self.pad_target = F.pad(tt(self.target_img), (0, 0, p, p, p, p))
            self.net = CellularAutomataModel(
                n_channels=self.n_channels, fire_rate=self.fire_rate,
                hidden_channels=self.hidden_size)

            h, w = self.pad_target.shape[:2]
            self.seed = np.zeros([h, w, self.n_channels], np.float64)
            self.seed[h // 2, w // 2, 3:] = 1.0

            # Append zeroed hidden channels so the target has as many channels
            # as the state it is compared against in the noise loss.
            n_hidden = self.n_channels - self.pad_target.shape[2]
            whidden = torch.cat(
                (self.pad_target.detach(), torch.zeros((h, w, n_hidden))), dim=2)
            self.batch_target = whidden.unsqueeze(0)
        else:
            self.net = CAModel(n_channels=args.n_channels,
                               hidden_channels=args.hidden_size)
            self.seed = make_seed(self.size, args.n_channels)

            self.pad_target = F.pad(
                self.target_img, (p, p, p, p), "constant", 0)

            # Same hidden-channel extension as above, channels-first.
            _, n_target, h, w = self.pad_target.shape
            n_hidden = self.n_channels - n_target
            whidden = torch.cat(
                (self.pad_target[0].detach(), torch.zeros((n_hidden, h, w))), dim=0)
            self.batch_target = whidden.unsqueeze(0).float()

        if args.load_model_path:
            self.load_model(args.load_model_path)

    def load_model(self, path):
        """Load a PyTorch state dict from ``path`` into ``self.net``.

        Tensors are mapped to the CPU first, so a checkpoint saved on a GPU also
        loads on a CPU-only machine. ``load_state_dict`` copies the values into
        the existing parameters either way. In ES mode the network is then
        converted to float64 to match the float64 seed.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        self.net.load_state_dict(torch.load(path, map_location="cpu"))
        if self.es:
            self.net.double()

    def seedclone(self):
        """Return a fresh copy of the seed state.

        In ES mode the numpy seed becomes a new tensor of shape
        ``(1, *seed.shape)``. Otherwise the tensor returned by ``make_seed`` is
        cloned as-is (presumably already batched; confirm against ``make_seed``).
        """
        if self.es:
            return tt(self.seed[None, ...])
        return self.seed.clone()

    def game_update(self, surface, cur_img, sz):
        """Draw ``cur_img`` onto ``surface`` as a grid of ``sz``-pixel squares.

        Args:
            surface: pygame surface to draw on.
            cur_img: image array indexed as ``cur_img[row, col]``. Each
                pixel's channel values are used as that square's colour.
            sz: side length of each drawn square, in pixels.

        Returns:
            An all-zero ``(rows, cols)`` array. Nothing in this class uses it;
            it is kept so the return value is unchanged for other callers.
        """
        rows, cols = cur_img.shape[:2]
        # Iterate pixels only. Iterating the channel axis as well would redraw
        # every square once per channel.
        for r, c in np.ndindex(rows, cols):
            pygame.draw.rect(surface, cur_img[r, c], (c * sz, r * sz, sz, sz))
        return np.zeros((rows, cols))

    def save_cell(self, x, path):
        """Render state ``x`` to an image file at ``path``."""
        if self.es:
            # Channels-last state -> channels-first image for save_image.
            image = to_rgb(x).permute(0, 3, 1, 2)
        else:
            # Only the first four channels are rendered.
            image = to_rgb_ad(x[:, :4].detach().cpu())
        save_image(image, path, nrow=1, padding=0)

    def _damage(self, x, row, col, size):
        """Zero, in place, a ``size``-cell square of state ``x`` at (row, col)."""
        if self.es:
            x[:, row:row + size, col:col + size, :] = 0
        else:
            x[:, :, row:row + size, col:col + size] = 0

    def _adversarial_noise(self, x):
        """Return ``x`` perturbed by ``adv_attack`` using the target-loss gradient.

        The gradient is taken w.r.t. a detached CPU copy of the state, so the
        network's parameters are not involved.
        """
        e = x.clone().detach().cpu()
        e.requires_grad = True
        loss = self.l_func(e, self.batch_target)
        loss.backward()
        return adv_attack(x, self.eps, e.grad.data)

    def interactive(self):
        """Do damage on model using pygame.

        Runs until the window is closed, in which case it returns. When
        ``MAX_STEPS`` is reached it quits pygame and exits the process with
        ``sys.exit``.

        Controls:
            mouse click: zero the state from the clicked cell to the
                bottom-right corner of the grid.
            ``r``: reset the state to the seed.
        """
        x_eval = self.seedclone()
        cellsize = self.CELL_SIZE

        pygame.init()
        surface = pygame.display.set_mode(
            (self.size * cellsize, self.size * cellsize))
        pygame.display.set_caption("Interactive CA-ES")

        counter = 0
        while True:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    return
                if event.type == pygame.MOUSEBUTTONDOWN:
                    mpos_x, mpos_y = event.pos
                    # A square of side self.size always reaches the grid edge,
                    # so this clears everything below and right of the click.
                    self._damage(x_eval, mpos_y // cellsize,
                                 mpos_x // cellsize, self.size)
                    pygame.display.set_caption("Saving loss...")
                elif event.type == pygame.KEYDOWN and event.key == pygame.K_r:
                    x_eval = self.seedclone()

            # Inference only. If the network's parameters require grad (the
            # nn.Module default), each step would otherwise keep the autograd
            # graph of all previous steps alive, and memory would keep growing.
            with torch.no_grad():
                x_eval = self.net(x_eval)
            self.save_cell(x_eval, self.imgpath)
            cur_path = f'{self.logdir}/{counter}.png'

            if counter < self.NOISE_STEPS:
                x_eval = self._adversarial_noise(x_eval)
            if counter in self.SNAPSHOT_STEPS:
                self.save_cell(x_eval, cur_path)

            loss = self.net.loss(x_eval, self.pad_target)
            self.writer.add_scalar("train/fit", loss, counter)

            # Saving and reloading the frame is a quick hack to get rid of the
            # batch dimension in the tensor.
            with Image.open(self.imgpath) as frame:
                image = np.asarray(frame)
            self.game_update(surface, image, cellsize)
            time.sleep(self.FRAME_DELAY)  # update delay

            if counter == self.MAX_STEPS:
                print(f'Reached {self.MAX_STEPS} iterations. Shutting down...')
                pygame.quit()
                sys.exit()
            counter += 1
            pygame.display.update()