fffiloni committed on
Commit f4a11e0
1 Parent(s): 7343da4

Update app.py

Files changed (1)
  1. app.py +57 -6
app.py CHANGED
@@ -1,17 +1,31 @@
 import gradio as gr
-
 import torch
 
-from spectro import wav_bytes_from_spectrogram_image
+from scipy.io import wavfile
+import numpy as np
+from PIL import Image
+
+from spectro import wav_bytes_from_spectrogram_image, spectrogram_from_waveform, image_from_spectrogram
+
 from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionImg2ImgPipeline
 
 from share_btn import community_icon_html, loading_icon_html, share_js
 
-model_id = "riffusion/riffusion-model-v1"
-pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+MODEL_ID = "riffusion/riffusion-model-v1"
+pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
 pipe = pipe.to("cuda")
+pipe2 = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
+pipe2 = pipe2.to("cuda")
+
+def predict(prompt, audio_input, duration):
+    if audio_input == None:
+        return classic(prompt, duration)
+    else:
+        return audio_transfer(prompt, audio_input)
+
 
-def predict(prompt, duration):
+def classic(prompt, duration):
     if duration == 5:
         width_duration=512
     else :
@@ -23,6 +37,42 @@ def predict(prompt, duration):
         f.write(wav[0].getbuffer())
     return spec, 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
 
+def audio_transfer(prompt, audio):
+    # read uploaded file to wav
+    rate, data = wavfile.read(audio)
+
+    # convert to mono
+    data = np.mean(data, axis=1)
+
+    # convert to float32
+    data = data.astype(np.float32)
+
+    # take a random 7 second slice of the audio
+    data = data[rate*7:rate*14]
+
+    spectrogram = spectrogram_from_waveform(
+        waveform=data,
+        sample_rate=rate,
+        # width=768,
+        n_fft=8192,
+        hop_length=512,
+        win_length=8192,
+    )
+
+    spec = image_from_spectrogram(spectrogram)
+
+    images = pipe2(
+        prompt=prompt,
+        image=spec,
+        strength=0.5,
+        guidance_scale=7
+    ).images
+
+    wav = wav_bytes_from_spectrogram_image(images[0])
+    with open("output.wav", "wb") as f:
+        f.write(wav[0].getbuffer())
+    return images[0], 'output.wav', gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
+
 title = """
 <div style="text-align: center; max-width: 500px; margin: 0 auto;">
   <div
@@ -139,6 +189,7 @@ with gr.Blocks(css=css) as demo:
        gr.HTML(title)
 
        prompt_input = gr.Textbox(placeholder="a cat diva singing in a New York jazz club", label="Musical prompt", elem_id="prompt-in")
+       audio_input = gr.Audio(label="audio input", type="filepath", source="upload")
        duration_input = gr.Slider(label="Duration in seconds", minimum=5, maximum=10, step=1, value=8, elem_id="duration-slider")
        send_btn = gr.Button(value="Get a new spectrogram ! ", elem_id="submit-btn")
 
@@ -154,7 +205,7 @@ with gr.Blocks(css=css) as demo:
 
        gr.HTML(article)
 
-       send_btn.click(predict, inputs=[prompt_input, duration_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
+       send_btn.click(predict, inputs=[prompt_input, audio_input, duration_input], outputs=[spectrogram_output, sound_output, share_button, community_icon, loading_icon])
        share_button.click(None, [], [], _js=share_js)
 
 demo.queue(max_size=250).launch(debug=True)
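
For reference, here is a minimal sketch of the image-to-image path this commit introduces, runnable outside the Space. It is not part of the commit: the input file name "input_spectrogram.png" is hypothetical, and it assumes a CUDA GPU with fp16 support plus the repo's spectro.py helpers on the import path.

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
from spectro import wav_bytes_from_spectrogram_image

MODEL_ID = "riffusion/riffusion-model-v1"

# Same img2img pipeline as pipe2 in app.py
pipe2 = StableDiffusionImg2ImgPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
pipe2 = pipe2.to("cuda")

# A riffusion-style spectrogram image; in the app this comes from the uploaded
# clip (seconds 7-14) via spectrogram_from_waveform + image_from_spectrogram.
init_spec = Image.open("input_spectrogram.png").convert("RGB")

# Same knobs as audio_transfer(): strength=0.5 noises the input halfway before
# denoising, so the output keeps much of the source spectrogram's structure;
# guidance_scale=7 steers it toward the text prompt.
result = pipe2(
    prompt="a cat diva singing in a New York jazz club",
    image=init_spec,
    strength=0.5,
    guidance_scale=7,
).images[0]

# Convert the generated spectrogram back to audio, exactly as the app does.
wav = wav_bytes_from_spectrogram_image(result)
with open("output.wav", "wb") as f:
    f.write(wav[0].getbuffer())

In the app itself, this path only runs when the new gr.Audio input is non-empty; with no upload, predict() falls back to classic() and the duration slider behaves as before.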