Skip to content
Snippets Groups Projects
Commit 10b78689 authored by Tommy Duc Luu's avatar Tommy Duc Luu
Browse files

interactive changes

parent e9f52f24
No related branches found
No related tags found
No related merge requests found
...@@ -52,12 +52,15 @@ class Interactive: ...@@ -52,12 +52,15 @@ class Interactive:
self.net.load_state_dict(torch.load(path)) self.net.load_state_dict(torch.load(path))
if self.es: self.net.double() if self.es: self.net.double()
def seedclone(self):
if self.es: return tt(np.repeat(self.seed[None, ...], 1, 0))
else: return self.seed.clone()
# Do damage on model using pygame, cannot run through ssh # Do damage on model using pygame, cannot run through ssh
def interactive(self): def interactive(self):
if self.es: x_eval = tt(np.repeat(self.seed[None, ...], 1, 0)) x_eval = self.seedclone()
else: x_eval = self.seed.clone()
cellsize = 20 cellsize = 20
imgpath = '%s/one.png' % (self.logdir) imgpath = '%s/one.png' % (self.logdir)
...@@ -96,6 +99,10 @@ class Interactive: ...@@ -96,6 +99,10 @@ class Interactive:
# l.backward() # l.backward()
# x_eval = adv_attack(x_eval, self.eps, e.grad.data) # x_eval = adv_attack(x_eval, self.eps, e.grad.data)
pygame.display.set_caption("Saving loss...") pygame.display.set_caption("Saving loss...")
elif event.type == pygame.KEYDOWN:
if event.key == pygame.K_r:
x_eval = self.seedclone()
x_eval = self.net(x_eval) x_eval = self.net(x_eval)
if self.es: image = to_rgb(x_eval).permute(0, 3, 1, 2) if self.es: image = to_rgb(x_eval).permute(0, 3, 1, 2)
...@@ -111,7 +118,7 @@ class Interactive: ...@@ -111,7 +118,7 @@ class Interactive:
# dmg_size = self.size # dmg_size = self.size
# if self.es: x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0 # if self.es: x_eval[:, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size, :] = 0
# else: x_eval[:, :, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size] = 0 # else: x_eval[:, :, mpos_y:mpos_y + dmg_size, mpos_x:mpos_x + dmg_size] = 0
if counter == 400: pygame.quit() if counter == 4000000: pygame.quit()
counter += 1 counter += 1
loss = self.net.loss(x_eval, self.pad_target) loss = self.net.loss(x_eval, self.pad_target)
...@@ -130,7 +137,7 @@ class Interactive: ...@@ -130,7 +137,7 @@ class Interactive:
image = np.asarray(Image.open(imgpath)) image = np.asarray(Image.open(imgpath))
self.game_update(surface, image, cellsize) self.game_update(surface, image, cellsize)
time.sleep(0.05) time.sleep(0.05) # update delay
pygame.display.update() pygame.display.update()
def game_update(self, surface, cur_img, sz): def game_update(self, surface, cur_img, sz):
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment