Skip to content
Snippets Groups Projects
Commit 7adbac8d authored by Tommy Duc Luu's avatar Tommy Duc Luu
Browse files

saving function added

Former-commit-id: 4b34f28970de0a7805ad5bf3b47f26bd89bef904 [formerly da337725]
Former-commit-id: 345a68665218f4dcd63a102dee997d858c357f16
parent 2cfaee43
No related branches found
No related tags found
No related merge requests found
......@@ -15,6 +15,16 @@ from model import CAModel, CellularAutomataModel
from utils import load_emoji, to_rgb, adv_attack
def make_seed(size, n_channels):
x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
x[:, 3:, size // 2, size // 2] = 1
return x
def to_rgb_ad(img_rgba):
rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
return torch.clamp(1.0 - a + rgb, 0, 1)
class Interactive:
def __init__(self, args):
......@@ -27,8 +37,11 @@ class Interactive:
self.eps = args.eps
self.emoji_size = args.emoji_size
if self.es: self.target_img = load_emoji(args.img, self.emoji_size)
else: self.target_img = torch.from_numpy(load_emoji(args.img, self.emoji_size)).permute(2, 0, 1)[None, ...]
if self.es:
self.target_img = load_emoji(args.img, self.emoji_size)
else:
self.target_img = torch.from_numpy(load_emoji(
args.img, self.emoji_size)).permute(2, 0, 1)[None, ...]
self.writer = SummaryWriter(self.logdir)
# auto calculate padding
p = (self.size-self.emoji_size)//2
......@@ -36,14 +49,17 @@ class Interactive:
if self.es:
self.pad_target = F.pad(tt(self.target_img), (0, 0, p, p, p, p))
self.net = CellularAutomataModel(n_channels=self.n_channels, fire_rate=self.fire_rate, hidden_channels=self.hidden_size)
self.net = CellularAutomataModel(
n_channels=self.n_channels, fire_rate=self.fire_rate, hidden_channels=self.hidden_size)
h, w = self.pad_target.shape[:2]
self.seed = np.zeros([h, w, self.n_channels], np.float64)
self.seed[h // 2, w // 2, 3:] = 1.0
else:
self.net = CAModel(n_channels=args.n_channels, hidden_channels=args.hidden_size)
self.net = CAModel(n_channels=args.n_channels,
hidden_channels=args.hidden_size)
self.seed = make_seed(self.size, args.n_channels)
self.pad_target = F.pad(self.target_img, (p, p, p, p), "constant", 0)
self.pad_target = F.pad(
self.target_img, (p, p, p, p), "constant", 0)
self.pad_target = self.pad_target.repeat(1, 1, 1, 1)
# whidden = torch.concat((self.pad_target, torch.zeros((self.size,self.size,12))), axis=2)
......@@ -55,23 +71,40 @@ class Interactive:
def load_model(self, path):
"""Load a PyTorch model from path."""
self.net.load_state_dict(torch.load(path))
if self.es: self.net.double()
if self.es:
self.net.double()
def seedclone(self):
if self.es: return tt(np.repeat(self.seed[None, ...], 1, 0))
else: return self.seed.clone()
if self.es:
return tt(np.repeat(self.seed[None, ...], 1, 0))
else:
return self.seed.clone()
def game_update(self, surface, cur_img, sz):
nxt = np.zeros((cur_img.shape[0], cur_img.shape[1]))
for r, c, _ in np.ndindex(cur_img.shape):
pygame.draw.rect(surface, cur_img[r, c], (c*sz, r*sz, sz, sz))
return nxt
def save_cell(self, x, path):
if self.es:
image = to_rgb(x).permute(0, 3, 1, 2)
else:
image = to_rgb_ad(x[:, :4].detach().cpu())
save_image(image, path, nrow=1, padding=0)
# Do damage on model using pygame, cannot run through ssh
def interactive(self):
x_eval = self.seedclone()
cellsize = 20
imgpath = '%s/one.png' % (self.logdir)
pygame.init()
surface = pygame.display.set_mode((self.size * cellsize, self.size * cellsize))
surface = pygame.display.set_mode(
(self.size * cellsize, self.size * cellsize))
pygame.display.set_caption("Interactive CA-ES")
damaged = 100
......@@ -91,10 +124,14 @@ class Interactive:
# mpos_y = (self.size // 2) + 1
# mpos_x = 0
dmg_size = self.size
if self.es: x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0
else: x_eval[:, :, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size] = 0
if self.es:
x_eval[:, mpos_y:mpos_y + dmg_size,
mpos_x:mpos_x + dmg_size, :] = 0
else:
x_eval[:, :, mpos_y:mpos_y + dmg_size,
mpos_x:mpos_x + dmg_size] = 0
# damaged = 0 # number of steps to record loss after damage has occurred
# # For noise:
# l_func = torch.nn.MSELoss()
# e = x_eval.detach().cpu()
......@@ -105,30 +142,34 @@ class Interactive:
# x_eval = adv_attack(x_eval, self.eps, e.grad.data)
pygame.display.set_caption("Saving loss...")
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_r: # reset
if event.key == pygame.K_r: # reset when pressing r
x_eval = self.seedclone()
counter = 0
x_eval = self.net(x_eval)
if self.es: image = to_rgb(x_eval).permute(0, 3, 1, 2)
else: image = to_rgb_ad(x_eval[:, :4].detach().cpu())
save_image(image, imgpath, nrow=1, padding=0)
self.save_cell(x_eval, imgpath)
# Damage at 51:
# if counter == 51:
# # For lower half:
# mpos_y = (self.size // 2) + 1
# mpos_x = 0
# dmg_size = self.size
# if self.es: x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0
# else: x_eval[:, :, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size] = 0
if counter == 40:
self.save_cell(x_eval, f'{self.logdir}/{counter}.png')
elif counter == 51:
# For lower half:
mpos_y = (self.size // 2) + 1
mpos_x = 0
dmg_size = self.size
# damage then save image
if self.es:
x_eval[:, mpos_y:mpos_y + dmg_size,
mpos_x:mpos_x + dmg_size, :] = 0
else:
x_eval[:, :, mpos_y:mpos_y + dmg_size,
mpos_x:mpos_x + dmg_size] = 0
self.save_cell(x_eval, f'{self.logdir}/{counter}.png')
elif counter == 60:
self.save_cell(x_eval, f'{self.logdir}/{counter}.png')
loss = self.net.loss(x_eval, self.pad_target)
self.writer.add_scalar("train/fit", loss, counter)
# # For manual damage:
# if damaged < 100:
# loss = self.net.loss(x_eval, self.pad_target)
......@@ -142,28 +183,10 @@ class Interactive:
image = np.asarray(Image.open(imgpath))
self.game_update(surface, image, cellsize)
time.sleep(0.05) # update delay
if counter == 400:
time.sleep(0.00) # update delay
if counter == 400:
print('Reached 400 iterations. Shutting down...')
pygame.quit()
sys.exit()
counter += 1
pygame.display.update()
def game_update(self, surface, cur_img, sz):
nxt = np.zeros((cur_img.shape[0], cur_img.shape[1]))
for r, c, _ in np.ndindex(cur_img.shape):
pygame.draw.rect(surface, cur_img[r,c], (c*sz, r*sz, sz, sz))
return nxt
def make_seed(size, n_channels):
x = torch.zeros((1, n_channels, size, size), dtype=torch.float32)
x[:, 3:, size // 2, size // 2] = 1
return x
def to_rgb_ad(img_rgba):
rgb, a = img_rgba[:, :3, ...], torch.clamp(img_rgba[:, 3:, ...], 0, 1)
return torch.clamp(1.0 - a + rgb, 0, 1)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment