liuyuan-pal committed on
Commit
df916e6
1 Parent(s): a14768e
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  ckpt/* filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  ckpt/* filter=lfs diff=lfs merge=lfs -text
37
+ hf_demo/examples/* filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -7,7 +7,6 @@ import torch
7
  import os
8
  import fire
9
  from omegaconf import OmegaConf
10
- from rembg import remove
11
 
12
  from ldm.util import add_margin, instantiate_from_config
13
  from sam_utils import sam_init, sam_out_nosave
@@ -39,6 +38,28 @@ _USER_GUIDE3 = "Generated multiview images are shown below!"
39
 
40
  deployed = True
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def resize_inputs(image_input, crop_size):
43
  alpha_np = np.asarray(image_input)[:, :, 3]
44
  coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
@@ -95,9 +116,9 @@ def white_background(img):
95
  rgb = (rgb*255).astype(np.uint8)
96
  return Image.fromarray(rgb)
97
 
98
- def sam_predict(predictor, raw_im):
99
  raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
100
- image_nobg = remove(raw_im.convert('RGBA'), alpha_matting=True)
101
  arr = np.asarray(image_nobg)[:, :, -1]
102
  x_nonzero = np.nonzero(arr.sum(axis=0))
103
  y_nonzero = np.nonzero(arr.sum(axis=1))
@@ -140,6 +161,7 @@ def run_demo():
140
 
141
  # init sam model
142
  mask_predictor = sam_init()
 
143
 
144
  # with open('instructions_12345.md', 'r') as f:
145
  # article = f.read()
@@ -192,7 +214,7 @@ def run_demo():
192
  output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
193
 
194
  update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
195
- image_block.change(fn=partial(sam_predict, mask_predictor), inputs=[image_block], outputs=[sam_block], queue=False)\
196
  .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
197
 
198
  crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
 
7
  import os
8
  import fire
9
  from omegaconf import OmegaConf
 
10
 
11
  from ldm.util import add_margin, instantiate_from_config
12
  from sam_utils import sam_init, sam_out_nosave
 
38
 
39
  deployed = True
40
 
41
+ class BackgroundRemoval:
42
+ def __init__(self, device='cuda'):
43
+ from carvekit.api.high import HiInterface
44
+ self.interface = HiInterface(
45
+ object_type="object", # Can be "object" or "hairs-like".
46
+ batch_size_seg=5,
47
+ batch_size_matting=1,
48
+ device=device,
49
+ seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
50
+ matting_mask_size=2048,
51
+ trimap_prob_threshold=231,
52
+ trimap_dilation=30,
53
+ trimap_erosion_iters=5,
54
+ fp16=True,
55
+ )
56
+
57
+ @torch.no_grad()
58
+ def __call__(self, image):
59
+ # image: [H, W, 3] array in [0, 255].
60
+ image = self.interface([image])[0]
61
+ return image
62
+
63
  def resize_inputs(image_input, crop_size):
64
  alpha_np = np.asarray(image_input)[:, :, 3]
65
  coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
 
116
  rgb = (rgb*255).astype(np.uint8)
117
  return Image.fromarray(rgb)
118
 
119
+ def sam_predict(predictor, removal, raw_im):
120
  raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
121
+ image_nobg = removal(raw_im.convert('RGB'))
122
  arr = np.asarray(image_nobg)[:, :, -1]
123
  x_nonzero = np.nonzero(arr.sum(axis=0))
124
  y_nonzero = np.nonzero(arr.sum(axis=1))
 
161
 
162
  # init sam model
163
  mask_predictor = sam_init()
164
+ removal = BackgroundRemoval()
165
 
166
  # with open('instructions_12345.md', 'r') as f:
167
  # article = f.read()
 
214
  output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
215
 
216
  update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
217
+ image_block.change(fn=partial(sam_predict, mask_predictor, removal), inputs=[image_block], outputs=[sam_block], queue=False)\
218
  .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
219
 
220
  crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
detection_test.py CHANGED
@@ -1,18 +1,39 @@
 
1
  import numpy as np
2
  from PIL import Image
3
  from skimage.io import imsave
4
-
5
- from app import white_background
6
- from ldm.util import add_margin
7
  from sam_utils import sam_out_nosave, sam_init
8
- from rembg import remove
9
 
10
- raw_im = Image.open('hf_demo/examples/basket.png')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  predictor = sam_init()
12
 
13
  raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
14
  width, height = raw_im.size
15
- image_nobg = remove(raw_im.convert('RGBA'), alpha_matting=True)
16
  arr = np.asarray(image_nobg)[:, :, -1]
17
  x_nonzero = np.nonzero(arr.sum(axis=0))
18
  y_nonzero = np.nonzero(arr.sum(axis=1))
@@ -20,16 +41,16 @@ x_min = int(x_nonzero[0].min())
20
  y_min = int(y_nonzero[0].min())
21
  x_max = int(x_nonzero[0].max())
22
  y_max = int(y_nonzero[0].max())
23
- # image_nobg.save('./nobg.png')
24
 
25
  image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS)
26
  image_sam = sam_out_nosave(predictor, image_nobg.convert("RGB"), (x_min, y_min, x_max, y_max))
27
 
28
- # imsave('./mask.png', np.asarray(image_sam)[:,:,3]*255)
29
  image_sam = np.asarray(image_sam, np.float32) / 255
30
  out_mask = image_sam[:, :, 3:]
31
  out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask
32
  out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8)
33
 
34
  image_sam = Image.fromarray(out_img, mode='RGBA')
35
- # image_sam.save('./output.png')
 
1
+ import torch
2
  import numpy as np
3
  from PIL import Image
4
  from skimage.io import imsave
 
 
 
5
  from sam_utils import sam_out_nosave, sam_init
 
6
 
7
+ class BackgroundRemoval:
8
+ def __init__(self, device='cuda'):
9
+ from carvekit.api.high import HiInterface
10
+ self.interface = HiInterface(
11
+ object_type="object", # Can be "object" or "hairs-like".
12
+ batch_size_seg=5,
13
+ batch_size_matting=1,
14
+ device=device,
15
+ seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
16
+ matting_mask_size=2048,
17
+ trimap_prob_threshold=231,
18
+ trimap_dilation=30,
19
+ trimap_erosion_iters=5,
20
+ fp16=True,
21
+ )
22
+
23
+ @torch.no_grad()
24
+ def __call__(self, image):
25
+ # image: [H, W, 3] array in [0, 255].
26
+ # image = Image.fromarray(image)
27
+ image = self.interface([image])[0]
28
+ # image = np.array(image)
29
+ return image
30
+
31
+ raw_im = Image.open('hf_demo/examples/flower.png')
32
  predictor = sam_init()
33
 
34
  raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
35
  width, height = raw_im.size
36
+ image_nobg = BackgroundRemoval()(raw_im.convert('RGB'))
37
  arr = np.asarray(image_nobg)[:, :, -1]
38
  x_nonzero = np.nonzero(arr.sum(axis=0))
39
  y_nonzero = np.nonzero(arr.sum(axis=1))
 
41
  y_min = int(y_nonzero[0].min())
42
  x_max = int(x_nonzero[0].max())
43
  y_max = int(y_nonzero[0].max())
44
+ image_nobg.save('./nobg.png')
45
 
46
  image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS)
47
  image_sam = sam_out_nosave(predictor, image_nobg.convert("RGB"), (x_min, y_min, x_max, y_max))
48
 
49
+ imsave('./mask.png', np.asarray(image_sam)[:,:,3])
50
  image_sam = np.asarray(image_sam, np.float32) / 255
51
  out_mask = image_sam[:, :, 3:]
52
  out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask
53
  out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8)
54
 
55
  image_sam = Image.fromarray(out_img, mode='RGBA')
56
+ image_sam.save('./output.png')
hf_demo/examples/basket.png CHANGED

Git LFS Details

  • SHA256: 9b7d07f44e1b223b5f3f6e97bf1e64198dbc63a020e860d1ffc177a5a42e7bd9
  • Pointer size: 130 Bytes
  • Size of remote file: 46.1 kB
hf_demo/examples/cat.png ADDED

Git LFS Details

  • SHA256: 7a2138057a0299987b7d8efde6e5d66d5f38c1666911ec4d60e414e33aac35a1
  • Pointer size: 130 Bytes
  • Size of remote file: 66.2 kB
hf_demo/examples/crab.png ADDED

Git LFS Details

  • SHA256: f97097523d5ce1d4544742b2970c5c7d482603d3ed3bd3937464c291a72946a8
  • Pointer size: 130 Bytes
  • Size of remote file: 60.3 kB
hf_demo/examples/elephant.png ADDED

Git LFS Details

  • SHA256: f9fbf42ae34c12cd94b17d1c540762039576c8321c82be7e589353a4d9cd1bd2
  • Pointer size: 130 Bytes
  • Size of remote file: 73.5 kB
hf_demo/examples/flower.png ADDED

Git LFS Details

  • SHA256: 91cd1effc2454a8c81c4b6d9b0ee46bdf12c2879b1ff8b44d2323976a792339a
  • Pointer size: 130 Bytes
  • Size of remote file: 26.4 kB
hf_demo/examples/forest.png ADDED

Git LFS Details

  • SHA256: 3439271a725e03e466f6064440d4dada3c51a22da00f8ce6d12fae744d8583db
  • Pointer size: 130 Bytes
  • Size of remote file: 65.3 kB
hf_demo/examples/monkey.png CHANGED

Git LFS Details

  • SHA256: 688ae93e7ac3a7afa87ff1c25e29f50cdcfa6f0843f8599ba3a344319a2ff90f
  • Pointer size: 130 Bytes
  • Size of remote file: 62.8 kB
hf_demo/examples/teapot.png ADDED

Git LFS Details

  • SHA256: 368aa4fb501ba8c35f2650efa170fdf8d8e3435ccc6c07e218a6a607338451fd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB