import random
import time

import gym
import numpy as np
import torch
from typing import Tuple, List, Dict, Callable, Optional

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn, VecEnvObs
from stable_baselines3.common.vec_env.subproc_vec_env import _flatten_obs

from src.env.logger import InfoCollector
from src.env.rfunc import RewardFunc
from src.smb.asyncsimlt import AsycSimltPool
from src.smb.level import MarioLevel, lvlhcat
from src.gan.gans import SAGenerator
from src.gan.gankits import *
from src.smb.proxy import MarioProxy, MarioJavaAgents
from src.utils.datastruct import RingQueue


def get_padded_obs(vecs, histlen, add_batch_dim=False):
    # Left-pad the latent-vector history with zeros so the observation always
    # has a fixed length of histlen * nz.
    if len(vecs) < histlen:
        lack = histlen - len(vecs)
        pad = [np.zeros([nz], np.float32) for _ in range(lack)]
        obs = np.concatenate(pad + vecs)
    else:
        obs = np.concatenate(vecs)
    if add_batch_dim:
        obs = np.reshape(obs, [1, -1])
    return obs


class SingleProcessOLGenEnv(gym.Env):
    def __init__(self, rfunc, decoder: SAGenerator, eplen: int = 50, device='cuda:0'):
        self.rfunc = rfunc
        self.initvec_set = np.load(getpath('smb/init_latvecs.npy'))
        self.decoder = decoder
        self.decoder.to(device)
        self.hist_len = self.rfunc.get_n()
        self.eplen = eplen
        self.device = device
        self.action_space = gym.spaces.Box(-1, 1, (nz,))
        self.observation_space = gym.spaces.Box(-1, 1, (self.hist_len * nz,))
        self.lat_vecs = []
        self.simulator = MarioProxy()

    def step(self, action):
        self.lat_vecs.append(action)
        # The first entry of lat_vecs is the random initial latent vector, so an
        # episode ends after eplen generated segments.
        done = len(self.lat_vecs) == (self.eplen + 1)
        info = {}
        if done:
            rewss = self.__evaluate()
            info['rewss'] = rewss
            rewsums = [sum(items) for items in zip(*rewss.values())]
            info['transitions'] = self.__process_traj(rewsums[-self.eplen:])
            self.reset()
        return self.getobs(), 0, done, info

    def __evaluate(self):
        z = torch.tensor(np.stack(self.lat_vecs).reshape([-1, nz, 1, 1]), device=self.device, dtype=torch.float)
        segs = process_onehot(self.decoder(z))
        lvl = lvlhcat(segs)
        simlt_res = MarioProxy.get_seg_infos(self.simulator.simulate_complete(lvl))
        rewardss = self.rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
        return rewardss

    def __process_traj(self, rewards):
        obs = []
        for i in range(1, len(self.lat_vecs) + 1):
            ob = get_padded_obs(self.lat_vecs[max(0, i - self.hist_len): i], self.hist_len)
            obs.append(ob)
        traj = [
            (obs[i], self.lat_vecs[i + 1], rewards[i], obs[i + 1])
            for i in range(len(self.lat_vecs) - 1)
        ]
        return traj

    def reset(self):
        self.lat_vecs.clear()
        z0 = self.initvec_set[random.randrange(0, len(self.initvec_set))]
        self.lat_vecs.append(z0)
        return self.getobs()

    def getobs(self):
        s = max(0, len(self.lat_vecs) - self.hist_len)
        return get_padded_obs(self.lat_vecs[s:], self.hist_len, True)

    def render(self, mode="human"):
        pass

    def generate_levels(self, agent, n=1, max_parallel=None):
        if max_parallel is None:
            max_parallel = min(n, 512)
        levels = []
        latvecs = []
        obs_queues = [RingQueue(self.hist_len) for _ in range(max_parallel)]
        while len(levels) < n:
            veclists = [[] for _ in range(min(max_parallel, n - len(levels)))]
            for queue, veclist in zip(obs_queues, veclists):
                queue.clear()
                init_latvec = self.initvec_set[random.randrange(0, len(self.initvec_set))]
                queue.push(init_latvec)
                veclist.append(init_latvec)
            for _ in range(self.eplen):
                obs = np.stack([get_padded_obs(queue.to_list(), self.hist_len) for queue in obs_queues])
                actions = agent.make_decision(obs)
                for queue, veclist, action in zip(obs_queues, veclists, actions):
                    queue.push(action)
                    veclist.append(action)
            for veclist in veclists:
                latvecs.append(np.stack(veclist))
                z = torch.tensor(latvecs[-1], device=self.device).view(-1, nz, 1, 1)
                lvl = lvlhcat(process_onehot(self.decoder(z)))
                levels.append(lvl)
        return levels, latvecs
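
# A minimal usage sketch for SingleProcessOLGenEnv, not part of the module's API.
# Assumptions: a pretrained decoder is obtainable via get_decoder(), the file
# 'smb/init_latvecs.npy' is in place, the Java Mario simulator used by MarioProxy
# is available, and RewardFunc() provides get_n()/get_rewards() as used above.
#
#   env = SingleProcessOLGenEnv(RewardFunc(), get_decoder(), eplen=50)
#   obs, done = env.reset(), False
#   while not done:
#       obs, _, done, info = env.step(env.action_space.sample())
#   transitions = info['transitions']   # (obs, action, reward, next_obs) tuples
#   rewss = info['rewss']               # per-term reward lists from the reward function
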

class AsyncOlGenEnv:
    def __init__(self, histlen, decoder: SAGenerator, eval_pool: AsycSimltPool, eplen: int = 50, device='cuda:0'):
        self.initvec_set = np.load(getpath('smb/init_latvecs.npy'))
        self.decoder = decoder
        self.decoder.to(device)
        self.device = device
        # The Mario simulator lives inside eval_pool.
        self.eval_pool = eval_pool
        self.eplen = eplen
        self.tid = 0
        self.histlen = histlen
        self.cur_vectraj = []
        self.buffer = {}

    def reset(self):
        if len(self.cur_vectraj) > 0:
            # Archive the finished trajectory under its task id so that rollout()
            # can match it with the evaluation results later.
            self.buffer[self.tid] = self.cur_vectraj
            self.cur_vectraj = []
            self.tid += 1
        z0 = self.initvec_set[random.randrange(0, len(self.initvec_set))]
        self.cur_vectraj.append(z0)
        return self.getobs()

    def step(self, action):
        self.cur_vectraj.append(action)
        done = len(self.cur_vectraj) == (self.eplen + 1)
        if done:
            self.__submit_eval_task()
            self.reset()
        return self.getobs(), done

    def getobs(self):
        s = max(0, len(self.cur_vectraj) - self.histlen)
        return get_padded_obs(self.cur_vectraj[s:], self.histlen, True)

    def __submit_eval_task(self):
        # Decode the latent trajectory into a level and hand it to the
        # asynchronous evaluation pool.
        z = torch.tensor(np.stack(self.cur_vectraj).reshape([-1, nz, 1, 1]), device=self.device)
        segs = process_onehot(self.decoder(torch.clamp(z, -1, 1)))
        lvl = lvlhcat(segs)
        args = (self.tid, str(lvl))
        self.eval_pool.put('evaluate', args)

    def refresh(self):
        if self.eval_pool is not None:
            self.eval_pool.refresh()

    def rollout(self, close=False, wait=False) -> Tuple[List[Tuple], List[Dict[str, List]]]:
        transitions, rewss = [], []
        if close:
            eval_res = self.eval_pool.close()
        else:
            eval_res = self.eval_pool.get(wait)
        for tid, rewards in eval_res:
            rewss.append(rewards)
            rewsums = [sum(items) for items in zip(*rewards.values())]
            vectraj = self.buffer.pop(tid)
            transitions += self.__process_traj(vectraj, rewsums[-self.eplen:])
        return transitions, rewss

    def __process_traj(self, vectraj, rewards):
        obs = []
        for i in range(1, len(vectraj) + 1):
            ob = get_padded_obs(vectraj[max(0, i - self.histlen): i], self.histlen)
            obs.append(ob)
        traj = [
            (obs[i], vectraj[i + 1], rewards[i], obs[i + 1])
            for i in range(len(vectraj) - 1)
        ]
        return traj

    def close(self):
        res = self.rollout(True)
        self.eval_pool = None
        return res

    def generate_levels(self, agent, n=1, max_parallel=None):
        if max_parallel is None:
            max_parallel = min(n, 512)
        levels = []
        latvecs = []
        obs_queues = [RingQueue(self.histlen) for _ in range(max_parallel)]
        while len(levels) < n:
            veclists = [[] for _ in range(min(max_parallel, n - len(levels)))]
            for queue, veclist in zip(obs_queues, veclists):
                queue.clear()
                init_latvec = self.initvec_set[random.randrange(0, len(self.initvec_set))]
                queue.push(init_latvec)
                veclist.append(init_latvec)
            for _ in range(self.eplen):
                obs = np.stack([get_padded_obs(queue.to_list(), self.histlen) for queue in obs_queues])
                actions = agent.make_decision(obs)
                for queue, veclist, action in zip(obs_queues, veclists, actions):
                    queue.push(action)
                    veclist.append(action)
            for veclist in veclists:
                latvecs.append(np.stack(veclist))
                z = torch.tensor(latvecs[-1], device=self.device).view(-1, nz, 1, 1)
                lvl = lvlhcat(process_onehot(self.decoder(z)))
                levels.append(lvl)
        return levels, latvecs
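
# A minimal training-loop sketch for AsyncOlGenEnv, not part of the module's API.
# Assumptions: `pool` is an AsycSimltPool whose workers handle the 'evaluate' task
# submitted above, and `agent.make_decision(obs)` returns a batch of latent actions
# (one row per observation), as in generate_levels().
#
#   env = AsyncOlGenEnv(histlen=5, decoder=get_decoder(), eval_pool=pool, eplen=50)
#   obs = env.reset()
#   for _ in range(100000):
#       action = agent.make_decision(obs)[0]
#       obs, done = env.step(action)
#       transitions, rewss = env.rollout()       # harvest any finished evaluations
#       # ... push `transitions` into a replay buffer and update `agent` ...
#   transitions, rewss = env.close()             # flush the remaining evaluations
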

######### Adapted from https://github.com/SUSTechGameAI/MFEDRL #########

class SyncOLGenWorkerEnv(gym.Env):
    def __init__(self, rfunc=None, hist_len=5, eplen=25, return_lvl=False, init_one=False, play_style='Runner'):
        self.rfunc = RewardFunc() if rfunc is None else rfunc
        self.mario_proxy = MarioProxy() if self.rfunc.require_simlt else None
        self.action_space = gym.spaces.Box(-1, 1, (nz,))
        self.hist_len = hist_len
        self.observation_space = gym.spaces.Box(-1, 1, (hist_len * nz,))
        self.segs = []
        self.latvec_archive = RingQueue(hist_len)
        self.eplen = eplen
        self.counter = 0
        # self.repairer = DivideConquerRepairer()
        self.init_one = init_one
        self.backup_latvecs = None
        self.backup_strsegs = None
        self.return_lvl = return_lvl
        self.jagent = MarioJavaAgents.__getitem__(play_style)
        self.simlt_k = 80 if play_style == 'Runner' else 320

    def receive(self, **kwargs):
        for key in kwargs.keys():
            setattr(self, key, kwargs[key])

    def step(self, data):
        action, strseg = data
        seg = MarioLevel(strseg)
        self.latvec_archive.push(action)
        self.counter += 1
        self.segs.append(seg)
        done = self.counter >= self.eplen
        if done:
            full_level = lvlhcat(self.segs)
            w = MarioLevel.seg_width
            segs = [full_level[:, s: s + w] for s in range(0, full_level.w, w)]
            if self.mario_proxy:
                raw_simlt_res = self.mario_proxy.simulate_complete(lvlhcat(segs), self.jagent, self.simlt_k)
                simlt_res = MarioProxy.get_seg_infos(raw_simlt_res)
            else:
                simlt_res = None
            rewards = self.rfunc.get_rewards(segs=segs, simlt_res=simlt_res)
            info = {}
            total_score = 0
            if self.return_lvl:
                info['LevelStr'] = str(full_level)
            for key in rewards:
                info[f'{key}_reward_list'] = rewards[key][-self.eplen:]
                info[f'{key}'] = sum(rewards[key][-self.eplen:])
                total_score += info[f'{key}']
            info['TotalScore'] = total_score
            info['EpLength'] = self.counter
        else:
            info = {}
        return self.__get_obs(), 0, done, info

    def reset(self):
        self.segs.clear()
        self.latvec_archive.clear()
        for latvec, strseg in zip(self.backup_latvecs, self.backup_strsegs):
            self.latvec_archive.push(latvec)
            self.segs.append(MarioLevel(strseg))
        self.backup_latvecs, self.backup_strsegs = None, None
        self.counter = 0
        return self.__get_obs()

    def __get_obs(self):
        lack = self.hist_len - len(self.latvec_archive)
        pad = [np.zeros([nz], np.float32) for _ in range(lack)]
        return np.concatenate([*pad, *self.latvec_archive.to_list()])

    def render(self, mode='human'):
        pass


class VecOLGenEnv(SubprocVecEnv):
    def __init__(
        self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None,
        hist_len=5, eplen=50, init_one=True, log_path=None, log_itv=-1, log_targets=None,
        device='cuda:0'
    ):
        super(VecOLGenEnv, self).__init__(env_fns, start_method)
        self.decoder = get_decoder(device=device)
        if log_path:
            self.logger = InfoCollector(log_path, log_itv, log_targets)
        else:
            self.logger = None
        self.hist_len = hist_len
        self.total_steps = 0
        self.start_time = time.time()
        self.eplen = eplen
        self.device = device
        self.init_one = init_one
        self.latvec_set = np.load(getpath('smb/init_latvecs.npy'))

    def step_async(self, actions: np.ndarray) -> None:
        # Decode the latent actions into segments here, then ship each segment
        # (as a string) to its worker together with the raw action.
        with torch.no_grad():
            z = torch.tensor(actions.astype(np.float32), device=self.device).view(-1, nz, 1, 1)
            segs = process_onehot(self.decoder(z))
        for remote, action, seg in zip(self.remotes, actions, segs):
            remote.send(("step", (action, str(seg))))
        self.waiting = True

    def step_wait(self) -> VecEnvStepReturn:
        self.total_steps += self.num_envs
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        envs_to_send = [i for i in range(self.num_envs) if dones[i]]
        self.send_reset_data(envs_to_send)
        if self.logger is not None:
            for i in range(self.num_envs):
                if infos[i]:
                    infos[i]['TotalSteps'] = self.total_steps
                    infos[i]['TimePassed'] = time.time() - self.start_time
            self.logger.on_step(dones, infos)
        return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

    def reset(self) -> VecEnvObs:
        self.send_reset_data()
        for remote in self.remotes:
            remote.send(("reset", None))
        obs = [remote.recv() for remote in self.remotes]
        self.send_reset_data()
        return _flatten_obs(obs, self.observation_space)

    def send_reset_data(self, env_ids=None):
        # Sample fresh initial latent vectors, decode them, and push both to the
        # workers through the 'receive' env method so their next reset() can use them.
        if env_ids is None:
            env_ids = [*range(self.num_envs)]
        target_remotes = self._get_target_remotes(env_ids)
        n_inits = 1 if self.init_one else self.hist_len
        latvecs = [
            self.latvec_set[random.sample(range(len(self.latvec_set)), n_inits)]
            for _ in range(len(env_ids))
        ]
        with torch.no_grad():
            segss = [[] for _ in range(len(env_ids))]
            for i in range(len(env_ids)):
                z = torch.tensor(latvecs[i]).view(-1, nz, 1, 1).to(self.device)
                segss[i] = [process_onehot(self.decoder(z))] if self.init_one else process_onehot(self.decoder(z))
        for remote, latvec, segs in zip(target_remotes, latvecs, segss):
            kwargs = {'backup_latvecs': latvec, 'backup_strsegs': [str(seg) for seg in segs]}
            remote.send(("env_method", ('receive', [], kwargs)))
        for remote in target_remotes:
            remote.recv()

    def close(self) -> None:
        super().close()
        if self.logger is not None:
            self.logger.close()

def make_vec_offrew_env(
    num_envs, rfunc=None, log_path=None, eplen=25, log_itv=-1, hist_len=5,
    init_one=True, play_style='Runner', device='cuda:0', log_targets=None, return_lvl=False
):
    return make_vec_env(
        SyncOLGenWorkerEnv, n_envs=num_envs, vec_env_cls=VecOLGenEnv,
        vec_env_kwargs={
            'log_path': log_path, 'log_itv': log_itv, 'log_targets': log_targets,
            'device': device, 'eplen': eplen, 'hist_len': hist_len, 'init_one': init_one
        },
        env_kwargs={
            'rfunc': rfunc, 'eplen': eplen, 'return_lvl': return_lvl,
            'play_style': play_style, 'hist_len': hist_len, 'init_one': init_one
        }
    )
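

# A minimal smoke-test sketch, not part of the original module. Assumptions: run
# from the repository root with the pretrained GAN decoder, 'smb/init_latvecs.npy',
# and the Java Mario simulator available; the default RewardFunc() is used since
# no rfunc is passed.
if __name__ == '__main__':
    venv = make_vec_offrew_env(num_envs=2, eplen=25, hist_len=5, device='cuda:0')
    obs = venv.reset()
    for _ in range(25):
        # Random latent actions in [-1, 1]; a trained agent would supply these.
        actions = np.stack([venv.action_space.sample() for _ in range(venv.num_envs)])
        obs, rews, dones, infos = venv.step(actions)
    venv.close()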