Commit cc53fe02 authored by Bård Sørensen Hestmark's avatar Bård Sørensen Hestmark
Browse files

Merge branch 'main' of gitlab.stud.idi.ntnu.no:baardshe/IDATT2900-45b


Former-commit-id: 3ae2f258ca7b0ea78f46f10f9c69c19bf9312f05 [formerly 611cadc4]
Former-commit-id: 791529c2da569a11d958c0773dc4e98d83ceb5f1
parents 802e3fb7 d30b74f0
......@@ -26,7 +26,7 @@ class Interactive:
self.emoji_size = args.emoji_size
self.imgpath = '%s/one.png' % (self.logdir)
self.isRepaired = False
self.l_func = torch.nn.MSELoss()
if self.es:
self.target_img = load_emoji(args.img, self.emoji_size)
else:
......@@ -44,6 +44,8 @@ class Interactive:
h, w = self.pad_target.shape[:2]
self.seed = np.zeros([h, w, self.n_channels], np.float64)
self.seed[h // 2, w // 2, 3:] = 1.0
whidden = torch.concat((self.pad_target.detach(), torch.zeros((self.size,self.size,12))), axis=2)
self.batch_target = np.repeat(whidden.clone().detach()[None, ...], 1, 0)
else:
self.net = CAModel(n_channels=args.n_channels,
hidden_channels=args.hidden_size)
......@@ -51,6 +53,10 @@ class Interactive:
self.pad_target = F.pad(
self.target_img, (p, p, p, p), "constant", 0)
self.pad_target = self.pad_target.repeat(1, 1, 1, 1)
whidden = torch.concat((self.pad_target[0].detach(), torch.zeros((12,self.size,self.size))), axis=0)
self.batch_target = np.repeat(whidden.clone().detach()[None, ...], 1, 0).float()
if args.load_model_path != "":
self.load_model(args.load_model_path)
......@@ -136,6 +142,14 @@ class Interactive:
self.save_cell(x_eval, self.imgpath)
cur_path = f'{self.logdir}/{counter}.png'
if counter < 40:
# For noise:
e = x_eval.clone().detach().cpu()
e.requires_grad = True
l = self.l_func(e, self.batch_target)
self.net.zero_grad()
l.backward()
x_eval = adv_attack(x_eval, self.eps, e.grad.data)
if counter in [40, 60, 100, 150, 200, 400]:
self.save_cell(x_eval, cur_path)
# Quadratic erasing at 51
......
......@@ -74,7 +74,7 @@ if __name__ == '__main__':
metavar=0.5, help="Cell fire rate")
parser.add_argument("-e", "--es", type=str, default='',
metavar=True, help="ES or adam")
parser.add_argument("--eps", type=float, default=0.007,
parser.add_argument("--eps", type=float, default=0.00007,
help="Epsilon scales the amount of damage done from adversarial attacks")
args = parser.parse_args()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment