diff --git a/CA-ES/es.py b/CA-ES/es.py index 91ef55959337f87e4a7391f2045c47765a8066ce..acf8a456fa8e893a475104936758db08b0b1bc6e 100644 --- a/CA-ES/es.py +++ b/CA-ES/es.py @@ -1,7 +1,6 @@ import copy import math from tqdm import trange -import multiprocessing as mp import numpy as np import torch @@ -226,21 +225,4 @@ class ES: self.writer.add_scalar("growth_loss/200", growth_loss[1], iteration) save_image(torch.cat(pics, dim=0), '%s/pic/big%04d.png' % (self.logdir, iteration), nrow=1, padding=0) save_model(self.net, self.logdir + "/models/model_" + str(iteration)) - - # if mean_fit > -0.003: - # logging.info("Training goal reached, exiting") - # break - - def generate_graphic(self): - model = self.net - x_eval = tt(np.repeat(self.seed[None, ...], self.batch_size, 0)) - pics = [] - pics.append(to_rgb(x_eval).permute(0, 3, 1, 2)) - - for eval in range(40): - x_eval = model(x_eval) - if eval in [10, 20, 30, 39]: # frames to save img of - pics.append(to_rgb(x_eval).permute(0, 3, 1, 2)) - - save_image(torch.cat(pics, dim=0), '%s/graphic.png' % (self.logdir), nrow=len(pics), padding=0) - \ No newline at end of file + diff --git a/interactive_CA/interactive.py b/interactive_CA/interactive.py index 72a284f130981957240d52b054c8e10aa8d4e503..e8d49e0183b57dd7238bb5b444dffdd5c63f9054 100644 --- a/interactive_CA/interactive.py +++ b/interactive_CA/interactive.py @@ -212,3 +212,22 @@ class Interactive: print('Reached 400 iterations. Shutting down...') pygame.quit() sys.exit() + + + def generate_graphic(self): + model = self.net + x_eval = self.seedclone() + pics = [] + pics.append(to_rgb(x_eval).permute(0, 3, 1, 2)) + + for eval in range(40): + x_eval = model(x_eval) + if eval in [4, 9, 20, 39]: # frames to save img of + if self.es: + image = to_rgb(x_eval).permute(0, 3, 1, 2) + else: + image = to_rgb_ad(x_eval[:, :4].detach().cpu()) + pics.append(image) + + save_image(torch.cat(pics, dim=0), '%s/graphic.png' % (self.logdir), nrow=len(pics), padding=0) + diff --git a/interactive_CA/main.py b/interactive_CA/main.py index 076915514fb3fabc9d0b4e7a62c7a025c71f1bc3..f3fa3c8a77140cbd589c46dcf62477a25f8e0f8d 100644 --- a/interactive_CA/main.py +++ b/interactive_CA/main.py @@ -54,7 +54,7 @@ if __name__ == '__main__': "Logging directory '%s' not found in base folder" % args.logdir) match args.es: - case 'True': + case 'True': #heh method = 'ES' args.es = True case 'False': @@ -76,3 +76,4 @@ if __name__ == '__main__': Interactive = Interactive(args) Interactive.interactive() + # Interactive.generate_graphic_es() diff --git a/interactive_CA/run_all.py b/interactive_CA/run_all.py index 49d1776da11cd32be57a54b5f29d4fa74533de81..f6243fbba0d77531c72f5eff965cc0c795ba6bd3 100644 --- a/interactive_CA/run_all.py +++ b/interactive_CA/run_all.py @@ -33,20 +33,20 @@ models = [adam_nonsample_models, adam_sample_models, if __name__ == '__main__': # Run all models: - for i in models: - for model in i: - load_model = model[0] - emoji = model[1] - size = model[2] - es = model[3] - command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es) - subprocess.run(command) + # for i in models: + # for model in i: + # load_model = model[0] + # emoji = model[1] + # size = model[2] + # es = model[3] + # command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es) + # subprocess.run(command) # # Run single model: - # model = models[3][2] - # load_model = model[0] - # emoji = model[1] - # size = model[2] - # es = model[3] - # command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es) - # subprocess.run(command) \ No newline at end of file + model = models[3][2] + load_model = model[0] + emoji = model[1] + size = model[2] + es = model[3] + command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es) + subprocess.run(command) \ No newline at end of file