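"""Gradio Space for experimenting with Stable Diffusion models.

Images are generated by POSTing prompts to the Hugging Face Inference API,
and short prompts can be expanded with a small text-generation model.
"""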
import io
import os
import random

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import pipeline
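# Base URL for the Hugging Face Inference API; a model id is appended to form each endpoint.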
API_BASE_URL = "https://api-inference.huggingface.co/models/"
MODEL_LIST = [
    "openskyml/dalle-3-xl",
    "Linaqruf/animagine-xl-2.0",
    "Lykon/dreamshaper-7",
    "Linaqruf/animagine-xl",
    "runwayml/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "prompthero/openjourney-v4",
    "nerijs/pixel-art-xl",
    "Linaqruf/anything-v3.0",
    "playgroundai/playground-v2-1024px-aesthetic",
    "ilovecutiee/fantastical-art-lora",
    "segmind/SSD-1B",
    "segmind/Segmind-Vega",
    "stablediffusionapi/anything-v5",
    "stablediffusionapi/realistic-vision-v51",
    "hakurei/waifu-diffusion",
]
API_TOKEN = os.getenv("HF_READ_TOKEN") # Make sure to set your Hugging Face token
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
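# Small text-generation model used by extend_prompt() to elaborate short prompts.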
pipe = pipeline("text-generation", model="isek-ai/SDPrompt-RetNet-300M", trust_remote_code=True)
def select_model(model_name):
    # Map a model name to its Inference API endpoint.
    if model_name in MODEL_LIST:
        return f"{API_BASE_URL}{model_name}"
    raise gr.Error(f"Unknown model: {model_name}")
def extend_prompt(input_text):
    # Expand a short prompt into a more detailed one with the prompt-extension model.
    if not input_text.strip():
        gr.Warning("Input text is empty!")
        return None
    seed = random.randint(1, 1000000)
    torch.manual_seed(seed)
    output = pipe(
        input_text,
        max_length=len(input_text) + random.randint(60, 90),
        num_return_sequences=1,
    )
    # The pipeline returns a list of dicts; hand the generated text back to the prompt box.
    return output[0]["generated_text"]
def generate_image(prompt, selected_model, negative_prompt="", steps=30, cfg_scale=7, seed=None):
    if not prompt.strip():
        raise gr.Error("Cannot generate image: Input text is empty!")
    api_url = select_model(selected_model)
    # Generation settings are sent alongside the prompt in the request body.
    payload = {
        "inputs": prompt,
        "negative_prompt": negative_prompt,
        "steps": steps,
        "cfg_scale": cfg_scale,
        "seed": seed if seed is not None else random.randint(0, 2147483647),
    }
    try:
        response = requests.post(api_url, headers=HEADERS, json=payload)
        response.raise_for_status()
        # A successful response contains the raw image bytes.
        image = Image.open(io.BytesIO(response.content))
        return image
    except requests.exceptions.RequestException as e:
        raise gr.Error(str(e))
with gr.Blocks(theme="soft") as playground:
    gr.HTML(
        """
        <div style="text-align: center; margin: 0 auto;">
            <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
                <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">🎨🤖 Play with SD Models</h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
                Explore and create your AI art with Stable Diffusion models!
            </p>
        </div>
        """
    )
    with gr.Row():
        image_output = gr.Image(type="pil", label="Output Image", elem_id="gallery")
    with gr.Column(elem_id="prompt-container"):
        text_prompt = gr.Textbox(label="Prompt", placeholder="a cute cat", lines=1, elem_id="prompt-text-input")
        model_dropdown = gr.Dropdown(label="Model", choices=MODEL_LIST, elem_id="model-dropdown", value="runwayml/stable-diffusion-v1-5")
        gen_button = gr.Button("Generate", variant="primary", elem_id="gen-button")
        extend_button = gr.Button("Extend Prompt", variant="primary", elem_id="extend-button")
        with gr.Accordion("Advanced settings", open=False):
            negative_prompt = gr.Textbox(label="Negative Prompt", value="text, blurry, fuzziness", lines=1, elem_id="negative-prompt-text-input")

    gen_button.click(generate_image, inputs=[text_prompt, model_dropdown, negative_prompt], outputs=image_output)
    extend_button.click(extend_prompt, inputs=text_prompt, outputs=text_prompt)

playground.launch(show_api=False)