wbhu-tc committed
Commit
7c1a14b
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,169 @@
1
+ ### Python template
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+
7
+ # C extensions
8
+ *.so
9
+
10
+ # Distribution / packaging
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/#use-with-ide
111
+ .pdm.toml
112
+
113
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114
+ __pypackages__/
115
+
116
+ # Celery stuff
117
+ celerybeat-schedule
118
+ celerybeat.pid
119
+
120
+ # SageMath parsed files
121
+ *.sage.py
122
+
123
+ # Environments
124
+ .env
125
+ .venv
126
+ env/
127
+ venv/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ .idea/
162
+
163
+ /logs
164
+ /gin-config
165
+ *.json
166
+ /eval/*csv
167
+ *__pycache__
168
+ scripts/
169
+ eval/
LICENSE ADDED
@@ -0,0 +1,32 @@
1
+ Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications").
2
+
3
+ License Terms of the inference code of DepthCrafter:
4
+ --------------------------------------------------------------------
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7
+
8
+ - You agree to use the DepthCrafter only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
9
+
10
+ - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
11
+
12
+ For avoidance of doubts, “Software” means the DepthCrafter model inference code and weights made available under this license excluding any pre-trained data and other AI components.
13
+
14
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15
+
16
+
17
+ Other dependencies and licenses:
18
+
19
+ Open Source Software Licensed under the MIT License:
20
+ --------------------------------------------------------------------
21
+ 1. Stability AI - Code
22
+ Copyright (c) 2023 Stability AI
23
+
24
+ Terms of the MIT License:
25
+ --------------------------------------------------------------------
26
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
27
+
28
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
29
+
30
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
31
+
32
+ You may find the code license of Stability AI at the following link: https://github.com/Stability-AI/generative-models/blob/main/LICENSE-CODE
README.md ADDED
@@ -0,0 +1,105 @@
+ ## ___***DepthCrafter: Generating Consistent Long Depth Sequences for Open-world Videos***___
+ <div align="center">
+ <img src='https://depthcrafter.github.io/img/logo.png' style="height:140px"></img>
+
+
+
+ <a href='https://arxiv.org/abs/2409.02095'><img src='https://img.shields.io/badge/arXiv-2409.02095-b31b1b.svg'></a> &nbsp;
+ <a href='https://depthcrafter.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp;
+
+
+ _**[Wenbo Hu<sup>1* &dagger;</sup>](https://wbhu.github.io),
+ [Xiangjun Gao<sup>2*</sup>](https://scholar.google.com/citations?user=qgdesEcAAAAJ&hl=en),
+ [Xiaoyu Li<sup>1* &dagger;</sup>](https://xiaoyu258.github.io),
+ [Sijie Zhao<sup>1</sup>](https://scholar.google.com/citations?user=tZ3dS3MAAAAJ&hl=en),
+ [Xiaodong Cun<sup>1</sup>](https://vinthony.github.io/academic), <br>
+ [Yong Zhang<sup>1</sup>](https://yzhang2016.github.io),
+ [Long Quan<sup>2</sup>](https://home.cse.ust.hk/~quan),
+ [Ying Shan<sup>3, 1</sup>](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en)**_
+ <br><br>
+ <sup>1</sup>Tencent AI Lab
+ <sup>2</sup>The Hong Kong University of Science and Technology
+ <sup>3</sup>ARC Lab, Tencent PCG
+
+ arXiv preprint, 2024
+
+ </div>
+
+ ## 🔆 Introduction
+ 🤗 DepthCrafter can generate temporally consistent long depth sequences with fine-grained details for open-world videos,
+ without requiring additional information such as camera poses or optical flow.
+
+ ## 🎥 Visualization
+ We provide some demos of unprojected point cloud sequences, with reference RGB and estimated depth videos.
+ Please refer to our [project page](https://depthcrafter.github.io) for more details.
+
+
+ https://github.com/user-attachments/assets/62141cc8-04d0-458f-9558-fe50bc04cc21
+
+
+
+
+ ## 🚀 Quick Start
+
+ ### 🛠️ Installation
+ 1. Clone this repo:
+ ```bash
+ git clone https://github.com/Tencent/DepthCrafter.git
+ ```
+ 2. Install dependencies (please refer to [requirements.txt](requirements.txt)):
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## 🤗 Model Zoo
+ [DepthCrafter](https://huggingface.co/tencent/DepthCrafter) is available in the Hugging Face Model Hub.
+
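+ For reference, the demo in [run.py](run.py) builds the pipeline by loading the DepthCrafter UNet from the Hub and plugging it into the Stable Video Diffusion pipeline. A minimal sketch of that loading step (mirroring run.py, not a separate API) looks like this:
+
+ ```python
+ import torch
+
+ from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
+ from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
+
+ # DepthCrafter swaps in its own UNet; the remaining components (VAE, image
+ # encoder, scheduler) come from the Stable Video Diffusion checkpoint.
+ unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
+     "tencent/DepthCrafter", subfolder="unet", torch_dtype=torch.float16
+ )
+ pipe = DepthCrafterPipeline.from_pretrained(
+     "stabilityai/stable-video-diffusion-img2vid-xt",
+     unet=unet, torch_dtype=torch.float16, variant="fp16",
+ )
+ pipe.enable_model_cpu_offload()  # optional, trades speed for lower GPU memory
+ ```
+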
+ ### 🏃‍♂️ Inference
+ #### 1. High-resolution inference (requires a GPU with ~26GB of memory at 1024x576 resolution):
+ - Full inference (~0.6 fps on A100, recommended for high-quality results):
+
+ ```bash
+ python run.py --video-path examples/example_01.mp4
+ ```
+
+
+ - Fast inference with 4-step denoising and without classifier-free guidance (~2.3 fps on A100):
+
+ ```bash
+ python run.py --video-path examples/example_01.mp4 --num-inference-steps 4 --guidance-scale 1.0
+ ```
+
+
+ #### 2. Low-resolution inference (requires a GPU with ~9GB of memory at 512x256 resolution):
+
+ - Full inference (~2.3 fps on A100):
+
+ ```bash
+ python run.py --video-path examples/example_01.mp4 --max-res 512
+ ```
+
+ - Fast inference with 4-step denoising and without classifier-free guidance (~9.4 fps on A100):
+ ```bash
+ python run.py --video-path examples/example_01.mp4 --max-res 512 --num-inference-steps 4 --guidance-scale 1.0
+ ```
+
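+ If you prefer to call the pipeline from Python instead of the CLI, the core of `DepthCrafterDemo.infer` in [run.py](run.py) boils down to the sketch below (paths and settings are illustrative; `pipe` is the pipeline built as shown in the Model Zoo section):
+
+ ```python
+ import torch
+ from depthcrafter.utils import read_video_frames, vis_sequence_depth, save_video
+
+ # read frames as a [t, h, w, 3] float array in [0, 1], resized so the longer
+ # side is roughly at most `max_res` and both dimensions are multiples of 64
+ frames, fps = read_video_frames("examples/example_01.mp4", process_length=195, target_fps=15, max_res=1024)
+
+ with torch.inference_mode():
+     res = pipe(
+         frames,
+         height=frames.shape[1], width=frames.shape[2],
+         output_type="np",
+         guidance_scale=1.2, num_inference_steps=25,
+         window_size=110, overlap=25,
+     ).frames[0]
+
+ # average the three channels and normalize to [0, 1] over the whole video
+ depth = res.sum(-1) / res.shape[-1]
+ depth = (depth - depth.min()) / (depth.max() - depth.min())
+ save_video(vis_sequence_depth(depth), "example_01_vis.mp4", fps=fps)
+ ```
+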
+ ## 🤖 Gradio Demo
+ We provide a local Gradio demo for DepthCrafter, which can be launched by running:
+ ```bash
+ gradio app.py
+ ```
+
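+ Since [app.py](app.py) also guards `demo.launch()` under `__main__`, the demo can alternatively be started with plain Python (note that it binds to `0.0.0.0:80` by default, which may require elevated privileges on a local machine):
+ ```bash
+ python app.py
+ ```
+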
+ ## 🤝 Contributing
+ - Issues and pull requests are welcome.
+ - Contributions that reduce inference time and memory usage are especially welcome, e.g., through model quantization, distillation, or other acceleration techniques.
+
+ ## 📜 Citation
+ If you find this work helpful, please consider citing:
+ ```bibtex
+ @article{hu2024-DepthCrafter,
+ author = {Hu, Wenbo and Gao, Xiangjun and Li, Xiaoyu and Zhao, Sijie and Cun, Xiaodong and Zhang, Yong and Quan, Long and Shan, Ying},
+ title = {DepthCrafter: Generating Consistent Long Depth Sequences for Open-world Videos},
+ journal = {arXiv preprint arXiv:2409.02095},
+ year = {2024}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,141 @@
1
+ import gc
2
+ import os
3
+ from copy import deepcopy
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.training_utils import set_seed
9
+
10
+ from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
11
+ from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
12
+ from depthcrafter.utils import read_video_frames, vis_sequence_depth, save_video
13
+ from run import DepthCrafterDemo
14
+
15
+ examples = [
16
+ ["examples/example_01.mp4", 25, 1.2, 1024, 195],
17
+ ]
18
+
19
+
20
+ def construct_demo():
21
+ with gr.Blocks(analytics_enabled=False) as depthcrafter_iface:
22
+ gr.Markdown(
23
+ """
24
+ <div align='center'> <h1> DepthCrafter: Generating Consistent Long Depth Sequences for Open-world Videos </h1> \
25
+ <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
26
+ <a href='https://wbhu.github.io'>Wenbo Hu</a>, \
27
+ <a href='https://scholar.google.com/citations?user=qgdesEcAAAAJ&hl=en'>Xiangjun Gao</a>, \
28
+ <a href='https://xiaoyu258.github.io/'>Xiaoyu Li</a>, \
29
+ <a href='https://scholar.google.com/citations?user=tZ3dS3MAAAAJ&hl=en'>Sijie Zhao</a>, \
30
+ <a href='https://vinthony.github.io/academic'> Xiaodong Cun</a>, \
31
+ <a href='https://yzhang2016.github.io'>Yong Zhang</a>, \
32
+ <a href='https://home.cse.ust.hk/~quan'>Long Quan</a>, \
33
+ <a href='https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en'>Ying Shan</a>\
34
+ </h2> \
35
+ <a style='font-size:18px;color: #000000'>If you find DepthCrafter useful, please help star the </a>\
36
+ <a style='font-size:18px;color: #FF5DB0' href='https://github.com/wbhu/DepthCrafter'>[Github Repo]</a>\
37
+ <a style='font-size:18px;color: #000000'>, which is important to Open-Source projects. Thanks!</a>\
38
+ <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2409.02095'> [ArXiv] </a>\
39
+ <a style='font-size:18px;color: #000000' href='https://depthcrafter.github.io/'> [Project Page] </a> </div>
40
+ """
41
+ )
42
+ # demo
43
+ depthcrafter_demo = DepthCrafterDemo(
44
+ unet_path="tencent/DepthCrafter",
45
+ pre_train_path="stabilityai/stable-video-diffusion-img2vid-xt",
46
+ )
47
+
48
+ with gr.Row(equal_height=True):
49
+ with gr.Column(scale=1):
50
+ input_video = gr.Video(label="Input Video")
51
+
52
+ # with gr.Tab(label="Output"):
53
+ with gr.Column(scale=2):
54
+ with gr.Row(equal_height=True):
55
+ output_video_1 = gr.Video(
56
+ label="Preprocessed video",
57
+ interactive=False,
58
+ autoplay=True,
59
+ loop=True,
60
+ show_share_button=True,
61
+ scale=5,
62
+ )
63
+ output_video_2 = gr.Video(
64
+ label="Generated Depth Video",
65
+ interactive=False,
66
+ autoplay=True,
67
+ loop=True,
68
+ show_share_button=True,
69
+ scale=5,
70
+ )
71
+
72
+ with gr.Row(equal_height=True):
73
+ with gr.Column(scale=1):
74
+ with gr.Row(equal_height=False):
75
+ with gr.Accordion("Advanced Settings", open=False):
76
+ num_denoising_steps = gr.Slider(
77
+ label="num denoising steps",
78
+ minimum=1,
79
+ maximum=25,
80
+ value=25,
81
+ step=1,
82
+ )
83
+ guidance_scale = gr.Slider(
84
+ label="cfg scale",
85
+ minimum=1.0,
86
+ maximum=1.2,
87
+ value=1.2,
88
+ step=0.1,
89
+ )
90
+ max_res = gr.Slider(
91
+ label="max resolution",
92
+ minimum=512,
93
+ maximum=2048,
94
+ value=1024,
95
+ step=64,
96
+ )
97
+ process_length = gr.Slider(
98
+ label="process length",
99
+ minimum=1,
100
+ maximum=280,
101
+ value=195,
102
+ step=1,
103
+ )
104
+ generate_btn = gr.Button("Generate")
105
+ with gr.Column(scale=2):
106
+ pass
107
+
108
+ gr.Examples(
109
+ examples=examples,
110
+ inputs=[
111
+ input_video,
112
+ num_denoising_steps,
113
+ guidance_scale,
114
+ max_res,
115
+ process_length,
116
+ ],
117
+ outputs=[output_video_1, output_video_2],
118
+ fn=depthcrafter_demo.run,
119
+ cache_examples=False,
120
+ )
121
+
122
+ generate_btn.click(
123
+ fn=depthcrafter_demo.run,
124
+ inputs=[
125
+ input_video,
126
+ num_denoising_steps,
127
+ guidance_scale,
128
+ max_res,
129
+ process_length,
130
+ ],
131
+ outputs=[output_video_1, output_video_2],
132
+ )
133
+
134
+ return depthcrafter_iface
135
+
136
+
137
+ demo = construct_demo()
138
+
139
+ if __name__ == "__main__":
140
+ demo.queue()
141
+ demo.launch(server_name="0.0.0.0", server_port=80, debug=True)
depthcrafter/__init__.py ADDED
File without changes
depthcrafter/depth_crafter_ppl.py ADDED
@@ -0,0 +1,366 @@
1
+ from typing import Callable, Dict, List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
7
+ _resize_with_antialiasing,
8
+ StableVideoDiffusionPipelineOutput,
9
+ StableVideoDiffusionPipeline,
10
+ retrieve_timesteps,
11
+ )
12
+ from diffusers.utils import logging
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+
18
+ class DepthCrafterPipeline(StableVideoDiffusionPipeline):
19
+
20
+ @torch.inference_mode()
21
+ def encode_video(
22
+ self,
23
+ video: torch.Tensor,
24
+ chunk_size: int = 14,
25
+ ) -> torch.Tensor:
26
+ """
27
+ :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
28
+ :param chunk_size: the chunk size to encode video
29
+ :return: image_embeddings in shape of [b, 1024]
30
+ """
31
+
32
+ video_224 = _resize_with_antialiasing(video.float(), (224, 224))
33
+ video_224 = (video_224 + 1.0) / 2.0 # [-1, 1] -> [0, 1]
34
+
35
+ embeddings = []
36
+ for i in range(0, video_224.shape[0], chunk_size):
37
+ tmp = self.feature_extractor(
38
+ images=video_224[i : i + chunk_size],
39
+ do_normalize=True,
40
+ do_center_crop=False,
41
+ do_resize=False,
42
+ do_rescale=False,
43
+ return_tensors="pt",
44
+ ).pixel_values.to(video.device, dtype=video.dtype)
45
+ embeddings.append(self.image_encoder(tmp).image_embeds) # [b, 1024]
46
+
47
+ embeddings = torch.cat(embeddings, dim=0) # [t, 1024]
48
+ return embeddings
49
+
50
+ @torch.inference_mode()
51
+ def encode_vae_video(
52
+ self,
53
+ video: torch.Tensor,
54
+ chunk_size: int = 14,
55
+ ):
56
+ """
57
+ :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
58
+ :param chunk_size: the chunk size to encode video
59
+ :return: vae latents in shape of [b, c, h, w]
60
+ """
61
+ video_latents = []
62
+ for i in range(0, video.shape[0], chunk_size):
63
+ video_latents.append(
64
+ self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
65
+ )
66
+ video_latents = torch.cat(video_latents, dim=0)
67
+ return video_latents
68
+
69
+ @staticmethod
70
+ def check_inputs(video, height, width):
71
+ """
72
+ :param video:
73
+ :param height:
74
+ :param width:
75
+ :return:
76
+ """
77
+ if not isinstance(video, torch.Tensor) and not isinstance(video, np.ndarray):
78
+ raise ValueError(
79
+ f"Expected `video` to be a `torch.Tensor` or `np.ndarray`, but got a {type(video)}"
80
+ )
81
+
82
+ if height % 8 != 0 or width % 8 != 0:
83
+ raise ValueError(
84
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
85
+ )
86
+
87
+ @torch.no_grad()
88
+ def __call__(
89
+ self,
90
+ video: Union[np.ndarray, torch.Tensor],
91
+ height: int = 576,
92
+ width: int = 1024,
93
+ num_inference_steps: int = 25,
94
+ guidance_scale: float = 1.0,
95
+ window_size: Optional[int] = 110,
96
+ noise_aug_strength: float = 0.02,
97
+ decode_chunk_size: Optional[int] = None,
98
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
99
+ latents: Optional[torch.FloatTensor] = None,
100
+ output_type: Optional[str] = "pil",
101
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
102
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
103
+ return_dict: bool = True,
104
+ overlap: int = 25,
105
+ track_time: bool = False,
106
+ ):
107
+ """
108
+ :param video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]
109
+ :param height:
110
+ :param width:
111
+ :param num_inference_steps:
112
+ :param guidance_scale:
113
+ :param window_size: sliding window processing size
114
+ :param noise_aug_strength:
117
+ :param decode_chunk_size:
118
+ :param generator:
119
+ :param latents:
120
+ :param output_type:
121
+ :param callback_on_step_end:
122
+ :param callback_on_step_end_tensor_inputs:
123
+ :param return_dict:
124
+ :return:
125
+ """
126
+ # 0. Default height and width to unet
127
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
128
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
129
+ num_frames = video.shape[0]
130
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
131
+ if num_frames <= window_size:
132
+ window_size = num_frames
133
+ overlap = 0
134
+ stride = window_size - overlap
135
+
136
+ # 1. Check inputs. Raise error if not correct
137
+ self.check_inputs(video, height, width)
138
+
139
+ # 2. Define call parameters
140
+ batch_size = 1
141
+ device = self._execution_device
142
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
143
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
144
+ # corresponds to doing no classifier free guidance.
145
+ self._guidance_scale = guidance_scale
146
+
147
+ # 3. Encode input video
148
+ if isinstance(video, np.ndarray):
149
+ video = torch.from_numpy(video.transpose(0, 3, 1, 2))
150
+ else:
151
+ assert isinstance(video, torch.Tensor)
152
+ video = video.to(device=device, dtype=self.dtype)
153
+ video = video * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
154
+
155
+ if track_time:
156
+ start_event = torch.cuda.Event(enable_timing=True)
157
+ encode_event = torch.cuda.Event(enable_timing=True)
158
+ denoise_event = torch.cuda.Event(enable_timing=True)
159
+ decode_event = torch.cuda.Event(enable_timing=True)
160
+ start_event.record()
161
+
162
+ video_embeddings = self.encode_video(
163
+ video, chunk_size=decode_chunk_size
164
+ ).unsqueeze(
165
+ 0
166
+ ) # [1, t, 1024]
167
+ torch.cuda.empty_cache()
168
+ # 4. Encode input image using VAE
169
+ noise = randn_tensor(
170
+ video.shape, generator=generator, device=device, dtype=video.dtype
171
+ )
172
+ video = video + noise_aug_strength * noise # in [t, c, h, w]
173
+
174
+ # pdb.set_trace()
175
+ needs_upcasting = (
176
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
177
+ )
178
+ if needs_upcasting:
179
+ self.vae.to(dtype=torch.float32)
180
+
181
+ video_latents = self.encode_vae_video(
182
+ video.to(self.vae.dtype),
183
+ chunk_size=decode_chunk_size,
184
+ ).unsqueeze(
185
+ 0
186
+ ) # [1, t, c, h, w]
187
+
188
+ if track_time:
189
+ encode_event.record()
190
+ torch.cuda.synchronize()
191
+ elapsed_time_ms = start_event.elapsed_time(encode_event)
192
+ print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")
193
+
194
+ torch.cuda.empty_cache()
195
+
196
+ # cast back to fp16 if needed
197
+ if needs_upcasting:
198
+ self.vae.to(dtype=torch.float16)
199
+
200
+ # 5. Get Added Time IDs
201
+ added_time_ids = self._get_add_time_ids(
202
+ 7,
203
+ 127,
204
+ noise_aug_strength,
205
+ video_embeddings.dtype,
206
+ batch_size,
207
+ 1,
208
+ False,
209
+ ) # [1 or 2, 3]
210
+ added_time_ids = added_time_ids.to(device)
211
+
212
+ # 6. Prepare timesteps
213
+ timesteps, num_inference_steps = retrieve_timesteps(
214
+ self.scheduler, num_inference_steps, device, None, None
215
+ )
216
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
217
+ self._num_timesteps = len(timesteps)
218
+
219
+ # 7. Prepare latent variables
220
+ num_channels_latents = self.unet.config.in_channels
221
+ latents_init = self.prepare_latents(
222
+ batch_size,
223
+ window_size,
224
+ num_channels_latents,
225
+ height,
226
+ width,
227
+ video_embeddings.dtype,
228
+ device,
229
+ generator,
230
+ latents,
231
+ ) # [1, t, c, h, w]
232
+ latents_all = None
233
+
234
+ idx_start = 0
235
+ if overlap > 0:
236
+ weights = torch.linspace(0, 1, overlap, device=device)
237
+ weights = weights.view(1, overlap, 1, 1, 1)
238
+ else:
239
+ weights = None
240
+
241
+ torch.cuda.empty_cache()
242
+
243
+ # inference strategy for long videos
244
+ # two main strategies: 1. noise init from previous frame, 2. segments stitching
245
+ while idx_start < num_frames - overlap:
246
+ idx_end = min(idx_start + window_size, num_frames)
247
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
248
+
249
+ # 9. Denoising loop
250
+ latents = latents_init[:, : idx_end - idx_start].clone()
251
+ latents_init = torch.cat(
252
+ [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
253
+ )
254
+
255
+ video_latents_current = video_latents[:, idx_start:idx_end]
256
+ video_embeddings_current = video_embeddings[:, idx_start:idx_end]
257
+
258
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
259
+ for i, t in enumerate(timesteps):
260
+ if latents_all is not None and i == 0:
261
+ latents[:, :overlap] = (
262
+ latents_all[:, -overlap:]
263
+ + latents[:, :overlap]
264
+ / self.scheduler.init_noise_sigma
265
+ * self.scheduler.sigmas[i]
266
+ )
267
+
268
+ latent_model_input = latents # [1, t, c, h, w]
269
+ latent_model_input = self.scheduler.scale_model_input(
270
+ latent_model_input, t
271
+ ) # [1, t, c, h, w]
272
+ latent_model_input = torch.cat(
273
+ [latent_model_input, video_latents_current], dim=2
274
+ )
275
+ noise_pred = self.unet(
276
+ latent_model_input,
277
+ t,
278
+ encoder_hidden_states=video_embeddings_current,
279
+ added_time_ids=added_time_ids,
280
+ return_dict=False,
281
+ )[0]
282
+ # perform guidance
283
+ if self.do_classifier_free_guidance:
284
+ latent_model_input = latents
285
+ latent_model_input = self.scheduler.scale_model_input(
286
+ latent_model_input, t
287
+ )
288
+ latent_model_input = torch.cat(
289
+ [latent_model_input, torch.zeros_like(latent_model_input)],
290
+ dim=2,
291
+ )
292
+ noise_pred_uncond = self.unet(
293
+ latent_model_input,
294
+ t,
295
+ encoder_hidden_states=torch.zeros_like(
296
+ video_embeddings_current
297
+ ),
298
+ added_time_ids=added_time_ids,
299
+ return_dict=False,
300
+ )[0]
301
+
302
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
303
+ noise_pred - noise_pred_uncond
304
+ )
305
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
306
+
307
+ if callback_on_step_end is not None:
308
+ callback_kwargs = {}
309
+ for k in callback_on_step_end_tensor_inputs:
310
+ callback_kwargs[k] = locals()[k]
311
+ callback_outputs = callback_on_step_end(
312
+ self, i, t, callback_kwargs
313
+ )
314
+
315
+ latents = callback_outputs.pop("latents", latents)
316
+
317
+ if i == len(timesteps) - 1 or (
318
+ (i + 1) > num_warmup_steps
319
+ and (i + 1) % self.scheduler.order == 0
320
+ ):
321
+ progress_bar.update()
322
+
323
+ if latents_all is None:
324
+ latents_all = latents.clone()
325
+ else:
326
+ assert weights is not None
327
+ # latents_all[:, -overlap:] = (
328
+ # latents[:, :overlap] + latents_all[:, -overlap:]
329
+ # ) / 2.0
330
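+ # blend the overlapping latents of consecutive windows with a linear ramp
+ # (the previous window fades out as the current window fades in)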
+ latents_all[:, -overlap:] = latents[
331
+ :, :overlap
332
+ ] * weights + latents_all[:, -overlap:] * (1 - weights)
333
+ latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
334
+
335
+ idx_start += stride
336
+
337
+ if track_time:
338
+ denoise_event.record()
339
+ torch.cuda.synchronize()
340
+ elapsed_time_ms = encode_event.elapsed_time(denoise_event)
341
+ print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")
342
+
343
+ if not output_type == "latent":
344
+ # cast back to fp16 if needed
345
+ if needs_upcasting:
346
+ self.vae.to(dtype=torch.float16)
347
+ frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)
348
+
349
+ if track_time:
350
+ decode_event.record()
351
+ torch.cuda.synchronize()
352
+ elapsed_time_ms = denoise_event.elapsed_time(decode_event)
353
+ print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")
354
+
355
+ frames = self.video_processor.postprocess_video(
356
+ video=frames, output_type=output_type
357
+ )
358
+ else:
359
+ frames = latents_all
360
+
361
+ self.maybe_free_model_hooks()
362
+
363
+ if not return_dict:
364
+ return frames
365
+
366
+ return StableVideoDiffusionPipelineOutput(frames=frames)
depthcrafter/unet.py ADDED
@@ -0,0 +1,142 @@
1
+ from typing import Union, Tuple
2
+
3
+ import torch
4
+ from diffusers import UNetSpatioTemporalConditionModel
5
+ from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
6
+
7
+
8
+ class DiffusersUNetSpatioTemporalConditionModelDepthCrafter(
9
+ UNetSpatioTemporalConditionModel
10
+ ):
11
+
12
+ def forward(
13
+ self,
14
+ sample: torch.Tensor,
15
+ timestep: Union[torch.Tensor, float, int],
16
+ encoder_hidden_states: torch.Tensor,
17
+ added_time_ids: torch.Tensor,
18
+ return_dict: bool = True,
19
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
20
+
21
+ # 1. time
22
+ timesteps = timestep
23
+ if not torch.is_tensor(timesteps):
24
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
25
+ # This would be a good case for the `match` statement (Python 3.10+)
26
+ is_mps = sample.device.type == "mps"
27
+ if isinstance(timestep, float):
28
+ dtype = torch.float32 if is_mps else torch.float64
29
+ else:
30
+ dtype = torch.int32 if is_mps else torch.int64
31
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
32
+ elif len(timesteps.shape) == 0:
33
+ timesteps = timesteps[None].to(sample.device)
34
+
35
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
36
+ batch_size, num_frames = sample.shape[:2]
37
+ timesteps = timesteps.expand(batch_size)
38
+
39
+ t_emb = self.time_proj(timesteps)
40
+
41
+ # `Timesteps` does not contain any weights and will always return f32 tensors
42
+ # but time_embedding might actually be running in fp16. so we need to cast here.
43
+ # there might be better ways to encapsulate this.
44
+ t_emb = t_emb.to(dtype=self.conv_in.weight.dtype)
45
+
46
+ emb = self.time_embedding(t_emb) # [batch_size * num_frames, channels]
47
+
48
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
49
+ time_embeds = time_embeds.reshape((batch_size, -1))
50
+ time_embeds = time_embeds.to(emb.dtype)
51
+ aug_emb = self.add_embedding(time_embeds)
52
+ emb = emb + aug_emb
53
+
54
+ # Flatten the batch and frames dimensions
55
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
56
+ sample = sample.flatten(0, 1)
57
+ # Repeat the embeddings num_video_frames times
58
+ # emb: [batch, channels] -> [batch * frames, channels]
59
+ emb = emb.repeat_interleave(num_frames, dim=0)
60
+ # encoder_hidden_states: [batch, frames, channels] -> [batch * frames, 1, channels]
61
+ encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)
62
+
63
+ # 2. pre-process
64
+ sample = sample.to(dtype=self.conv_in.weight.dtype)
65
+ assert sample.dtype == self.conv_in.weight.dtype, (
66
+ f"sample.dtype: {sample.dtype}, "
67
+ f"self.conv_in.weight.dtype: {self.conv_in.weight.dtype}"
68
+ )
69
+ sample = self.conv_in(sample)
70
+
71
+ image_only_indicator = torch.zeros(
72
+ batch_size, num_frames, dtype=sample.dtype, device=sample.device
73
+ )
74
+
75
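+ # 3. down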
+ down_block_res_samples = (sample,)
76
+ for downsample_block in self.down_blocks:
77
+ if (
78
+ hasattr(downsample_block, "has_cross_attention")
79
+ and downsample_block.has_cross_attention
80
+ ):
81
+ sample, res_samples = downsample_block(
82
+ hidden_states=sample,
83
+ temb=emb,
84
+ encoder_hidden_states=encoder_hidden_states,
85
+ image_only_indicator=image_only_indicator,
86
+ )
87
+
88
+ else:
89
+ sample, res_samples = downsample_block(
90
+ hidden_states=sample,
91
+ temb=emb,
92
+ image_only_indicator=image_only_indicator,
93
+ )
94
+
95
+ down_block_res_samples += res_samples
96
+
97
+ # 4. mid
98
+ sample = self.mid_block(
99
+ hidden_states=sample,
100
+ temb=emb,
101
+ encoder_hidden_states=encoder_hidden_states,
102
+ image_only_indicator=image_only_indicator,
103
+ )
104
+
105
+ # 5. up
106
+ for i, upsample_block in enumerate(self.up_blocks):
107
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
108
+ down_block_res_samples = down_block_res_samples[
109
+ : -len(upsample_block.resnets)
110
+ ]
111
+
112
+ if (
113
+ hasattr(upsample_block, "has_cross_attention")
114
+ and upsample_block.has_cross_attention
115
+ ):
116
+ sample = upsample_block(
117
+ hidden_states=sample,
118
+ res_hidden_states_tuple=res_samples,
119
+ temb=emb,
120
+ encoder_hidden_states=encoder_hidden_states,
121
+ image_only_indicator=image_only_indicator,
122
+ )
123
+ else:
124
+ sample = upsample_block(
125
+ hidden_states=sample,
126
+ res_hidden_states_tuple=res_samples,
127
+ temb=emb,
128
+ image_only_indicator=image_only_indicator,
129
+ )
130
+
131
+ # 6. post-process
132
+ sample = self.conv_norm_out(sample)
133
+ sample = self.conv_act(sample)
134
+ sample = self.conv_out(sample)
135
+
136
+ # 7. Reshape back to original shape
137
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
138
+
139
+ if not return_dict:
140
+ return (sample,)
141
+
142
+ return UNetSpatioTemporalConditionOutput(sample=sample)
depthcrafter/utils.py ADDED
@@ -0,0 +1,92 @@
1
+ import numpy as np
2
+ import cv2
3
+ import matplotlib.cm as cm
4
+ import torch
5
+
6
+
7
+ def read_video_frames(video_path, process_length, target_fps, max_res):
8
+ # a simple function to read video frames
9
+ cap = cv2.VideoCapture(video_path)
10
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
11
+ original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
12
+ original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
13
+ # round the height and width to the nearest multiple of 64
14
+ height = round(original_height / 64) * 64
15
+ width = round(original_width / 64) * 64
16
+
17
+ # resize the video if the height or width is larger than max_res
18
+ if max(height, width) > max_res:
19
+ scale = max_res / max(original_height, original_width)
20
+ height = round(original_height * scale / 64) * 64
21
+ width = round(original_width * scale / 64) * 64
22
+
23
+ if target_fps < 0:
24
+ target_fps = original_fps
25
+
26
+ stride = max(round(original_fps / target_fps), 1)
27
+
28
+ frames = []
29
+ frame_count = 0
30
+ while cap.isOpened():
31
+ ret, frame = cap.read()
32
+ if not ret or (process_length > 0 and frame_count >= process_length):
33
+ break
34
+ if frame_count % stride == 0:
35
+ frame = cv2.resize(frame, (width, height))
36
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
37
+ frames.append(frame.astype("float32") / 255.0)
38
+ frame_count += 1
39
+ cap.release()
40
+
41
+ frames = np.array(frames)
42
+ return frames, target_fps
43
+
44
+
45
+ def save_video(
46
+ video_frames,
47
+ output_video_path,
48
+ fps: int = 15,
49
+ ) -> str:
50
+ # a simple function to save video frames
51
+ height, width = video_frames[0].shape[:2]
52
+ is_color = video_frames[0].ndim == 3
53
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
54
+ video_writer = cv2.VideoWriter(
55
+ output_video_path, fourcc, fps, (width, height), isColor=is_color
56
+ )
57
+
58
+ for frame in video_frames:
59
+ frame = (frame * 255).astype(np.uint8)
60
+ if is_color:
61
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
62
+ video_writer.write(frame)
63
+
64
+ video_writer.release()
65
+ return output_video_path
66
+
67
+
68
+ class ColorMapper:
69
+ # a color mapper to map depth values to a certain colormap
70
+ def __init__(self, colormap: str = "inferno"):
71
+ self.colormap = torch.tensor(cm.get_cmap(colormap).colors)
72
+
73
+ def apply(self, image: torch.Tensor, v_min=None, v_max=None):
74
+ # assert len(image.shape) == 2
75
+ if v_min is None:
76
+ v_min = image.min()
77
+ if v_max is None:
78
+ v_max = image.max()
79
+ image = (image - v_min) / (v_max - v_min)
80
+ image = (image * 255).long()
81
+ image = self.colormap[image]
82
+ return image
83
+
84
+
85
+ def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):
86
+ visualizer = ColorMapper()
87
+ if v_min is None:
88
+ v_min = depths.min()
89
+ if v_max is None:
90
+ v_max = depths.max()
91
+ res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()
92
+ return res
examples/example_01.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:afb78decc210225793b20d5bca5b13da07c97233e6fabea44bf02eba8a52bdaf
+ size 14393250
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==2.3.0+cu117
+ diffusers==0.29.1
+ numpy==1.26.4
+ matplotlib==3.8.4
+ opencv-python==4.8.1.78
run.py ADDED
@@ -0,0 +1,210 @@
1
+ import gc
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ import argparse
6
+ from diffusers.training_utils import set_seed
7
+
8
+ from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
9
+ from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
10
+ from depthcrafter.utils import vis_sequence_depth, save_video, read_video_frames
11
+
12
+
13
+ class DepthCrafterDemo:
14
+ def __init__(
15
+ self,
16
+ unet_path: str,
17
+ pre_train_path: str,
18
+ cpu_offload: str = "model",
19
+ ):
20
+ unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
21
+ unet_path,
22
+ subfolder="unet",
23
+ low_cpu_mem_usage=True,
24
+ torch_dtype=torch.float16,
25
+ )
26
+ # load weights of other components from the provided checkpoint
27
+ self.pipe = DepthCrafterPipeline.from_pretrained(
28
+ pre_train_path,
29
+ unet=unet,
30
+ torch_dtype=torch.float16,
31
+ variant="fp16",
32
+ )
33
+
34
+ # for saving memory, we can offload the model to CPU, or even run the model sequentially to save more memory
35
+ if cpu_offload is not None:
36
+ if cpu_offload == "sequential":
37
+ # this is slower, but saves more memory
38
+ self.pipe.enable_sequential_cpu_offload()
39
+ elif cpu_offload == "model":
40
+ self.pipe.enable_model_cpu_offload()
41
+ else:
42
+ raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
43
+ else:
44
+ self.pipe.to("cuda")
45
+ # enable attention slicing and xformers memory efficient attention
46
+ try:
47
+ self.pipe.enable_xformers_memory_efficient_attention()
48
+ except Exception as e:
49
+ print(e)
50
+ print("Xformers is not enabled")
51
+ self.pipe.enable_attention_slicing()
52
+
53
+ def infer(
54
+ self,
55
+ video: str,
56
+ num_denoising_steps: int,
57
+ guidance_scale: float,
58
+ save_folder: str = "./demo_output",
59
+ window_size: int = 110,
60
+ process_length: int = 195,
61
+ overlap: int = 25,
62
+ max_res: int = 1024,
63
+ target_fps: int = 15,
64
+ seed: int = 42,
65
+ track_time: bool = True,
66
+ save_npz: bool = False,
67
+ ):
68
+ set_seed(seed)
69
+
70
+ frames, target_fps = read_video_frames(
71
+ video, process_length, target_fps, max_res
72
+ )
73
+ print(f"==> video name: {video}, frames shape: {frames.shape}")
74
+
75
+ # inference the depth map using the DepthCrafter pipeline
76
+ with torch.inference_mode():
77
+ res = self.pipe(
78
+ frames,
79
+ height=frames.shape[1],
80
+ width=frames.shape[2],
81
+ output_type="np",
82
+ guidance_scale=guidance_scale,
83
+ num_inference_steps=num_denoising_steps,
84
+ window_size=window_size,
85
+ overlap=overlap,
86
+ track_time=track_time,
87
+ ).frames[0]
88
+ # convert the three-channel output to a single channel depth map
89
+ res = res.sum(-1) / res.shape[-1]
90
+ # normalize the depth map to [0, 1] across the whole video
91
+ res = (res - res.min()) / (res.max() - res.min())
92
+ # visualize the depth map and save the results
93
+ vis = vis_sequence_depth(res)
94
+ # save the depth map and visualization with the target FPS
95
+ save_path = os.path.join(
96
+ save_folder, os.path.splitext(os.path.basename(video))[0]
97
+ )
98
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
99
+ if save_npz:
100
+ np.savez_compressed(save_path + ".npz", depth=res)
101
+ save_video(res, save_path + "_depth.mp4", fps=target_fps)
102
+ save_video(vis, save_path + "_vis.mp4", fps=target_fps)
103
+ save_video(frames, save_path + "_input.mp4", fps=target_fps)
104
+ return [
105
+ save_path + "_input.mp4",
106
+ save_path + "_vis.mp4",
107
+ save_path + "_depth.mp4",
108
+ ]
109
+
110
+ def run(
111
+ self,
112
+ input_video,
113
+ num_denoising_steps,
114
+ guidance_scale,
115
+ max_res=1024,
116
+ process_length=195,
117
+ ):
118
+ res_path = self.infer(
119
+ input_video,
120
+ num_denoising_steps,
121
+ guidance_scale,
122
+ max_res=max_res,
123
+ process_length=process_length,
124
+ )
125
+ # clear the cache for the next video
126
+ gc.collect()
127
+ torch.cuda.empty_cache()
128
+ return res_path[:2]
129
+
130
+
131
+ if __name__ == "__main__":
132
+ # running configs
133
+ # the most important arguments for memory saving are `cpu_offload`, `enable_xformers`, `max_res`, and `window_size`
134
+ # the most important arguments for trade-off between quality and speed are
135
+ # `num_inference_steps`, `guidance_scale`, and `max_res`
136
+ parser = argparse.ArgumentParser(description="DepthCrafter")
137
+ parser.add_argument(
138
+ "--video-path", type=str, required=True, help="Path to the input video file(s)"
139
+ )
140
+ parser.add_argument(
141
+ "--save-folder",
142
+ type=str,
143
+ default="./demo_output",
144
+ help="Folder to save the output",
145
+ )
146
+ parser.add_argument(
147
+ "--unet-path",
148
+ type=str,
149
+ default="tencent/DepthCrafter",
150
+ help="Path to the UNet model",
151
+ )
152
+ parser.add_argument(
153
+ "--pre-train-path",
154
+ type=str,
155
+ default="stabilityai/stable-video-diffusion-img2vid-xt",
156
+ help="Path to the pre-trained model",
157
+ )
158
+ parser.add_argument(
159
+ "--process-length", type=int, default=195, help="Number of frames to process"
160
+ )
161
+ parser.add_argument(
162
+ "--cpu-offload",
163
+ type=str,
164
+ default="model",
165
+ choices=["model", "sequential", None],
166
+ help="CPU offload option",
167
+ )
168
+ parser.add_argument(
169
+ "--target-fps", type=int, default=15, help="Target FPS for the output video"
170
+ ) # -1 for original fps
171
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
172
+ parser.add_argument(
173
+ "--num-inference-steps", type=int, default=25, help="Number of inference steps"
174
+ )
175
+ parser.add_argument(
176
+ "--guidance-scale", type=float, default=1.2, help="Guidance scale"
177
+ )
178
+ parser.add_argument("--window-size", type=int, default=110, help="Window size")
179
+ parser.add_argument("--overlap", type=int, default=25, help="Overlap size")
180
+ parser.add_argument("--max-res", type=int, default=1024, help="Maximum resolution")
181
+ parser.add_argument("--save_npz", type=bool, default=True, help="Save npz file")
182
+ parser.add_argument("--track_time", type=bool, default=False, help="Track time")
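+ # note: argparse's type=bool converts via bool(str), so any non-empty value
+ # (including "False") is treated as True for --save_npz and --track_time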
183
+
184
+ args = parser.parse_args()
185
+
186
+ depthcrafter_demo = DepthCrafterDemo(
187
+ unet_path=args.unet_path,
188
+ pre_train_path=args.pre_train_path,
189
+ cpu_offload=args.cpu_offload,
190
+ )
191
+ # process the videos, the video paths are separated by comma
192
+ video_paths = args.video_path.split(",")
193
+ for video in video_paths:
194
+ depthcrafter_demo.infer(
195
+ video,
196
+ args.num_inference_steps,
197
+ args.guidance_scale,
198
+ save_folder=args.save_folder,
199
+ window_size=args.window_size,
200
+ process_length=args.process_length,
201
+ overlap=args.overlap,
202
+ max_res=args.max_res,
203
+ target_fps=args.target_fps,
204
+ seed=args.seed,
205
+ track_time=args.track_time,
206
+ save_npz=args.save_npz,
207
+ )
208
+ # clear the cache for the next video
209
+ gc.collect()
210
+ torch.cuda.empty_cache()