This file contains ops that make ViT attention compatible with torch.compile; some of the operations involved (for instance, the .item() call in flash attention) are not supported by torch.compile, so they are wrapped here as custom ops.
Using these ops and wrapping vision blocks with torch.compile can improve throughput in vision models by ~5% (relative) on H100 and reduce token latencies by ~7% (see qwen2_5_vl for example usage).
To use these ops, a recent version of PyTorch is required (>= 2.4.0).
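A minimal sketch of the intended usage pattern, loosely modeled on qwen2_5_vl. The module below and its parameterization are illustrative assumptions for this page, not the actual model code; only vit_flash_attn_wrapper comes from this file.

import torch
from torch import nn

from vllm.attention.ops.vit_attn_wrappers import vit_flash_attn_wrapper


class ViTSelfAttention(nn.Module):
    """Hypothetical vision attention module, used only for illustration."""

    def __init__(self, hidden_size: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        x: torch.Tensor,           # (batch_size, seq_len, hidden_size)
        cu_seqlens: torch.Tensor,  # (num_seqs + 1,), int32 cumulative lengths
        max_seqlen: torch.Tensor,  # 0-dim tensor; .item() happens inside the op
    ) -> torch.Tensor:
        b, s, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (t.reshape(b, s, self.num_heads, self.head_size) for t in (q, k, v))
        # The wrapper dispatches to a registered custom op, so the operations
        # torch.compile cannot trace (e.g. max_seqlen.item()) stay out of the graph.
        out = vit_flash_attn_wrapper(
            q, k, v,
            batch_size=b,
            is_rocm_aiter=False,
            fa_version=None,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return self.proj(out.reshape(b, s, -1))


# Individual vision blocks (rather than the whole model) are then wrapped:
# compiled_block = torch.compile(ViTSelfAttention(1280, 16).cuda().half())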
apply_sdpa
Input shape: (batch_size x seq_len x num_heads x head_size)
Source code in vllm/attention/ops/vit_attn_wrappers.py
def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Input shape:
    (batch_size x seq_len x num_heads x head_size)
    """
    q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
    output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
    output = einops.rearrange(output, "b h s d -> b s h d")
    return output
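A quick usage sketch for apply_sdpa, assuming it can be imported directly from the module above; the shapes are arbitrary.

import torch

from vllm.attention.ops.vit_attn_wrappers import apply_sdpa

batch_size, seq_len, num_heads, head_size = 2, 16, 8, 64
q = torch.randn(batch_size, seq_len, num_heads, head_size)
k = torch.randn_like(q)
v = torch.randn_like(q)

# The output keeps the (batch_size, seq_len, num_heads, head_size) layout.
out = apply_sdpa(q, k, v)
assert out.shape == (batch_size, seq_len, num_heads, head_size)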
flash_attn_maxseqlen_wrapper
Source code in vllm/attention/ops/vit_attn_wrappers.py
def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    fa_version: int | None,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
    kwargs = {}
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    else:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    if not current_platform.is_rocm() and fa_version is not None:
        kwargs["fa_version"] = fa_version
    q_len = q.size(1)
    if cu_seqlens is None:
        cu_seqlens = torch.arange(
            0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
        )
    max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        causal=False,
        **kwargs,
    )
    context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
    return context_layer
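When cu_seqlens is None, the wrapper treats the input as batch_size fixed-length sequences of q_len tokens each. A small sketch of the default cu_seqlens it builds (values chosen arbitrarily):

import torch

batch_size, q_len = 3, 4
# Cumulative sequence lengths for 3 sequences of 4 tokens each.
cu_seqlens = torch.arange(0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32)
print(cu_seqlens)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)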
flash_attn_maxseqlen_wrapper_fake
Source code in vllm/attention/ops/vit_attn_wrappers.py
def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    fa_version: int | None,
) -> torch.Tensor:
    return torch.empty_like(q)
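The _fake variants exist so that torch.compile can trace the custom op with fake tensors: only the output shape, dtype, and device matter during tracing, and no attention kernel runs. A minimal sketch of how a real/fake pair can be registered with torch.library (PyTorch >= 2.4); the op name and decorator usage below are illustrative, not necessarily how vLLM registers these wrappers:

import torch


@torch.library.custom_op("mylib::toy_attn", mutates_args=())
def toy_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # "Real" implementation, executed eagerly outside the compiled graph.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


@toy_attn.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Fake implementation used during tracing: shape/dtype only, no compute.
    return torch.empty_like(q)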
torch_sdpa_wrapper
Source code in vllm/attention/ops/vit_attn_wrappers.py
def torch_sdpa_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    # Never remove the contiguous logic for ROCm
    # Without it, hallucinations occur with the backend
    if current_platform.is_rocm():
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
    if cu_seqlens is None:
        return apply_sdpa(q, k, v)
    outputs = []
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    q_chunks = torch.split(q, lens, dim=1)
    k_chunks = torch.split(k, lens, dim=1)
    v_chunks = torch.split(v, lens, dim=1)
    for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
        output_i = apply_sdpa(q_i, k_i, v_i)
        outputs.append(output_i)
    context_layer = torch.cat(outputs, dim=1)
    return context_layer
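A usage sketch for the variable-length path, assuming the function is importable directly and the input uses a packed layout where all sequences are concatenated along the sequence dimension of a single batch entry; passing cu_seqlens keeps attention from mixing tokens across the packed sequences.

import torch

from vllm.attention.ops.vit_attn_wrappers import torch_sdpa_wrapper

num_heads, head_size = 8, 64
# Two packed sequences of lengths 5 and 11 (16 tokens total).
cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32)
q = torch.randn(1, 16, num_heads, head_size)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = torch_sdpa_wrapper(q, k, v, cu_seqlens)
assert out.shape == q.shape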
torch_sdpa_wrapper_fake
Source code in vllm/attention/ops/vit_attn_wrappers.py
def torch_sdpa_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    return torch.empty_like(q)
vit_flash_attn_wrapper
Source code in vllm/attention/ops/vit_attn_wrappers.py
def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    fa_version: int | None,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
    return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
        q,
        k,
        v,
        batch_size,
        is_rocm_aiter,
        fa_version,
        cu_seqlens,
        max_seqlen,
    )
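A direct-call sketch; the values are arbitrary, and a CUDA device with a supported FlashAttention build is assumed for the underlying kernel. Note that max_seqlen is passed as a tensor (or left as None): the .item() conversion happens inside the custom op, not in compiled code.

import torch

from vllm.attention.ops.vit_attn_wrappers import vit_flash_attn_wrapper

batch_size, seq_len, num_heads, head_size = 2, 32, 8, 64
q = torch.randn(batch_size, seq_len, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = vit_flash_attn_wrapper(
    q, k, v,
    batch_size=batch_size,
    is_rocm_aiter=False,
    fa_version=None,   # no override; use the default FlashAttention version
    cu_seqlens=None,   # None -> fixed-length sequences of seq_len tokens
    max_seqlen=None,   # None -> seq_len; otherwise a single-element tensor
)
assert out.shape == q.shape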
vit_torch_sdpa_wrapper
Source code in vllm/attention/ops/vit_attn_wrappers.py
def vit_torch_sdpa_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
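Finally, an illustrative sketch of choosing between the two compile-safe wrappers at the attention layer. The selection flag and helper below are assumptions for this example, not vLLM's actual backend dispatch.

import torch

from vllm.attention.ops.vit_attn_wrappers import (
    vit_flash_attn_wrapper,
    vit_torch_sdpa_wrapper,
)


def vit_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor | None,
    max_seqlen: torch.Tensor | None,
    use_flash_attn: bool,
) -> torch.Tensor:
    if use_flash_attn:
        return vit_flash_attn_wrapper(
            q, k, v,
            batch_size=q.shape[0],
            is_rocm_aiter=False,
            fa_version=None,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
    # Fallback: pure-PyTorch SDPA, also routed through a custom op so the
    # data-dependent .tolist() on cu_seqlens stays outside the compiled graph.
    return vit_torch_sdpa_wrapper(q, k, v, cu_seqlens)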