from typing import Optional, List

import numpy as np
import torch
from cv2 import dilate
from diffusers import DDIMScheduler, StableDiffusionPipeline
from tqdm import tqdm

from src.attention_based_segmentation import Segmentor
from src.attention_utils import show_cross_attention
from src.prompt_to_prompt_controllers import DummyController, AttentionStore


def get_stable_diffusion_model(args):
    device = torch.device(f'cuda:{args.gpu_id}') if torch.cuda.is_available() else torch.device('cpu')
    if args.real_image_path != "":
        # Real-image editing needs a deterministic DDIM scheduler (for inversion).
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                  clip_sample=False, set_alpha_to_one=False)
        ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",
                                                             use_auth_token=args.auth_token,
                                                             scheduler=scheduler).to(device)
    else:
        ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4",
                                                             use_auth_token=args.auth_token).to(device)
    return ldm_stable


def get_stable_diffusion_config(args):
    return {
        "low_resource": args.low_resource,
        "num_diffusion_steps": args.num_diffusion_steps,
        "guidance_scale": args.guidance_scale,
        "max_num_words": args.max_num_words
    }


def generate_original_image(args, ldm_stable, ldm_stable_config, prompts, latent, uncond_embeddings):
    g_cpu = torch.Generator(device=ldm_stable.device).manual_seed(args.seed)
    controller = AttentionStore(ldm_stable_config["low_resource"])
    diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller,
                                                    generator=g_cpu)
    image, x_t, orig_all_latents, _ = diffusion_model_wrapper.forward(prompts, latent=latent,
                                                                      uncond_embeddings=uncond_embeddings)
    # Segment the stored attention maps to get a mask for the edited "{word}" token
    # (+1 accounts for the BOS token prepended by the tokenizer).
    orig_mask = Segmentor(controller,
                          prompts,
                          args.num_segments,
                          args.background_segment_threshold,
                          background_nouns=args.background_nouns)\
        .get_background_mask(args.prompt.split(' ').index("{word}") + 1)
    average_attention = controller.get_average_attention()
    return image, x_t, orig_all_latents, orig_mask, average_attention


class DiffusionModelWrapper:
    def __init__(self, args, model, model_config, controller=None, prompt_mixing=None, generator=None):
        self.args = args
        self.model = model
        self.model_config = model_config
        self.controller = controller
        if self.controller is None:
            self.controller = DummyController()
        self.prompt_mixing = prompt_mixing
        self.device = model.device
        self.generator = generator
        self.height = 512
        self.width = 512
        self.diff_step = 0
        self.register_attention_control()

    def diffusion_step(self, latents, context, t, other_context=None):
        if self.model_config["low_resource"]:
            # Run the unconditional and conditional UNet passes separately to save memory.
            self.uncond_pred = True
            noise_pred_uncond = self.model.unet(latents, t, encoder_hidden_states=(context[0], None))["sample"]
            self.uncond_pred = False
            noise_prediction_text = self.model.unet(latents, t,
                                                    encoder_hidden_states=(context[1], other_context))["sample"]
        else:
            latents_input = torch.cat([latents] * 2)
            noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=(context, other_context))["sample"]
            noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        # Classifier-free guidance.
        noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_prediction_text - noise_pred_uncond)
        latents = self.model.scheduler.step(noise_pred, t, latents)["prev_sample"]
        latents = self.controller.step_callback(latents)
        return latents

    def latent2image(self, latents):
        # Undo the VAE scaling factor, decode, and convert to uint8 HWC images.
        latents = 1 / 0.18215 * latents
        image = self.model.vae.decode(latents)['sample']
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).astype(np.uint8)
        return image
    def init_latent(self, latent, batch_size):
        if latent is None:
            latent = torch.randn(
                (1, self.model.unet.in_channels, self.height // 8, self.width // 8),
                generator=self.generator, device=self.model.device
            )
        latents = latent.expand(batch_size, self.model.unet.in_channels,
                                self.height // 8, self.width // 8).to(self.device)
        return latent, latents

    def register_attention_control(self):
        # Patch every CrossAttention layer in the UNet so its attention maps pass through
        # the controller and, when present, the prompt-mixing hooks.
        def ca_forward(model_self, place_in_unet):
            to_out = model_self.to_out
            if isinstance(to_out, torch.nn.ModuleList):
                # to_out can be a ModuleList (Linear followed by Dropout); use the linear projection.
                to_out = to_out[0]

            def forward(x, context=None, mask=None):
                batch_size, sequence_length, dim = x.shape
                h = model_self.heads
                q = model_self.to_q(x)
                is_cross = context is not None
                context = context if is_cross else (x, None)
                k = model_self.to_k(context[0])
                if is_cross and self.prompt_mixing is not None:
                    # Prompt mixing may supply a different context for the value projection.
                    v_context = self.prompt_mixing.get_context_for_v(self.diff_step, context[0], context[1])
                    v = model_self.to_v(v_context)
                else:
                    v = model_self.to_v(context[0])
                q = model_self.reshape_heads_to_batch_dim(q)
                k = model_self.reshape_heads_to_batch_dim(k)
                v = model_self.reshape_heads_to_batch_dim(v)

                sim = torch.einsum("b i d, b j d -> b i j", q, k) * model_self.scale

                if mask is not None:
                    mask = mask.reshape(batch_size, -1)
                    max_neg_value = -torch.finfo(sim.dtype).max
                    mask = mask[:, None, :].repeat(h, 1, 1)
                    sim.masked_fill_(~mask, max_neg_value)

                # attention, what we cannot get enough of
                attn = sim.softmax(dim=-1)
                if self.enbale_attn_controller_changes:
                    attn = self.controller(attn, is_cross, place_in_unet)

                if is_cross and self.prompt_mixing is not None and context[1] is not None:
                    attn = self.prompt_mixing.get_cross_attn(self, self.diff_step, attn, place_in_unet, batch_size)

                if not is_cross and (not self.model_config["low_resource"] or not self.uncond_pred) and self.prompt_mixing is not None:
                    attn = self.prompt_mixing.get_self_attn(self, self.diff_step, attn, place_in_unet, batch_size)

                out = torch.einsum("b i j, b j d -> b i d", attn, v)
                out = model_self.reshape_batch_dim_to_heads(out)
                return to_out(out)

            return forward

        def register_recr(net_, count, place_in_unet):
            if net_.__class__.__name__ == 'CrossAttention':
                net_.forward = ca_forward(net_, place_in_unet)
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        cross_att_count = 0
        sub_nets = self.model.unet.named_children()
        for net in sub_nets:
            if "down" in net[0]:
                cross_att_count += register_recr(net[1], 0, "down")
            elif "up" in net[0]:
                cross_att_count += register_recr(net[1], 0, "up")
            elif "mid" in net[0]:
                cross_att_count += register_recr(net[1], 0, "mid")
        self.controller.num_att_layers = cross_att_count

    def get_text_embedding(self, prompt: List[str], max_length=None, truncation=True):
        text_input = self.model.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length if max_length is None else max_length,
            truncation=truncation,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.device))[0]
        max_length = text_input.input_ids.shape[-1]
        return text_embeddings, max_length

    @torch.no_grad()
    def forward(self, prompt: List[str],
                latent: Optional[torch.FloatTensor] = None,
                other_prompt: Optional[List[str]] = None,
                post_background=False,
                orig_all_latents=None,
                orig_mask=None,
                uncond_embeddings=None,
                start_time=51,
                return_type='image'):
        self.enbale_attn_controller_changes = True
        batch_size = len(prompt)
        text_embeddings, max_length = self.get_text_embedding(prompt)
        if uncond_embeddings is None:
            uncond_embeddings_, _ = self.get_text_embedding([""] * batch_size, max_length=max_length,
                                                            truncation=False)
        else:
            # Per-step unconditional embeddings were supplied (e.g. from null-text inversion).
            uncond_embeddings_ = None

        other_context = None
        if other_prompt is not None:
            other_text_embeddings, _ = self.get_text_embedding(other_prompt)
            other_context = other_text_embeddings

        latent, latents = self.init_latent(latent, batch_size)

        # set timesteps
        self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])

        all_latents = []
        object_mask = None
        self.diff_step = 0
        for i, t in enumerate(tqdm(self.model.scheduler.timesteps[-start_time:])):
            if uncond_embeddings_ is None:
                context = [uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]
            else:
                context = [uncond_embeddings_, text_embeddings]
            if not self.model_config["low_resource"]:
                context = torch.cat(context)

            # Reset the per-block cross-attention counters for this step.
            self.down_cross_index = 0
            self.mid_cross_index = 0
            self.up_cross_index = 0
            latents = self.diffusion_step(latents, context, t, other_context)

            if post_background and self.diff_step == self.args.background_blend_timestep:
                # Compute the mask for the current generation (+1 accounts for the BOS token).
                object_mask = Segmentor(self.controller,
                                        prompt,
                                        self.args.num_segments,
                                        self.args.background_segment_threshold,
                                        background_nouns=self.args.background_nouns)\
                    .get_background_mask(self.args.prompt.split(' ').index("{word}") + 1)
                self.enbale_attn_controller_changes = False
                # Union of the new and original masks, upsampled to latent resolution and slightly dilated.
                mask = object_mask.astype(bool) + orig_mask.astype(bool)
                mask = torch.from_numpy(mask).float().cuda()
                shape = (1, 1, mask.shape[0], mask.shape[1])
                mask = torch.nn.Upsample(size=(64, 64), mode='nearest')(mask.view(shape))
                mask_dilated = dilate(mask.cpu().numpy()[0, 0], np.ones((3, 3), np.uint8), iterations=1)
                mask = torch.from_numpy(mask_dilated).float().cuda().view(1, 1, 64, 64)
                # Keep the new latents inside the mask and restore the original latents elsewhere.
                latents = mask * latents + (1 - mask) * orig_all_latents[self.diff_step]

            all_latents.append(latents)
            self.diff_step += 1

        if return_type == 'image':
            image = self.latent2image(latents)
        else:
            image = latents
        return image, latent, all_latents, object_mask

    def show_last_cross_attention(self, res: int, from_where: List[str], prompts, select: int = 0):
        show_cross_attention(self.controller, res, from_where, prompts, tokenizer=self.model.tokenizer,
                             select=select)
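

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# A minimal, hedged example of how the helpers above fit together: build the
# pipeline, then generate the original image plus its per-step latents and
# object/background mask. The `args` namespace is an assumption -- the real
# project builds it with argparse, and the values below (prompt, thresholds,
# timesteps) are placeholders, not the authors' settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        gpu_id=0,
        auth_token=None,               # HF token, if the checkpoint requires one
        real_image_path="",            # empty -> plain generation, no DDIM inversion
        low_resource=False,
        num_diffusion_steps=50,
        guidance_scale=7.5,
        max_num_words=77,
        seed=42,
        prompt="a photo of a {word} in a garden",  # must contain the literal "{word}" placeholder
        num_segments=5,
        background_segment_threshold=0.35,
        background_nouns=["garden"],
        background_blend_timestep=35,
    )

    ldm_stable = get_stable_diffusion_model(args)
    ldm_stable_config = get_stable_diffusion_config(args)

    # Concrete prompt for generation; the object word here is arbitrary.
    prompts = [args.prompt.format(word="cat")]

    image, x_t, orig_all_latents, orig_mask, avg_attention = generate_original_image(
        args, ldm_stable, ldm_stable_config, prompts, latent=None, uncond_embeddings=None)
    print("generated image batch:", image.shape)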