# SEG-SDXL / app.py — Hugging Face Space by nyanko7
# Last commit: 93f33bc "Update app.py" (verified); file size 3.74 kB.
import spaces
import gradio as gr
from diffusers import StableDiffusionXLPipeline
import numpy as np
import math
import torch
import random
from gradio_imageslider import ImageSlider
# UI theme: Google-hosted Libre Franklin / Public Sans, falling back to
# the platform's default sans-serif stack.
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont("Libre Franklin"),
        gr.themes.GoogleFont("Public Sans"),
        "system-ui",
        "sans-serif",
    ],
)

# Target device for the pipeline and for the seeded generators in run().
device = "cuda"

# SDXL base weights in float16, loaded through the custom Smoothed Energy
# Guidance pipeline hosted at nyanko7/sdxl_smoothed_energy_guidance, and
# moved to the GPU once at import time.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="nyanko7/sdxl_smoothed_energy_guidance",
    torch_dtype=torch.float16,
).to(device)
def _sample(prompt, negative_prompt, guidance_scale, seed, num_inference_steps=25, **pipe_kwargs):
    """Run one SDXL sampling pass with a generator freshly seeded from ``seed``.

    A new generator is created per call, so two calls with the same ``seed``
    draw from identically seeded RNG state. That is what makes the SEG and
    no-SEG images comparable. Extra keyword arguments (e.g. ``seg_scale``,
    ``seg_applied_layers``) are forwarded unchanged to ``pipe``.

    Returns:
        The first image of the pipeline output.
    """
    generator = torch.Generator(device=device).manual_seed(seed)
    result = pipe(
        prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        generator=generator,
        num_inference_steps=num_inference_steps,
        **pipe_kwargs,
    )
    return result.images[0]


@spaces.GPU
def run(prompt, negative_prompt=None, guidance_scale=7.0, seg_scale=3.0, seg_layers=None, randomize_seed=True, seed=42, progress=gr.Progress(track_tqdm=True)):
    """Generate a SEG image and a no-SEG baseline from the same seed.

    Args:
        prompt: Text prompt. Blank (or None) means unconditional generation.
        negative_prompt: Optional negative prompt. Blank/whitespace-only
            values are treated as None.
        guidance_scale: CFG scale. Forced to 0.0 when both prompt and
            negative prompt are blank.
        seg_scale: Smoothed Energy Guidance scale for the first image. The
            baseline image always uses 0.0.
        seg_layers: Layer groups SEG is applied to ("up", "mid", "down").
            Defaults to ["mid"]. A single string is wrapped in a list.
        randomize_seed: If True, ignore ``seed`` and draw a fresh one.
        seed: Seed shared by both generations when not randomized.
        progress: Gradio progress tracker that mirrors tqdm bars.

    Returns:
        ``((seg_image, normal_image), seed)``: the image pair for the slider
        and the seed actually used, so the UI can display it.
    """
    # None sentinel instead of a list-literal default, so each call gets a fresh list.
    if seg_layers is None:
        seg_layers = ["mid"]
    elif isinstance(seg_layers, str):
        # Guard against a bare string (the UI Dropdown was initialised with
        # value="mid"). The pipeline always receives a list.
        seg_layers = [seg_layers]

    prompt = (prompt or "").strip()
    # Whitespace-only negative prompts count as "no negative prompt".
    negative_prompt = (negative_prompt or "").strip() or None

    print(f"Initial seed for prompt `{prompt}`", seed)
    if randomize_seed:
        # 9007199254740991 == 2**53 - 1, the largest integer a JavaScript
        # number (and thus the browser-side Slider) represents exactly.
        seed = random.randint(0, 9007199254740991)
    # Slider values may arrive as floats; torch.Generator.manual_seed takes an int.
    seed = int(seed)

    # With no text conditioning at all there is nothing for CFG to contrast.
    if not prompt and not negative_prompt:
        guidance_scale = 0.0

    print(f"Seed before sending to generator for prompt: `{prompt}`", seed)
    image = _sample(
        prompt,
        negative_prompt,
        guidance_scale,
        seed,
        seg_scale=seg_scale,
        seg_applied_layers=seg_layers,
    )
    # Baseline: identical seed and settings, SEG switched off.
    image_normal = _sample(prompt, negative_prompt, guidance_scale, seed, seg_scale=0.0)
    print(f"Seed at the end of generation for prompt: `{prompt}`", seed)
    return (image, image_normal), seed
# Keep the app centred and narrow enough that the before/after slider stays readable.
css = '''
.gradio-container{
    max-width: 768px !important;
    margin: 0 auto;
}
'''

# Layout: prompt box + Generate button, a before/after ImageSlider, and an
# "Advanced Settings" accordion. Its controls are passed to run() in the
# order listed in gr.on(inputs=...).
with gr.Blocks(css=css, theme=theme) as demo:
    # Link target fixed: the host was previously "huggingface.co." (stray trailing dot).
    gr.Markdown('''# Smoothed Energy Guidance SDXL
SDXL [diffusers implementation](https://huggingface.co/nyanko7/sdxl_smoothed_energy_guidance) of [Smoothed Energy Guidance](https://arxiv.org/abs/2408.00760)
''')
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(show_label=False, scale=4, placeholder="Your prompt", info="Leave blank to test unconditional generation")
            button = gr.Button("Generate", min_width=120)
        output = ImageSlider(label="Left: SEG, Right: No SEG", interactive=False)
    with gr.Accordion("Advanced Settings", open=False):
        guidance_scale = gr.Number(label="CFG Guidance Scale", info="The guidance scale for CFG, ignored if no prompt is entered (unconditional generation)", value=7.0)
        negative_prompt = gr.Textbox(label="Negative prompt", info="Is only applied for the CFG part, leave blank for unconditional generation")
        seg_scale = gr.Number(label="Seg Scale", value=3.0)
        # multiselect=True takes a list of selected values, not a bare string.
        seg_layers = gr.Dropdown(label="Model layers to apply Seg to", info="mid is the one used on the paper, up and down blocks seem unstable", choices=["up", "mid", "down"], multiselect=True, value=["mid"])
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # Range matches run(), which draws random.randint(0, 2**53 - 1).
        seed = gr.Slider(label="Seed", minimum=0, maximum=9007199254740991, step=1, randomize=True)
    # Examples pass only the prompt; run()'s defaults (randomize_seed=True, etc.) fill the rest.
    gr.Examples(fn=run, examples=[" ", "an insect robot preparing a delicious meal, anime style", "a photo of a group of friends at an amusement park"], inputs=prompt, outputs=[output, seed], cache_examples="lazy")
    gr.on(
        triggers=[
            button.click,
            prompt.submit,
        ],
        fn=run,
        inputs=[prompt, negative_prompt, guidance_scale, seg_scale, seg_layers, randomize_seed, seed],
        outputs=[output, seed],
    )

if __name__ == "__main__":
    demo.launch(share=True)