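# Provenance (residue of the web file viewer this file was copied from):
# uploaded by baiyanlali-zhao, commit eaf2e33 ("init"), 2.8 kB.
# Viewer labels "raw", "history", "blame" and "No virus" were UI elements, not code.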
import os
import random
import numpy as np
from src.gan.gankits import *
from src.utils.filesys import getpath
from src.utils.img import make_img_sheet
from src.utils.datastruct import RingQueue
from src.olgen.olg_policy import RLGenPolicy, RandGenPolicy
from src.smb.level import lvlhcat, save_batch
def rand_gen_levels(n=100, h=50, dest_path='', device='cuda:0'):
    """Generate ``n`` levels whose segments are decoded from random latent vectors.

    Each level starts from a latent vector drawn uniformly at random from the
    archive ``smb/init_latvecs.npy``, followed by ``h`` latent vectors from
    ``sample_latvec``. All ``h + 1`` vectors are decoded in one batch and the
    resulting segments are concatenated with ``lvlhcat``.

    Args:
        n: Number of levels to generate.
        h: Number of randomly sampled segments appended after the initial one.
        dest_path: If non-empty, levels are written with ``save_batch(levels,
            dest_path)`` and the latent vectors with ``np.save(getpath(dest_path), ...)``
            (note ``np.save`` appends ``.npy`` when the name lacks it).
        device: Torch device used for decoding; defaults to the previously
            hard-coded ``'cuda:0'``.

    Returns:
        ``(levels, latvecs)``: the generated levels and an array stacking one
        ``(h + 1, nz)`` latent matrix per level.
    """
    decoder = get_decoder('models/decoder.pth', device)
    init_arxv = np.load(getpath('smb/init_latvecs.npy'))
    levels = []
    latvecs = []
    # Pure inference: avoid building autograd graphs for every decoded batch.
    with torch.no_grad():
        for _ in range(n):
            z0 = torch.tensor(
                init_arxv[random.randrange(len(init_arxv))], device=device, dtype=torch.float
            )
            # The archive row is a flat vector, while the decoder is fed
            # (batch, nz, 1, 1) tensors (generate_levels uses .view(-1, nz, 1, 1)).
            # Reshape both parts so torch.cat joins them along the batch axis.
            z = torch.cat(
                [z0.reshape(-1, nz, 1, 1), sample_latvec(h, device).reshape(-1, nz, 1, 1)],
                dim=0,
            )
            levels.append(lvlhcat(process_onehot(decoder(z))))
            latvecs.append(z.reshape(-1, nz).cpu().numpy())
    # np.stack raises on an empty list, so build an empty array explicitly when n == 0.
    if latvecs:
        latvec_arr = np.stack(latvecs)
    else:
        latvec_arr = np.empty((0, h + 1, nz), dtype=np.float32)
    if dest_path:
        save_batch(levels, dest_path)
        np.save(getpath(dest_path), latvec_arr)
    return levels, latvec_arr
def generate_levels(policy, dest_folder='', batch_name='samples.lvls', n=200, h=50, parallel=64,
                    save_img=False, device='cuda:0'):
    """Roll out ``policy`` to generate ``n`` levels of ``h + 1`` segments each.

    Episodes run in batches of ``parallel``. Each episode keeps a ``RingQueue``
    built with ``policy.n`` and seeded with a latent vector drawn uniformly from
    ``smb/init_latvecs.npy``. At each of the ``h`` steps, the queue contents of
    every episode are concatenated into one observation row and passed to
    ``policy.step``. The returned actions are pushed into the queues and
    recorded as the next segments' latent vectors. Each episode's latent
    sequence is then decoded and concatenated with ``lvlhcat``.

    Args:
        policy: Object exposing ``n`` and ``step(obs) -> actions``.
        dest_folder: If non-empty, the levels are saved to
            ``getpath(dest_folder, batch_name)`` (the folder is created if needed).
        batch_name: File name of the saved level batch.
        n: Number of levels to return.
        h: Number of policy steps (segments after the initial one) per level.
        parallel: Number of episodes rolled out together per batch.
        save_img: If True and ``dest_folder`` is set, also render each level to
            ``{dest_folder}/lvl-{i}.png``.
        device: Torch device used for decoding; defaults to the previously
            hard-coded ``'cuda:0'``.

    Returns:
        The first ``n`` generated levels.
    """
    init_arxv = np.load(getpath('smb/init_latvecs.npy'))
    decoder = get_decoder('models/decoder.pth', device)
    obs_queues = [RingQueue(policy.n) for _ in range(parallel)]
    levels = []
    while len(levels) < n:
        veclists = [[] for _ in range(parallel)]
        for queue, veclist in zip(obs_queues, veclists):
            queue.clear()
            init_latvec = init_arxv[random.randrange(len(init_arxv))]
            queue.push(init_latvec)
            veclist.append(init_latvec)
        for _ in range(h):
            obs = np.stack([np.concatenate(queue.to_list()) for queue in obs_queues])
            actions = policy.step(obs)
            for queue, veclist, action in zip(obs_queues, veclists, actions):
                queue.push(action)
                veclist.append(action)
        # Inference only. The explicit float32 dtype keeps the decoder input
        # consistent even if the archive or actions are float64.
        with torch.no_grad():
            for veclist in veclists:
                z = torch.tensor(np.stack(veclist), device=device, dtype=torch.float).view(-1, nz, 1, 1)
                levels.append(lvlhcat(process_onehot(decoder(z))))
    levels = levels[:n]
    if dest_folder:
        os.makedirs(getpath(dest_folder), exist_ok=True)
        save_batch(levels, getpath(dest_folder, batch_name))
        # Image paths are built from dest_folder, so rendering only makes sense when it
        # is set; otherwise they would resolve to '/lvl-{i}.png'.
        if save_img:
            for i, lvl in enumerate(levels):
                lvl.to_img(f'{dest_folder}/lvl-{i}.png')
    return levels
def make_samples(path, n=12, h=20, space=12):
    """Render a one-column image sheet of sample levels from a trained RL generator.

    Loads the policy with ``RLGenPolicy.from_path(path)`` and generates ``n``
    levels with ``h`` steps each via ``generate_levels``. With its default empty
    ``dest_folder``, ``generate_levels`` does not save the levels itself. Each
    level is rendered with ``to_img()`` and the images are laid out by
    ``make_img_sheet`` (``ncols=1``, ``y_margin=space``) into
    ``{path}/samples.png``.
    """
    sheet_path = f'{path}/samples.png'
    generator = RLGenPolicy.from_path(path)
    make_img_sheet(
        [level.to_img() for level in generate_levels(generator, n=n, h=h)],
        ncols=1,
        y_margin=space,
        save_path=sheet_path,
    )
if __name__ == '__main__':
    # Script entry point: currently does nothing; call rand_gen_levels,
    # generate_levels or make_samples here when running this module directly.
    ...