Commit 53165ddd authored by Tommy Duc Luu's avatar Tommy Duc Luu
Browse files

adverarial test

parent 7985b3c0
......@@ -26,7 +26,7 @@ class Interactive:
self.emoji_size = args.emoji_size
self.imgpath = '%s/one.png' % (self.logdir)
self.isRepaired = False
self.l_func = torch.nn.MSELoss()
if self.es:
self.target_img = load_emoji(args.img, self.emoji_size)
else:
......@@ -44,6 +44,8 @@ class Interactive:
h, w = self.pad_target.shape[:2]
self.seed = np.zeros([h, w, self.n_channels], np.float64)
self.seed[h // 2, w // 2, 3:] = 1.0
whidden = torch.concat((self.pad_target.detach(), torch.zeros((self.size,self.size,12))), axis=2)
self.batch_target = np.repeat(whidden.clone().detach()[None, ...], 1, 0)
else:
self.net = CAModel(n_channels=args.n_channels,
hidden_channels=args.hidden_size)
......@@ -51,6 +53,10 @@ class Interactive:
self.pad_target = F.pad(
self.target_img, (p, p, p, p), "constant", 0)
self.pad_target = self.pad_target.repeat(1, 1, 1, 1)
whidden = torch.concat((self.pad_target[0].detach(), torch.zeros((12,self.size,self.size))), axis=0)
self.batch_target = np.repeat(whidden.clone().detach()[None, ...], 1, 0).float()
if args.load_model_path != "":
self.load_model(args.load_model_path)
......@@ -136,6 +142,14 @@ class Interactive:
self.save_cell(x_eval, self.imgpath)
cur_path = f'{self.logdir}/{counter}.png'
if counter < 40:
# For noise:
e = x_eval.clone().detach().cpu()
e.requires_grad = True
l = self.l_func(e, self.batch_target)
self.net.zero_grad()
l.backward()
x_eval = adv_attack(x_eval, self.eps, e.grad.data)
if counter in [40, 60, 100, 150, 200, 400]:
self.save_cell(x_eval, cur_path)
# Quadratic erasing at 51
......
......@@ -46,7 +46,7 @@ models = [adam_nonsample_models, adam_sample_models,
if __name__ == '__main__':
# pick model [Index of model types][index of 9x9 or 15x15 rabbit or carrot]
model = models[1][2] # change only this or size
model = models[0][0] # change only this or size
# Auto determined
load_model = model[0]
......@@ -74,7 +74,7 @@ if __name__ == '__main__':
metavar=0.5, help="Cell fire rate")
parser.add_argument("--es", type=bool, default=es,
metavar=True, help="ES or adam")
parser.add_argument("--eps", type=float, default=0.007,
parser.add_argument("--eps", type=float, default=0.00007,
help="Epsilon scales the amount of damage done from adversarial attacks")
args = parser.parse_args()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment