Skip to content
Snippets Groups Projects
Commit 709c0d30 authored by Bård Sørensen Hestmark's avatar Bård Sørensen Hestmark
Browse files

interactive run all models

parent a6c3fe7a
No related branches found
No related tags found
No related merge requests found
...@@ -139,7 +139,7 @@ class Interactive: ...@@ -139,7 +139,7 @@ class Interactive:
if counter in [40, 60, 100, 150, 200, 400]: if counter in [40, 60, 100, 150, 200, 400]:
self.save_cell(x_eval, cur_path) self.save_cell(x_eval, cur_path)
# Quadratic erasing at 51 # Quadratic erasing at 51
# elif counter == 51: # if counter == 51:
# # record loss before dmg # # record loss before dmg
# before_loss = self.net.loss(x_eval, self.pad_target) # before_loss = self.net.loss(x_eval, self.pad_target)
# # For lower half: # # For lower half:
...@@ -176,9 +176,9 @@ class Interactive: ...@@ -176,9 +176,9 @@ class Interactive:
image = np.asarray(Image.open(self.imgpath)) image = np.asarray(Image.open(self.imgpath))
self.game_update(surface, image, cellsize) self.game_update(surface, image, cellsize)
time.sleep(0.00) # update delay time.sleep(0.00) # update delay
if counter == 2000: counter += 1
pygame.display.update()
if counter == 400:
print('Reached 400 iterations. Shutting down...') print('Reached 400 iterations. Shutting down...')
pygame.quit() pygame.quit()
sys.exit() sys.exit()
counter += 1
pygame.display.update()
...@@ -46,49 +46,54 @@ models = [adam_nonsample_models, adam_sample_models, ...@@ -46,49 +46,54 @@ models = [adam_nonsample_models, adam_sample_models,
if __name__ == '__main__': if __name__ == '__main__':
# pick model [Index of model types][index of 9x9 or 15x15 rabbit or carrot] # pick model [Index of model types][index of 9x9 or 15x15 rabbit or carrot]
model = models[1][2] # change only this or size # model = models[1][2] # change only this or size
# Auto determined # Auto determined
load_model = model[0] # load_model = model[0]
emoji = model[1] # emoji = model[1]
emoji_size = model[2] # size of training image usually 9 or 15 # emoji_size = model[2] # size of training image usually 9 or 15
es = model[3] # es = model[3]
# canvas size # canvas size
size = emoji_size # size = emoji_size
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-i", "--img", type=str, default=emoji, parser.add_argument("-i", "--img", type=str, default='',
metavar="🐰", help="The emoji to train on") metavar="🐰", help="The emoji to train on")
parser.add_argument("-s", "--size", type=int, parser.add_argument("-s", "--size", type=int,
default=size, help="Image size") default=9, help="Image size")
parser.add_argument("--logdir", type=str, default="interactive_CA/logs", parser.add_argument("--logdir", type=str, default="interactive_CA/logs",
help="Logging folder for new model") help="Logging folder for new model")
parser.add_argument("-l", "--load_model_path", type=str, parser.add_argument("-l", "--load_model_path", type=str,
default=load_model, help="Path to pre trained model") default='', help="Path to pre trained model")
parser.add_argument("--n_channels", type=int, default=16, parser.add_argument("--n_channels", type=int, default=16,
help="Number of channels of the input tensor") help="Number of channels of the input tensor")
parser.add_argument("--hidden_size", type=int, default=32, parser.add_argument("--hidden_size", type=int, default=32,
help="Number of hidden channels") help="Number of hidden channels")
parser.add_argument("--fire_rate", type=float, default=0.5, parser.add_argument("--fire_rate", type=float, default=0.5,
metavar=0.5, help="Cell fire rate") metavar=0.5, help="Cell fire rate")
parser.add_argument("--es", type=bool, default=es, parser.add_argument("-e", "--es", type=str, default='',
metavar=True, help="ES or adam") metavar=True, help="ES or adam")
parser.add_argument("--eps", type=float, default=0.007, parser.add_argument("--eps", type=float, default=0.007,
help="Epsilon scales the amount of damage done from adversarial attacks") help="Epsilon scales the amount of damage done from adversarial attacks")
args = parser.parse_args() args = parser.parse_args()
args.emoji_size = emoji_size args.emoji_size = args.size
if not os.path.isdir(args.logdir): if not os.path.isdir(args.logdir):
raise Exception( raise Exception(
"Logging directory '%s' not found in base folder" % args.logdir) "Logging directory '%s' not found in base folder" % args.logdir)
match es: match args.es:
case True: method = 'ES' case 'True':
case False: method = 'ADAM' method = 'ES'
args.es = True
case 'False':
method = 'ADAM'
args.es = False
args.logdir = "%s/%s-%s-%s_%s" % (args.logdir, emoji_size, method,
args.logdir = "%s/%s-%s-%s_%s" % (args.logdir, args.emoji_size, method,
unicodedata.name(args.img), time.strftime("%d-%m-%Y_%H-%M-%S")) unicodedata.name(args.img), time.strftime("%d-%m-%Y_%H-%M-%S"))
os.mkdir(args.logdir) os.mkdir(args.logdir)
......
import subprocess
adam_nonsample_models = [
['final_models\\Adam\\NonSamplePools\\9-CARROT-train_05-05-2022_17-55-17\\models\\model_19500.pt', '🥕', 9, False],
['final_models\\Adam\\NonSamplePools\\9-RABBIT-FACE-train_05-05-2022_16-07-21\\models\\model_19500.pt', '🐰', 9, False],
['final_models\\Adam\\NonSamplePools\\15-CARROT-train_05-05-2022_17-31-03\\models\\model_19500.pt', '🥕', 15, False],
['final_models\\Adam\\NonSamplePools\\15-RABBIT-FACE-train_05-05-2022_14-39-24\\models\\model_19500.pt', '🐰', 15, False]
]
adam_sample_models = [
['final_models\\Adam\\SamplePools\\9-CARROT-train_05-05-2022_17-04-48\\models\\model_19500.pt', '🥕', 9, False],
['final_models\\Adam\\SamplePools\\9-RABBIT-FACE-train_05-05-2022_11-43-14\\models\\model_19500.pt', '🐰', 9, False],
['final_models\\Adam\\SamplePools\\15-CARROT-train_05-05-2022_17-00-31\\models\\model_19500.pt', '🥕', 15, False],
['final_models\\Adam\\SamplePools\\15-RABBIT-FACE-train_05-05-2022_13-44-43\\models\\model_19500.pt', '🐰', 15, False]
]
es_nonsample_models = [
['final_models\\ES\\NonSamplePools\\9-CARROT-train_05-05-2022_09-14-06\\models\\model_2212000', '🥕', 9, True],
['final_models\\ES\\NonSamplePools\\9-RABBIT-FACE-train_06-05-2022_12-18-56\\models\\model_1999000', '🐰', 9, True ],
['final_models\\ES\\NonSamplePools\\15-CARROT-train_06-05-2022_11-06-58\\models\\model_1999000', '🥕', 15, True],
['final_models\\ES\\NonSamplePools\\15-RABBIT-FACE-train_06-05-2022_12-21-00\\models\\model_1999000', '🐰', 15, True]
]
es_sample_models = [
['final_models\\ES\\SamplePools\\9-CARROT-train_29-04-2022_11-33-16\\models\\model_1036000', '🥕', 9, True],
['final_models\\ES\\SamplePools\\9-RABBIT-FACE-train_01-05-2022_10-32-24\\models\\model_1157000', '🐰', 9, True],
['final_models\\ES\\SamplePools\\15-CARROT-train_29-04-2022_11-18-06\\models\\model_1105000', '🥕', 15, True],
['final_models\\ES\\SamplePools\\15-RABBIT-FACE-train_01-05-2022_10-32-40\\models\\model_1126000', '🐰', 15, True]
]
models = [adam_nonsample_models, adam_sample_models,
es_nonsample_models, es_sample_models]
if __name__ == '__main__':
for i in models:
for model in i:
load_model = model[0]
emoji = model[1]
size = model[2]
es = model[3]
command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es)
subprocess.run(command)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment