"""Cross- and self-attention visualization utilities for prompt-to-prompt editing."""
import torch
import numpy as np
from typing import Tuple, List
from cv2 import putText, getTextSize, FONT_HERSHEY_SIMPLEX
# import matplotlib.pyplot as plt
from PIL import Image
from src.prompt_to_prompt_controllers import AttentionStore
def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int, prompts):
    """Average all recorded attention maps of spatial resolution `res`.

    Pools the stored maps (averaged over diffusion steps by the store) across
    the requested UNet locations and all heads/layers whose maps are res x res.

    Args:
        attention_store: store exposing `get_average_attention()`, a dict
            keyed by "<location>_<cross|self>".
        res: spatial resolution to keep (maps with res**2 pixels).
        from_where: UNet locations to pool, e.g. ["up", "down", "mid"].
        is_cross: True for cross-attention maps, False for self-attention.
        select: index of the prompt in the batch to extract.
        prompts: prompt batch; only its length is used, to split the batch dim.

    Returns:
        CPU tensor of shape (res, res, tokens): the mean attention map.
    """
    attn_kind = 'cross' if is_cross else 'self'
    target_pixels = res ** 2
    averaged = attention_store.get_average_attention()
    collected = []
    for location in from_where:
        for layer_map in averaged[f"{location}_{attn_kind}"]:
            if layer_map.shape[1] != target_pixels:
                continue
            # Split the (batch*heads) dim per prompt, keep the selected prompt.
            per_prompt = layer_map.reshape(len(prompts), -1, res, res, layer_map.shape[-1])
            collected.append(per_prompt[select])
    stacked = torch.cat(collected, dim=0)
    mean_map = stacked.sum(0) / stacked.shape[0]
    return mean_map.cpu()
def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], prompts, tokenizer, select: int = 0):
    """Render one 256x256 grayscale heatmap per prompt token, labeled with the token text.

    Args:
        attention_store: store holding the recorded attention maps.
        res: spatial resolution of the maps to aggregate.
        from_where: UNet locations to pool over.
        prompts: prompt batch; `prompts[select]` is tokenized for labels.
        tokenizer: tokenizer providing `encode` / `decode`.
        select: batch index of the prompt to visualize.
    """
    token_ids = tokenizer.encode(prompts[select])
    maps = aggregate_attention(attention_store, res, from_where, True, select, prompts)
    tiles = []
    for idx in range(len(token_ids)):
        heat = maps[:, :, idx]
        heat = 255 * heat / heat.max()  # normalize to [0, 255]
        heat = heat.unsqueeze(-1).expand(*heat.shape, 3)  # gray -> 3 channels
        tile = heat.numpy().astype(np.uint8)
        tile = np.array(Image.fromarray(tile).resize((256, 256)))
        tile = text_under_image(tile, tokenizer.decode(int(token_ids[idx])))
        tiles.append(tile)
    view_images(np.stack(tiles, axis=0))
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0, prompts=None):
    """Display the top `max_com` spatial principal components of the self-attention map.

    Args:
        attention_store: store holding the recorded attention maps.
        res: spatial resolution of the maps to aggregate.
        from_where: UNet locations to pool over.
        max_com: number of SVD components to visualize.
        select: batch index of the prompt to inspect.
        prompts: prompt batch forwarded to `aggregate_attention` (only its
            length is used there). Defaults to a single-prompt placeholder.

    Bug fix: the original call omitted the required `prompts` argument of
    `aggregate_attention`, so this function always raised TypeError. The new
    keyword parameter is backward-compatible.
    """
    if prompts is None:
        # Assumes a single-prompt batch; pass `prompts` explicitly otherwise.
        prompts = [""]
    attention_maps = aggregate_attention(attention_store, res, from_where, False, select, prompts).numpy().reshape(
        (res ** 2, res ** 2))
    # Center each row, then take the right singular vectors as spatial components.
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()  # normalize to [0, 255]
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile `images` into a `num_rows`-row grid with white gaps and display it.

    Args:
        images: a list of HxWx3 uint8 arrays, a single HxWx3 array, or an
            NxHxWx3 batch array (all images must share one shape).
        num_rows: number of grid rows.
        offset_ratio: gap between cells, as a fraction of image height.

    Bug fix: the pad count was `count % num_rows` (the remainder) instead of
    its complement, so e.g. 4 images at num_rows=3 produced an undersized
    grid that silently dropped images. Now pads with
    `(num_rows - count % num_rows) % num_rows` blanks.
    """
    if type(images) is list:
        num_empty = (num_rows - len(images) % num_rows) % num_rows
    elif images.ndim == 4:
        num_empty = (num_rows - images.shape[0] % num_rows) % num_rows
    else:
        images = [images]
        num_empty = 0
    # White filler cells so every row is completely filled.
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)
    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    # White canvas sized for the full grid plus inter-cell gaps.
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]
    pil_img = Image.fromarray(image_)
    display(pil_img)
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a copy of `image` with `text` centered in a white strip appended below it.

    Args:
        image: HxWxC uint8 image.
        text: caption to draw.
        text_color: BGR/RGB tuple for the caption color.

    Returns:
        (H + 0.2*H) x W x C uint8 array: the image plus its caption strip.
    """
    height, width, channels = image.shape
    strip = int(height * .2)
    canvas = np.ones((height + strip, width, channels), dtype=np.uint8) * 255
    canvas[:height] = image
    font = FONT_HERSHEY_SIMPLEX
    text_w, text_h = getTextSize(text, font, 1, 2)[0]
    # Center horizontally; sit the baseline just above the strip's bottom.
    text_x = (width - text_w) // 2
    text_y = height + strip - text_h // 2
    putText(canvas, text, (text_x, text_y), font, 1, text_color, 2)
    return canvas
def display(image):
    """Show `image` (PIL Image or array) on screen.

    Bug fix: the original referenced `plt` although the matplotlib import at
    the top of the file is commented out, so every call raised NameError.
    Matplotlib is now imported lazily, with PIL's viewer as a fallback, and
    the unused `global display_index` declaration was removed.
    """
    try:
        import matplotlib.pyplot as plt  # optional dependency (commented out at file top)
    except ImportError:
        # Fall back to PIL's external viewer when matplotlib is unavailable.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.asarray(image))
        image.show()
        return
    plt.imshow(image)
    plt.show()