stale2000 commited on
Commit
78e6f58
1 Parent(s): a982127

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -2
app.py CHANGED
@@ -8,12 +8,17 @@ import gradio
8
  import torch
9
  from diffusers import StableDiffusionPipeline
10
  from torch import autocast
 
 
 
 
11
 
12
 
13
  openai.api_key = os.getenv('openaikey')
14
 
15
  def predict(input, manual_query_repacement, history=[]):
16
 
 
17
  if manual_query_repacement != "":
18
  input = manual_query_repacement
19
 
@@ -30,15 +35,64 @@ def predict(input, manual_query_repacement, history=[]):
30
  responseText = response["choices"][0]["text"]
31
  history.append((input, responseText))
32
 
33
- return history, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
 
37
  inputText = gradio.Textbox(value="tmp")
38
  manual_query = gradio.Textbox(placeholder="Input any query here, to replace the image generation query builder entirely.")
39
 
 
 
 
40
  gradio.Interface(fn=predict,
41
  inputs=[inputText,manual_query,'state'],
42
 
43
- outputs=["chatbot",'state']).launch()
 
 
 
 
 
 
 
 
44
 
 
8
  import torch
9
  from diffusers import StableDiffusionPipeline
10
  from torch import autocast
11
+ #from PIL import Image
12
+ #from torchvision import transforms
13
+
14
+ #from diffusers import StableDiffusionImageVariationPipeline
15
 
16
 
17
  openai.api_key = os.getenv('openaikey')
18
 
19
  def predict(input, manual_query_repacement, history=[]):
20
 
21
+ # gpt3
22
  if manual_query_repacement != "":
23
  input = manual_query_repacement
24
 
 
35
  responseText = response["choices"][0]["text"]
36
  history.append((input, responseText))
37
 
38
+
39
+ #img generation
40
+ prompt = "Yoda"
41
+ scale = 10
42
+ n_samples = 4
43
+
44
+ # Sometimes the nsfw checker is confused by the Naruto images, you can disable
45
+ # it at your own risk here
46
+ #disable_safety = False
47
+
48
+ #if disable_safety:
49
+ # def null_safety(images, **kwargs):
50
+ # return images, False
51
+ # pipe.safety_checker = null_safety
52
+
53
+ with autocast("cuda"):
54
+ images = pipe(n_samples*[prompt], guidance_scale=scale).images
55
+
56
+ for idx, im in enumerate(images):
57
+ im.save(f"{idx:06}.png")
58
+
59
+ images_list = pipe(
60
+ inp.tile(n_samples, 1, 1, 1),
61
+ guidance_scale=scale,
62
+ num_inference_steps=steps,
63
+ generator=generator,
64
+ )
65
+
66
+ images = []
67
+ for i, image in enumerate(images_list["images"]):
68
+ if(images_list["nsfw_content_detected"][i]):
69
+ safe_image = Image.open(r"unsafe.png")
70
+ images.append(safe_image)
71
+ else:
72
+ images.append(image)
73
+
74
+
75
+
76
+ return history, history, images
77
 
78
 
79
 
80
  inputText = gradio.Textbox(value="tmp")
81
  manual_query = gradio.Textbox(placeholder="Input any query here, to replace the image generation query builder entirely.")
82
 
83
+ output_img = gr.Gallery(label="Generated image")
84
+ output_img.style(grid=2)
85
+
86
  gradio.Interface(fn=predict,
87
  inputs=[inputText,manual_query,'state'],
88
 
89
+ outputs=["chatbot",'state', output_img]).launch()
90
+
91
+
92
+
93
+ device = "cuda" if torch.cuda.is_available() else "cpu"
94
+ pipe = StableDiffusionPipeline.from_pretrained("stale2000/sd-dnditem", torch_dtype=torch.float16)
95
+ pipe = pipe.to(device)
96
+
97
+
98