liuyuan-pal committed on
Commit ab287b7
1 Parent(s): 0fa63ef
Files changed (1)
  1. app.py +28 -10
app.py CHANGED
@@ -17,12 +17,19 @@ _DESCRIPTION = '''
  <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.03453"><img src="https://img.shields.io/badge/2309.03453-f9f7f7?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADcAAABMCAYAAADJPi9EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAuIwAALiMBeKU/dgAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAAa2SURBVHja3Zt7bBRFGMAXUCDGF4rY7m7bAwuhlggKStFgLBgFEkCIIRJEEoOBYHwRFYKilUgEReVNJEGCJJpehHI3M9vZvd3bUP1DjNhEIRQQsQgSHiJgQZ5dv7krWEvvdmZ7d7vHJN+ft/f99pv5XvOtJMFCqvoCUpTdIEeRLC+L9Ox5i3Q9LACaCeK0kXoSChVcD3C/tQPHpAEsquQ73IkUcEz2kcLCknyGW5MGjkljRFVL8xJOKyi4CwCOuQAeAkfTP1+tNxLkogvgEbDgffkJqKqvuMA5ifOpqg/5qWecRstNg7xoUTI1Fovdxg8oy2s5AP8CGeYHmGngeZaOL4I4LXLcpHg4149/GDz4xqgsb+UAbMKKUpkrqHA43MUyyJpWUK0EHeG2YKRXr7tB+QMcgGewLD+ebTDbtrtbBt7UPlhS4rV4IvcDI7J8P1OeA/AcAI7LHljN7aB8XTowJmZt9EFRD/o0SDMH4HlwMhMyDWZZSAHFf3YDs3RS49WDLuaAY3IJq+qzmQKLxXAZKN7oDoYbdV3v5elPqiSpMyiOuAEVZVqHXb1OhloUH+MA+ztO0cAO/RkrfyBE7OAEbAZvO8vzVtTRWFD6DAfY5biBM3PWiaL0a4lvXICwnV8WjmE6ntYmhqX2jjp5LbMZjCw/wbYeN6CizOa2GMVzQOlmHjB4Ceuyk6LJ8huccEmR5Xddg7OOV/NAtchW+E3XbOag60QA4Qwuarca0bRuEJyr+cFQwzcY98huxhAKdQelt4kAQpj4qJ3gvFXAYn+aJumXk1yPlpQUgtIHhbYoFMUstNRRWgjnpl4A7IKlayNymqFHFaWCpV9CFry3LGxR1CgA5kB5M8OX2goApwpaz6mdOMGxtAgXWJySxb4WuQD4qTDgU+N5AAnzpr7ChSWpCyisiQJqY0Y7FtmSKpbV23b45kC0KHBxcQ9QeI8w4KgnHRPVtIU7rOtbioLVg5Hl/qDwSVFAMqLSMSObroCdZYlzIJtMRFVHCaRo/wFWPgaAXzdbBpkc2A4aKzCNd97+URQuESYGDDhIVfWOQIKZJu4D2+oXlgDTV1865gUQZDts756BArMNMoR1oa46BYqbyPixZz1ZUFV3sgwoGBajuBKATl3btIn8QYYMuezRgrsiRUWyr2BxA40EkPMpA/Hm6gbUu7fjEXA3azP6AsbKD9bxdUuhjM9W7fII52BF+daRpE4+WA3P501+jbfmHvQKyFqMuXf7Ot4mkN2fr50y+bRH61X7AXdUpHSxaPQ4GVbR5AGw3g+434XgQGKfr72I+vQRhfsu92dOx7WicInzt3CBg1RVpMm0NveWo2SqFzgmdNZMbriILD+S+zoueWf2vSdAipzacWN5nMl6XxNlUHa/J8DoJodUDE0HR8Ll5V0lPxcrLEHZPV4AzS83OLis7FowVa3RSku7BSNxJqQAlN3hBTC2apmDSkpaw22wJemGQFUG7J4MlP3JC6A+f96V7vRyX9It3nzT/GrjIU8edM7rMSnIi10f476lzbE1K7yEiEuWro0OJBguLCwDuFOJc1Na6sRWL/cCeMIwUN9ggSVbe3v/5/EgzTKWLvEAiBrYRUkgwNI2ZaFQNT75UDxEUEx97zYnzpmiLEmbaYCbNxYtFAb0/Z4AztgUrhyxuNgxPnhfHFDHz/vTgFWUQZxTRkkJhQ6YNdVUEPAfO6ZV5BRss6LcCVb7VaAma9giy0XJZBt9IQh42NY0NSdgbLIPlLUF6rEdrdt0CUCK1wsCbkcI3ZSLc7ZSwGLbmJXbPsNxnE5xilYKAobZ77LpGZ8TAIun+/iCKQoF71IxQDI3K2CCd+ARNvXg9sykBcnHAoCZG4u66hlDoQLe6QV4CRtFSxZQ+D0BwNO2jgdkzoGoah1nj3FVlSR19taTSYxI8QLut23U8dsgzqHulJNCQpcqBnpTALCuQ6NSYLHpmR5i42gZzuIdcrMMvMJbQlxe3jXxyZnLACl7ARm/FjPIDOY8ODtpM71sxwfcZpvBeUzKWmfNINM5AS+wO0Khh7dMqKccu4+qatarZjYAwDlgetzStHtEt+XedsBOQtU9XMrRgjg4KTnc5nr+dmqadit/4C4uLm8DuA9koJTj1TL7fI5nDL+qqoo/FLGAzL7dYT17PzvAcQONYSUQRxW/QMrHZVIyik0ZuQA2mzp+Ji8BW4YM3Mbzm9inaHkJCGfrUZZjujiYailfFwA8DHIy3acwUj4v9vUVa+SmgNsl5fuyDTKovW9/IAmfLV0Pi2UncA515kjYdrwC9i9rpuHiq3JwtAAAAABJRU5ErkJggg=="></a>
  <a style="display:inline-block; margin-left: .5em" href='https://github.com/liuyuan-pal/SyncDreamer'><img src='https://img.shields.io/github/stars/liuyuan-pal/SyncDreamer?style=social' /></a>
 </div>
-Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss'''
+Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss
+
+1. Upload the image.
+2. Predict the mask for the foreground object.
+3. Crop the foreground object.
+4. Generate multiview images.
+'''
 _USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example above). We use alpha values as object masks if given."
 _USER_GUIDE1 = "Step1: Please select a crop size using the glider."
 _USER_GUIDE2 = "Step2: Please choose a suitable elevation angle and then click the Generate button."
 _USER_GUIDE3 = "Generated multiview images are shown below!"
 
+deployed = True
 
 def mask_prediction(mask_predictor, image_in: Image.Image):
     if image_in.mode=='RGBA':
@@ -56,11 +63,16 @@ def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, el
     elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
     data = {"input_image": image_input, "input_elevation": elevation_input}
     for k, v in data.items():
-        data[k] = v.unsqueeze(0)#.cuda()
+        if deployed:
+            data[k] = v.unsqueeze(0).cuda()
+        else:
+            data[k] = v.unsqueeze(0)
         data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
 
-    x_sample = model.sample(data, cfg_scale, batch_view_num)
-    # x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
+    if deployed:
+        x_sample = model.sample(data, cfg_scale, batch_view_num)
+    else:
+        x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
 
     B, N, _, H, W = x_sample.shape
     x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
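The two `if deployed:` branches in the hunk above let the same generate() path serve both the hosted Space (tensors moved to CUDA, real sampling) and a GPU-less machine (a dummy zero tensor of the expected shape, so the rest of the pipeline and the UI still run). A minimal, self-contained sketch of that pattern follows; the dummy inputs and the 2-sample batch are illustrative values, not ones taken from app.py.

    import torch

    # Mirrors the flag added in this commit: True on the deployed Space (GPU present),
    # False when exercising the Gradio UI locally without a GPU or the checkpoint.
    deployed = False
    sample_num = 2

    data = {
        "input_image": torch.rand(3, 256, 256),     # stand-in for the preprocessed input image
        "input_elevation": torch.tensor([0.5236]),  # elevation in radians (30 degrees here)
    }
    for k, v in data.items():
        data[k] = v.unsqueeze(0).cuda() if deployed else v.unsqueeze(0)  # add a batch dimension
        data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)    # one row per requested sample

    if deployed:
        ...  # on the Space this would be model.sample(data, cfg_scale, batch_view_num)
    else:
        x_sample = torch.zeros(sample_num, 16, 3, 256, 256)  # (batch, 16 views, RGB, 256, 256)

    print(x_sample.shape)  # torch.Size([2, 16, 3, 256, 256])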
@@ -80,12 +92,15 @@ def run_demo():
     ckpt = 'ckpt/syncdreamer-pretrain.ckpt'
     config = OmegaConf.load(cfg)
     # model = None
-    model = instantiate_from_config(config.model)
-    print(f'loading model from {ckpt} ...')
-    ckpt = torch.load(ckpt,map_location='cpu')
-    model.load_state_dict(ckpt['state_dict'], strict=True)
-    model = model.cuda().eval()
-    del ckpt
+    if deployed:
+        model = instantiate_from_config(config.model)
+        print(f'loading model from {ckpt} ...')
+        ckpt = torch.load(ckpt,map_location='cpu')
+        model.load_state_dict(ckpt['state_dict'], strict=True)
+        model = model.cuda().eval()
+        del ckpt
+    else:
+        model = None
 
     # init sam model
     mask_predictor = None # sam_init(device_idx)
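For readers without the repository open: `instantiate_from_config` is the usual latent-diffusion-style factory that imports whatever class the config's `target` field names and constructs it with `params` as keyword arguments. SyncDreamer's own helper is not part of this diff, so the sketch below shows the conventional shape of that utility rather than a copy of it.

    import importlib

    def get_obj_from_str(path: str):
        # Resolve a dotted path such as "package.module.ClassName" to the class object.
        module_name, cls_name = path.rsplit(".", 1)
        return getattr(importlib.import_module(module_name), cls_name)

    def instantiate_from_config(config):
        # Build config["target"] with config["params"] (if present) as keyword arguments.
        if "target" not in config:
            raise KeyError("expected a 'target' key in the model config")
        return get_obj_from_str(config["target"])(**config.get("params", dict()))

Note also the loading order in the new branch: the checkpoint is read with map_location='cpu', copied into the model, and only then moved to CUDA, after which `del ckpt` releases the CPU copy of the state dict, so the GPU never holds both the weights and the raw checkpoint at once.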
@@ -121,10 +136,12 @@ def run_demo():
             examples_per_page=40
         )
 
+
         with gr.Column(scale=1):
             sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
             crop_size_slider = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
             crop_btn = gr.Button('Crop the image', variant='primary', interactive=True)
+            fig0 = gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
         with gr.Column(scale=1):
             input_block = gr.Image(type='pil', image_mode='RGB', label="Input to SyncDreamer", height=256, interactive=False)
@@ -134,6 +151,7 @@ def run_demo():
             # batch_view_num = gr.Slider(1, 16, 8, step=1, label='', interactive=True)
             seed = gr.Number(6033, label='Random seed', interactive=True)
             run_btn = gr.Button('Run Generation', variant='primary', interactive=True)
+            fig1 = gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
         output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
 
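The new fig0 and fig1 widgets are static helper figures: show_label=False, interactive=False, and tool=None (an Image argument from Gradio 3.x) leave a plain picture sitting beside the control it explains. The event wiring for crop_btn and run_btn lies outside the hunks shown here, so the following is only a hypothetical sketch of how such a column is usually connected; crop_image and its behaviour are placeholders, not SyncDreamer's actual callback.

    import gradio as gr
    from PIL import Image

    def crop_image(image: Image.Image, crop_size: int) -> Image.Image:
        # Placeholder: centre-crop to crop_size x crop_size, then resize to the
        # 256 x 256 resolution the generator expects.
        w, h = image.size
        s = int(min(crop_size, w, h))
        left, top = (w - s) // 2, (h - s) // 2
        return image.crop((left, top, left + s, top + s)).resize((256, 256))

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=1):
                sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
                crop_size_slider = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
                crop_btn = gr.Button('Crop the image', variant='primary', interactive=True)
            with gr.Column(scale=1):
                input_block = gr.Image(type='pil', image_mode='RGB', label="Input to SyncDreamer", height=256, interactive=False)

        # The button feeds the SAM output and the slider value into the crop helper.
        crop_btn.click(fn=crop_image, inputs=[sam_block, crop_size_slider], outputs=[input_block])

    demo.launch()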
 