liuyuan-pal commited on
Commit
979cb09
1 Parent(s): 36a325d
Files changed (2) hide show
  1. app.py +9 -2
  2. sam_utils.py +3 -2
app.py CHANGED
@@ -82,6 +82,14 @@ def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, el
82
  results = np.concatenate(results, 0)
83
  return Image.fromarray(results)
84
 
 
 
 
 
 
 
 
 
85
  def run_demo():
86
  # device = f"cuda:0" if torch.cuda.is_available() else "cpu"
87
  # models = None # init_model(device, os.path.join(code_dir, ckpt))
@@ -101,7 +109,6 @@ def run_demo():
101
 
102
  # init sam model
103
  mask_predictor = sam_init()
104
- mask_predict_fn = lambda x: sam_out_nosave(mask_predictor, x)
105
 
106
  # with open('instructions_12345.md', 'r') as f:
107
  # article = f.read()
@@ -154,7 +161,7 @@ def run_demo():
154
  output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
155
 
156
  update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
157
- image_block.change(fn=mask_predict_fn, inputs=[image_block], outputs=[sam_block], queue=False)\
158
  .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
159
 
160
  crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
 
82
  results = np.concatenate(results, 0)
83
  return Image.fromarray(results)
84
 
85
+ def sam_predict(predictor, raw_im):
86
+ h, w = raw_im.height, raw_im.width
87
+ add_margin(raw_im, size=max(h, w))
88
+ raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
89
+ image_sam = sam_out_nosave(predictor, raw_im.convert("RGB"))
90
+ torch.cuda.empty_cache()
91
+ return image_sam
92
+
93
  def run_demo():
94
  # device = f"cuda:0" if torch.cuda.is_available() else "cpu"
95
  # models = None # init_model(device, os.path.join(code_dir, ckpt))
 
109
 
110
  # init sam model
111
  mask_predictor = sam_init()
 
112
 
113
  # with open('instructions_12345.md', 'r') as f:
114
  # article = f.read()
 
161
  output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
162
 
163
  update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
164
+ image_block.change(fn=partial(sam_predict, mask_predictor), inputs=[image_block], outputs=[sam_block], queue=False)\
165
  .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
166
 
167
  crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
sam_utils.py CHANGED
@@ -16,9 +16,10 @@ def sam_init(device_id=0):
16
  predictor = SamPredictor(sam)
17
  return predictor
18
 
19
- def sam_out_nosave(predictor, input_image, bbox_sliders=(0,0,255,255)):
20
- bbox = np.array(bbox_sliders)
21
  image = np.asarray(input_image)
 
 
22
 
23
  start_time = time.time()
24
  predictor.set_image(image)
 
16
  predictor = SamPredictor(sam)
17
  return predictor
18
 
19
+ def sam_out_nosave(predictor, input_image, ):
 
20
  image = np.asarray(input_image)
21
+ h, w, _ = image.shape
22
+ bbox = np.array([0, 0, w, h])
23
 
24
  start_time = time.time()
25
  predictor.set_image(image)