import json import struct import torch import threading import warnings ### # Code from ljleb/sd-mecha/sd_mecha/streaming.py DTYPE_MAPPING = { 'F64': (torch.float64, 8), 'F32': (torch.float32, 4), 'F16': (torch.float16, 2), 'BF16': (torch.bfloat16, 2), 'I8': (torch.int8, 1), 'I64': (torch.int64, 8), 'I32': (torch.int32, 4), 'I16': (torch.int16, 2), "F8_E4M3": (torch.float8_e4m3fn, 1), "F8_E5M2": (torch.float8_e5m2, 1), } class InSafetensorsDict: def __init__(self, f, buffer_size): self.default_buffer_size = buffer_size self.file = f self.header_size, self.header = self._read_header() self.buffer = bytearray() self.buffer_start_offset = 8 + self.header_size self.lock = threading.Lock() def __del__(self): self.close() def __getitem__(self, key): if key not in self.header or key == "__metadata__": raise KeyError(key) return self._load_tensor(key) def __iter__(self): return iter(self.keys()) def __len__(self): return len(self.header) def close(self): self.file.close() self.buffer = None self.header = None def keys(self): return ( key for key in self.header.keys() if key != "__metadata__" ) def values(self): for key in self.keys(): yield self[key] def items(self): for key in self.keys(): yield key, self[key] def _read_header(self): header_size_bytes = self.file.read(8) header_size = struct.unpack(' self.buffer_start_offset + len(self.buffer): self.file.seek(start_pos) necessary_buffer_size = max(self.default_buffer_size, length) if len(self.buffer) < necessary_buffer_size: self.buffer = bytearray(necessary_buffer_size) else: self.buffer = self.buffer[:necessary_buffer_size] self.file.readinto(self.buffer) self.buffer_start_offset = start_pos def _load_tensor(self, tensor_name): tensor_info = self.header[tensor_name] offsets = tensor_info['data_offsets'] dtype, dtype_bytes = DTYPE_MAPPING[tensor_info['dtype']] shape = tensor_info['shape'] total_bytes = offsets[1] - offsets[0] absolute_start_pos = 8 + self.header_size + offsets[0] with warnings.catch_warnings(): warnings.simplefilter('ignore') with self.lock: self._ensure_buffer(absolute_start_pos, total_bytes) buffer_offset = absolute_start_pos - self.buffer_start_offset return torch.frombuffer(self.buffer, count=total_bytes // dtype_bytes, offset=buffer_offset, dtype=dtype).reshape(shape)