from typing import Optional, List

import numpy as np
import torch
from cv2 import dilate
from diffusers import DDIMScheduler, StableDiffusionPipeline
from tqdm import tqdm

from src.attention_based_segmentation import Segmentor
from src.attention_utils import show_cross_attention
from src.prompt_to_prompt_controllers import DummyController, AttentionStore


def get_stable_diffusion_model(args):
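    """Load Stable Diffusion v1.4; when a real input image path is given, swap in a DDIM
    scheduler configured for inversion (no sample clipping, final alpha not forced to one)."""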
    device = torch.device(f'cuda:{args.gpu_id}') if torch.cuda.is_available() else torch.device('cpu')
    if args.real_image_path != "":
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
        ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token, scheduler=scheduler).to(device)
    else:
        ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token).to(device)

    return ldm_stable

def get_stable_diffusion_config(args):
    return {
        "low_resource": args.low_resource,
        "num_diffusion_steps": args.num_diffusion_steps,
        "guidance_scale": args.guidance_scale,
        "max_num_words": args.max_num_words
    }


def generate_original_image(args, ldm_stable, ldm_stable_config, prompts, latent, uncond_embeddings):
    g_cpu = torch.Generator(device=ldm_stable.device).manual_seed(args.seed)
    controller = AttentionStore(ldm_stable_config["low_resource"])
    diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, generator=g_cpu)
    image, x_t, orig_all_latents, _ = diffusion_model_wrapper.forward(prompts,
                                                                      latent=latent,
                                                                      uncond_embeddings=uncond_embeddings)
    orig_mask = Segmentor(controller, prompts, args.num_segments, args.background_segment_threshold, background_nouns=args.background_nouns)\
        .get_background_mask(args.prompt.split(' ').index("{word}") + 1)
    average_attention = controller.get_average_attention()
    return image, x_t, orig_all_latents, orig_mask, average_attention


class DiffusionModelWrapper:
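    """Thin wrapper around a StableDiffusionPipeline: patches its CrossAttention layers so
    attention maps flow through the controller (and an optional prompt-mixing object) and
    exposes the full denoising loop via forward()."""
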
    def __init__(self, args, model, model_config, controller=None, prompt_mixing=None, generator=None):
        self.args = args
        self.model = model
        self.model_config = model_config
        self.controller = controller
        if self.controller is None:
            self.controller = DummyController()
        self.prompt_mixing = prompt_mixing
        self.device = model.device
        self.generator = generator

        self.height = 512
        self.width = 512

        self.diff_step = 0
        self.register_attention_control()


    def diffusion_step(self, latents, context, t, other_context=None):
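        # Classifier-free guidance: compute the unconditional and text-conditioned noise
        # predictions (two separate UNet passes in low_resource mode, one batched pass
        # otherwise) and combine them with the configured guidance scale. The second element
        # of encoder_hidden_states carries the optional "other prompt" context consumed by
        # the patched attention forward registered in register_attention_control.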
        if self.model_config["low_resource"]:
            self.uncond_pred = True
            noise_pred_uncond = self.model.unet(latents, t, encoder_hidden_states=(context[0], None))["sample"]
            self.uncond_pred = False
            noise_prediction_text = self.model.unet(latents, t, encoder_hidden_states=(context[1], other_context))["sample"]
        else:
            latents_input = torch.cat([latents] * 2)
            noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=(context, other_context))["sample"]
            noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_prediction_text - noise_pred_uncond)
        latents = self.model.scheduler.step(noise_pred, t, latents)["prev_sample"]
        latents = self.controller.step_callback(latents)
        return latents


    def latent2image(self, latents):
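        # Undo the SD VAE latent scaling factor, decode to pixel space, and convert to uint8 HWC images.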
        latents = 1 / 0.18215 * latents
        image = self.model.vae.decode(latents)['sample']
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).astype(np.uint8)
        return image


    def init_latent(self, latent, batch_size):
        if latent is None:
            latent = torch.randn(
                (1, self.model.unet.in_channels, self.height // 8, self.width // 8),
                generator=self.generator, device=self.model.device
            )
        latents = latent.expand(batch_size, self.model.unet.in_channels, self.height // 8, self.width // 8).to(self.device)
        return latent, latents


    def register_attention_control(self):
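        # Replace the forward pass of every CrossAttention module in the UNet's down/mid/up
        # blocks so attention maps are recorded by the controller and, when a prompt-mixing
        # object is provided, modified for cross- and self-attention.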
        def ca_forward(model_self, place_in_unet):
            to_out = model_self.to_out
            if isinstance(to_out, torch.nn.ModuleList):
                to_out = to_out[0]

            def forward(x, context=None, mask=None):
                batch_size, sequence_length, dim = x.shape
                h = model_self.heads
                q = model_self.to_q(x)
                is_cross = context is not None
                context = context if is_cross else (x, None)

                k = model_self.to_k(context[0])
                if is_cross and self.prompt_mixing is not None:
                    v_context = self.prompt_mixing.get_context_for_v(self.diff_step, context[0], context[1])
                    v = model_self.to_v(v_context)
                else:
                    v = model_self.to_v(context[0])

                q = model_self.reshape_heads_to_batch_dim(q)
                k = model_self.reshape_heads_to_batch_dim(k)
                v = model_self.reshape_heads_to_batch_dim(v)

                sim = torch.einsum("b i d, b j d -> b i j", q, k) * model_self.scale

                if mask is not None:
                    mask = mask.reshape(batch_size, -1)
                    max_neg_value = -torch.finfo(sim.dtype).max
                    mask = mask[:, None, :].repeat(h, 1, 1)
                    sim.masked_fill_(~mask, max_neg_value)

                # attention, what we cannot get enough of
                attn = sim.softmax(dim=-1)
                if self.enable_attn_controller_changes:
                    attn = self.controller(attn, is_cross, place_in_unet)
                
                if is_cross and self.prompt_mixing is not None and context[1] is not None:
                    attn = self.prompt_mixing.get_cross_attn(self, self.diff_step, attn, place_in_unet, batch_size)

                if not is_cross and (not self.model_config["low_resource"] or not self.uncond_pred) and self.prompt_mixing is not None:
                    attn = self.prompt_mixing.get_self_attn(self, self.diff_step, attn, place_in_unet, batch_size)

                out = torch.einsum("b i j, b j d -> b i d", attn, v)
                out = model_self.reshape_batch_dim_to_heads(out)
                return to_out(out)

            return forward

        def register_recr(net_, count, place_in_unet):
            if net_.__class__.__name__ == 'CrossAttention':
                net_.forward = ca_forward(net_, place_in_unet)
                return count + 1
            elif hasattr(net_, 'children'):
                for net__ in net_.children():
                    count = register_recr(net__, count, place_in_unet)
            return count

        cross_att_count = 0
        sub_nets = self.model.unet.named_children()
        for net in sub_nets:
            if "down" in net[0]:
                cross_att_count += register_recr(net[1], 0, "down")
            elif "up" in net[0]:
                cross_att_count += register_recr(net[1], 0, "up")
            elif "mid" in net[0]:
                cross_att_count += register_recr(net[1], 0, "mid")
        self.controller.num_att_layers = cross_att_count


    def get_text_embedding(self, prompt: List[str], max_length=None, truncation=True):
        text_input = self.model.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length if max_length is None else max_length,
            truncation=truncation,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.device))[0]
        max_length = text_input.input_ids.shape[-1]
        return text_embeddings, max_length


    @torch.no_grad()
    def forward(self, prompt: List[str], latent: Optional[torch.FloatTensor] = None,
                other_prompt: Optional[List[str]] = None, post_background=False, orig_all_latents=None, orig_mask=None,
                uncond_embeddings=None, start_time=51, return_type='image'):
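        """Run the full denoising loop for `prompt` and return (image, initial latent,
        all intermediate latents, object mask).

        If `post_background` is set, at `args.background_blend_timestep` a mask is computed
        from the stored attention maps (combined with `orig_mask`), and latents outside the
        mask are replaced with the corresponding latents from `orig_all_latents`, so that
        region follows the original generation.
        """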
        self.enable_attn_controller_changes = True
        batch_size = len(prompt)

        text_embeddings, max_length = self.get_text_embedding(prompt)
        if uncond_embeddings is None:
            uncond_embeddings_, _ = self.get_text_embedding([""] * batch_size, max_length=max_length, truncation=False)
        else:
            uncond_embeddings_ = None

        other_context = None
        if other_prompt is not None:
            other_text_embeddings, _ = self.get_text_embedding(other_prompt)
            other_context = other_text_embeddings

        latent, latents = self.init_latent(latent, batch_size)
        
        # set timesteps
        self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])
        all_latents = []

        object_mask = None
        self.diff_step = 0
        for i, t in enumerate(tqdm(self.model.scheduler.timesteps[-start_time:])):
            if uncond_embeddings_ is None:
                context = [uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]
            else:
                context = [uncond_embeddings_, text_embeddings]
            if not self.model_config["low_resource"]:
                context = torch.cat(context)

            self.down_cross_index = 0
            self.mid_cross_index = 0
            self.up_cross_index = 0
            latents = self.diffusion_step(latents, context, t, other_context)

            if post_background and self.diff_step == self.args.background_blend_timestep:
                object_mask = Segmentor(self.controller,
                                        prompt,
                                        self.args.num_segments,
                                        self.args.background_segment_threshold,
                                        background_nouns=self.args.background_nouns)\
                    .get_background_mask(self.args.prompt.split(' ').index("{word}") + 1)
                self.enable_attn_controller_changes = False
                # Combine the current and original masks, resize them to the latent
                # resolution, and dilate slightly before blending the latents.
                mask = object_mask.astype(bool) + orig_mask.astype(bool)
                mask = torch.from_numpy(mask).float().to(self.device)
                shape = (1, 1, mask.shape[0], mask.shape[1])
                mask = torch.nn.Upsample(size=(64, 64), mode='nearest')(mask.view(shape))
                mask_dilated = dilate(mask.cpu().numpy()[0, 0], np.ones((3, 3), np.uint8), iterations=1)
                mask = torch.from_numpy(mask_dilated).float().to(self.device).view(1, 1, 64, 64)
                latents = mask * latents + (1 - mask) * orig_all_latents[self.diff_step]

            all_latents.append(latents)
            self.diff_step += 1

        if return_type == 'image':
            image = self.latent2image(latents)
        else:
            image = latents
        
        return image, latent, all_latents, object_mask
    
    
    def show_last_cross_attention(self, res: int, from_where: List[str], prompts, select: int = 0):
        show_cross_attention(self.controller, res, from_where, prompts, tokenizer=self.model.tokenizer, select=select)
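

# Illustrative usage sketch (not part of the original module): the `args` object is assumed
# to be an argparse-style namespace exposing the fields referenced above (gpu_id, auth_token,
# real_image_path, seed, prompt, low_resource, num_diffusion_steps, guidance_scale,
# max_num_words, num_segments, background_segment_threshold, background_nouns,
# background_blend_timestep). `my_prompt_mixing` and `other_prompt` below are hypothetical
# placeholders.
#
#   ldm_stable = get_stable_diffusion_model(args)
#   config = get_stable_diffusion_config(args)
#   prompts = [...]  # prompts derived from args.prompt
#   image, x_t, orig_all_latents, orig_mask, _ = generate_original_image(
#       args, ldm_stable, config, prompts, latent=None, uncond_embeddings=None)
#
#   controller = AttentionStore(config["low_resource"])
#   wrapper = DiffusionModelWrapper(args, ldm_stable, config, controller=controller,
#                                   prompt_mixing=my_prompt_mixing)
#   edited_image, _, _, _ = wrapper.forward(prompts, latent=x_t, other_prompt=other_prompt,
#                                           post_background=True,
#                                           orig_all_latents=orig_all_latents,
#                                           orig_mask=orig_mask)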