karthikmohan409 committed on
Commit
311f85c
1 Parent(s): b3faf0d

Upload 35 files

Files changed (35)
  1. app/__pycache__/genai.cpython-311.pyc +0 -0
  2. app/app.py +71 -0
  3. app/gen_mask.py +63 -0
  4. app/genai.py +167 -0
  5. app/huggingface-cloth-segmentation/LICENSE +21 -0
  6. app/huggingface-cloth-segmentation/README.md +38 -0
  7. app/huggingface-cloth-segmentation/__pycache__/network.cpython-311.pyc +0 -0
  8. app/huggingface-cloth-segmentation/__pycache__/options.cpython-311.pyc +0 -0
  9. app/huggingface-cloth-segmentation/__pycache__/process.cpython-311.pyc +0 -0
  10. app/huggingface-cloth-segmentation/app.py +39 -0
  11. app/huggingface-cloth-segmentation/assets/1.png +0 -0
  12. app/huggingface-cloth-segmentation/assets/2.png +0 -0
  13. app/huggingface-cloth-segmentation/input/03615_00.jpg +0 -0
  14. app/huggingface-cloth-segmentation/input/08909_00.jpg +0 -0
  15. app/huggingface-cloth-segmentation/model/cloth_segm.pth +3 -0
  16. app/huggingface-cloth-segmentation/network.py +560 -0
  17. app/huggingface-cloth-segmentation/options.py +12 -0
  18. app/huggingface-cloth-segmentation/output/alpha/1.png +0 -0
  19. app/huggingface-cloth-segmentation/output/alpha/3.png +0 -0
  20. app/huggingface-cloth-segmentation/output/cloth_seg/final_seg.png +0 -0
  21. app/huggingface-cloth-segmentation/process.py +190 -0
  22. app/huggingface-cloth-segmentation/requirements.txt +7 -0
  23. app/main.py +71 -0
  24. app/model/cloth_segm.pth +3 -0
  25. app/output/alpha/1.png +0 -0
  26. app/output/alpha/2.png +0 -0
  27. app/output/alpha/3.png +0 -0
  28. app/output/cloth_seg/final_seg.png +0 -0
  29. app/output_image.jpg +0 -0
  30. app/output_image_1.jpg +0 -0
  31. app/output_image_2.jpg +0 -0
  32. app/output_image_3.jpg +0 -0
  33. app/output_image_4.jpg +0 -0
  34. app/processed_images/output_image.jpg +0 -0
  35. app/processed_images/output_image_1.jpg +0 -0
app/__pycache__/genai.cpython-311.pyc ADDED
Binary file (5.97 kB).
 
app/app.py ADDED
@@ -0,0 +1,71 @@
+ from flask import Flask, request, jsonify, send_from_directory
+ from flask_cors import CORS
+ from genai import gen_vton
+ from werkzeug.utils import secure_filename
+ import os
+ import tempfile
+
+ app = Flask(__name__, static_folder='processed_images')
+
+ CORS(app, supports_credentials=True)
+
+
+ @app.route('/proc', methods=['POST'])
+ def process_images():
+     # Retrieve the two uploads from the multipart request
+     user_image_t = request.files.get('userImage')
+     dress_image_t = request.files.get('dressImage')
+     if user_image_t is None or dress_image_t is None:
+         return jsonify({'error': 'Both userImage and dressImage are required'}), 400
+
+     # Save both uploads to temporary files
+     temp_dir = tempfile.gettempdir()
+     temp_path = os.path.join(temp_dir, secure_filename(dress_image_t.filename))
+     dress_image_t.save(temp_path)
+     dress_image = temp_path
+
+     temp_path_1 = os.path.join(temp_dir, secure_filename(user_image_t.filename))
+     user_image_t.save(temp_path_1)
+     user_image = temp_path_1
+
+     # Run the virtual try-on pipeline; it writes its results to ./processed_images
+     gen_vton(user_image, dress_image)
+
+     # Build URLs for the generated images (request.host_url already ends with '/')
+     url_to_processed_image_1 = request.host_url + 'processed_images/output_image.jpg'
+     url_to_processed_image_2 = request.host_url + 'processed_images/output_image_1.jpg'
+     processed_image_urls = [url_to_processed_image_1, url_to_processed_image_2]
+
+     # Clean up the temporary uploads
+     os.remove(temp_path)
+     os.remove(temp_path_1)
+     return jsonify({'processedImages': processed_image_urls})
+
+
+ @app.route('/processed_images/<filename>')
+ def processed_images(filename):
+     # Serve the generated images from the static folder
+     return send_from_directory(app.static_folder, filename)
+
+
+ if __name__ == '__main__':
+     app.run()
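For reference, a minimal client-side sketch for the `/proc` endpoint above (assuming Flask's default port 5000; the file paths are placeholders):

```python
import requests

# Field names must match what process_images() reads from request.files
files = {
    'userImage': open('user.jpg', 'rb'),    # placeholder path
    'dressImage': open('dress.jpg', 'rb'),  # placeholder path
}
resp = requests.post('http://127.0.0.1:5000/proc', files=files)
print(resp.json())  # {'processedImages': [<url_1>, <url_2>]}
```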
app/gen_mask.py ADDED
@@ -0,0 +1,63 @@
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
+ from PIL import Image
+ import requests
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+ import torch.nn as nn
+
+ processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
+ model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
+
+ url = "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
+
+ # image = Image.open(requests.get(url, stream=True).raw)
+ image_path = "C:/Users/Admin/Downloads/dress1.jpg"
+ image = Image.open(image_path)
+
+ inputs = processor(images=image, return_tensors="pt")
+ outputs = model(**inputs)
+ logits = outputs.logits.cpu()
+
+ # Upsample the logits to the original image size (PIL size is (W, H), so reverse it)
+ upsampled_logits = nn.functional.interpolate(
+     logits,
+     size=image.size[::-1],
+     mode="bilinear",
+     align_corners=False,
+ )
+
+ # Per-pixel class ids
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
+
+ # Human-readable label names (id2label maps class ids to names)
+ label_names = list(model.config.id2label.values())
+
+ # Create a color map with at least as many colors as there are labels
+ cmap = mpl.colormaps['tab20']
+
+ # Display the segmentation with a labeled colorbar
+ fig, ax = plt.subplots()
+ im = ax.imshow(pred_seg, cmap=cmap)
+ cbar = fig.colorbar(im, ax=ax, ticks=range(len(label_names)))
+ cbar.ax.set_yticklabels(label_names)
+ plt.show()
+
+ # Get the number of labels
+ n_labels = len(label_names)
+
+ # Extract the RGB values for each color in the colormap
+ colors = cmap.colors[:n_labels]
+ rgb_colors = [color[:3] for color in colors]
+
+ # Create a dictionary mapping labels to RGB colors
+ label_to_color = dict(zip(label_names, rgb_colors))
+
+ # Display the mapping
+ for label, color in label_to_color.items():
+     print(f"{label}: {color}")
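If a single binary garment mask is needed downstream (for example as an inpainting mask, as genai.py uses), the prediction can be thresholded directly. A sketch, assuming `pred_seg` from the script above; the id 4 = "Upper-clothes" mapping is taken from the mattmdjaga/segformer_b2_clothes model card:

```python
import numpy as np
from PIL import Image

UPPER_CLOTHES = 4  # class id per the model card (assumption worth verifying)
binary = (pred_seg.numpy() == UPPER_CLOTHES).astype(np.uint8) * 255
Image.fromarray(binary, mode='L').save('upper_clothes_mask.png')
```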
app/genai.py ADDED
@@ -0,0 +1,167 @@
+ from diffusers import StableDiffusionInpaintPipeline
+ from PIL import Image
+ import torch
+ from rembg import remove
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ import sys
+ import os
+
+ # Make the vendored cloth-segmentation code importable
+ sys.path.append(
+     os.path.join(os.path.dirname(__file__), "huggingface-cloth-segmentation"))
+
+ from process import load_seg_model, get_palette, generate_mask
+
+
+ device = 'cpu'
+
+
+ def initialize_and_load_models():
+     checkpoint_path = 'model/cloth_segm.pth'
+     net = load_seg_model(checkpoint_path, device=device)
+     return net
+
+
+ net = initialize_and_load_models()
+ palette = get_palette(4)
+
+
+ def run(img):
+     cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
+     return cloth_seg
+
+
+ def image_caption(image_path, img_type):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     processor = BlipProcessor.from_pretrained("noamrot/FuseCap")
+     model = BlipForConditionalGeneration.from_pretrained("noamrot/FuseCap").to(device)
+
+     raw_image = Image.open(image_path).convert('RGB')
+     if img_type == "dress":
+         # Remove the background so the caption describes only the garment
+         raw_image = remove(raw_image)
+
+     text = "a picture of "
+     inputs = processor(raw_image, text, return_tensors="pt").to(device)
+     out = model.generate(**inputs, num_beams=3)
+     caption = processor.decode(out[0], skip_special_tokens=True)
+     return caption
+
+
+ def gen_vton(image_input, dress_input):
+     # Load the pre-trained inpainting pipeline
+     pipe = StableDiffusionInpaintPipeline.from_pretrained(
+         "runwayml/stable-diffusion-inpainting",
+         torch_dtype=torch.float32,  # use torch.float16 on GPU
+     )
+     image_path = image_input
+
+     # Segment the user photo; generate_mask writes one alpha mask per clothing class
+     img_open = Image.open(image_path)
+     run(img_open)
+     gen_mask_1 = "./huggingface-cloth-segmentation/output/alpha/1.png"
+     gen_mask_2 = "./huggingface-cloth-segmentation/output/alpha/2.png"
+     gen_mask_3 = "./huggingface-cloth-segmentation/output/alpha/3.png"
+
+     # Pick the first mask that was actually produced
+     # (1 = upper body, 2 = lower body, 3 = full body)
+     if os.path.exists(gen_mask_1):
+         mask_path = gen_mask_1
+     elif os.path.exists(gen_mask_2):
+         mask_path = gen_mask_2
+     else:
+         mask_path = gen_mask_3
+
+     dress_path = dress_input
+
+     image = Image.open(image_path).convert("RGB").resize((512, 512))
+     # Grayscale mask: white is inpainted, black is kept as is
+     mask = Image.open(mask_path).convert("L").resize((512, 512))
+
+     # Caption both photos and build the prompts from the captions
+     user_caption = image_caption(image_path, "user")
+     dress_caption = image_caption(dress_path, "dress")
+     print(user_caption)
+     print(dress_caption)
+     prompt = f"a human wearing a {dress_caption}"
+     neg_prompt = user_caption
+
+     guidance_scale = 7.5
+     strength = 0.9  # denoising strength (the diffusers parameter name is `strength`)
+     num_samples = 2
+     generator = torch.Generator(device="cpu")  # explicitly create a CPU generator
+
+     images = pipe(
+         prompt=prompt,
+         negative_prompt=neg_prompt,
+         image=image,
+         mask_image=mask,
+         guidance_scale=guidance_scale,
+         strength=strength,
+         generator=generator,
+         num_images_per_prompt=num_samples,
+     ).images
+
+     # Save the generated try-on candidates
+     images[0].save("./processed_images/output_image.jpg")
+     images[1].save("./processed_images/output_image_1.jpg")
+
+
+ def predict(pipe, dict, prompt):
+     # Standalone helper: inpaint a user-supplied image/mask pair with a given pipeline
+     image = dict['image'].convert("RGB").resize((512, 512))
+     mask_image = dict['mask'].convert("RGB").resize((512, 512))
+     images = pipe(prompt=prompt, image=image, mask_image=mask_image).images
+     return images[0]
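A minimal way to exercise `gen_vton` end to end (the input paths are placeholders; outputs land in `./processed_images`, which must exist):

```python
if __name__ == '__main__':
    user_image = 'C:/Users/Admin/Downloads/woman.jpg'    # placeholder path
    dress_image = 'C:/Users/Admin/Downloads/dress1.jpg'  # placeholder path
    gen_vton(user_image, dress_image)
    # Writes output_image.jpg and output_image_1.jpg under ./processed_images
```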
app/huggingface-cloth-segmentation/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Alok Pandey
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app/huggingface-cloth-segmentation/README.md ADDED
@@ -0,0 +1,38 @@
+ # Huggingface cloth segmentation using U2NET
+
+ ![Python 3.8](https://img.shields.io/badge/python-3.8-green.svg)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LGgLiHiWcmpQalgazLgq4uQuVUm9ZM4M?usp=sharing)
+
+ This repo contains inference code and a Gradio demo script that use a pre-trained U2NET model to parse clothes from a human portrait.<br>
+ Clothes are parsed into three categories: Upper body (red), Lower body (green), and Full body (yellow). The provided script also generates an alpha image for each class.
+
+ # Inference
+ - Clone the repo: `git clone https://github.com/wildoctopus/huggingface-cloth-segmentation.git`.
+ - Install dependencies: `pip install -r requirements.txt`.
+ - Run `python process.py --image 'input/03615_00.jpg'`. **The script will automatically download the pretrained model.**
+ - Outputs will be saved in the `output` folder:
+   - `output/alpha/..` contains the alpha image corresponding to each class.
+   - `output/cloth_seg` contains the final segmentation.
+
+ # Gradio Demo
+ - Run `python app.py`.
+ - Navigate to the local or public URL printed by the app on successful launch.
+ ### OR
+ - Run inference in Colab from here [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LGgLiHiWcmpQalgazLgq4uQuVUm9ZM4M?usp=sharing)
+
+ # Huggingface Demo
+ - Check the Gradio demo on a Hugging Face Space here: [huggingface-cloth-segmentation](https://huggingface.co/spaces/wildoctopus/cloth-segmentation).
+
+ # Output samples
+ ![Sample 000](assets/1.png)
+ ![Sample 024](assets/2.png)
+
+ This model works well with any background and almost all poses.
+
+ # Acknowledgements
+ - The U2NET model is from the original [u2net repo](https://github.com/xuebinqin/U-2-Net). Thanks to Xuebin Qin for the amazing repo.
+ - Most of the code is taken and modified from [levindabhi/cloth-segmentation](https://github.com/levindabhi/cloth-segmentation).
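For programmatic use (this is what app/genai.py in this commit does), a minimal sketch that loads the model and writes masks for one image, using the repo's default paths:

```python
from PIL import Image
from process import load_seg_model, get_palette, generate_mask

# The checkpoint is downloaded automatically if it is missing
net = load_seg_model('model/cloth_segm.pth', device='cpu')
palette = get_palette(4)

img = Image.open('input/03615_00.jpg').convert('RGB')
cloth_seg = generate_mask(img, net=net, palette=palette, device='cpu')
# Alpha masks: output/alpha/<class>.png; combined map: output/cloth_seg/final_seg.png
```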
app/huggingface-cloth-segmentation/__pycache__/network.cpython-311.pyc ADDED
Binary file (27.3 kB).
 
app/huggingface-cloth-segmentation/__pycache__/options.cpython-311.pyc ADDED
Binary file (779 Bytes).
 
app/huggingface-cloth-segmentation/__pycache__/process.cpython-311.pyc ADDED
Binary file (10.4 kB).
 
app/huggingface-cloth-segmentation/app.py ADDED
@@ -0,0 +1,39 @@
+ import PIL
+ import torch
+ import gradio as gr
+ from process import load_seg_model, get_palette, generate_mask
+
+
+ device = 'cpu'
+
+
+ def initialize_and_load_models():
+     checkpoint_path = 'model/cloth_segm.pth'
+     net = load_seg_model(checkpoint_path, device=device)
+     return net
+
+
+ net = initialize_and_load_models()
+ palette = get_palette(4)
+
+
+ def run(img):
+     cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
+     return cloth_seg
+
+
+ # Define the input and output components
+ # (gr.inputs/gr.outputs are deprecated; current Gradio uses gr.Image directly)
+ input_image = gr.Image(label="Input Image", type="pil")
+ cloth_seg_image = gr.Image(label="Cloth Segmentation", type="pil")
+
+ title = "Demo for Cloth Segmentation"
+ description = "An app for Cloth Segmentation"
+ inputs = [input_image]
+ outputs = [cloth_seg_image]
+
+
+ gr.Interface(fn=run, inputs=inputs, outputs=outputs, title=title, description=description).launch(share=True)
app/huggingface-cloth-segmentation/assets/1.png ADDED
app/huggingface-cloth-segmentation/assets/2.png ADDED
app/huggingface-cloth-segmentation/input/03615_00.jpg ADDED
app/huggingface-cloth-segmentation/input/08909_00.jpg ADDED
app/huggingface-cloth-segmentation/model/cloth_segm.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f71fad2bc11789a996acc507d1a5a1602ae0edefc2b9aba1cd198be5cc9c1a44
+ size 176625341
app/huggingface-cloth-segmentation/network.py ADDED
@@ -0,0 +1,560 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1):
+         super(REBNCONV, self).__init__()
+
+         self.conv_s1 = nn.Conv2d(
+             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
+         )
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+         return xout
+
+
+ ## upsample tensor 'src' to have the same spatial size as tensor 'tar'
+ def _upsample_like(src, tar):
+     src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
+
+     return src
+
+
+ ### RSU-7 ###
+ class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU7, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+
+         hx6 = self.rebnconv6(hx)
+
+         hx7 = self.rebnconv7(hx6)
+
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6, hx7
+         del hx6d, hx5d, hx3d, hx2d
+         del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-6 ###
+ class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+
+         hx5 = self.rebnconv5(hx)
+
+         hx6 = self.rebnconv6(hx5)
+
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6
+         del hx5d, hx4d, hx3d, hx2d
+         del hx2dup, hx3dup, hx4dup, hx5dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-5 ###
+ class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+
+         hx4 = self.rebnconv4(hx)
+
+         hx5 = self.rebnconv5(hx4)
+
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5
+         del hx4d, hx3d, hx2d
+         del hx2dup, hx3dup, hx4dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-4 ###
+ class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+
+         hx3 = self.rebnconv3(hx)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4
+         del hx3d, hx2d
+         del hx2dup, hx3dup
+         """
+
+         return hx1d + hxin
+
+
+ ### RSU-4F ###
+ class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+     def forward(self, x):
+         hx = x
+
+         hxin = self.rebnconvin(hx)
+
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+
+         hx4 = self.rebnconv4(hx3)
+
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+         """
+         del hx1, hx2, hx3, hx4
+         del hx3d, hx2d
+         """
+
+         return hx1d + hxin
+
+
+ ##### U^2-Net ####
+ class U2NET(nn.Module):
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NET, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(512, 256, 512)
+
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+     def forward(self, x):
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         """
+         del hx1, hx2, hx3, hx4, hx5, hx6
+         del hx5d, hx4d, hx3d, hx2d, hx1d
+         del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
+         """
+
+         return d0, d1, d2, d3, d4, d5, d6
+
+
+ ### U^2-Net small ###
+ class U2NETP(nn.Module):
+     def __init__(self, in_ch=3, out_ch=1):
+         super(U2NETP, self).__init__()
+
+         self.stage1 = RSU7(in_ch, 16, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage2 = RSU6(64, 16, 64)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage3 = RSU5(64, 16, 64)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage4 = RSU4(64, 16, 64)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage5 = RSU4F(64, 16, 64)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+         self.stage6 = RSU4F(64, 16, 64)
+
+         # decoder
+         self.stage5d = RSU4F(128, 16, 64)
+         self.stage4d = RSU4(128, 16, 64)
+         self.stage3d = RSU5(128, 16, 64)
+         self.stage2d = RSU6(128, 16, 64)
+         self.stage1d = RSU7(128, 16, 64)
+
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+
+         self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+     def forward(self, x):
+         hx = x
+
+         # stage 1
+         hx1 = self.stage1(hx)
+         hx = self.pool12(hx1)
+
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+
+         # decoder
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+         # side output
+         d1 = self.side1(hx1d)
+
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, d1)
+
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, d1)
+
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, d1)
+
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, d1)
+
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, d1)
+
+         d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+         return d0, d1, d2, d3, d4, d5, d6
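A quick shape sanity check for the network above: `U2NET` returns the fused map plus six side outputs, each upsampled back to the input resolution. A sketch with a random input (768 matches what process.py feeds the model):

```python
import torch
from network import U2NET

net = U2NET(in_ch=3, out_ch=4).eval()
with torch.no_grad():
    outs = net(torch.randn(1, 3, 768, 768))
print([tuple(o.shape) for o in outs])  # 7 tensors of shape (1, 4, 768, 768)
```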
app/huggingface-cloth-segmentation/options.py ADDED
@@ -0,0 +1,12 @@
+ import os.path as osp
+ import os
+
+
+ class parser(object):
+     def __init__(self):
+         self.output = "./output"  # output image folder path
+         self.logs_dir = './logs'
+         self.device = 'cuda:0'
+
+
+ opt = parser()
app/huggingface-cloth-segmentation/output/alpha/1.png ADDED
app/huggingface-cloth-segmentation/output/alpha/3.png ADDED
app/huggingface-cloth-segmentation/output/cloth_seg/final_seg.png ADDED
app/huggingface-cloth-segmentation/process.py ADDED
@@ -0,0 +1,190 @@
+ from network import U2NET
+
+ import os
+ from PIL import Image
+ import cv2
+ import gdown
+ import argparse
+ import numpy as np
+
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+
+ from collections import OrderedDict
+ from options import opt
+
+
+ def load_checkpoint(model, checkpoint_path):
+     if not os.path.exists(checkpoint_path):
+         print("----No checkpoints at given path----")
+         return
+     model_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+     new_state_dict = OrderedDict()
+     for k, v in model_state_dict.items():
+         name = k[7:]  # remove the `module.` prefix added by DataParallel
+         new_state_dict[name] = v
+
+     model.load_state_dict(new_state_dict)
+     print("----checkpoints loaded from path: {}----".format(checkpoint_path))
+     return model
+
+
+ def get_palette(num_cls):
+     """Returns the color map for visualizing the segmentation mask.
+     Args:
+         num_cls: Number of classes
+     Returns:
+         The color map
+     """
+     n = num_cls
+     palette = [0] * (n * 3)
+     for j in range(0, n):
+         lab = j
+         palette[j * 3 + 0] = 0
+         palette[j * 3 + 1] = 0
+         palette[j * 3 + 2] = 0
+         i = 0
+         while lab:
+             palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
+             palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
+             palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
+             i += 1
+             lab >>= 3
+     return palette
+
+
+ class Normalize_image(object):
+     """Normalize the given tensor with the given mean and standard deviation.
+
+     Args:
+         mean (float): Desired mean to subtract from tensors
+         std (float): Desired std to divide tensors by
+     """
+
+     def __init__(self, mean, std):
+         assert isinstance(mean, float)
+         if isinstance(mean, float):
+             self.mean = mean
+
+         if isinstance(std, float):
+             self.std = std
+
+         self.normalize_1 = transforms.Normalize(self.mean, self.std)
+         self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
+         self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)
+
+     def __call__(self, image_tensor):
+         if image_tensor.shape[0] == 1:
+             return self.normalize_1(image_tensor)
+
+         elif image_tensor.shape[0] == 3:
+             return self.normalize_3(image_tensor)
+
+         elif image_tensor.shape[0] == 18:
+             return self.normalize_18(image_tensor)
+
+         else:
+             raise ValueError("Please set proper channels! Normalization implemented only for 1, 3 and 18")
+
+
+ def apply_transform(img):
+     transforms_list = []
+     transforms_list += [transforms.ToTensor()]
+     transforms_list += [Normalize_image(0.5, 0.5)]
+     transform_rgb = transforms.Compose(transforms_list)
+     return transform_rgb(img)
+
+
+ def generate_mask(input_image, net, palette, device='cpu'):
+     # input_image is already a PIL image
+     img = input_image
+     img_size = img.size
+     img = img.resize((768, 768), Image.BICUBIC)
+     image_tensor = apply_transform(img)
+     image_tensor = torch.unsqueeze(image_tensor, 0)
+
+     alpha_out_dir = os.path.join(opt.output, 'alpha')
+     cloth_seg_out_dir = os.path.join(opt.output, 'cloth_seg')
+
+     os.makedirs(alpha_out_dir, exist_ok=True)
+     os.makedirs(cloth_seg_out_dir, exist_ok=True)
+
+     with torch.no_grad():
+         output_tensor = net(image_tensor.to(device))
+         output_tensor = F.log_softmax(output_tensor[0], dim=1)
+         output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
+         output_tensor = torch.squeeze(output_tensor, dim=0)
+         output_arr = output_tensor.cpu().numpy()
+
+     classes_to_save = []
+
+     # Check which classes are present in the image
+     for cls in range(1, 4):  # Exclude background class (0)
+         if np.any(output_arr == cls):
+             classes_to_save.append(cls)
+
+     # Save alpha masks
+     for cls in classes_to_save:
+         alpha_mask = (output_arr == cls).astype(np.uint8) * 255
+         alpha_mask = alpha_mask[0]  # Selecting the first channel to make it 2D
+         alpha_mask_img = Image.fromarray(alpha_mask, mode='L')
+         alpha_mask_img = alpha_mask_img.resize(img_size, Image.BICUBIC)
+         alpha_mask_img.save(os.path.join(alpha_out_dir, f'{cls}.png'))
+
+     # Save the final cloth segmentation
+     cloth_seg = Image.fromarray(output_arr[0].astype(np.uint8), mode='P')
+     cloth_seg.putpalette(palette)
+     cloth_seg = cloth_seg.resize(img_size, Image.BICUBIC)
+     cloth_seg.save(os.path.join(cloth_seg_out_dir, 'final_seg.png'))
+     return cloth_seg
+
+
+ def check_or_download_model(file_path):
+     if not os.path.exists(file_path):
+         os.makedirs(os.path.dirname(file_path), exist_ok=True)
+         url = "https://drive.google.com/uc?id=11xTBALOeUkyuaK3l60CpkYHLTmv7k3dY"
+         gdown.download(url, file_path, quiet=False)
+         print("Model downloaded successfully.")
+     else:
+         print("Model already exists.")
+
+
+ def load_seg_model(checkpoint_path, device='cpu'):
+     net = U2NET(in_ch=3, out_ch=4)
+     check_or_download_model(checkpoint_path)
+     net = load_checkpoint(net, checkpoint_path)
+     net = net.to(device)
+     net = net.eval()
+
+     return net
+
+
+ def main(args):
+     device = 'cuda:0' if args.cuda else 'cpu'
+
+     # Create an instance of the model
+     model = load_seg_model(args.checkpoint_path, device=device)
+
+     palette = get_palette(4)
+
+     img = Image.open(args.image).convert('RGB')
+
+     cloth_seg = generate_mask(img, net=model, palette=palette, device=device)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Help to set arguments for Cloth Segmentation.')
+     parser.add_argument('--image', type=str, help='Path to the input image')
+     parser.add_argument('--cuda', action='store_true', help='Enable CUDA (default: False)')
+     parser.add_argument('--checkpoint_path', type=str, default='model/cloth_segm.pth', help='Path to the checkpoint file')
+     args = parser.parse_args()
+
+     main(args)
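As a sanity check on `get_palette`, the first four palette entries reproduce the class colors the README names (background black, upper body red, lower body green, full body yellow):

```python
from process import get_palette

palette = get_palette(4)
print([tuple(palette[i * 3:i * 3 + 3]) for i in range(4)])
# [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0)]
```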
app/huggingface-cloth-segmentation/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ gradio
+ gdown
+ Pillow
+ opencv-python
+ numpy
app/main.py ADDED
@@ -0,0 +1,71 @@
+ from flask import Flask, request, jsonify, send_from_directory
+ from flask_cors import CORS
+ from genai import gen_vton
+ from werkzeug.utils import secure_filename
+ import os
+ import tempfile
+
+ app = Flask(__name__, static_folder='processed_images')
+
+ CORS(app, supports_credentials=True)
+
+
+ @app.route('/proc', methods=['POST'])
+ def process_images():
+     # Retrieve the two uploads from the multipart request
+     user_image_t = request.files.get('userImage')
+     dress_image_t = request.files.get('dressImage')
+     if user_image_t is None or dress_image_t is None:
+         return jsonify({'error': 'Both userImage and dressImage are required'}), 400
+
+     # Save both uploads to temporary files
+     temp_dir = tempfile.gettempdir()
+     temp_path = os.path.join(temp_dir, secure_filename(dress_image_t.filename))
+     dress_image_t.save(temp_path)
+     dress_image = temp_path
+
+     temp_path_1 = os.path.join(temp_dir, secure_filename(user_image_t.filename))
+     user_image_t.save(temp_path_1)
+     user_image = temp_path_1
+
+     # Run the virtual try-on pipeline; it writes its results to ./processed_images
+     gen_vton(user_image, dress_image)
+
+     # Build URLs for the generated images (request.host_url already ends with '/')
+     url_to_processed_image_1 = request.host_url + 'processed_images/output_image.jpg'
+     url_to_processed_image_2 = request.host_url + 'processed_images/output_image_1.jpg'
+     processed_image_urls = [url_to_processed_image_1, url_to_processed_image_2]
+
+     # Clean up the temporary uploads
+     os.remove(temp_path)
+     os.remove(temp_path_1)
+     return jsonify({'processedImages': processed_image_urls})
+
+
+ @app.route('/processed_images/<filename>')
+ def processed_images(filename):
+     # Serve the generated images from the static folder
+     return send_from_directory(app.static_folder, filename)
+
+
+ if __name__ == '__main__':
+     app.run(debug=True, host='0.0.0.0')
app/model/cloth_segm.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f71fad2bc11789a996acc507d1a5a1602ae0edefc2b9aba1cd198be5cc9c1a44
+ size 176625341
app/output/alpha/1.png ADDED
app/output/alpha/2.png ADDED
app/output/alpha/3.png ADDED
app/output/cloth_seg/final_seg.png ADDED
app/output_image.jpg ADDED
app/output_image_1.jpg ADDED
app/output_image_2.jpg ADDED
app/output_image_3.jpg ADDED
app/output_image_4.jpg ADDED
app/processed_images/output_image.jpg ADDED
app/processed_images/output_image_1.jpg ADDED