twodgirl committed
Commit
5d2a4a6
1 Parent(s): 6a55d9b

Upload 3 files

Files changed (3)
  1. app.py +272 -0
  2. convert_diffusers_to_sdxl.py +106 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,272 @@
+ from convert_diffusers_to_sdxl import convert_unet_state_dict
+ from huggingface_hub import HfApi, HfFileSystem
+ import gradio
+ import gguf
+ import os
+ import requests
+ from safetensors.torch import load, load_file
+ import time
+ import urllib.request
+ from urllib.parse import urlparse, parse_qs, unquote
+
+ def convert(intro, url, api_key, arch):
+     path = urlparse(url).path
+     components = path.split('/')
+     filename = components[-1]
+     output_file = 'locked_model.safetensors'
+     print('Step 1/3')
+     lock = Filelock(output_file)
+     if not os.path.exists(output_file):
+         if len(url.split('/')) == 2:
+             if not check_hf_safety(url):
+                 raise Exception('Unexpected error ;)')
+             if not lock.acquire():
+                 raise Exception('Wait your time in the queue.')
+             print('Download safetensors from {}.'.format(url))
+             try:
+                 # We won't download the file with hf_hub_download or urllib.request,
+                 # but access it remotely.
+                 fs = HfFileSystem()
+                 with fs.open('{}/unet/diffusion_pytorch_model.safetensors'.format(url), 'rb') as f:
+                     byte_data = f.read()
+                 sd_fp16 = load_transformer_by_diffuser_checkpoint(sd=load(byte_data))
+             except:
+                 lock.release()
+                 raise
+         else:
+             if not check_model_safety(filename):
+                 raise Exception('Unexpected error ;)')
+             if not lock.acquire():
+                 raise Exception('Wait your time in the queue.')
+             print('Download model by id {}.'.format(filename))
+             try:
+                 # Save a hf copy of the remote file, then access it remotely.
+                 fs = HfFileSystem()
+                 copy_path = 'twodgirl/wild-sdxl/civit/{}.safetensors'.format(filename)
+                 with fs.open(copy_path, 'wb') as f:
+                     download_file(url, f, api_key)
+                 with fs.open(copy_path, 'rb') as f:
+                     byte_data = f.read()
+                 sd_fp16 = load_transformer_by_original_checkpoint(sd=load(byte_data))
+             except:
+                 lock.release()
+                 raise
+     else:
+         # Assumption: a leftover output_file from an earlier run is an
+         # original-layout checkpoint; reuse it instead of downloading.
+         sd_fp16 = load_transformer_by_original_checkpoint(ckpt_path=output_file)
+     print('Step 2/3')
+     if os.path.exists(output_file):
+         os.remove(output_file)  # The free Hugging Face Space runs out of disk space.
+     write('locked_model.gguf', output_file, arch, sd_fp16)
+     print('Step 3/3')
+     api = HfApi()
+     api.upload_file(path_or_fileobj='locked_model.gguf',
+                     path_in_repo=filename + '.comfyui.Q8.gguf',
+                     repo_id='twodgirl/wild-sdxl',
+                     repo_type='model')
+     lock.release()
+     gradio.Info('Download the file from twodgirl/wild-sdxl/{}'.format(filename + '.comfyui.Q8.gguf'))
+     print(output_file)
+
+ def download_file(url: str, f, token: str):
+     ###
+     # Code from ashleykleynhans/civitai-downloader.
+     USER_AGENT = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) Gecko/20100101 Firefox/119.0'
+
+     headers = {
+         'Authorization': f'Bearer {token}',
+         'User-Agent': USER_AGENT,
+     }
+
+     # Disable automatic redirect handling.
+     class NoRedirection(urllib.request.HTTPErrorProcessor):
+         def http_response(self, request, response):
+             return response
+         https_response = http_response
+
+     request = urllib.request.Request(url, headers=headers)
+     opener = urllib.request.build_opener(NoRedirection)
+     response = opener.open(request)
+
+     if response.status in [301, 302, 303, 307, 308]:
+         redirect_url = response.getheader('Location')
+
+         # Extract the filename from the redirect URL.
+         parsed_url = urlparse(redirect_url)
+         query_params = parse_qs(parsed_url.query)
+         content_disposition = query_params.get('response-content-disposition', [None])[0]
+
+         if content_disposition:
+             filename = unquote(content_disposition.split('filename=')[1].strip('"'))
+         else:
+             raise Exception('Unable to determine filename')
+
+         response = urllib.request.urlopen(redirect_url)
+     elif response.status == 404:
+         raise Exception('File not found')
+     else:
+         raise Exception('No redirect found, something went wrong')
+
+     total_size = response.getheader('Content-Length')
+
+     if total_size is not None:
+         total_size = int(total_size)
+
+     # Stream the download through the given file pointer.
+     downloaded = 0
+     start_time = time.time()
+
+     CHUNK_SIZE = 1638400
+     while True:
+         chunk_start_time = time.time()
+         buffer = response.read(CHUNK_SIZE)
+         chunk_end_time = time.time()
+
+         if not buffer:
+             break
+
+         downloaded += len(buffer)
+         f.write(buffer)
+         chunk_time = chunk_end_time - chunk_start_time
+
+         if chunk_time > 0:
+             speed = len(buffer) / chunk_time / (1024 ** 2)  # Speed in MB/s.
+
+         if total_size is not None:
+             progress = downloaded / total_size
+             # sys.stdout.write(f'\rDownloading: {filename} [{progress*100:.2f}%] - {speed:.2f} MB/s')
+             # sys.stdout.flush()
+
+     end_time = time.time()
+     time_taken = end_time - start_time
+     hours, remainder = divmod(time_taken, 3600)
+     minutes, seconds = divmod(remainder, 60)
+
+     if hours > 0:
+         time_str = f'{int(hours)}h {int(minutes)}m {int(seconds)}s'
+     elif minutes > 0:
+         time_str = f'{int(minutes)}m {int(seconds)}s'
+     else:
+         time_str = f'{int(seconds)}s'
+
+     # sys.stdout.write('\n')
+     print(f'Download completed. File saved as: {filename}')
+     print(f'Downloaded in {time_str}')
+
+ ###
+ # huggingface/twodgirl.
+ # License: apache-2.0
+
+ class Filelock:
+     def __init__(self, file_path):
+         self.file_path = file_path
+         self.lock_path = "{}.lock".format(file_path)
+         self.lock_file = None
+
+     def acquire(self):
+         if os.path.exists(self.lock_path):
+             lock_stat = os.stat(self.lock_path)
+             if time.time() - lock_stat.st_mtime > 900:  # 15 minutes.
+                 os.remove(self.lock_path)
+         if not os.path.exists(self.lock_path):
+             try:
+                 self.lock_file = open(self.lock_path, 'w')
+                 self.lock_file.write(str(os.getpid()))
+                 self.lock_file.flush()
+                 return True
+             except IOError:
+                 return False
+         return False
+
+     def release(self):
+         if self.lock_file:
+             self.lock_file.close()
+             os.remove(self.lock_path)
+             self.lock_file = None
+
+ def check_hf_safety(repo_id):
+     return 'porn' not in repo_id
+
+ def check_model_safety(model_id):
+     url = f"https://civitai.com/api/v1/model-versions/{model_id}"
+     response = requests.get(url)
+     data = response.json()
+
+     # The model-versions endpoint reports the parent model id as 'modelId'.
+     model_id = data.get('modelId')
+
+     if model_id:
+         url = f"https://civitai.com/api/v1/models/{model_id}"
+         response = requests.get(url)
+         data = response.json()
+
+         tags = data.get('tags', [])
+
+         return 'porn' not in tags
+
+     return True
+
+ def load_transformer_by_diffuser_checkpoint(filepath=None, sd=None):
+     if sd is None:
+         sd = load_file(filepath)
+     unet_state_dict = convert_unet_state_dict(sd)
+     sd_copy = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+     return sd_copy
+
+ def load_transformer_by_original_checkpoint(ckpt_path=None, sd=None):
+     if sd is None:
+         sd = load_file(ckpt_path)
+     sd_copy = {}
+     for key in sd.keys():
+         if key.startswith('model.diffusion_model.'):
+             sd_copy[key] = sd[key]
+
+     return sd_copy
+
+ def write(target_path, checkpoint_path, arch, sd_fp16):
+     writer = gguf.GGUFWriter(target_path, arch=arch)
+     target_quant = gguf.GGMLQuantizationType.Q8_0
+     writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+     writer.add_file_type(target_quant)
+     sd = {}
+     for key in sd_fp16.keys():
+         tensor = sd_fp16[key]
+         # Keep 1-d tensors (biases, norms) and 4-d tensors (convolutions)
+         # in F16; quantize the 2-d weights to Q8_0.
+         if len(tensor.shape) == 1 or len(tensor.shape) == 4:
+             q = gguf.GGMLQuantizationType.F16
+         else:
+             q = target_quant
+         sd[key] = gguf.quants.quantize(tensor.numpy(), q)
+         writer.add_tensor(key, sd[key], raw_dtype=q)
+     writer.write_header_to_file(target_path)
+     writer.write_kv_data_to_file()
+     writer.write_tensors_to_file()
+     writer.close()
+
+ intro = gradio.Markdown("""
+ ## Convert an SDXL model to GGUF
+
+ Convert a Pony/SDXL model's UNet to GGUF (Q8).
+
+ The question is whether I can automate tasks to the extent that would allow me to spend more time with my cat at home.
+
+ This space takes a diffusers file from 🤗, then converts it to a [name your UI] compatible* format. The result should be available within 10 minutes in the twodgirl/wild-sdxl model directory.
+
+ *That's an overstatement, as I only test it with my own comfy-gguf node.
+
+ The url must follow the format:
+
+ *[hf-username]/[sdxl-repo-name]*, which must lead to the repo's /unet/diffusion_pytorch_model.safetensors.
+
+ ### Disclaimer
+
+ Use of this code requires citation and attribution to the author via a link to their Hugging Face profile in all resulting work.
+ """)
+ url = gradio.Textbox(label='Download url')
+ api_key = gradio.Textbox(label='API key')
+ arch = gradio.Textbox(label='Architecture', value='sdxl')
+
+ if __name__ == '__main__':
+     demo = gradio.Interface(convert,
+                             [intro, url, api_key, arch],
+                             outputs=None)
+     demo.queue().launch()
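
A quick way to sanity-check the produced file is to read it back with gguf-py's `GGUFReader`. A minimal sketch, assuming the file was written locally as `locked_model.gguf` (not part of the commit):

```python
from gguf import GGUFReader

# List a few tensors from the quantized file; per write() above, the 2-d
# weights should show up as Q8_0 and the 1-d/4-d tensors as F16.
reader = GGUFReader('locked_model.gguf')
for tensor in reader.tensors[:8]:
    print(tensor.name, tensor.tensor_type.name, tensor.shape)
```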
convert_diffusers_to_sdxl.py ADDED
@@ -0,0 +1,106 @@
+ ###
+ # Code from huggingface/diffusers/scripts/convert_diffusers_to_original_sdxl.py.
+
+ unet_conversion_map = [
+     # (stable-diffusion, HF Diffusers)
+     ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+     ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+     ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+     ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+     ("input_blocks.0.0.weight", "conv_in.weight"),
+     ("input_blocks.0.0.bias", "conv_in.bias"),
+     ("out.0.weight", "conv_norm_out.weight"),
+     ("out.0.bias", "conv_norm_out.bias"),
+     ("out.2.weight", "conv_out.weight"),
+     ("out.2.bias", "conv_out.bias"),
+     # the following are for sdxl
+     ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
+     ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
+     ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
+     ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
+ ]
+
+ unet_conversion_map_resnet = [
+     # (stable-diffusion, HF Diffusers)
+     ("in_layers.0", "norm1"),
+     ("in_layers.2", "conv1"),
+     ("out_layers.0", "norm2"),
+     ("out_layers.3", "conv2"),
+     ("emb_layers.1", "time_emb_proj"),
+     ("skip_connection", "conv_shortcut"),
+ ]
+
+ unet_conversion_map_layer = []
+ # hardcoded number of downblocks and resnets/attentions...
+ # would need smarter logic for other networks.
+ for i in range(3):
+     # loop over downblocks/upblocks
+     for j in range(2):
+         # loop over resnets/attentions for downblocks
+         hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+         sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+         unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+         if i > 0:
+             # no attention layers in down_blocks.0
+             hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+             sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+             unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+     for j in range(4):
+         # loop over resnets/attentions for upblocks
+         hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+         sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+         unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+         if i < 2:
+             # no attention layers in up_blocks.2
+             hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+             sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+             unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+     if i < 3:
+         # no downsample in down_blocks.3
+         hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+         sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+         unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+         # no upsample in up_blocks.3
+         hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+         sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+         unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+ unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
+
+ hf_mid_atn_prefix = "mid_block.attentions.0."
+ sd_mid_atn_prefix = "middle_block.1."
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+ for j in range(2):
+     hf_mid_res_prefix = f"mid_block.resnets.{j}."
+     sd_mid_res_prefix = f"middle_block.{2*j}."
+     unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+ def convert_unet_state_dict(unet_state_dict):
+     # buyer beware: this is a *brittle* function,
+     # and correct output requires that all of these pieces interact in
+     # the exact order in which I have arranged them.
+     mapping = {k: k for k in unet_state_dict.keys()}
+     for sd_name, hf_name in unet_conversion_map:
+         mapping[hf_name] = sd_name
+     for k, v in mapping.items():
+         if "resnets" in k:
+             for sd_part, hf_part in unet_conversion_map_resnet:
+                 v = v.replace(hf_part, sd_part)
+             mapping[k] = v
+     for k, v in mapping.items():
+         for sd_part, hf_part in unet_conversion_map_layer:
+             v = v.replace(hf_part, sd_part)
+         mapping[k] = v
+     new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
+
+     return new_state_dict
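
`convert_unet_state_dict` can also be used on its own to rename a diffusers-layout SDXL UNet into the original stable-diffusion key layout. A minimal sketch with a hypothetical input path; the `model.diffusion_model.` prefix matches what app.py prepends before quantizing:

```python
from safetensors.torch import load_file, save_file
from convert_diffusers_to_sdxl import convert_unet_state_dict

# Load a diffusers-layout SDXL UNet and rename its keys.
sd = load_file('unet/diffusion_pytorch_model.safetensors')
converted = convert_unet_state_dict(sd)

# Prefix the keys the way checkpoint-style tools expect them.
converted = {'model.diffusion_model.' + k: v for k, v in converted.items()}
save_file(converted, 'sdxl_unet_original_layout.safetensors')
```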
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ -e git+https://github.com/ggerganov/llama.cpp.git@master#egg=gguf&subdirectory=gguf-py
+ diffusers
+ sentencepiece
+ torch
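
Note that gguf is installed as an editable checkout of llama.cpp's gguf-py subdirectory rather than from PyPI. A quick sketch to confirm the install exposes the pieces app.py relies on:

```python
# Should print the quantization version constant and the Q8_0 enum member
# used by write() in app.py.
import gguf

print(gguf.GGML_QUANT_VERSION)
print(gguf.GGMLQuantizationType.Q8_0)
```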