interactive.py 7.24 KB
Newer Older
1
2
3
4

import time

import numpy as np
5
import sys
6
7
8
9
10
11
12
13
14
15
16
17
import pygame
import torch
import torch.nn.functional as F
from PIL import Image
from torch import tensor as tt
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

from model import CAModel, CellularAutomataModel
from utils import load_emoji, to_rgb, adv_attack


Tommy Duc Luu's avatar
Tommy Duc Luu committed
18
19
20
21
22
23
24
25
26
27
def make_seed(size, n_channels):
    """Build the initial CA state: a single live cell centred in an empty grid.

    Returns a float32 tensor of shape (1, n_channels, size, size) that is zero
    everywhere except the centre cell, whose channels 3 and up are set to 1.
    """
    mid = size // 2
    state = torch.zeros(1, n_channels, size, size, dtype=torch.float32)
    state[0, 3:, mid, mid] = 1
    return state


def to_rgb_ad(img_rgba):
    """Composite a channels-first CA state onto a white background.

    Channels 0-2 are used as RGB and channel 3 as alpha (clamped to [0, 1]);
    any further (hidden) channels are ignored. The result is
    ``clamp(1 - alpha + rgb, 0, 1)``, i.e. RGB is treated as premultiplied by
    alpha, so fully transparent cells come out white. For an input of shape
    (batch, C, ...) the output has shape (batch, 3, ...).
    """
    rgb = img_rgba[:, :3, ...]
    # 3:4 (not 3:) keeps exactly one alpha channel, so a full state with
    # hidden channels still broadcasts against the 3 RGB channels.
    alpha = torch.clamp(img_rgba[:, 3:4, ...], 0, 1)
    return torch.clamp(1.0 - alpha + rgb, 0, 1)

28

29
30
31
32
33
34
class Interactive:
    """Interactive pygame viewer for a trained neural cellular automaton (CA).

    Two model variants are supported, selected by ``args.es``:

    * ``es`` true: a ``CellularAutomataModel`` whose state is a channels-last
      tensor of shape (1, H, W, n_channels), seeded from a float64 numpy array.
    * ``es`` false: a ``CAModel`` whose state is a channels-first tensor of
      shape (1, n_channels, size, size) built by ``make_seed``.

    ``args`` must provide ``n_channels``, ``hidden_size``, ``fire_rate``,
    ``logdir``, ``size``, ``es``, ``eps``, ``emoji_size``, ``img`` and
    ``load_model_path``.
    """

    # Steps at which a snapshot of the CA is written to ``<logdir>/<step>.png``.
    _SNAPSHOT_STEPS = (40, 60, 100, 150, 200, 400)
    # Step after which ``interactive`` shuts down and exits the process.
    _MAX_STEPS = 400

    def __init__(self, args):
        self.n_channels = args.n_channels
        self.hidden_size = args.hidden_size
        self.fire_rate = args.fire_rate
        self.logdir = args.logdir
        self.size = args.size
        self.es = args.es
        self.eps = args.eps  # only used by the disabled adversarial-noise code
        self.emoji_size = args.emoji_size
        # Latest frame: rewritten every step and read back for drawing.
        self.imgpath = f'{self.logdir}/one.png'
        # Only used by the disabled "repaired" detection in ``interactive``.
        self.isRepaired = False

        if self.es:
            self.target_img = load_emoji(args.img, self.emoji_size)
        else:
            # Move the last axis to the front and add a batch axis
            # (load_emoji presumably returns H x W x C -- confirm).
            self.target_img = torch.from_numpy(load_emoji(
                args.img, self.emoji_size)).permute(2, 0, 1)[None, ...]
        self.writer = SummaryWriter(self.logdir)

        # Pad the target symmetrically so it fills the grid. The grid size is
        # recomputed because an odd (size - emoji_size) can only be padded to
        # size - 1. NOTE(review): size < emoji_size gives a negative pad,
        # which F.pad treats as cropping the target.
        p = (self.size - self.emoji_size) // 2
        self.size = self.emoji_size + 2 * p

        if self.es:
            # Pad the two leading axes by p; the last (channel) axis is untouched.
            self.pad_target = F.pad(tt(self.target_img), (0, 0, p, p, p, p))
            self.net = CellularAutomataModel(
                n_channels=self.n_channels, fire_rate=self.fire_rate,
                hidden_channels=self.hidden_size)
            h, w = self.pad_target.shape[:2]
            # One live cell in the centre: channels 3 and up set to 1.
            self.seed = np.zeros([h, w, self.n_channels], np.float64)
            self.seed[h // 2, w // 2, 3:] = 1.0
        else:
            self.net = CAModel(n_channels=self.n_channels,
                               hidden_channels=self.hidden_size)
            self.seed = make_seed(self.size, self.n_channels)
            # Pad the two trailing (spatial) axes by p with zeros.
            self.pad_target = F.pad(
                self.target_img, (p, p, p, p), "constant", 0)

        # Empty string (or None) means "start from an untrained network".
        if args.load_model_path:
            self.load_model(args.load_model_path)

    def load_model(self, path):
        """Load a PyTorch state dict from ``path`` into ``self.net``.

        ``map_location="cpu"`` lets checkpoints saved from a GPU run load here;
        nothing in this class moves the network off the CPU. ES networks are
        converted to float64 to match the float64 ES seed.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        self.net.load_state_dict(state_dict)
        if self.es:
            self.net.double()

    def seedclone(self):
        """Return a fresh copy of the starting state; the stored seed is never mutated."""
        if self.es:
            # Numpy seed is (H, W, C); torch.tensor copies it and the [None]
            # adds the batch axis -> (1, H, W, C).
            return tt(self.seed[None, ...])
        return self.seed.clone()

    def game_update(self, surface, cur_img, sz):
        """Draw ``cur_img`` onto ``surface`` as a grid of ``sz``-pixel squares.

        Args:
            surface: pygame surface to draw on.
            cur_img: image array indexed ``[row, col]`` -> colour. In this file
                it is the frame written by ``save_cell`` and read back with PIL,
                so presumably (H, W, 3) uint8 -- confirm for other callers.
            sz: side length of one cell in pixels.

        Returns:
            A float array of zeros shaped ``cur_img.shape[:2]``. Nothing in this
            file uses it; it is kept so the return value is unchanged.
        """
        rows, cols = cur_img.shape[:2]
        # Iterate over (row, col) only: walking the full (H, W, C) index space
        # drew every rectangle once per channel.
        for r, c in np.ndindex(rows, cols):
            pygame.draw.rect(surface, cur_img[r, c], (c * sz, r * sz, sz, sz))
        return np.zeros((rows, cols))

    def save_cell(self, x, path):
        """Write the visible RGB rendering of CA state ``x`` to ``path``.

        ES states are channels-last, so the output of ``to_rgb`` is permuted to
        channels-first for ``save_image``. Other states keep their first four
        channels, composited onto white by ``to_rgb_ad``.
        """
        if self.es:
            image = to_rgb(x).permute(0, 3, 1, 2)
        else:
            image = to_rgb_ad(x[:, :4].detach().cpu())
        save_image(image, path, nrow=1, padding=0)

    def _damage(self, x, row, col, size):
        """Zero the ``size`` x ``size`` window of state ``x`` with top-left cell (row, col).

        Modifies ``x`` in place (slices past the grid edge are clipped) and
        returns it for convenience.
        """
        if self.es:
            # ES state is channels-last: (batch, H, W, C).
            x[:, row:row + size, col:col + size, :] = 0
        else:
            # CAModel state is channels-first: (batch, C, H, W).
            x[:, :, row:row + size, col:col + size] = 0
        return x

    def interactive(self):
        """Run the CA live in a pygame window and let the user damage it.

        Needs a display, so it cannot run through a plain ssh session.

        Every frame: handle input, advance the CA one step, write the frame to
        ``self.imgpath`` (plus a snapshot at ``_SNAPSHOT_STEPS``), log the loss
        against ``self.pad_target`` as ``train/fit`` in TensorBoard, and redraw.

        Controls:
            * any mouse button: zero every cell from the clicked cell to the
              bottom-right edge of the grid (the damage square is
              ``self.size`` cells wide);
            * ``r``: reset to the seed;
            * closing the window: close the TensorBoard writer, quit pygame and
              return.

        After step ``_MAX_STEPS`` the writer is closed, pygame quits and the
        process exits via ``sys.exit()``.
        """
        x_eval = self.seedclone()
        cellsize = 20  # on-screen pixels per CA cell

        pygame.init()
        surface = pygame.display.set_mode(
            (self.size * cellsize, self.size * cellsize))
        pygame.display.set_caption("Interactive CA-ES")

        # Stays 100 unless the disabled manual-damage loss recording below is
        # re-enabled; while it is below 100, clicks are ignored.
        damaged = 100
        counter = 0

        while True:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.writer.close()  # flush pending TensorBoard events
                    pygame.quit()
                    return
                elif event.type == pygame.MOUSEBUTTONDOWN:
                    if damaged == 100:
                        # Pixel position -> grid cell.
                        mpos_x, mpos_y = event.pos
                        mpos_x, mpos_y = mpos_x // cellsize, mpos_y // cellsize
                        # To always erase the lower half instead:
                        # mpos_y = (self.size // 2) + 1
                        # mpos_x = 0
                        dmg_size = self.size
                        self._damage(x_eval, mpos_y, mpos_x, dmg_size)
                        # damaged = 0 # number of steps to record loss after damage has occurred

                        # # For noise (disabled). NOTE(review): this needs
                        # # gradients, so it must run outside torch.no_grad(),
                        # # and self.batch_target is not defined in this class.
                        # l_func = torch.nn.MSELoss()
                        # e = x_eval.detach().cpu()
                        # e.requires_grad = True
                        # l = l_func(e, self.batch_target)
                        # self.net.zero_grad()
                        # l.backward()
                        # x_eval = adv_attack(x_eval, self.eps, e.grad.data)

                        # NOTE(review): only the disabled manual-damage block
                        # below resets this caption.
                        pygame.display.set_caption("Saving loss...")
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_r:  # reset when pressing r
                        x_eval = self.seedclone()

            # Inference only. Without no_grad, each step would be recorded on
            # the autograd graph of all previous steps (for parameters that
            # require grad), so memory would grow for the whole run.
            with torch.no_grad():
                x_eval = self.net(x_eval)
            self.save_cell(x_eval, self.imgpath)
            cur_path = f'{self.logdir}/{counter}.png'

            if counter in self._SNAPSHOT_STEPS:
                self.save_cell(x_eval, cur_path)
            # Quadratic erasing at 51 (disabled; re-enable as an ``elif`` of
            # the snapshot check above):
            # elif counter == 51:
            #     # record loss before dmg
            #     before_loss = self.net.loss(x_eval, self.pad_target)
            #     # For lower half; damage then save image
            #     self._damage(x_eval, (self.size // 2) + 1, 0, self.size)
            #     self.save_cell(x_eval, cur_path)

            with torch.no_grad():
                loss = self.net.loss(x_eval, self.pad_target)
            self.writer.add_scalar("train/fit", loss, counter)

            # Find the iteration at which the image was repaired (tol = 7x the
            # pre-damage loss); goes together with quadratic erasing above:
            # if counter > 51 and loss/before_loss <= 7 and not self.isRepaired:
            #     self.save_cell(x_eval, cur_path)
            #     self.isRepaired = True

            # # For manual damage:
            # if damaged < 100:
            #     loss = self.net.loss(x_eval, self.pad_target)
            #     self.writer.add_scalar("train/fit", loss, damaged)
            #
            #     if damaged == 99:
            #         pygame.display.set_caption("Interactive CA-ES")
            #     damaged += 1

            # Saving and reloading the frame is a quick hack to drop the batch
            # dimension; the `with` closes the file, np.array keeps a copy.
            with Image.open(self.imgpath) as frame:
                image = np.array(frame)
            self.game_update(surface, image, cellsize)
            time.sleep(0.00)  # update delay
            if counter == self._MAX_STEPS:
                print(f'Reached {self._MAX_STEPS} iterations. Shutting down...')
                self.writer.close()  # flush pending TensorBoard events
                pygame.quit()
                sys.exit()
            counter += 1
            pygame.display.update()