liuyuan-pal committed
Commit 0fa63ef
1 Parent(s): 959adf1
Files changed (2):
  1. app.py +25 -12
  2. requirements.txt +2 -1
app.py CHANGED
@@ -6,9 +6,9 @@ import gradio as gr
 import torch
 import os
 import fire
+from omegaconf import OmegaConf
 
-from generate import load_model
-from ldm.util import add_margin
+from ldm.util import add_margin, instantiate_from_config
 
 _TITLE = '''SyncDreamer: Generating Multiview-consistent Images from a Single-view Image'''
 _DESCRIPTION = '''
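
Note on the import swap above: the app no longer pulls load_model from generate.py; it instead rebuilds the model from the YAML config inside run_demo (see the hunk further down). For readers unfamiliar with the ldm codebase, instantiate_from_config in latent-diffusion-style repos is conventionally a small reflection helper along the lines of this sketch (a paraphrase, not necessarily this repo's exact code):

import importlib

def instantiate_from_config(config):
    # config carries a dotted "target" path plus optional "params"
    # that are forwarded as keyword arguments to the constructor.
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    module, cls = config["target"].rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)(**config.get("params", dict()))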
@@ -21,6 +21,7 @@ Given a single-view image, SyncDreamer is able to generate multiview-consistent
 _USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example above). We use alpha values as object masks if given."
 _USER_GUIDE1 = "Step1: Please select a crop size using the glider."
 _USER_GUIDE2 = "Step2: Please choose a suitable elevation angle and then click the Generate button."
+_USER_GUIDE3 = "Generated multiview images are shown below!"
 
 
 def mask_prediction(mask_predictor, image_in: Image.Image):
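
The new _USER_GUIDE3 string is surfaced through update_guide, which this diff references later (partial(update_guide, _USER_GUIDE2) and so on) but never shows. A minimal sketch consistent with that usage, assuming the Gradio 3.x gr.update API and the guide_text block defined elsewhere in app.py:

def update_guide(message):
    # Routed to the guide_text block via outputs=[guide_text] in the event wiring.
    return gr.update(value=message)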
@@ -42,24 +43,24 @@ def resize_inputs(image_input, crop_size):
     results = add_margin(ref_img_, size=256)
     return results
 
-def generate(model, seed, batch_view_num, sample_num, cfg_scale, image_input, elevation_input):
+def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input):
+    seed=int(seed)
     torch.random.manual_seed(seed)
     np.random.seed(seed)
 
     # prepare data
     image_input = np.asarray(image_input)
     image_input = image_input.astype(np.float32) / 255.0
-    ref_mask = image_input[:, :, 3:]
-    image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask # white background
     image_input = image_input[:, :, :3] * 2.0 - 1.0
     image_input = torch.from_numpy(image_input.astype(np.float32))
     elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
     data = {"input_image": image_input, "input_elevation": elevation_input}
     for k, v in data.items():
-        data[k] = v.unsqueeze(0).cuda()
+        data[k] = v.unsqueeze(0)#.cuda()
        data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
 
     x_sample = model.sample(data, cfg_scale, batch_view_num)
+    # x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
 
     B, N, _, H, W = x_sample.shape
     x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
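
Two easy-to-miss details in the rewritten generate: the seed now arrives from a gr.Number widget as a Python float, and np.random.seed rejects floats, hence the new int(seed) cast; and each input tensor gains a batch dimension that is then repeated so sample_num instances are sampled in one forward pass. A self-contained illustration of that batching (shapes are placeholders):

import numpy as np
import torch

seed = int(6033.0)        # gr.Number hands callbacks a float
torch.random.manual_seed(seed)
np.random.seed(seed)      # would raise TypeError on a float

sample_num = 2
image = torch.zeros(256, 256, 3)                           # H x W x C
image = image.unsqueeze(0)                                 # 1 x H x W x C
image = torch.repeat_interleave(image, sample_num, dim=0)  # 2 x H x W x C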
@@ -68,14 +69,23 @@ def generate(model, seed, batch_view_num, sample_num, cfg_scale, image_input, elevation_input):
 
     results = []
     for bi in range(B):
-        results.append(torch.concat([x_sample[bi,ni] for ni in range(N)], 1))
-    results = torch.concat(results, 0)
+        results.append(np.concatenate([x_sample[bi,ni] for ni in range(N)], 1))
+    results = np.concatenate(results, 0)
     return Image.fromarray(results)
 
 def run_demo():
     # device = f"cuda:0" if torch.cuda.is_available() else "cpu"
     # models = None # init_model(device, os.path.join(code_dir, ckpt))
-    model = load_model('configs/syncdreamer', 'ckpt/syncdreamer-pretrain.ckpt', strict=True)
+    cfg = 'configs/syncdreamer.yaml'
+    ckpt = 'ckpt/syncdreamer-pretrain.ckpt'
+    config = OmegaConf.load(cfg)
+    # model = None
+    model = instantiate_from_config(config.model)
+    print(f'loading model from {ckpt} ...')
+    ckpt = torch.load(ckpt,map_location='cpu')
+    model.load_state_dict(ckpt['state_dict'], strict=True)
+    model = model.cuda().eval()
+    del ckpt
 
     # init sam model
     mask_predictor = None # sam_init(device_idx)
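
The switch from torch.concat to np.concatenate implies x_sample is already a NumPy array when the rows are assembled, so the elided lines between the two hunks (old lines 66-67) presumably move the tensor to the CPU, reorder it channels-last, and convert it to uint8. A sketch of the whole postprocessing under that assumption (zeros stand in for real samples):

import numpy as np
import torch
from PIL import Image

x_sample = torch.zeros(1, 16, 3, 256, 256)       # B x N x C x H x W in [-1, 1]
x_sample = (torch.clamp(x_sample, max=1.0, min=-1.0) + 1) * 0.5
# assumed elided step: channels-last uint8 NumPy array
x_sample = (x_sample.permute(0, 1, 3, 4, 2).cpu().numpy() * 255).astype(np.uint8)

B, N = x_sample.shape[:2]
rows = [np.concatenate([x_sample[bi, ni] for ni in range(N)], 1) for bi in range(B)]
grid = np.concatenate(rows, 0)   # each row tiles the 16 views side by side
image = Image.fromarray(grid)    # 256 x 4096 RGB image when B=1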
@@ -114,6 +124,7 @@ def run_demo():
             with gr.Column(scale=1):
                 sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
                 crop_size_slider = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
+                crop_btn = gr.Button('Crop the image', variant='primary', interactive=True)
 
             with gr.Column(scale=1):
                 input_block = gr.Image(type='pil', image_mode='RGB', label="Input to SyncDreamer", height=256, interactive=False)
@@ -122,7 +133,7 @@ def run_demo():
                 # sample_num = gr.Slider(1, 2, 2, step=1, label='Sample Num', interactive=True, info='How many instance (16 images per instance)')
                 # batch_view_num = gr.Slider(1, 16, 8, step=1, label='', interactive=True)
                 seed = gr.Number(6033, label='Random seed', interactive=True)
-                run_btn = gr.Button('Run Generation', variant='primary', interactive=False)
+                run_btn = gr.Button('Run Generation', variant='primary', interactive=True)
 
         output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
 
@@ -132,9 +143,11 @@ def run_demo():
 
         crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
                         .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
+        crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
+                .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
 
-        run_btn.click(partial(generate, model, seed, 16, 1, cfg_scale, input_block, elevation), outputs=[output_block])\
-               .success(fn=partial(update_guide, _USER_GUIDE0), outputs=[guide_text], queue=False)
+        run_btn.click(partial(generate, model, 16, 1), inputs=[cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=False)\
+               .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
 
     demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
 
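The old run_btn wiring was broken: it froze the Gradio components themselves (seed, input_block, elevation) into the partial and declared no inputs at all, so generate would have received component objects rather than their current values. The fix freezes only the true constants and lets Gradio pass the live widget values positionally; the success handler now shows the completion message _USER_GUIDE3 instead of resetting to _USER_GUIDE0. A standalone illustration of the pattern (dummy_generate is a stand-in, not the app's function):

from functools import partial

def dummy_generate(model, batch_view_num, sample_num, cfg_scale, seed, image, elevation):
    return f"model={model} views={batch_view_num} n={sample_num} cfg={cfg_scale} seed={seed}"

# Freeze the leading constants; on click, Gradio appends the values of
# inputs=[cfg_scale, seed, input_block, elevation] in declaration order.
fn = partial(dummy_generate, "model-object", 16, 1)
print(fn(2.0, 6033, "img", 30.0))
# model=model-object views=16 n=1 cfg=2.0 seed=6033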
 
requirements.txt CHANGED
@@ -19,4 +19,5 @@ trimesh
 easydict
 nerfacc
 imageio-ffmpeg==0.4.7
-git+https://github.com/openai/CLIP.git
+fire
+git+https://github.com/openai/CLIP.git