from diffusers.models.normalization import RMSNorm
import torch
from torch import nn, Tensor


###
# Code from aredden/flux-fp8-api
class Linear8(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.does_fp8 = self.supports_fp8_compute()
        self.scale_weight = torch.ones(1, device=device, dtype=torch.float32)
        self.scale_input = torch.ones(1, device=device, dtype=torch.float32)

    def supports_fp8_compute(self, device=None):
        # fp8 matmuls need an Ada (compute capability 8.9) or Hopper (9.0)
        # class GPU or newer.
        if not torch.cuda.is_available():
            return False
        props = torch.cuda.get_device_properties(device)
        return props.major >= 9 or (props.major == 8 and props.minor >= 9)

    def forward(self, x: Tensor) -> Tensor:
        if not self.does_fp8:
            # Fall back to a regular linear layer on devices without fp8 support.
            return torch.nn.functional.linear(x, self.weight, self.bias)
        dims = x.shape[:-1]
        x = x.view(-1, self.in_features)
        # Requires torch 2.4. The matmul runs in fp8, with per-tensor scales
        # applied to the input and the weight.
        y = torch._scaled_mm(
            x.to(torch.float8_e4m3fn),
            torch.transpose(self.weight, 0, 1),
            scale_a=self.scale_input.to(device=x.device),
            scale_b=self.scale_weight.to(device=x.device),
            bias=self.bias.to(torch.bfloat16) if self.bias is not None else None,
            out_dtype=self.weight.dtype,
            use_fast_accum=True,
        )[0]
        return y.view(*dims, self.out_features).to(torch.bfloat16)


###
# Code from sayakpaul
# http://github.com/huggingface/diffusers/issues/6500
def replace_regular_linears(module, parent=''):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            in_features = child.in_features
            out_features = child.out_features
            device = child.weight.data.device
            dtype = child.weight.data.dtype
            has_bias = child.bias is not None
            new_layer = Linear8(in_features, out_features, has_bias, device, dtype)
            new_layer.load_state_dict(child.state_dict())
            new_layer = new_layer.to(device)
            setattr(module, name, new_layer)
        elif isinstance(child, RMSNorm):
            # RMSNorm doesn't support float8, so keep its weight in bfloat16.
            rsd = child.state_dict()
            if 'weight' in rsd:
                child.load_state_dict({'weight': rsd['weight'].to(torch.bfloat16)},
                                      assign=True)
        else:
            # Recursively apply to child modules.
            if parent == '':
                replace_regular_linears(child, parent=name)
            else:
                replace_regular_linears(child, parent='.'.join([parent, name]))
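

###
# Usage sketch (an assumption, not part of the sources above): apply
# replace_regular_linears to the transformer of a diffusers Flux pipeline.
# The checkpoint id, prompt, and step count below are illustrative only.
if __name__ == '__main__':
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-schnell',
                                        torch_dtype=torch.bfloat16)
    # Swap every nn.Linear inside the transformer for Linear8.
    replace_regular_linears(pipe.transformer)
    pipe.to('cuda')
    image = pipe('a photo of a forest at dawn', num_inference_steps=4).images[0]
    image.save('flux_fp8.png')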