from src.utils import *
from src.flow_utils import warp_tensor
import torch
import torchvision
import gc

"""
==========================================================================
* step(): one DDPM step with background smoothing
* inference(): translate one batch with FRESCO and background smoothing
==========================================================================
"""


def step(pipe, model_output, timestep, sample, generator, repeat_noise=False,
         visualize_pipeline=False, flows=None, occs=None, saliency=None):
    """
    DDPM step with background smoothing
    * background smoothing: warp the background region of the previous frame to the current frame
    """
    scheduler = pipe.scheduler

    prev_timestep = scheduler.previous_timestep(timestep)
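
    # DDPM forward process (Ho et al., 2020): with abar_t = alphas_cumprod[t],
    #   x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps,
    # so the clean-sample prediction computed below is
    #   x_0 = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t).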
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.one

    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

    """
    [HACK] add background smoothing
    decode the predicted latent into an image
    warp frame f_{i-1} to frame f_{i}
    fuse the warped f_{i-1} with f_{i} in the non-salient region (i.e., the background)
    encode the fused image back into a latent
    """
    if saliency is not None and flows is not None and occs is not None:
        image = pipe.vae.decode(pred_original_sample / pipe.vae.config.scaling_factor).sample
        image = warp_tensor(image, flows, occs, saliency, unet_chunk_size=1)
        pred_original_sample = pipe.vae.config.scaling_factor * pipe.vae.encode(image).latent_dist.sample()
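    # Note: this decode -> warp -> encode round-trip passes through the VAE and is
    # therefore lossy, so inference() applies it only at a few late denoising steps
    # (see bg_smoothing_steps) rather than at every step.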

    # posterior mean of q(x_{t-1} | x_t, x_0): a combination of the predicted
    # x_0 and the current sample x_t
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # posterior variance of q(x_{t-1} | x_t, x_0), clamped for numerical stability
    variance = beta_prod_t_prev / beta_prod_t * current_beta_t
    variance = torch.clamp(variance, min=1e-20)
    variance = (variance ** 0.5) * torch.randn(model_output.shape, generator=generator,
                                               device=model_output.device, dtype=model_output.dtype)
    """
    [HACK] background smoothing
    applying the same noise to all frames can help keep a static background consistent
    """
    if repeat_noise:
        variance = variance[0:1].repeat(model_output.shape[0], 1, 1, 1)

    if visualize_pipeline:
        image = pipe.vae.decode(pred_original_sample / pipe.vae.config.scaling_factor).sample
        viz = torchvision.utils.make_grid(torch.clamp(image, -1, 1), image.shape[0], 1)
        visualize(viz.cpu(), 90)

    pred_prev_sample = pred_prev_sample + variance

    return (pred_prev_sample, pred_original_sample)
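
# Minimal usage sketch of step() (illustrative only): given a diffusers pipeline
# whose scheduler exposes previous_timestep() and alphas_cumprod (e.g., DDPM),
# a plain denoising loop would look roughly like
#
#   pipe.scheduler.set_timesteps(20)
#   for t in pipe.scheduler.timesteps:
#       noise_pred = pipe.unet(latents, t, encoder_hidden_states=prompt_embeds).sample
#       latents, pred_x0 = step(pipe, noise_pred, t, latents, generator)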


@torch.no_grad()
def inference(pipe, controlnet, frescoProc,
              imgs, prompt_embeds, edges, timesteps,
              cond_scale=[0.7]*20, num_inference_steps=20, num_warmup_steps=6,
              do_classifier_free_guidance=True, seed=0, guidance_scale=7.5, use_controlnet=True,
              record_latents=[], propagation_mode=False, visualize_pipeline=False,
              flows=None, occs=None, saliency=None, repeat_noise=False,
              num_intraattn_steps=1, step_interattn_end=350, bg_smoothing_steps=[16, 17]):
    """
    video-to-video translation inference pipeline with FRESCO
    * add ControlNet and SDEdit
    * add FRESCO-guided attention
    * add FRESCO-guided optimization
    * add background smoothing
    * add support for inter-batch long video translation

    [input of the original pipe]
    pipe: base diffusion model
    imgs: a batch of input frames
    prompt_embeds: prompt embeddings
    num_inference_steps: number of DDPM steps
    timesteps: generated by pipe.scheduler.set_timesteps(num_inference_steps)
    do_classifier_free_guidance: whether to use classifier-free guidance (should always be True)
    guidance_scale: cfg scale
    seed: random seed

    [input of SDEdit]
    num_warmup_steps: skip the first num_warmup_steps DDPM steps

    [input of controlnet]
    use_controlnet: bool, whether to use ControlNet
    controlnet: ControlNet model
    edges: input for ControlNet (edge/stroke/depth, etc.)
    cond_scale: ControlNet conditioning scale per step

    [input of FRESCO]
    frescoProc: FRESCO attention controller
    flows: optical flows
    occs: occlusion masks
    num_intraattn_steps: apply num_intraattn_steps steps of spatial-guided attention
    step_interattn_end: apply temporal-guided attention for timesteps in [step_interattn_end, 1000]

    [input for background smoothing]
    saliency: saliency mask
    repeat_noise: bool, use the same noise for all frames
    bg_smoothing_steps: steps at which to apply background smoothing

    [input for long video translation]
    record_latents: latents recorded from the previous batch
    propagation_mode: bool, whether this is not the first batch

    [output]
    latents: a batch of latents of the translated frames
    """
    gc.collect()
    torch.cuda.empty_cache()

    device = pipe._execution_device
    noise_scheduler = pipe.scheduler
    generator = torch.Generator(device=device).manual_seed(seed)
    B, C, H, W = imgs.shape
    latents = pipe.prepare_latents(
        B,
        pipe.unet.config.in_channels,
        H,
        W,
        prompt_embeds.dtype,
        device,
        generator,
        latents=None,
    )

    if repeat_noise:
        # use the same initial noise for every frame
        latents = latents[0:1].repeat(B, 1, 1, 1).detach()

    if num_warmup_steps < 0:
        # pure generation: start from Gaussian noise and run all steps
        latents_init = latents.detach()
        num_warmup_steps = 0
    else:
        # SDEdit: start from the input frames noised to timesteps[num_warmup_steps],
        # skipping the first num_warmup_steps denoising steps
        latent_x0 = pipe.vae.config.scaling_factor * pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
        latents_init = noise_scheduler.add_noise(latent_x0, latents, timesteps[num_warmup_steps]).detach()

    with pipe.progress_bar(total=num_inference_steps - num_warmup_steps) as progress_bar:
        latents = latents_init
        for i, t in enumerate(timesteps[num_warmup_steps:]):
            """
            [HACK] control the steps to apply spatial/temporal-guided attention
            [HACK] record and restore latents from the previous batch
            """
            if i >= num_intraattn_steps:
                frescoProc.controller.disable_intraattn()
            if t < step_interattn_end:
                frescoProc.controller.disable_interattn()
            if propagation_mode:
                latents[0:2] = record_latents[i].detach().clone()
                record_latents[i] = latents[[0, len(latents) - 1]].detach().clone()
            else:
                record_latents += [latents[[0, len(latents) - 1]].detach().clone()]
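            # Long-video propagation: the first batch records the latents of its
            # first and last frames at every step; each subsequent batch pins its
            # first two latents to those recorded ones, so consecutive batches
            # share anchor frames and stay consistent.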

            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            if use_controlnet:
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds
                down_block_res_samples, mid_block_res_sample = controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=edges,
                    conditioning_scale=cond_scale[i + num_warmup_steps],
                    guess_mode=False,
                    return_dict=False,
                )
            else:
                down_block_res_samples, mid_block_res_sample = None, None

            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                return_dict=False,
            )[0]
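
            # with classifier-free guidance the batch holds the [unconditional, text]
            # copies stacked along dim 0, so split the prediction and recombine:
            #   noise_pred = uncond + guidance_scale * (text - uncond)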
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            """
            [HACK] background smoothing
            Note: bg_smoothing_steps should be rescaled based on num_inference_steps;
            the current [16, 17] assumes num_inference_steps=20
            """
            if i + num_warmup_steps in bg_smoothing_steps:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline,
                               flows=flows, occs=occs, saliency=saliency)[0]
            else:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline)[0]

            if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()

    return latents
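

# --------------------------------------------------------------------------
# Minimal end-to-end usage sketch (illustrative only; the pipeline, ControlNet,
# FRESCO attention processor, flows/occlusion masks and saliency are assumed to
# be prepared by the caller, and prompt_embeds holds the concatenated
# negative+positive embeddings when classifier-free guidance is on):
#
#   pipe.scheduler.set_timesteps(num_inference_steps, device=pipe._execution_device)
#   timesteps = pipe.scheduler.timesteps
#   record_latents = []
#   latents = inference(pipe, controlnet, frescoProc,
#                       imgs, prompt_embeds, edges, timesteps,
#                       num_inference_steps=20, num_warmup_steps=6,
#                       flows=flows, occs=occs, saliency=saliency,
#                       record_latents=record_latents, propagation_mode=False)
#   frames = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
# --------------------------------------------------------------------------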