waveydaveygravy committed on
Commit
6bc67da
1 Parent(s): 8e0bd83

Upload 2 files

Files changed (2)
  1. apphf.py +448 -0
  2. apphfupscaletest.py +610 -0
apphf.py ADDED
@@ -0,0 +1,448 @@
1
+ import gradio as gr
2
+ import os
3
+ import shutil
4
+ import ffmpeg
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import cv2
9
+ import torch
10
+ import spaces
11
+
12
+ from diffusers import AutoencoderKL, DDIMScheduler
13
+ from einops import repeat
14
+ from omegaconf import OmegaConf
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ from transformers import CLIPVisionModelWithProjection
18
+
19
+ from src.models.pose_guider import PoseGuider
20
+ from src.models.unet_2d_condition import UNet2DConditionModel
21
+ from src.models.unet_3d import UNet3DConditionModel
22
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
23
+ from src.utils.util import get_fps, read_frames, save_videos_grid, save_pil_imgs
24
+
25
+ from src.audio_models.model import Audio2MeshModel
26
+ from src.utils.audio_util import prepare_audio_feature
27
+ from src.utils.mp_utils import LMKExtractor
28
+ from src.utils.draw_util import FaceMeshVisualizer
29
+ from src.utils.pose_util import project_points, project_points_with_trans, matrix_to_euler_and_translation, euler_and_translation_to_matrix
30
+ from src.utils.crop_face_single import crop_face
31
+ from src.audio2vid import get_headpose_temp, smooth_pose_seq
32
+ from src.utils.frame_interpolation import init_frame_interpolation_model, batch_images_interpolation_tool
33
+
34
+
35
+ config = OmegaConf.load('./configs/prompts/animation_audio.yaml')
36
+ if config.weight_dtype == "fp16":
37
+ weight_dtype = torch.float16
38
+ else:
39
+ weight_dtype = torch.float32
40
+
41
+ audio_infer_config = OmegaConf.load(config.audio_inference_config)
42
+ # prepare model
43
+ a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
44
+ a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
45
+ a2m_model.cuda().eval()
46
+
47
+ vae = AutoencoderKL.from_pretrained(
48
+ config.pretrained_vae_path,
49
+ ).to("cuda", dtype=weight_dtype)
50
+
51
+ reference_unet = UNet2DConditionModel.from_pretrained(
52
+ config.pretrained_base_model_path,
53
+ subfolder="unet",
54
+ ).to(dtype=weight_dtype, device="cuda")
55
+
56
+ inference_config_path = config.inference_config
57
+ infer_config = OmegaConf.load(inference_config_path)
58
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
59
+ config.pretrained_base_model_path,
60
+ config.motion_module_path,
61
+ subfolder="unet",
62
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
63
+ ).to(dtype=weight_dtype, device="cuda")
64
+
65
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
66
+
67
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
68
+ config.image_encoder_path
69
+ ).to(dtype=weight_dtype, device="cuda")
70
+
71
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
72
+ scheduler = DDIMScheduler(**sched_kwargs)
73
+
74
+ # load pretrained weights
75
+ denoising_unet.load_state_dict(
76
+ torch.load(config.denoising_unet_path, map_location="cpu"),
77
+ strict=False,
78
+ )
79
+ reference_unet.load_state_dict(
80
+ torch.load(config.reference_unet_path, map_location="cpu"),
81
+ )
82
+ pose_guider.load_state_dict(
83
+ torch.load(config.pose_guider_path, map_location="cpu"),
84
+ )
85
+
86
+ pipe = Pose2VideoPipeline(
87
+ vae=vae,
88
+ image_encoder=image_enc,
89
+ reference_unet=reference_unet,
90
+ denoising_unet=denoising_unet,
91
+ pose_guider=pose_guider,
92
+ scheduler=scheduler,
93
+ )
94
+ pipe = pipe.to("cuda", dtype=weight_dtype)
95
+
96
+ # lmk_extractor = LMKExtractor()
97
+ # vis = FaceMeshVisualizer()
98
+
99
+ frame_inter_model = init_frame_interpolation_model()
100
+
101
+ @spaces.GPU
102
+ def audio2video(input_audio, ref_img, headpose_video=None, size=512, steps=25, length=60, seed=42):
103
+ fps = 30
104
+ cfg = 3.5
105
+ fi_step = 3
106
+
107
+ generator = torch.manual_seed(seed)
108
+
109
+ lmk_extractor = LMKExtractor()
110
+ vis = FaceMeshVisualizer()
111
+
112
+ width, height = size, size
113
+
114
+ date_str = datetime.now().strftime("%Y%m%d")
115
+ time_str = datetime.now().strftime("%H%M")
116
+ save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
117
+
118
+ save_dir = Path(f"a2v_output/{date_str}/{save_dir_name}")
119
+ while os.path.exists(save_dir):
120
+ save_dir = Path(f"a2v_output/{date_str}/{save_dir_name}_{np.random.randint(10000):04d}")
121
+ save_dir.mkdir(exist_ok=True, parents=True)
122
+
123
+ ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
124
+ ref_image_np = crop_face(ref_image_np, lmk_extractor)
125
+ if ref_image_np is None:
126
+ return None, Image.fromarray(ref_img)
127
+
128
+ ref_image_np = cv2.resize(ref_image_np, (size, size))
129
+ ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
130
+
131
+ face_result = lmk_extractor(ref_image_np)
132
+ if face_result is None:
133
+ return None, ref_image_pil
134
+
135
+ lmks = face_result['lmks'].astype(np.float32)
136
+ ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
137
+
138
+ sample = prepare_audio_feature(input_audio, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
139
+ sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
140
+ sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
141
+
142
+ # inference
143
+ pred = a2m_model.infer(sample['audio_feature'], sample['seq_len'])
144
+ pred = pred.squeeze().detach().cpu().numpy()
145
+ pred = pred.reshape(pred.shape[0], -1, 3)
146
+ pred = pred + face_result['lmks3d']
147
+
148
+ if headpose_video is not None:
149
+ pose_seq = get_headpose_temp(headpose_video)
150
+ else:
151
+ pose_seq = np.load(config['pose_temp'])
152
+ mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
153
+ cycled_pose_seq = np.tile(mirrored_pose_seq, (sample['seq_len'] // len(mirrored_pose_seq) + 1, 1))[:sample['seq_len']]
154
+
155
+ # project 3D mesh to 2D landmark
156
+ projected_vertices = project_points(pred, face_result['trans_mat'], cycled_pose_seq, [height, width])
157
+
158
+ pose_images = []
159
+ for i, verts in enumerate(projected_vertices):
160
+ lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
161
+ pose_images.append(lmk_img)
162
+
163
+ pose_list = []
164
+ # pose_tensor_list = []
165
+
166
+ # pose_transform = transforms.Compose(
167
+ # [transforms.Resize((height, width)), transforms.ToTensor()]
168
+ # )
169
+ args_L = len(pose_images) if length==0 or length > len(pose_images) else length
170
+ #args_L = min(args_L, 9999)
171
+ for pose_image_np in pose_images[: args_L : fi_step]:
172
+ # pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
173
+ # pose_tensor_list.append(pose_transform(pose_image_pil))
174
+ pose_image_np = cv2.resize(pose_image_np, (width, height))
175
+ pose_list.append(pose_image_np)
176
+
177
+ pose_list = np.array(pose_list)
178
+
179
+ video_length = len(pose_list)
180
+
181
+ video = pipe(
182
+ ref_image_pil,
183
+ pose_list,
184
+ ref_pose,
185
+ width,
186
+ height,
187
+ video_length,
188
+ steps,
189
+ cfg,
190
+ generator=generator,
191
+ ).videos
192
+
193
+ video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step-1)
194
+
195
+ save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
196
+ save_videos_grid(
197
+ video,
198
+ save_path,
199
+ n_rows=1,
200
+ fps=fps,
201
+ )
202
+
203
+ # save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio"
204
+ # save_pil_imgs(video, save_path)
205
+
206
+ # save_path = batch_images_interpolation_tool(save_path, frame_inter_model, int(fps))
207
+
208
+ stream = ffmpeg.input(save_path)
209
+ audio = ffmpeg.input(input_audio)
210
+ ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
211
+ os.remove(save_path)
212
+
213
+ return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
214
+
215
+ @spaces.GPU
216
+ def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42):
217
+ cfg = 3.5
218
+ fi_step = 3
219
+
220
+ generator = torch.manual_seed(seed)
221
+
222
+ lmk_extractor = LMKExtractor()
223
+ vis = FaceMeshVisualizer()
224
+
225
+ width, height = size, size
226
+
227
+ date_str = datetime.now().strftime("%Y%m%d")
228
+ time_str = datetime.now().strftime("%H%M")
229
+ save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
230
+
231
+ save_dir = Path(f"v2v_output/{date_str}/{save_dir_name}")
232
+ while os.path.exists(save_dir):
233
+ save_dir = Path(f"v2v_output/{date_str}/{save_dir_name}_{np.random.randint(10000):04d}")
234
+ save_dir.mkdir(exist_ok=True, parents=True)
235
+
236
+ ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
237
+ ref_image_np = crop_face(ref_image_np, lmk_extractor)
238
+ if ref_image_np is None:
239
+ return None, Image.fromarray(ref_img)
240
+
241
+ ref_image_np = cv2.resize(ref_image_np, (size, size))
242
+ ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
243
+
244
+ face_result = lmk_extractor(ref_image_np)
245
+ if face_result is None:
246
+ return None, ref_image_pil
247
+
248
+ lmks = face_result['lmks'].astype(np.float32)
249
+ ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
250
+
251
+ source_images = read_frames(source_video)
252
+ src_fps = get_fps(source_video)
253
+ pose_transform = transforms.Compose(
254
+ [transforms.Resize((height, width)), transforms.ToTensor()]
255
+ )
256
+
257
+ step = 1
258
+ if src_fps == 60:
259
+ src_fps = 30
260
+ step = 2
261
+
262
+ pose_trans_list = []
263
+ verts_list = []
264
+ bs_list = []
265
+ args_L = len(source_images) if length==0 or length*step > len(source_images) else length*step
266
+ #args_L = min(args_L, 90*step)
267
+ for src_image_pil in source_images[: args_L : step*fi_step]:
268
+ src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
269
+ frame_height, frame_width, _ = src_img_np.shape
270
+ src_img_result = lmk_extractor(src_img_np)
271
+ if src_img_result is None:
272
+ break
273
+ pose_trans_list.append(src_img_result['trans_mat'])
274
+ verts_list.append(src_img_result['lmks3d'])
275
+ bs_list.append(src_img_result['bs'])
276
+
277
+ trans_mat_arr = np.array(pose_trans_list)
278
+ verts_arr = np.array(verts_list)
279
+ bs_arr = np.array(bs_list)
280
+ min_bs_idx = np.argmin(bs_arr.sum(1))
281
+
282
+ # compute delta pose
283
+ pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
284
+
285
+ for i in range(pose_arr.shape[0]):
286
+ euler_angles, translation_vector = matrix_to_euler_and_translation(trans_mat_arr[i]) # real pose of source
287
+ pose_arr[i, :3] = euler_angles
288
+ pose_arr[i, 3:6] = translation_vector
289
+
290
+ init_tran_vec = face_result['trans_mat'][:3, 3] # init translation of tgt
291
+ pose_arr[:, 3:6] = pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec # (relative translation of source) + (init translation of tgt)
292
+
293
+ pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
294
+ pose_mat_smooth = [euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6]) for i in range(pose_arr_smooth.shape[0])]
295
+ pose_mat_smooth = np.array(pose_mat_smooth)
296
+
297
+ # face retarget
298
+ verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result['lmks3d']
299
+ # project 3D mesh to 2D landmark
300
+ projected_vertices = project_points_with_trans(verts_arr, pose_mat_smooth, [frame_height, frame_width])
301
+
302
+ pose_list = []
303
+ for i, verts in enumerate(projected_vertices):
304
+ lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
305
+ pose_image_np = cv2.resize(lmk_img, (width, height))
306
+ pose_list.append(pose_image_np)
307
+
308
+ pose_list = np.array(pose_list)
309
+
310
+ video_length = len(pose_list)
311
+
312
+ video = pipe(
313
+ ref_image_pil,
314
+ pose_list,
315
+ ref_pose,
316
+ width,
317
+ height,
318
+ video_length,
319
+ steps,
320
+ cfg,
321
+ generator=generator,
322
+ ).videos
323
+
324
+ video = batch_images_interpolation_tool(video, frame_inter_model, inter_frames=fi_step-1)
325
+
326
+ save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
327
+ save_videos_grid(
328
+ video,
329
+ save_path,
330
+ n_rows=1,
331
+ fps=src_fps,
332
+ )
333
+
334
+ # save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio"
335
+ # save_pil_imgs(video, save_path)
336
+
337
+ # save_path = batch_images_interpolation_tool(save_path, frame_inter_model, int(src_fps))
338
+
339
+ audio_output = f'{save_dir}/audio_from_video.aac'
340
+ # extract audio
341
+ try:
342
+ ffmpeg.input(source_video).output(audio_output, acodec='copy').run()
343
+ # merge audio and video
344
+ stream = ffmpeg.input(save_path)
345
+ audio = ffmpeg.input(audio_output)
346
+ ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac', shortest=None).run()
347
+
348
+ os.remove(save_path)
349
+ os.remove(audio_output)
350
+ except Exception:
351
+ shutil.move(
352
+ save_path,
353
+ save_path.replace('_noaudio.mp4', '.mp4')
354
+ )
355
+
356
+ return save_path.replace('_noaudio.mp4', '.mp4'), ref_image_pil
357
+
358
+
359
+ ################# GUI ################
360
+
361
+ title = r"""
362
+ <h1>AniPortrait</h1>
363
+ """
364
+
365
+ description = r"""
366
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/Zejun-Yang/AniPortrait' target='_blank'><b>AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations</b></a>.<br>
367
+ """
368
+
369
+ tips = r"""
370
+ This is an accelerated version of AniPortrait. Due to limited computing power, the wait time will be quite long. Please use the source code to experience the full performance.
371
+ """
372
+
373
+ with gr.Blocks() as demo:
374
+
375
+ gr.Markdown(title)
376
+ gr.Markdown(description)
377
+ gr.Markdown(tips)
378
+
379
+ with gr.Tab("Audio2video"):
380
+ with gr.Row():
381
+ with gr.Column():
382
+ with gr.Row():
383
+ a2v_input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input audio", interactive=True)
384
+ a2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
385
+ a2v_headpose_video = gr.Video(label="Option: upload head pose reference video", sources="upload")
386
+
387
+ with gr.Row():
388
+ a2v_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")
389
+ a2v_step_slider = gr.Slider(minimum=5, maximum=20, step=1, value=15, label="Steps (--steps)")
390
+
391
+ with gr.Row():
392
+ a2v_length = gr.Slider(minimum=0, maximum=9999, step=1, value=30, label="Length (-L) (Set to 0 to automatically calculate length)")
393
+ a2v_seed = gr.Number(value=42, label="Seed (--seed)")
394
+
395
+ a2v_botton = gr.Button("Generate", variant="primary")
396
+ a2v_output_video = gr.PlayableVideo(label="Result", interactive=False)
397
+
398
+ gr.Examples(
399
+ examples=[
400
+ ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/Aragaki.png", None],
401
+ ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/solo.png", None],
402
+ ["configs/inference/audio/lyl.wav", "configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
403
+ ],
404
+ inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video],
405
+ )
406
+
407
+
408
+ with gr.Tab("Video2video"):
409
+ with gr.Row():
410
+ with gr.Column():
411
+ with gr.Row():
412
+ v2v_ref_img = gr.Image(label="Upload reference image", sources="upload")
413
+ v2v_source_video = gr.Video(label="Upload source video", sources="upload")
414
+
415
+ with gr.Row():
416
+ v2v_size_slider = gr.Slider(minimum=256, maximum=512, step=8, value=384, label="Video size (-W & -H)")
417
+ v2v_step_slider = gr.Slider(minimum=5, maximum=20, step=1, value=15, label="Steps (--steps)")
418
+
419
+ with gr.Row():
420
+ v2v_length = gr.Slider(minimum=0, maximum=999, step=1, value=30, label="Length (-L) (Set to 0 to automatically calculate length)")
421
+ v2v_seed = gr.Number(value=42, label="Seed (--seed)")
422
+
423
+ v2v_botton = gr.Button("Generate", variant="primary")
424
+ v2v_output_video = gr.PlayableVideo(label="Result", interactive=False)
425
+
426
+ gr.Examples(
427
+ examples=[
428
+ ["configs/inference/ref_images/Aragaki.png", "configs/inference/video/Aragaki_song.mp4"],
429
+ ["configs/inference/ref_images/solo.png", "configs/inference/video/Aragaki_song.mp4"],
430
+ ["configs/inference/ref_images/lyl.png", "configs/inference/head_pose_temp/pose_ref_video.mp4"],
431
+ ],
432
+ inputs=[v2v_ref_img, v2v_source_video],
433
+ )
434
+
435
+ a2v_botton.click(
436
+ fn=audio2video,
437
+ inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video,
438
+ a2v_size_slider, a2v_step_slider, a2v_length, a2v_seed],
439
+ outputs=[a2v_output_video, a2v_ref_img]
440
+ )
441
+ v2v_botton.click(
442
+ fn=video2video,
443
+ inputs=[v2v_ref_img, v2v_source_video,
444
+ v2v_size_slider, v2v_step_slider, v2v_length, v2v_seed],
445
+ outputs=[v2v_output_video, v2v_ref_img]
446
+ )
447
+
448
+ demo.launch(share=True)
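
For readers following the audio2video path above: when no head-pose video is given, the template sequence is mirrored (dropping the duplicated endpoints) and tiled until it covers the audio's seq_len. A minimal standalone illustration of that cycling, with made-up toy values standing in for the real per-frame 6-D pose template, is:

import numpy as np

# Toy stand-in for the head-pose template loaded from config['pose_temp'];
# the real array holds one 6-D pose (Euler angles + translation) per frame.
pose_seq = np.arange(5).reshape(5, 1)
seq_len = 12  # stand-in for sample['seq_len'] derived from the audio

# Palindromic loop without repeating the endpoints: 0 1 2 3 4 3 2 1
mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)

# Tile until the loop is at least seq_len long, then truncate.
cycled_pose_seq = np.tile(mirrored_pose_seq, (seq_len // len(mirrored_pose_seq) + 1, 1))[:seq_len]

print(cycled_pose_seq.ravel())  # [0 1 2 3 4 3 2 1 0 1 2 3]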
apphfupscaletest.py ADDED
@@ -0,0 +1,610 @@
1
+ import gradio as gr
2
+ import os
3
+ import shutil
4
+ import ffmpeg
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import cv2
9
+ import torch
10
+ import spaces
11
+
12
+ from diffusers import AutoencoderKL, DDIMScheduler
13
+ from einops import repeat
14
+ from omegaconf import OmegaConf
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ from transformers import CLIPVisionModelWithProjection
18
+ from face_enhancer import (
19
+ get_available_enhancer_names,
20
+ load_face_enhancer_model,
21
+ cv2_interpolations,
22
+ )
23
+
24
+
25
+ from src.models.pose_guider import PoseGuider
26
+ from src.models.unet_2d_condition import UNet2DConditionModel
27
+ from src.models.unet_3d import UNet3DConditionModel
28
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
29
+ from src.utils.util import get_fps, read_frames, save_videos_grid, save_pil_imgs
30
+
31
+ from src.audio_models.model import Audio2MeshModel
32
+ from src.utils.audio_util import prepare_audio_feature
33
+ from src.utils.mp_utils import LMKExtractor
34
+ from src.utils.draw_util import FaceMeshVisualizer
35
+ from src.utils.pose_util import (
36
+ project_points,
37
+ project_points_with_trans,
38
+ matrix_to_euler_and_translation,
39
+ euler_and_translation_to_matrix,
40
+ )
41
+ from src.utils.crop_face_single import crop_face
42
+ from src.audio2vid import get_headpose_temp, smooth_pose_seq
43
+ from src.utils.frame_interpolation import (
44
+ init_frame_interpolation_model,
45
+ batch_images_interpolation_tool,
46
+ )
47
+
48
+
49
+ config = OmegaConf.load("./configs/prompts/animation_audio.yaml")
50
+ if config.weight_dtype == "fp16":
51
+ weight_dtype = torch.float16
52
+ else:
53
+ weight_dtype = torch.float32
54
+
55
+ audio_infer_config = OmegaConf.load(config.audio_inference_config)
56
+ # prepare model
57
+ a2m_model = Audio2MeshModel(audio_infer_config["a2m_model"])
58
+ a2m_model.load_state_dict(
59
+ torch.load(audio_infer_config["pretrained_model"]["a2m_ckpt"], map_location="cpu"),
60
+ strict=False,
61
+ )
62
+ a2m_model.cuda().eval()
63
+
64
+ vae = AutoencoderKL.from_pretrained(
65
+ config.pretrained_vae_path,
66
+ ).to("cuda", dtype=weight_dtype)
67
+
68
+ reference_unet = UNet2DConditionModel.from_pretrained(
69
+ config.pretrained_base_model_path,
70
+ subfolder="unet",
71
+ ).to(dtype=weight_dtype, device="cuda")
72
+
73
+ inference_config_path = config.inference_config
74
+ infer_config = OmegaConf.load(inference_config_path)
75
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
76
+ config.pretrained_base_model_path,
77
+ config.motion_module_path,
78
+ subfolder="unet",
79
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
80
+ ).to(dtype=weight_dtype, device="cuda")
81
+
82
+ pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(
83
+ device="cuda", dtype=weight_dtype
84
+ ) # not use cross attention
85
+
86
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(config.image_encoder_path).to(
87
+ dtype=weight_dtype, device="cuda"
88
+ )
89
+
90
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
91
+ scheduler = DDIMScheduler(**sched_kwargs)
92
+
93
+ # load pretrained weights
94
+ denoising_unet.load_state_dict(
95
+ torch.load(config.denoising_unet_path, map_location="cpu"),
96
+ strict=False,
97
+ )
98
+ reference_unet.load_state_dict(
99
+ torch.load(config.reference_unet_path, map_location="cpu"),
100
+ )
101
+ pose_guider.load_state_dict(
102
+ torch.load(config.pose_guider_path, map_location="cpu"),
103
+ )
104
+
105
+ pipe = Pose2VideoPipeline(
106
+ vae=vae,
107
+ image_encoder=image_enc,
108
+ reference_unet=reference_unet,
109
+ denoising_unet=denoising_unet,
110
+ pose_guider=pose_guider,
111
+ scheduler=scheduler,
112
+ )
113
+ pipe = pipe.to("cuda", dtype=weight_dtype)
114
+
115
+ # lmk_extractor = LMKExtractor()
116
+ # vis = FaceMeshVisualizer()
117
+
118
+ frame_inter_model = init_frame_interpolation_model()
119
+
120
+
121
+ @spaces.GPU
122
+ def audio2video(
123
+ input_audio, ref_img, headpose_video=None, size=512, steps=25, length=60, seed=42
124
+ ):
125
+ fps = 30
126
+ cfg = 3.5
127
+ fi_step = 3
128
+
129
+ generator = torch.manual_seed(seed)
130
+
131
+ lmk_extractor = LMKExtractor()
132
+ vis = FaceMeshVisualizer()
133
+
134
+ width, height = size, size
135
+
136
+ date_str = datetime.now().strftime("%Y%m%d")
137
+ time_str = datetime.now().strftime("%H%M")
138
+ save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
139
+
140
+ save_dir = Path(f"a2v_output/{date_str}/{save_dir_name}")
141
+ while os.path.exists(save_dir):
142
+ save_dir = Path(
143
+ f"a2v_output/{date_str}/{save_dir_name}_{np.random.randint(10000):04d}"
144
+ )
145
+ save_dir.mkdir(exist_ok=True, parents=True)
146
+
147
+ ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
148
+ ref_image_np = crop_face(ref_image_np, lmk_extractor)
149
+ if ref_image_np is None:
150
+ return None, Image.fromarray(ref_img)
151
+
152
+ ref_image_np = cv2.resize(ref_image_np, (size, size))
153
+ ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
154
+
155
+ face_result = lmk_extractor(ref_image_np)
156
+ if face_result is None:
157
+ return None, ref_image_pil
158
+
159
+ lmks = face_result["lmks"].astype(np.float32)
160
+ ref_pose = vis.draw_landmarks(
161
+ (ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True
162
+ )
163
+
164
+ sample = prepare_audio_feature(
165
+ input_audio, wav2vec_model_path=audio_infer_config["a2m_model"]["model_path"]
166
+ )
167
+ sample["audio_feature"] = torch.from_numpy(sample["audio_feature"]).float().cuda()
168
+ sample["audio_feature"] = sample["audio_feature"].unsqueeze(0)
169
+
170
+ # inference
171
+ pred = a2m_model.infer(sample["audio_feature"], sample["seq_len"])
172
+ pred = pred.squeeze().detach().cpu().numpy()
173
+ pred = pred.reshape(pred.shape[0], -1, 3)
174
+ pred = pred + face_result["lmks3d"]
175
+
176
+ if headpose_video is not None:
177
+ pose_seq = get_headpose_temp(headpose_video)
178
+ else:
179
+ pose_seq = np.load(config["pose_temp"])
180
+ mirrored_pose_seq = np.concatenate((pose_seq, pose_seq[-2:0:-1]), axis=0)
181
+ cycled_pose_seq = np.tile(
182
+ mirrored_pose_seq, (sample["seq_len"] // len(mirrored_pose_seq) + 1, 1)
183
+ )[: sample["seq_len"]]
184
+
185
+ # project 3D mesh to 2D landmark
186
+ projected_vertices = project_points(
187
+ pred, face_result["trans_mat"], cycled_pose_seq, [height, width]
188
+ )
189
+
190
+ pose_images = []
191
+ for i, verts in enumerate(projected_vertices):
192
+ lmk_img = vis.draw_landmarks((width, height), verts, normed=False)
193
+ pose_images.append(lmk_img)
194
+
195
+ pose_list = []
196
+ # pose_tensor_list = []
197
+
198
+ # pose_transform = transforms.Compose(
199
+ # [transforms.Resize((height, width)), transforms.ToTensor()]
200
+ # )
201
+ args_L = len(pose_images) if length == 0 or length > len(pose_images) else length
202
+ # args_L = min(args_L, 9999)
203
+ for pose_image_np in pose_images[:args_L:fi_step]:
204
+ # pose_image_pil = Image.fromarray(cv2.cvtColor(pose_image_np, cv2.COLOR_BGR2RGB))
205
+ # pose_tensor_list.append(pose_transform(pose_image_pil))
206
+ pose_image_np = cv2.resize(pose_image_np, (width, height))
207
+ pose_list.append(pose_image_np)
208
+
209
+ pose_list = np.array(pose_list)
210
+
211
+ video_length = len(pose_list)
212
+
213
+ video = pipe(
214
+ ref_image_pil,
215
+ pose_list,
216
+ ref_pose,
217
+ width,
218
+ height,
219
+ video_length,
220
+ steps,
221
+ cfg,
222
+ generator=generator,
223
+ ).videos
224
+
225
+ video = batch_images_interpolation_tool(
226
+ video, frame_inter_model, inter_frames=fi_step - 1
227
+ )
228
+
229
+ save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
230
+ save_videos_grid(
231
+ video,
232
+ save_path,
233
+ n_rows=1,
234
+ fps=fps,
235
+ )
236
+
237
+ # save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio"
238
+ # save_pil_imgs(video, save_path)
239
+
240
+ # save_path = batch_images_interpolation_tool(save_path, frame_inter_model, int(fps))
241
+
242
+ stream = ffmpeg.input(save_path)
243
+ audio = ffmpeg.input(input_audio)
244
+ ffmpeg.output(
245
+ stream.video,
246
+ audio.audio,
247
+ save_path.replace("_noaudio.mp4", ".mp4"),
248
+ vcodec="copy",
249
+ acodec="aac",
250
+ shortest=None,
251
+ ).run()
252
+ os.remove(save_path)
253
+
254
+ return save_path.replace("_noaudio.mp4", ".mp4"), ref_image_pil
255
+
256
+
257
+ @spaces.GPU
258
+ def video2video(ref_img, source_video, size=512, steps=25, length=60, seed=42):
259
+ cfg = 3.5
260
+ fi_step = 3
261
+
262
+ generator = torch.manual_seed(seed)
263
+
264
+ lmk_extractor = LMKExtractor()
265
+ vis = FaceMeshVisualizer()
266
+
267
+ width, height = size, size
268
+
269
+ date_str = datetime.now().strftime("%Y%m%d")
270
+ time_str = datetime.now().strftime("%H%M")
271
+ save_dir_name = f"{time_str}--seed_{seed}-{size}x{size}"
272
+
273
+ save_dir = Path(f"v2v_output/{date_str}/{save_dir_name}")
274
+ while os.path.exists(save_dir):
275
+ save_dir = Path(
276
+ f"v2v_output/{date_str}/{save_dir_name}_{np.random.randint(10000):04d}"
277
+ )
278
+ save_dir.mkdir(exist_ok=True, parents=True)
279
+
280
+ ref_image_np = cv2.cvtColor(ref_img, cv2.COLOR_RGB2BGR)
281
+ ref_image_np = crop_face(ref_image_np, lmk_extractor)
282
+ if ref_image_np is None:
283
+ return None, Image.fromarray(ref_img)
284
+
285
+ ref_image_np = cv2.resize(ref_image_np, (size, size))
286
+ ref_image_pil = Image.fromarray(cv2.cvtColor(ref_image_np, cv2.COLOR_BGR2RGB))
287
+
288
+ face_result = lmk_extractor(ref_image_np)
289
+ if face_result is None:
290
+ return None, ref_image_pil
291
+
292
+ lmks = face_result["lmks"].astype(np.float32)
293
+ ref_pose = vis.draw_landmarks(
294
+ (ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True
295
+ )
296
+
297
+ source_images = read_frames(source_video)
298
+ src_fps = get_fps(source_video)
299
+ pose_transform = transforms.Compose(
300
+ [transforms.Resize((height, width)), transforms.ToTensor()]
301
+ )
302
+
303
+ step = 1
304
+ if src_fps == 60:
305
+ src_fps = 30
306
+ step = 2
307
+
308
+ pose_trans_list = []
309
+ verts_list = []
310
+ bs_list = []
311
+ args_L = (
312
+ len(source_images)
313
+ if length == 0 or length * step > len(source_images)
314
+ else length * step
315
+ )
316
+ # args_L = min(args_L, 90*step)
317
+ for src_image_pil in source_images[: args_L : step * fi_step]:
318
+ src_img_np = cv2.cvtColor(np.array(src_image_pil), cv2.COLOR_RGB2BGR)
319
+ frame_height, frame_width, _ = src_img_np.shape
320
+ src_img_result = lmk_extractor(src_img_np)
321
+ if src_img_result is None:
322
+ break
323
+ pose_trans_list.append(src_img_result["trans_mat"])
324
+ verts_list.append(src_img_result["lmks3d"])
325
+ bs_list.append(src_img_result["bs"])
326
+
327
+ trans_mat_arr = np.array(pose_trans_list)
328
+ verts_arr = np.array(verts_list)
329
+ bs_arr = np.array(bs_list)
330
+ min_bs_idx = np.argmin(bs_arr.sum(1))
331
+
332
+ # compute delta pose
333
+ pose_arr = np.zeros([trans_mat_arr.shape[0], 6])
334
+
335
+ for i in range(pose_arr.shape[0]):
336
+ euler_angles, translation_vector = matrix_to_euler_and_translation(
337
+ trans_mat_arr[i]
338
+ ) # real pose of source
339
+ pose_arr[i, :3] = euler_angles
340
+ pose_arr[i, 3:6] = translation_vector
341
+
342
+ init_tran_vec = face_result["trans_mat"][:3, 3] # init translation of tgt
343
+ pose_arr[:, 3:6] = (
344
+ pose_arr[:, 3:6] - pose_arr[0, 3:6] + init_tran_vec
345
+ ) # (relative translation of source) + (init translation of tgt)
346
+
347
+ pose_arr_smooth = smooth_pose_seq(pose_arr, window_size=3)
348
+ pose_mat_smooth = [
349
+ euler_and_translation_to_matrix(pose_arr_smooth[i][:3], pose_arr_smooth[i][3:6])
350
+ for i in range(pose_arr_smooth.shape[0])
351
+ ]
352
+ pose_mat_smooth = np.array(pose_mat_smooth)
353
+
354
+ # face retarget
355
+ verts_arr = verts_arr - verts_arr[min_bs_idx] + face_result["lmks3d"]
356
+ # project 3D mesh to 2D landmark
357
+ projected_vertices = project_points_with_trans(
358
+ verts_arr, pose_mat_smooth, [frame_height, frame_width]
359
+ )
360
+
361
+ pose_list = []
362
+ for i, verts in enumerate(projected_vertices):
363
+ lmk_img = vis.draw_landmarks((frame_width, frame_height), verts, normed=False)
364
+ pose_image_np = cv2.resize(lmk_img, (width, height))
365
+ pose_list.append(pose_image_np)
366
+
367
+ pose_list = np.array(pose_list)
368
+
369
+ video_length = len(pose_list)
370
+
371
+ video = pipe(
372
+ ref_image_pil,
373
+ pose_list,
374
+ ref_pose,
375
+ width,
376
+ height,
377
+ video_length,
378
+ steps,
379
+ cfg,
380
+ generator=generator,
381
+ ).videos
382
+
383
+ video = batch_images_interpolation_tool(
384
+ video, frame_inter_model, inter_frames=fi_step - 1
385
+ )
386
+
387
+ save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio.mp4"
388
+ save_videos_grid(
389
+ video,
390
+ save_path,
391
+ n_rows=1,
392
+ fps=src_fps,
393
+ )
394
+
395
+ # save_path = f"{save_dir}/{size}x{size}_{time_str}_noaudio"
396
+ # save_pil_imgs(video, save_path)
397
+
398
+ # save_path = batch_images_interpolation_tool(save_path, frame_inter_model, int(src_fps))
399
+
400
+ audio_output = f"{save_dir}/audio_from_video.aac"
401
+ # extract audio
402
+ try:
403
+ ffmpeg.input(source_video).output(audio_output, acodec="copy").run()
404
+ # merge audio and video
405
+ stream = ffmpeg.input(save_path)
406
+ audio = ffmpeg.input(audio_output)
407
+ ffmpeg.output(
408
+ stream.video,
409
+ audio.audio,
410
+ save_path.replace("_noaudio.mp4", ".mp4"),
411
+ vcodec="copy",
412
+ acodec="aac",
413
+ shortest=None,
414
+ ).run()
415
+
416
+ os.remove(save_path)
417
+ os.remove(audio_output)
418
+ except Exception:
419
+ shutil.move(save_path, save_path.replace("_noaudio.mp4", ".mp4"))
420
+
421
+ return save_path.replace("_noaudio.mp4", ".mp4"), ref_image_pil
422
+
423
+
424
+ ################# GUI ################
425
+
426
+ title = r"""
427
+ <h1>AniPortrait</h1>
428
+ """
429
+
430
+ description = r"""
431
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/Zejun-Yang/AniPortrait' target='_blank'><b>AniPortrait: Audio-Driven Synthesis of Photorealistic Portrait Animations</b></a>.<br>
432
+ """
433
+
434
+ tips = r"""
435
+ This is an accelerated version of AniPortrait. Due to limited computing power, the wait time will be quite long. Please use the source code to experience the full performance.
436
+ """
437
+
438
+ with gr.Blocks() as demo:
439
+
440
+ gr.Markdown(title)
441
+ gr.Markdown(description)
442
+ gr.Markdown(tips)
443
+
444
+ with gr.Tab("Audio2video"):
445
+ with gr.Row():
446
+ with gr.Column():
447
+ with gr.Row():
448
+ a2v_input_audio = gr.Audio(
449
+ sources=["upload", "microphone"],
450
+ type="filepath",
451
+ editable=True,
452
+ label="Input audio",
453
+ interactive=True,
454
+ )
455
+ a2v_ref_img = gr.Image(
456
+ label="Upload reference image", sources="upload"
457
+ )
458
+ a2v_headpose_video = gr.Video(
459
+ label="Option: upload head pose reference video",
460
+ sources="upload",
461
+ )
462
+
463
+ with gr.Row():
464
+ a2v_size_slider = gr.Slider(
465
+ minimum=256,
466
+ maximum=512,
467
+ step=8,
468
+ value=384,
469
+ label="Video size (-W & -H)",
470
+ )
471
+ a2v_step_slider = gr.Slider(
472
+ minimum=5, maximum=20, step=1, value=15, label="Steps (--steps)"
473
+ )
474
+
475
+ with gr.Row():
476
+ a2v_length = gr.Slider(
477
+ minimum=0,
478
+ maximum=9999,
479
+ step=1,
480
+ value=30,
481
+ label="Length (-L) (Set to 0 to automatically calculate length)",
482
+ )
483
+ a2v_seed = gr.Number(value=42, label="Seed (--seed)")
484
+
485
+ a2v_botton = gr.Button("Generate", variant="primary")
486
+ a2v_output_video = gr.PlayableVideo(label="Result", interactive=False)
487
+
488
+ gr.Examples(
489
+ examples=[
490
+ [
491
+ "configs/inference/audio/lyl.wav",
492
+ "configs/inference/ref_images/Aragaki.png",
493
+ None,
494
+ ],
495
+ [
496
+ "configs/inference/audio/lyl.wav",
497
+ "configs/inference/ref_images/solo.png",
498
+ None,
499
+ ],
500
+ [
501
+ "configs/inference/audio/lyl.wav",
502
+ "configs/inference/ref_images/lyl.png",
503
+ "configs/inference/head_pose_temp/pose_ref_video.mp4",
504
+ ],
505
+ ],
506
+ inputs=[a2v_input_audio, a2v_ref_img, a2v_headpose_video],
507
+ )
508
+
509
+ with gr.Tab("Video2video"):
510
+ with gr.Row():
511
+ with gr.Column():
512
+ with gr.Row():
513
+ v2v_ref_img = gr.Image(
514
+ label="Upload reference image", sources="upload"
515
+ )
516
+ v2v_source_video = gr.Video(
517
+ label="Upload source video", sources="upload"
518
+ )
519
+
520
+ with gr.Row():
521
+ v2v_size_slider = gr.Slider(
522
+ minimum=256,
523
+ maximum=512,
524
+ step=8,
525
+ value=384,
526
+ label="Video size (-W & -H)",
527
+ )
528
+ v2v_step_slider = gr.Slider(
529
+ minimum=5, maximum=20, step=1, value=15, label="Steps (--steps)"
530
+ )
531
+
532
+ with gr.Row():
533
+ v2v_length = gr.Slider(
534
+ minimum=0,
535
+ maximum=999,
536
+ step=1,
537
+ value=30,
538
+ label="Length (-L) (Set to 0 to automatically calculate length)",
539
+ )
540
+ v2v_seed = gr.Number(value=42, label="Seed (--seed)")
541
+
542
+ v2v_botton = gr.Button("Generate", variant="primary")
543
+ v2v_output_video = gr.PlayableVideo(label="Result", interactive=False)
544
+
545
+ gr.Examples(
546
+ examples=[
547
+ [
548
+ "configs/inference/ref_images/Aragaki.png",
549
+ "configs/inference/video/Aragaki_song.mp4",
550
+ ],
551
+ [
552
+ "configs/inference/ref_images/solo.png",
553
+ "configs/inference/video/Aragaki_song.mp4",
554
+ ],
555
+ [
556
+ "configs/inference/ref_images/lyl.png",
557
+ "configs/inference/head_pose_temp/pose_ref_video.mp4",
558
+ ],
559
+ ],
560
+ inputs=[v2v_ref_img, v2v_source_video],
561
+ )
562
+
563
+ with gr.Tab("Video Upscale"):
564
+ with gr.Row():
565
+ with gr.Column():
566
+ with gr.Row():
567
+ upscale_video = gr.Video(label="Upload video", sources="upload")
568
+ upscale_method = gr.Dropdown(
569
+ get_available_enhancer_names(),
570
+ label="Upscale method",
571
+ value="REAL-ESRGAN 4x",
572
+ )
573
+ upscale_botton = gr.Button("Upscale", variant="primary")
574
+ upscale_output_video = gr.PlayableVideo(
575
+ label="Upscaled video", interactive=False
576
+ )
577
+
578
+ upscale_botton.click(
579
+ fn=lambda video, method: upscale_video_with_face_enhancer(video, method),
580
+ inputs=[upscale_video, upscale_method],
581
+ outputs=[upscale_output_video],
582
+ )
583
+
584
+ a2v_botton.click(
585
+ fn=audio2video,
586
+ inputs=[
587
+ a2v_input_audio,
588
+ a2v_ref_img,
589
+ a2v_headpose_video,
590
+ a2v_size_slider,
591
+ a2v_step_slider,
592
+ a2v_length,
593
+ a2v_seed,
594
+ ],
595
+ outputs=[a2v_output_video, a2v_ref_img],
596
+ )
597
+ v2v_botton.click(
598
+ fn=video2video,
599
+ inputs=[
600
+ v2v_ref_img,
601
+ v2v_source_video,
602
+ v2v_size_slider,
603
+ v2v_step_slider,
604
+ v2v_length,
605
+ v2v_seed,
606
+ ],
607
+ outputs=[v2v_output_video, v2v_ref_img],
608
+ )
609
+
610
+ demo.launch(share=True)
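
Note that the "Video Upscale" tab above wires its button to upscale_video_with_face_enhancer, but no such function is defined in this file. A minimal sketch of what that helper could look like follows; it is an assumption, not the repository's implementation. Only the names imported from face_enhancer are known from this diff, so the (model, runner) return convention of load_face_enhancer_model and the runner(frame, model) call are hypothetical and may need adjusting to the real module.

import os
import cv2
import torch
from face_enhancer import load_face_enhancer_model  # already imported at the top of this file

def upscale_video_with_face_enhancer(video_path, method):
    # Hypothetical helper: assumes load_face_enhancer_model(method, device) returns a
    # (model, runner) pair where runner(frame_bgr, model) yields an enhanced BGR frame.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, runner = load_face_enhancer_model(method, device)

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30
    out_path = os.path.splitext(video_path)[0] + "_upscaled.mp4"
    writer = None

    while True:
        ok, frame = cap.read()
        if not ok:
            break
        enhanced = runner(frame, model)
        if writer is None:
            h, w = enhanced.shape[:2]
            writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
        writer.write(enhanced)

    cap.release()
    if writer is not None:
        writer.release()
    return out_path

Audio is dropped here for brevity; it could be re-muxed with ffmpeg the same way video2video does above.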