def prepare_moe_fp8_layer_for_marlin(
layer: torch.nn.Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
input_dtype: torch.dtype | None = None,
) -> tuple[
torch.Tensor, # workspace
torch.Tensor, # w13_weight
torch.Tensor, # w2_weight
torch.Tensor, # w13_weight_scale
torch.Tensor, # w2_weight_scale
]:
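    """Repack FP8 MoE expert weights and scales into the Marlin format.

    This is the weight-only FP8 fallback for GPUs without native FP8 compute:
    each expert's w13 (fused gate/up projection) and w2 (down projection)
    weight is repacked into the Marlin tile layout, and the per-tensor,
    channel-wise, or block-wise scales are converted into the channel-wise or
    group-wise layout expected by the Marlin kernel.

    Returns a tuple of (workspace, w13_weight, w2_weight, w13_weight_scale,
    w2_weight_scale) ready to be handed to the Marlin MoE kernel.
    """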
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
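    # An 8-bit activation dtype would make this W8A8; the Marlin path used
    # here is weight-only compression, so it is not supported.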
if input_dtype is not None and input_dtype.itemsize == 1:
raise NotImplementedError("Marlin W8A8 is not supported.")
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
weight_block_size = getattr(layer, "weight_block_size", None)
# WORKSPACE
device = layer.w13_weight.device
workspace = marlin_make_workspace_new(device, 4)
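    # FP8 weights use no act-order reordering, so the repack below gets an
    # empty permutation and only converts weights to the Marlin tile layout.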
perm = torch.empty(0, dtype=torch.int, device=device)
    # WEIGHT
    # Repack each expert's FP8 weight: pack four FP8 values per int32, then
    # repack into the Marlin tile layout via gptq_marlin_repack (num_bits=8).
def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
assert weight.shape == (e, size_n, size_k)
for i in range(e):
qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
qweight = qweight.T.contiguous()
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
)
tensor_list.append(marlin_qweight)
        return torch.stack(tensor_list, dim=0)
w13_weight = repack_weight("w13", w13_weight)
w2_weight = repack_weight("w2", w2_weight)
    # WEIGHT SCALES
    # Convert and permute scales into the layout expected by the Marlin kernel.
    # group_size == -1 selects channel-wise scales; otherwise the block size
    # along K becomes the Marlin group size.
    group_size = -1 if weight_block_size is None else weight_block_size[1]
def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
scales = scales.to(layer.orig_dtype)
tensor_list = []
if "w13" in name:
size_n, size_k = n * 2, k
else:
size_n, size_k = k, n
        # The Marlin kernel only supports channel-wise and group-wise
        # quantization, so convert the scales accordingly.
if weight_block_size is None:
if scales.nelement() == e:
# tensor-wise quantization -> channel-wise quantization
# (e, 1, 1) =>(repeat)=> (e, 1, size_n)
scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
elif scales.nelement() > e and scales.nelement() != e * size_n:
assert (e * size_n) % scales.nelement() == 0
s_size = scales.nelement() // e
                # per-tensor quantization of the fused gate/up projections
                # (one scale per sub-projection) -> channel-wise quantization
                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
scales = scales.view(e, 1, s_size)
scales = scales.repeat_interleave(size_n // s_size, 2)
else:
# channel-wise quantization
# (e, 1, size_n)
scales = scales.view(e, 1, size_n)
else:
            # block-wise quantization -> group-wise quantization
            # stored scales: (e, ceil(size_n / block_n), size_k // block_k),
            # where block_n, block_k = weight_block_size. After the permute the
            # shape is (e, size_k // block_k, ceil(size_n / block_n)); repeating
            # block_n times along the last dim gives group-wise scales of shape
            # (e, size_k // block_k, size_n).
            scales = scales.permute(0, 2, 1)
            block_n = weight_block_size[0]
            scales = scales.repeat_interleave(block_n, 2)
            # size_n may not be divisible by block_n; trim the repeated scales
            scales = scales[..., :size_n].contiguous()
for i in range(e):
marlin_scales = marlin_permute_scales(
s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
)
tensor_list.append(marlin_scales)
        scales = torch.stack(tensor_list, dim=0)
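        # Fold the FP8 exponent-bias correction into the scales for non-FP8
        # activation dtypes (see fp8_fused_exponent_bias_into_scales).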
if input_dtype != torch.float8_e4m3fn:
scales = fp8_fused_exponent_bias_into_scales(scales)
return scales
w13_weight_scale = permute_scales(w13_weight_scale, "w13")
w2_weight_scale = permute_scales(w2_weight_scale, "w2")
return (
workspace,
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
)
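# Illustrative usage sketch (hypothetical, not part of the module API): the
# SimpleNamespace "layer", the shapes, and the per-tensor scales below are
# assumptions for demonstration only. Running it requires a CUDA device and
# the compiled Marlin ops imported by this module.
if __name__ == "__main__":
    from types import SimpleNamespace

    e, k, n = 8, 4096, 1024  # num_experts, hidden_size, intermediate size
    device = torch.device("cuda")

    # Random FP8 expert weights: w13 is (e, 2n, k), w2 is (e, k, n).
    w13 = torch.randn(e, 2 * n, k, device=device).to(torch.float8_e4m3fn)
    w2 = torch.randn(e, k, n, device=device).to(torch.float8_e4m3fn)
    # One per-tensor scale per expert (the scales.nelement() == e branch).
    w13_s = torch.rand(e, device=device, dtype=torch.bfloat16)
    w2_s = torch.rand(e, device=device, dtype=torch.bfloat16)

    dummy_layer = SimpleNamespace(
        num_experts=e,
        hidden_size=k,
        intermediate_size_per_partition=n,
        orig_dtype=torch.bfloat16,
        w13_weight=w13,
        weight_block_size=None,
    )

    out = prepare_moe_fp8_layer_for_marlin(
        dummy_layer, w13, w2, w13_s, w2_s, input_dtype=torch.bfloat16
    )
    workspace, w13_q, w2_q, w13_qs, w2_qs = out
    print([t.shape for t in out])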