interactive.py 6.95 KB
Newer Older
1
2
3
4

import time

import numpy as np
5
import sys
6
7
8
9
10
11
12
13
14
import pygame
import torch
import torch.nn.functional as F
from PIL import Image
from torch import tensor as tt
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

from model import CAModel, CellularAutomataModel
Bård Sørensen Hestmark's avatar
Bård Sørensen Hestmark committed
15
from utils import load_emoji, to_rgb, adv_attack, make_seed, to_rgb_ad
16

17
18
19
20
21
22
class Interactive:
    def __init__(self, args):
        """Build the model, seed state and padded target for a session.

        Args:
            args: Parsed options object. The attributes read here are
                ``n_channels``, ``hidden_size``, ``fire_rate``, ``logdir``,
                ``size``, ``es``, ``eps``, ``emoji_size``, ``img`` and
                ``load_model_path``.
        """
        self.n_channels = args.n_channels
        self.hidden_size = args.hidden_size
        self.fire_rate = args.fire_rate
        self.logdir = args.logdir
        self.size = args.size
        self.es = args.es
        self.eps = args.eps
        self.emoji_size = args.emoji_size
        # Scratch file that interactive() overwrites on every frame.
        self.imgpath = '%s/one.png' % (self.logdir)
        self.isRepaired = False

        # The ES path keeps whatever load_emoji returns; the other path turns
        # it into a tensor, moves the last axis to position 1 and adds a
        # leading batch axis.
        emoji = load_emoji(args.img, self.emoji_size)
        if self.es:
            self.target_img = emoji
        else:
            self.target_img = torch.from_numpy(emoji).permute(2, 0, 1)[None, ...]

        self.writer = SummaryWriter(self.logdir)

        # Auto-calculate padding: split the spare space evenly and shrink
        # ``size`` by one when the difference is odd so both sides match.
        p = (self.size - self.emoji_size) // 2
        self.size = self.emoji_size + 2 * p

        if self.es:
            # Last axis is left unpadded, the two axes before it get ``p``
            # on each side.
            self.pad_target = F.pad(tt(self.target_img), (0, 0, p, p, p, p))
            self.net = CellularAutomataModel(
                n_channels=self.n_channels,
                fire_rate=self.fire_rate,
                hidden_channels=self.hidden_size)
            # The seed below is float64, so the network has to run in double
            # precision as well. Previously this cast only happened inside
            # load_model(), i.e. only when a checkpoint was given.
            self.net.double()

            h, w = self.pad_target.shape[:2]
            self.seed = np.zeros([h, w, self.n_channels], np.float64)
            # Single live cell in the centre: channels 3 and up set to 1.
            self.seed[h // 2, w // 2, 3:] = 1.0
        else:
            self.net = CAModel(n_channels=self.n_channels,
                               hidden_channels=self.hidden_size)
            self.seed = make_seed(self.size, self.n_channels)
            self.pad_target = F.pad(
                self.target_img, (p, p, p, p), "constant", 0)

        # Truthiness check so that both "" and None mean "no checkpoint".
        if args.load_model_path:
            self.load_model(args.load_model_path)

    def load_model(self, path):
        """Load a PyTorch model from path."""
        self.net.load_state_dict(torch.load(path))
Tommy Duc Luu's avatar
Tommy Duc Luu committed
61
62
        if self.es:
            self.net.double()
63

Tommy Duc Luu's avatar
Tommy Duc Luu committed
64
    def seedclone(self):
Tommy Duc Luu's avatar
Tommy Duc Luu committed
65
66
67
68
69
70
71
        if self.es:
            return tt(np.repeat(self.seed[None, ...], 1, 0))
        else:
            return self.seed.clone()

    def game_update(self, surface, cur_img, sz):
        """Paint ``cur_img`` onto ``surface`` as a grid of ``sz``-pixel squares.

        Args:
            surface: pygame surface to draw on.
            cur_img: Array indexed as ``cur_img[row, col]``; each entry is
                passed to ``pygame.draw.rect`` as the colour of that cell.
            sz: Side length, in screen pixels, of one cell.

        Returns:
            A zero-filled array with the same first two dimensions as
            ``cur_img``. It is never written to; it is returned only to keep
            the existing interface.
        """
        rows, cols = cur_img.shape[0], cur_img.shape[1]
        nxt = np.zeros((rows, cols))

        # Walk rows and columns only. Iterating the full shape also stepped
        # through the channel axis and redrew every cell once per channel.
        for r, c in np.ndindex(rows, cols):
            pygame.draw.rect(surface, cur_img[r, c], (c*sz, r*sz, sz, sz))

        return nxt

    def save_cell(self, x, path):
        """Render cell state ``x`` to an image file at ``path``.

        Args:
            x: Current automaton state as produced by ``self.net``.
            path: Destination file for ``save_image``.
        """
        if not self.es:
            # Only the first four channels along axis 1 are rendered;
            # detach and move to CPU before converting.
            visible = x[:, :4].detach().cpu()
            rgb = to_rgb_ad(visible)
        else:
            # Move the last axis of the converted image to position 1
            # (presumably channels-last -> channels-first; confirm in to_rgb).
            rgb = to_rgb(x).permute(0, 3, 1, 2)
        save_image(rgb, path, nrow=1, padding=0)
84

Tommy Duc Luu's avatar
Tommy Duc Luu committed
85

86
    def interactive(self):
        """Run the automaton in a pygame window and let the user damage it.

        Each loop iteration handles pending events, advances the automaton by
        one step, writes the current frame to ``self.imgpath``, logs the loss
        to TensorBoard under ``train/fit`` and redraws the window.

        Controls:
            * Mouse click: zero the state from the clicked cell down to the
              bottom-right corner of the grid.
            * ``r``: restart from a fresh seed.
            * Closing the window: quit pygame and return.

        When the iteration counter reaches 400 the method prints a message,
        quits pygame and terminates the process with ``sys.exit()``.
        """
        cellsize = 20  # screen pixels per automaton cell
        snapshot_steps = (40, 60, 100, 150, 200, 400)  # iterations saved to disk
        last_step = 400

        x_eval = self.seedclone()

        pygame.init()
        surface = pygame.display.set_mode(
            (self.size * cellsize, self.size * cellsize))
        pygame.display.set_caption("Interactive CA-ES")

        # ``damaged`` stays at 100 unless the commented-out "manual damage"
        # logging further down is re-enabled.
        damaged = 100
        counter = 0

        while True:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    # Push buffered scalars to disk before leaving.
                    self.writer.flush()
                    pygame.quit()
                    return
                elif event.type == pygame.MOUSEBUTTONDOWN:
                    # damage
                    if damaged == 100:
                        mpos_x, mpos_y = event.pos
                        mpos_x, mpos_y = mpos_x // cellsize, mpos_y // cellsize
                        # mpos_y = (self.size // 2) + 1
                        # mpos_x = 0
                        # The slice end runs past the grid, so everything from
                        # the clicked cell to the bottom-right corner is wiped.
                        dmg_size = self.size
                        if self.es:
                            x_eval[:, mpos_y:mpos_y + dmg_size,
                                   mpos_x:mpos_x + dmg_size, :] = 0
                        else:
                            x_eval[:, :, mpos_y:mpos_y + dmg_size,
                                   mpos_x:mpos_x + dmg_size] = 0
                        # damaged = 0 # number of steps to record loss after damage has occurred

                        # # For noise:
                        # l_func = torch.nn.MSELoss()
                        # e = x_eval.detach().cpu()
                        # e.requires_grad = True
                        # l = l_func(e, self.batch_target)
                        # self.net.zero_grad()
                        # l.backward()
                        # x_eval = adv_attack(x_eval, self.eps, e.grad.data)
                        pygame.display.set_caption("Saving loss...")
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_r:  # reset when pressing r
                        x_eval = self.seedclone()

            # Pure inference: without no_grad every step would extend the
            # autograd graph hanging off x_eval for the whole session.
            with torch.no_grad():
                x_eval = self.net(x_eval)
            self.save_cell(x_eval, self.imgpath)
            cur_path = f'{self.logdir}/{counter}.png'

            if counter in snapshot_steps:
                self.save_cell(x_eval, cur_path)
            # Quadratic erasing at 51
            # elif counter == 51:
            #     # record loss before dmg
            #     before_loss = self.net.loss(x_eval, self.pad_target)
            #     # For lower half:
            #     mpos_y = (self.size // 2) + 1
            #     mpos_x = 0
            #     dmg_size = self.size
            #     # damage then save image
            #     if self.es:
            #         x_eval[:, mpos_y:mpos_y + dmg_size,
            #                mpos_x:mpos_x + dmg_size, :] = 0
            #     else:
            #         x_eval[:, :, mpos_y:mpos_y + dmg_size,
            #                mpos_x:mpos_x + dmg_size] = 0
            #     self.save_cell(x_eval, cur_path)

            with torch.no_grad():
                loss = self.net.loss(x_eval, self.pad_target)
            self.writer.add_scalar("train/fit", loss, counter)

            # Find the iteration at which the image was repaired (tolerance 7);
            # belongs together with the quadratic erasing above.
            # if counter > 51 and loss/before_loss <= 7 and not self.isRepaired:
            #     self.save_cell(x_eval, cur_path)
            #     self.isRepaired = True

            # # For manual damage:
            # if damaged < 100:
            #     loss = self.net.loss(x_eval, self.pad_target)
            #     self.writer.add_scalar("train/fit", loss, damaged)

            #     if damaged == 99:
            #         pygame.display.set_caption("Interactive CA-ES")
            #     damaged += 1

            # Saving and loading each image as a quick hack to get rid of the
            # batch dimension in the tensor.
            image = np.asarray(Image.open(self.imgpath))
            self.game_update(surface, image, cellsize)
            time.sleep(0.00)  # update delay
            if counter == last_step:
                print('Reached 400 iterations. Shutting down...')
                # Push buffered scalars to disk before the process exits.
                self.writer.flush()
                pygame.quit()
                sys.exit()
            counter += 1
            pygame.display.update()