From 709c0d30e46fd94e17f44c2007f2022fbf543094 Mon Sep 17 00:00:00 2001 From: baardshe <baardshe@stud.ntnu.no> Date: Sun, 8 May 2022 16:22:33 +0200 Subject: [PATCH] interactive run all models --- interactive_CA/interactive.py | 8 +++---- interactive_CA/main.py | 37 +++++++++++++++++------------- interactive_CA/run_all.py | 42 +++++++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 20 deletions(-) create mode 100644 interactive_CA/run_all.py diff --git a/interactive_CA/interactive.py b/interactive_CA/interactive.py index 91cda6e..922982c 100644 --- a/interactive_CA/interactive.py +++ b/interactive_CA/interactive.py @@ -139,7 +139,7 @@ class Interactive: if counter in [40, 60, 100, 150, 200, 400]: self.save_cell(x_eval, cur_path) # Quadratic erasing at 51 - # elif counter == 51: + # if counter == 51: # # record loss before dmg # before_loss = self.net.loss(x_eval, self.pad_target) # # For lower half: @@ -176,9 +176,9 @@ class Interactive: image = np.asarray(Image.open(self.imgpath)) self.game_update(surface, image, cellsize) time.sleep(0.00) # update delay - if counter == 2000: + counter += 1 + pygame.display.update() + if counter == 400: print('Reached 400 iterations. Shutting down...') pygame.quit() sys.exit() - counter += 1 - pygame.display.update() diff --git a/interactive_CA/main.py b/interactive_CA/main.py index fac062f..a953bd0 100644 --- a/interactive_CA/main.py +++ b/interactive_CA/main.py @@ -29,7 +29,7 @@ adam_sample_models = [ es_nonsample_models = [ ['final_models\\ES\\NonSamplePools\\9-CARROT-train_05-05-2022_09-14-06\\models\\model_2212000', '🥕', 9, True], - ['final_models\\ES\\NonSamplePools\\9-RABBIT FACE-train_06-05-2022_12-18-56\\models\\model_1999000', '🐰', 9, True ], + ['final_models\\ES\\NonSamplePools\\9-RABBIT FACE-train_06-05-2022_12-18-56\\models\\model_1999000', '🐰', 9, True], ['final_models\\ES\\NonSamplePools\\15-CARROT-train_06-05-2022_11-06-58\\models\\model_1999000', '🥕', 15, True], ['final_models\\ES\\NonSamplePools\\15-RABBIT FACE-train_06-05-2022_12-21-00\\models\\model_1999000', '🐰', 15, True] ] @@ -46,49 +46,54 @@ models = [adam_nonsample_models, adam_sample_models, if __name__ == '__main__': # pick model [Index of model types][index of 9x9 or 15x15 rabbit or carrot] - model = models[1][2] # change only this or size + # model = models[1][2] # change only this or size # Auto determined - load_model = model[0] - emoji = model[1] - emoji_size = model[2] # size of training image usually 9 or 15 - es = model[3] + # load_model = model[0] + # emoji = model[1] + # emoji_size = model[2] # size of training image usually 9 or 15 + # es = model[3] # canvas size - size = emoji_size + # size = emoji_size parser = argparse.ArgumentParser() - parser.add_argument("-i", "--img", type=str, default=emoji, + parser.add_argument("-i", "--img", type=str, default='', metavar="🐰", help="The emoji to train on") parser.add_argument("-s", "--size", type=int, - default=size, help="Image size") + default=9, help="Image size") parser.add_argument("--logdir", type=str, default="interactive_CA/logs", help="Logging folder for new model") parser.add_argument("-l", "--load_model_path", type=str, - default=load_model, help="Path to pre trained model") + default='', help="Path to pre trained model") parser.add_argument("--n_channels", type=int, default=16, help="Number of channels of the input tensor") parser.add_argument("--hidden_size", type=int, default=32, help="Number of hidden channels") parser.add_argument("--fire_rate", type=float, default=0.5, metavar=0.5, help="Cell fire rate") - parser.add_argument("--es", type=bool, default=es, + parser.add_argument("-e", "--es", type=str, default='', metavar=True, help="ES or adam") parser.add_argument("--eps", type=float, default=0.007, help="Epsilon scales the amount of damage done from adversarial attacks") args = parser.parse_args() - args.emoji_size = emoji_size + args.emoji_size = args.size if not os.path.isdir(args.logdir): raise Exception( "Logging directory '%s' not found in base folder" % args.logdir) - match es: - case True: method = 'ES' - case False: method = 'ADAM' + match args.es: + case 'True': + method = 'ES' + args.es = True + case 'False': + method = 'ADAM' + args.es = False + - args.logdir = "%s/%s-%s-%s_%s" % (args.logdir, emoji_size, method, + args.logdir = "%s/%s-%s-%s_%s" % (args.logdir, args.emoji_size, method, unicodedata.name(args.img), time.strftime("%d-%m-%Y_%H-%M-%S")) os.mkdir(args.logdir) diff --git a/interactive_CA/run_all.py b/interactive_CA/run_all.py new file mode 100644 index 0000000..bef124f --- /dev/null +++ b/interactive_CA/run_all.py @@ -0,0 +1,42 @@ +import subprocess + +adam_nonsample_models = [ + ['final_models\\Adam\\NonSamplePools\\9-CARROT-train_05-05-2022_17-55-17\\models\\model_19500.pt', '🥕', 9, False], + ['final_models\\Adam\\NonSamplePools\\9-RABBIT-FACE-train_05-05-2022_16-07-21\\models\\model_19500.pt', '🐰', 9, False], + ['final_models\\Adam\\NonSamplePools\\15-CARROT-train_05-05-2022_17-31-03\\models\\model_19500.pt', '🥕', 15, False], + ['final_models\\Adam\\NonSamplePools\\15-RABBIT-FACE-train_05-05-2022_14-39-24\\models\\model_19500.pt', '🐰', 15, False] +] + +adam_sample_models = [ + ['final_models\\Adam\\SamplePools\\9-CARROT-train_05-05-2022_17-04-48\\models\\model_19500.pt', '🥕', 9, False], + ['final_models\\Adam\\SamplePools\\9-RABBIT-FACE-train_05-05-2022_11-43-14\\models\\model_19500.pt', '🐰', 9, False], + ['final_models\\Adam\\SamplePools\\15-CARROT-train_05-05-2022_17-00-31\\models\\model_19500.pt', '🥕', 15, False], + ['final_models\\Adam\\SamplePools\\15-RABBIT-FACE-train_05-05-2022_13-44-43\\models\\model_19500.pt', '🐰', 15, False] +] + +es_nonsample_models = [ + ['final_models\\ES\\NonSamplePools\\9-CARROT-train_05-05-2022_09-14-06\\models\\model_2212000', '🥕', 9, True], + ['final_models\\ES\\NonSamplePools\\9-RABBIT-FACE-train_06-05-2022_12-18-56\\models\\model_1999000', '🐰', 9, True ], + ['final_models\\ES\\NonSamplePools\\15-CARROT-train_06-05-2022_11-06-58\\models\\model_1999000', '🥕', 15, True], + ['final_models\\ES\\NonSamplePools\\15-RABBIT-FACE-train_06-05-2022_12-21-00\\models\\model_1999000', '🐰', 15, True] +] + +es_sample_models = [ + ['final_models\\ES\\SamplePools\\9-CARROT-train_29-04-2022_11-33-16\\models\\model_1036000', '🥕', 9, True], + ['final_models\\ES\\SamplePools\\9-RABBIT-FACE-train_01-05-2022_10-32-24\\models\\model_1157000', '🐰', 9, True], + ['final_models\\ES\\SamplePools\\15-CARROT-train_29-04-2022_11-18-06\\models\\model_1105000', '🥕', 15, True], + ['final_models\\ES\\SamplePools\\15-RABBIT-FACE-train_01-05-2022_10-32-40\\models\\model_1126000', '🐰', 15, True] +] + +models = [adam_nonsample_models, adam_sample_models, + es_nonsample_models, es_sample_models] + +if __name__ == '__main__': + for i in models: + for model in i: + load_model = model[0] + emoji = model[1] + size = model[2] + es = model[3] + command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es) + subprocess.run(command) \ No newline at end of file -- GitLab