Spaces:
Runtime error
Runtime error
najoungkim
committed on
Commit
•
87468ed
1
Parent(s):
58cacd2
Initial commit
Browse files- LiberationMono-Bold.ttf +0 -0
- README.md +4 -4
- app.py +204 -0
- requirements.txt +13 -0
LiberationMono-Bold.ttf
ADDED
Binary file (302 kB). View file
|
|
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
title: Round Trip Dalle Mini
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.0.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
1 |
---
|
2 |
title: Round Trip Dalle Mini
|
3 |
+
emoji: 🔁
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.0.14
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
app.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
import os
|
5 |
+
# Uncomment to run on cpu
|
6 |
+
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
7 |
+
os.environ["WANDB_DISABLED"] = "true"
|
8 |
+
os.environ['WANDB_SILENT']="true"
|
9 |
+
|
10 |
+
import random
|
11 |
+
import re
|
12 |
+
import torch
|
13 |
+
|
14 |
+
import gradio as gr
|
15 |
+
import jax
|
16 |
+
import jax.numpy as jnp
|
17 |
+
import numpy as np
|
18 |
+
from flax.jax_utils import replicate
|
19 |
+
from flax.training.common_utils import shard, shard_prng_key
|
20 |
+
from PIL import Image, ImageDraw, ImageFont
|
21 |
+
|
22 |
+
from functools import partial
|
23 |
+
|
24 |
+
from transformers import CLIPProcessor, FlaxCLIPModel, AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
|
25 |
+
from dalle_mini import DalleBart, DalleBartProcessor
|
26 |
+
from vqgan_jax.modeling_flax_vqgan import VQModel
|
27 |
+
|
28 |
+
|
29 |
+
# DALL·E mini checkpoint reference; no specific revision is pinned.
DALLE_REPO = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN checkpoint, pinned to a fixed commit.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# With _do_init=False both loaders hand back (module, params) pairs; the
# parameter trees are replicated to the devices further down in the script.
model, params = DalleBart.from_pretrained(
    DALLE_REPO,
    revision=DALLE_COMMIT_ID,
    dtype=jnp.float16,
    _do_init=False,
)
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
    _do_init=False,
)

# The captioning model runs under torch; use the GPU when torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Feature extractor, tokenizer and captioning model all come from one checkpoint.
encoder_checkpoint = decoder_checkpoint = model_checkpoint = (
    "nlpconnect/vit-gpt2-image-captioning"
)
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
viz_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
|
50 |
+
|
51 |
+
|
52 |
+
def captioned_strip(images, caption=None, rows=1):
    """Tile ``images`` onto a single RGB canvas, optionally under a caption band.

    The tile size is taken from ``images[0]``. Tiles are laid out column by
    column: image ``i`` lands in column ``i // rows`` and row ``i % rows``.
    When ``caption`` is given, a 24-pixel band is reserved at the top and the
    caption is drawn into it in white.

    Args:
        images: sequence of image objects exposing ``.size`` that can be
            pasted onto a PIL image.
        caption: text to draw above the tiles, or ``None`` for no caption band.
        rows: number of rows in the grid.

    Returns:
        The composed ``PIL.Image`` canvas.
    """
    tile_w, tile_h = images[0].size[0], images[0].size[1]
    has_caption = caption is not None
    header_h = 24 if has_caption else 0

    canvas = Image.new(
        "RGB", (len(images) * tile_w // rows, tile_h * rows + header_h)
    )
    for idx, tile in enumerate(images):
        col, row = divmod(idx, rows)
        canvas.paste(tile, (col * tile_w, header_h + row * tile_h))

    if has_caption:
        pen = ImageDraw.Draw(canvas)
        caption_font = ImageFont.truetype("LiberationMono-Bold.ttf", 7)
        pen.text((20, 3), caption, (255, 255, 255), font=caption_font)
    return canvas
|
66 |
+
|
67 |
+
|
68 |
+
def get_images(indices, params):
    """Decode code ``indices`` with the module-level ``vqgan`` using ``params``."""
    decoded = vqgan.decode_code(indices, params=params)
    return decoded
|
70 |
+
|
71 |
+
|
72 |
+
def predict_caption(image, max_length=128, num_beams=4):
    """Generate a one-line caption for ``image`` with the captioning model.

    Args:
        image: image object exposing ``convert('RGB')`` (presumably a PIL
            image; confirm against callers).
        max_length: maximum generation length forwarded to ``generate``.
        num_beams: beam-search width forwarded to ``generate``. Previously
            this argument was accepted but silently ignored.

    Returns:
        The decoded caption with ``<|endoftext|>`` markers removed, cut at
        the first newline.
    """
    pixel_values = feature_extractor(
        image.convert("RGB"), return_tensors="pt"
    ).pixel_values.to(device)
    # Only the first generated sequence is used.
    caption_ids = viz_model.generate(
        pixel_values, max_length=max_length, num_beams=num_beams
    )[0]
    decoded = tokenizer.decode(caption_ids)
    return decoded.replace("<|endoftext|>", "").split("\n")[0]
|
79 |
+
|
80 |
+
|
81 |
+
# model inference
|
82 |
+
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    """Run ``model.generate`` under ``jax.pmap`` (axis name ``"batch"``).

    ``tokenized_prompt``, ``key`` and ``params`` are mapped over the leading
    device axis; the four sampling arguments (positions 3-6) are static
    broadcast arguments, as declared in the decorator.
    """
    sampling_options = dict(
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )
    return model.generate(
        **tokenized_prompt, prng_key=key, params=params, **sampling_options
    )
|
95 |
+
|
96 |
+
|
97 |
+
# decode image
|
98 |
+
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Decode code ``indices`` with ``vqgan`` under ``jax.pmap`` (axis ``"batch"``)."""
    decoded = vqgan.decode_code(indices, params=params)
    return decoded
|
101 |
+
|
102 |
+
# pmapped form of get_images, mapped over the axis named "batch".
p_get_images = jax.pmap(get_images, axis_name="batch")

# Put a copy of each parameter tree on every device for the pmapped calls.
params = replicate(params)
vqgan_params = replicate(vqgan_params)

# Prompt processor for the DALL·E checkpoint configured above.
processor = DalleBartProcessor.from_pretrained(
    DALLE_REPO,
    revision=DALLE_COMMIT_ID,
)
print("Initialized DalleBartProcessor")

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print("Initialized FlaxCLIPModel")
|
111 |
+
|
112 |
+
|
113 |
+
def hallucinate(prompt, num_images=8):
    """Generate images for ``prompt`` and return them as PIL images.

    Each round calls the pmapped ``p_generate`` and ``p_decode`` once, which
    yields one image per device. ``max(num_images // jax.device_count(), 1)``
    rounds are run, so at least one round always happens.

    Args:
        prompt: text prompt to generate images for.
        num_images: approximate number of images wanted.

    Returns:
        A list of ``PIL.Image`` objects built from the 256x256x3 decoded arrays.
    """
    # Fixed sampling settings handed to p_generate as static arguments.
    gen_top_k = None
    gen_top_p = None
    temperature = None
    cond_scale = 10.0

    print(f"Prompts: {prompt}")
    # Tokenize the prompt once. replicate() already gives every device its own
    # copy, so the prompt must not also be repeated device_count times:
    # doing both made each device generate device_count images per round.
    tokenized_prompt = replicate(processor([prompt]))

    # create a random key
    seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)

    images = []
    num_rounds = max(num_images // jax.device_count(), 1)
    for i in range(num_rounds):
        # Fresh subkey per round, split across devices.
        key, subkey = jax.random.split(key)
        encoded_images = p_generate(
            tokenized_prompt,
            shard_prng_key(subkey),
            params,
            gen_top_k,
            gen_top_p,
            temperature,
            cond_scale,
        )
        print(f"Encoded image {i}")
        # remove BOS
        encoded_images = encoded_images.sequences[..., 1:]
        # decode images and flatten the device axis into one batch of images
        decoded_images = p_decode(encoded_images, vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        images.extend(
            Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            for decoded_img in decoded_images
        )

        print(f"Finished decoding image {i}")
    return images
|
152 |
+
|
153 |
+
|
154 |
+
def run_inference(prompt, num_roundtrips=3, num_images=1):
    """Alternate image generation and captioning for ``num_roundtrips`` rounds.

    Each round generates images for the current prompt, captions the first
    one, and then uses that caption as the prompt for the next round.

    Args:
        prompt: starting text prompt (user-supplied).
        num_roundtrips: number of generate/caption rounds; converted with
            ``int``.
        num_images: forwarded to ``hallucinate``; only the first returned
            image is used.

    Returns:
        A flat list with three entries per round, in this order: an HTML
        title snippet, the generated image, and an HTML caption snippet.
    """
    import html  # local import keeps this block self-contained

    outputs = []
    for i in range(int(num_roundtrips)):
        images = hallucinate(prompt, num_images=num_images)
        image = images[0]
        print("Generated image")
        caption = predict_caption(image)
        print(f"Predicted caption: {caption}")

        # The prompt comes from the user and the caption from a model; both
        # are plain text, so escape them before embedding them in HTML.
        output_title = (
            '\n<font size="+3">\n'
            f"<b>[Roundtrip {i}]</b><br>\n"
            f"Prompt: {html.escape(prompt)}<br>\n"
            "🥑 :<br></font>"
        )
        output_caption = (
            '\n<font size="+3">\n'
            f"🤖💬 : {html.escape(caption)}<br>\n"
            "</font>\n"
        )
        outputs.append(output_title)
        outputs.append(image)
        outputs.append(output_caption)
        # The raw (unescaped) caption drives the next round.
        prompt = caption

    return outputs
|
179 |
+
|
180 |
+
|
181 |
+
inputs = gr.inputs.Textbox(
    label="What prompt do you want to start with?",
    default="a poster of cookie monster live action",
)
# num_roundtrips = gr.inputs.Number(default=2, label="How many roundtrips?")
num_roundtrips = 3

# One (title HTML, image, caption HTML) triple per round trip, in the same
# order in which run_inference appends its results.
outputs = []
for _ in range(int(num_roundtrips)):
    outputs.extend(
        [
            gr.outputs.HTML(label=""),
            gr.Image(label=""),
            gr.outputs.HTML(label=""),
        ]
    )

description = (
    "\n"
    "Round trip DALL·E-mini iterates between DALL·E generation and image captioning, inspired by round trip translation!\n"
)
article = "<p style='text-align: center'>Put together by: Najoung Kim | Dall-E Mini code from flax-community/dalle-mini | Caption code from SRDdev/Image-Caption</p>"

demo = gr.Interface(
    fn=run_inference,
    inputs=[inputs],
    outputs=outputs,
    title="Round Trip DALL·E mini 🥑🔁🤖💬",
    description=description,
    article=article,
    theme="default",
    css=".output-image, .input-image, .image-preview {height: 256px !important} ",
)
demo.launch(share=True)
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio>=2.2.3
|
2 |
+
flax
|
3 |
+
transformers
|
4 |
+
einops
|
5 |
+
unidecode
|
6 |
+
ftfy
|
7 |
+
emoji
|
8 |
+
pillow
|
9 |
+
jax
|
10 |
+
flax
|
11 |
+
torch
|
12 |
+
git+https://github.com/patil-suraj/vqgan-jax.git
|
13 |
+
git+https://github.com/borisdayma/dalle-mini.git
|