# NCERL-Diverse-PCG / test_ddpm.py
import argparse
import datetime
import logging
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from src.ddpm.dataset import create_dataloader
from src.ddpm.diffusion import Diffusion
from src.ddpm.modules import UNet
from src.gan.gankits import get_decoder, process_onehot
from src.smb.level import lvlhcat, save_batch
from src.utils.filesys import getpath
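
# Empirical per-tile sprite counts (presumably tallied from the SMB training
# levels; the tile order is fixed by the dataset and not documented here).
# The 1/4 power flattens the frequency distribution before the counts are
# turned into per-class sampling temperatures in train() below.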
sprite_counts = np.power(
    np.array([
        74977, 15252, 572591, 5826, 1216, 7302, 237, 237, 2852, 1074,
        235, 304, 48, 96, 160, 1871, 936, 186, 428, 80, 428
    ]),
    1 / 4
)
min_count = np.min(sprite_counts)


def setup_logging(run_name, beta_schedule):
    model_path = os.path.join("models", beta_schedule, run_name)
    result_path = os.path.join("results", beta_schedule, run_name)
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)
    return model_path, result_path
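# NOTE: setup_logging is not called in this script; train() below writes its
# checkpoints and samples directly under args.res_path instead.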


# Test training of the DDPM model.
def train(args):
    path = getpath(args.res_path)
    os.makedirs(path, exist_ok=True)
    dataloader = create_dataloader(batch_size=args.batch_size, shuffle=True, num_workers=0)
    device = 'cpu' if args.gpuid < 0 else f'cuda:{args.gpuid}'
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(device=device, schedule=args.beta_schedule)
    # Per-class sampling temperatures in (0, 1]: the rarest sprite gets 1,
    # and more frequent sprites get proportionally lower (sharper) temperatures.
    temperatures = torch.tensor(min_count / sprite_counts, dtype=torch.float32).to(device)
    n_batches = len(dataloader)
    for epoch in range(args.resume_from + 1, args.resume_from + args.epochs + 1):
        logging.info(f"Starting epoch {epoch}:")
        epoch_loss = {'rec_loss': 0, 'mse': 0, 'loss': 0}
        for images in dataloader:
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)  # random int from 1~1000 per sample
            x_t, noise = diffusion.noise_images(images, t)  # x_t: noised image at step t; noise: the Gaussian noise added
            predicted_noise = model(x_t.float(), t.float())  # predicted noise eps_theta
            original_img = images.argmax(dim=1)  # batch x 14 x 14 map of tile indices
            # Distribution over tile types obtained by denoising x_t in one shot.
            reconstructed_img = diffusion.sample_only_final(x_t, t, predicted_noise, temperatures)
            # Negative log-likelihood of the original tiles, summed over the
            # 14 x 14 grid and averaged over the batch.
            rec_loss = -reconstructed_img.log_prob(original_img).sum(dim=(1, 2)).mean()
            mse_loss = mse(noise.float(), predicted_noise.float())
            # Standard DDPM noise-prediction loss plus a small reconstruction term.
            loss = 0.001 * rec_loss + mse_loss
            epoch_loss['rec_loss'] += rec_loss.item()
            epoch_loss['mse'] += mse_loss.item()
            epoch_loss['loss'] += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(
            '\nEpoch: %d' % epoch,
            'rec_loss: %.5g' % (epoch_loss['rec_loss'] / n_batches),
            'mse: %.5g' % (epoch_loss['mse'] / n_batches)
        )
        if epoch % 1000 == 0:
            # Periodic checkpoint plus a qualitative sample of full levels.
            itpath = getpath(path, f'it{epoch}')
            os.makedirs(itpath, exist_ok=True)
            model.save(getpath(itpath, 'ddpm.pth'))
            lvls = []
            init_latvecs = torch.tensor(np.load(getpath('analysis/initial_seg.npy')))
            gan = get_decoder()
            init_seg_onehots = gan(init_latvecs.view(*init_latvecs.shape, 1, 1))
            for init_seg_onehot in init_seg_onehots:
                # Last element of the trajectory is the fully denoised one-hot batch.
                seg_onehots = diffusion.sample(model, n=25)[-1]
                a = init_seg_onehot.view(1, *init_seg_onehot.shape)  # GAN-decoded initial segment as a batch of one
                b = seg_onehots.detach().cpu()  # 25 DDPM-sampled segments
                segs = process_onehot(torch.cat([a, b], dim=0))
                level = lvlhcat(segs)  # concatenate the segments horizontally into one level
                lvls.append(level)
            save_batch(lvls, getpath(path, 'samples.lvls'))
    model.save(getpath(path, 'ddpm.pth'))
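
# Minimal sampling sketch (not part of the original script): it assumes a
# trained checkpoint can be restored with a UNet.load() counterpart to the
# save() used above (hypothetical); every other call mirrors those in train().
#   model = UNet()
#   model.load(getpath('exp_data/DDPM/ddpm.pth'))  # hypothetical loader
#   diffusion = Diffusion(device='cpu', schedule='quadratic')
#   seg_onehots = diffusion.sample(model, n=25)[-1]
#   segs = process_onehot(seg_onehots.detach().cpu())
#   save_batch([lvlhcat(segs)], getpath('exp_data/DDPM/demo.lvls'))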


def launch():
    # Make the logging.info calls in train() visible on stdout.
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--res_path", type=str, default='exp_data/DDPM')
    parser.add_argument("--gpuid", type=int, default=0, help="GPU id; a negative value runs on CPU")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--beta_schedule", type=str, default="quadratic", choices=['linear', 'quadratic', 'sigmoid'])
    parser.add_argument("--run_name", type=str, default=f"{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")
    parser.add_argument("--resume_from", type=int, default=0)
    args = parser.parse_args()
    train(args)


if __name__ == "__main__":
    launch()
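
# Example usage (assuming the repository root is the working directory, so that
# the src/ package and analysis/initial_seg.npy resolve):
#   python test_ddpm.py --gpuid 0 --epochs 10000 --beta_schedule quadratic
#   python test_ddpm.py --gpuid -1 --batch_size 64   # CPU run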