Skip to content
Snippets Groups Projects
Commit ad9f026b authored by Bård Sørensen Hestmark's avatar Bård Sørensen Hestmark
Browse files

imggen

Former-commit-id: 550854c5f5a25bb7054a39680e210893cb1039c3 [formerly d1d867ca]
Former-commit-id: 08ad51819c48c86c4725965fc4257ef7a5f759e2
parent 2573e430
No related branches found
No related tags found
No related merge requests found
import copy
import math
from tqdm import trange
import multiprocessing as mp
import numpy as np
import torch
......@@ -226,21 +225,4 @@ class ES:
self.writer.add_scalar("growth_loss/200", growth_loss[1], iteration)
save_image(torch.cat(pics, dim=0), '%s/pic/big%04d.png' % (self.logdir, iteration), nrow=1, padding=0)
save_model(self.net, self.logdir + "/models/model_" + str(iteration))
# if mean_fit > -0.003:
# logging.info("Training goal reached, exiting")
# break
def generate_graphic(self):
model = self.net
x_eval = tt(np.repeat(self.seed[None, ...], self.batch_size, 0))
pics = []
pics.append(to_rgb(x_eval).permute(0, 3, 1, 2))
for eval in range(40):
x_eval = model(x_eval)
if eval in [10, 20, 30, 39]: # frames to save img of
pics.append(to_rgb(x_eval).permute(0, 3, 1, 2))
save_image(torch.cat(pics, dim=0), '%s/graphic.png' % (self.logdir), nrow=len(pics), padding=0)
\ No newline at end of file
......@@ -212,3 +212,22 @@ class Interactive:
print('Reached 400 iterations. Shutting down...')
pygame.quit()
sys.exit()
def generate_graphic(self):
model = self.net
x_eval = self.seedclone()
pics = []
pics.append(to_rgb(x_eval).permute(0, 3, 1, 2))
for eval in range(40):
x_eval = model(x_eval)
if eval in [4, 9, 20, 39]: # frames to save img of
if self.es:
image = to_rgb(x_eval).permute(0, 3, 1, 2)
else:
image = to_rgb_ad(x_eval[:, :4].detach().cpu())
pics.append(image)
save_image(torch.cat(pics, dim=0), '%s/graphic.png' % (self.logdir), nrow=len(pics), padding=0)
......@@ -54,7 +54,7 @@ if __name__ == '__main__':
"Logging directory '%s' not found in base folder" % args.logdir)
match args.es:
case 'True':
case 'True': #heh
method = 'ES'
args.es = True
case 'False':
......@@ -76,3 +76,4 @@ if __name__ == '__main__':
Interactive = Interactive(args)
Interactive.interactive()
# Interactive.generate_graphic_es()
......@@ -33,20 +33,20 @@ models = [adam_nonsample_models, adam_sample_models,
if __name__ == '__main__':
# Run all models:
for i in models:
for model in i:
load_model = model[0]
emoji = model[1]
size = model[2]
es = model[3]
command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es)
subprocess.run(command)
# for i in models:
# for model in i:
# load_model = model[0]
# emoji = model[1]
# size = model[2]
# es = model[3]
# command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es)
# subprocess.run(command)
# # Run single model:
# model = models[3][2]
# load_model = model[0]
# emoji = model[1]
# size = model[2]
# es = model[3]
# command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es)
# subprocess.run(command)
\ No newline at end of file
model = models[3][2]
load_model = model[0]
emoji = model[1]
size = model[2]
es = model[3]
command = "python .\interactive_CA\main.py -i %s -s %i -l %s -e %r" % (emoji, size, load_model, es)
subprocess.run(command)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment