vllm.model_executor.layers.rotary_embedding.common

logger module-attribute

logger = init_logger(__name__)

ApplyRotaryEmb

Bases: CustomOp

Source code in vllm/model_executor/layers/rotary_embedding/common.py
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
    def __init__(
        self,
        enforce_enable: bool = False,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> None:
        super().__init__(enforce_enable)
        self.is_neox_style = is_neox_style
        self.enable_fp32_compute = enable_fp32_compute

        self.apply_rotary_emb_flash_attn = None
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

            self.apply_rotary_emb_flash_attn = apply_rotary

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x: [batch_size (optional), seq_len, num_heads, head_size]
            cos: [seq_len, head_size // 2]
            sin: [seq_len, head_size // 2]
            is_neox_style: Whether to use Neox-style (True) or GPT-J-style
                           (False) rotation.
            enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
                                 for higher accuracy.
        """
        origin_dtype = x.dtype
        if enable_fp32_compute:
            x = x.float()

        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)

        if is_neox_style:
            x1, x2 = torch.chunk(x, 2, dim=-1)
        else:
            x1 = x[..., ::2]
            x2 = x[..., 1::2]

        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin

        if is_neox_style:
            output = torch.cat((o1, o2), dim=-1)
        else:
            output = torch.stack((o1, o2), dim=-1).flatten(-2)

        if enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_native(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        output = self.forward_static(
            x, cos, sin, self.is_neox_style, self.enable_fp32_compute
        )
        return output

    def forward_cuda(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()

        origin_shape = x.shape
        if len(origin_shape) == 3:
            # x: [seq_len, num_heads, head_size]
            x = x.unsqueeze(0)

        """
        Arguments of apply_rotary_emb() in vllm_flash_attn:
            x: [batch_size, seq_len, nheads, headdim]
            cos, sin: [seqlen_rotary, rotary_dim / 2]
            interleaved: defaults to False (Neox-style).
            ...
        """
        interleaved = not self.is_neox_style
        output = apply_rotary_emb(x, cos, sin, interleaved)

        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_hip(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        if self.apply_rotary_emb_flash_attn is not None:
            origin_dtype = x.dtype
            if self.enable_fp32_compute:
                x = x.float()
                cos = cos.float()
                sin = sin.float()

            origin_shape = x.shape
            if len(origin_shape) == 3:
                # x: [seq_len, num_heads, head_size]
                x = x.unsqueeze(0)

            """
            Arguments of apply_rotary() in flash_attn:
                x: [batch_size, seq_len, nheads, headdim]
                cos, sin: [seqlen_rotary, rotary_dim / 2]
                interleaved: defaults to False (Neox-style).
                ...
            """
            interleaved = not self.is_neox_style
            output = self.apply_rotary_emb_flash_attn(
                x, cos, sin, interleaved=interleaved
            ).type_as(x)

            if len(origin_shape) == 3:
                output = output.squeeze(0)
            if self.enable_fp32_compute:
                output = output.to(origin_dtype)
        else:
            # Fall back to the PyTorch-native implementation.
            output = self.forward_native(x, cos, sin)

        return output

    def extra_repr(self) -> str:
        s = f"is_neox_style={self.is_neox_style}"
        s += f"enable_fp32_compute={self.enable_fp32_compute}"
        return s
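
A minimal usage sketch (not from the source file): tensor shapes follow the forward_static docstring, while the sizes, dtypes, and keyword choices are illustrative. Calling the module dispatches to forward_native, forward_cuda, or forward_hip depending on the current platform, as CustomOp implementations do.

import torch
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb

# Illustrative sizes: 2 sequences of 16 tokens, 8 heads, head_size 64.
seq_len, num_heads, head_size = 16, 8, 64

x = torch.randn(2, seq_len, num_heads, head_size, dtype=torch.bfloat16)
# In practice cos and sin come from a precomputed cos/sin cache indexed by
# position; random values are used here only to exercise the shapes.
cos = torch.randn(seq_len, head_size // 2)
sin = torch.randn(seq_len, head_size // 2)

rope = ApplyRotaryEmb(is_neox_style=True, enable_fp32_compute=True)
out = rope(x, cos, sin)
assert out.shape == x.shape and out.dtype == x.dtype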

apply_rotary_emb_flash_attn instance-attribute

apply_rotary_emb_flash_attn = None

enable_fp32_compute instance-attribute

enable_fp32_compute = enable_fp32_compute

is_neox_style instance-attribute

is_neox_style = is_neox_style

__init__

__init__(
    enforce_enable: bool = False,
    is_neox_style: bool = True,
    enable_fp32_compute: bool = False,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def __init__(
    self,
    enforce_enable: bool = False,
    is_neox_style: bool = True,
    enable_fp32_compute: bool = False,
) -> None:
    super().__init__(enforce_enable)
    self.is_neox_style = is_neox_style
    self.enable_fp32_compute = enable_fp32_compute

    self.apply_rotary_emb_flash_attn = None
    if find_spec("flash_attn") is not None:
        from flash_attn.ops.triton.rotary import apply_rotary

        self.apply_rotary_emb_flash_attn = apply_rotary

extra_repr

extra_repr() -> str
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def extra_repr(self) -> str:
    s = f"is_neox_style={self.is_neox_style}"
    s += f"enable_fp32_compute={self.enable_fp32_compute}"
    return s

forward_cuda

forward_cuda(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def forward_cuda(
    self,
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

    origin_dtype = x.dtype
    if self.enable_fp32_compute:
        x = x.float()
        cos = cos.float()
        sin = sin.float()

    origin_shape = x.shape
    if len(origin_shape) == 3:
        # x: [seq_len, num_heads, head_size]
        x = x.unsqueeze(0)

    """
    Arguments of apply_rotary_emb() in vllm_flash_attn:
        x: [batch_size, seq_len, nheads, headdim]
        cos, sin: [seqlen_rotary, rotary_dim / 2]
        interleaved: defaults to False (Neox-style).
        ...
    """
    interleaved = not self.is_neox_style
    output = apply_rotary_emb(x, cos, sin, interleaved)

    if len(origin_shape) == 3:
        output = output.squeeze(0)
    if self.enable_fp32_compute:
        output = output.to(origin_dtype)
    return output

forward_hip

forward_hip(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def forward_hip(
    self,
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    if self.apply_rotary_emb_flash_attn is not None:
        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()

        origin_shape = x.shape
        if len(origin_shape) == 3:
            # x: [seq_len, num_heads, head_size]
            x = x.unsqueeze(0)

        """
        Arguments of apply_rotary() in flash_attn:
            x: [batch_size, seq_len, nheads, headdim]
            cos, sin: [seqlen_rotary, rotary_dim / 2]
            interleaved: defaults to False (Neox-style).
            ...
        """
        interleaved = not self.is_neox_style
        output = self.apply_rotary_emb_flash_attn(
            x, cos, sin, interleaved=interleaved
        ).type_as(x)

        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
    else:
        # Fall back to the PyTorch-native implementation.
        output = self.forward_native(x, cos, sin)

    return output

forward_native

forward_native(
    x: Tensor, cos: Tensor, sin: Tensor
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def forward_native(
    self,
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    output = self.forward_static(
        x, cos, sin, self.is_neox_style, self.enable_fp32_compute
    )
    return output

forward_static staticmethod

forward_static(
    x: Tensor,
    cos: Tensor,
    sin: Tensor,
    is_neox_style: bool = True,
    enable_fp32_compute: bool = False,
) -> Tensor

Parameters:

Name                 Type    Description                                                          Default
x                    Tensor  [batch_size (optional), seq_len, num_heads, head_size]              required
cos                  Tensor  [seq_len, head_size // 2]                                            required
sin                  Tensor  [seq_len, head_size // 2]                                            required
is_neox_style        bool    Whether to use Neox-style (True) or GPT-J-style (False) rotation.   True
enable_fp32_compute  bool    Temporarily convert x, cos, sin to FP32 dtype for higher accuracy.  False
Source code in vllm/model_executor/layers/rotary_embedding/common.py
@staticmethod
def forward_static(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool = True,
    enable_fp32_compute: bool = False,
) -> torch.Tensor:
    """
    Args:
        x: [batch_size (optional), seq_len, num_heads, head_size]
        cos: [seq_len, head_size // 2]
        sin: [seq_len, head_size // 2]
        is_neox_style: Whether to use Neox-style (True) or GPT-J-style
                       (False) rotation.
        enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
                             for higher accuracy.
    """
    origin_dtype = x.dtype
    if enable_fp32_compute:
        x = x.float()

    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)

    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]

    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin

    if is_neox_style:
        output = torch.cat((o1, o2), dim=-1)
    else:
        output = torch.stack((o1, o2), dim=-1).flatten(-2)

    if enable_fp32_compute:
        output = output.to(origin_dtype)
    return output
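
The two styles differ only in how the head dimension is split into the (x1, x2) halves that get rotated together: Neox style uses contiguous halves, GPT-J style uses interleaved even/odd lanes. A small sketch with illustrative values; with cos = 1 and sin = 0 the rotation is the identity, so both calls must reproduce the input after their different split/reassemble paths.

import torch
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb

x = torch.arange(8, dtype=torch.float32).view(1, 1, 8)  # [seq_len=1, heads=1, head_size=8]
cos = torch.ones(1, 4)
sin = torch.zeros(1, 4)

# Halves rotated together:
#   Neox:  (x[..., 0:4], x[..., 4:8])   contiguous halves
#   GPT-J: (x[..., 0::2], x[..., 1::2]) even/odd interleaved lanes
neox = ApplyRotaryEmb.forward_static(x, cos, sin, is_neox_style=True)
gptj = ApplyRotaryEmb.forward_static(x, cos, sin, is_neox_style=False)
assert torch.equal(neox, x) and torch.equal(gptj, x)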

_flashinfer_rotary_embedding

_flashinfer_rotary_embedding(
    positions: Tensor,
    query: Tensor,
    key: Tensor,
    head_size: int,
    cos_sin_cache: Tensor,
    is_neox: bool,
) -> None

Custom op wrapper for flashinfer's rotary embedding.

This is an in-place operation that modifies query and key tensors directly.

Source code in vllm/model_executor/layers/rotary_embedding/common.py
def _flashinfer_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Custom op wrapper for flashinfer's rotary embedding.

    This is an in-place operation that modifies query and key tensors directly.
    """
    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

    apply_rope_with_cos_sin_cache_inplace(
        positions=positions,
        query=query,
        key=key,
        head_size=head_size,
        cos_sin_cache=cos_sin_cache,
        is_neox=is_neox,
    )
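
A hedged usage sketch, not taken from the source: it assumes the usual vLLM layout of query/key flattened to [num_tokens, num_heads * head_size] and a cos_sin_cache of shape [max_position, rotary_dim] with cos in the first half and sin in the second half of the last dimension. It requires flashinfer and a CUDA device; in vLLM this function is normally reached through its custom-op registration, and it is called directly here only for illustration.

import torch
from vllm.model_executor.layers.rotary_embedding.common import (
    _flashinfer_rotary_embedding,
)

# Illustrative sizes; full-width rotary (rotary_dim == head_size) and a
# float32 cos/sin cache are assumptions of this sketch.
num_tokens, num_heads, head_size, max_pos = 4, 8, 64, 1024

positions = torch.arange(num_tokens, device="cuda")
query = torch.randn(num_tokens, num_heads * head_size, device="cuda", dtype=torch.float16)
key = torch.randn(num_tokens, num_heads * head_size, device="cuda", dtype=torch.float16)
cos_sin_cache = torch.randn(max_pos, head_size, device="cuda", dtype=torch.float32)

# Modifies query and key in place; there is no return value.
_flashinfer_rotary_embedding(
    positions, query, key, head_size, cos_sin_cache, is_neox=True
)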

_flashinfer_rotary_embedding_fake

_flashinfer_rotary_embedding_fake(
    positions: Tensor,
    query: Tensor,
    key: Tensor,
    head_size: int,
    cos_sin_cache: Tensor,
    is_neox: bool,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def _flashinfer_rotary_embedding_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    return

rotate_gptj

rotate_gptj(x: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)

rotate_neox

rotate_neox(x: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
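
Both helpers implement the "rotate half" building block typically combined as x * cos + rotate(x) * sin; they differ only in which lanes are paired. A small demonstration with illustrative values:

import torch
from vllm.model_executor.layers.rotary_embedding.common import (
    rotate_gptj,
    rotate_neox,
)

x = torch.arange(1.0, 7.0)  # [1, 2, 3, 4, 5, 6]

# Neox style pairs element i with element i + d/2:
print(rotate_neox(x))       # tensor([-4., -5., -6.,  1.,  2.,  3.])

# GPT-J style pairs adjacent (even, odd) lanes:
print(rotate_gptj(x))       # tensor([-2.,  1., -4.,  3., -6.,  5.])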

yarn_find_correction_dim

yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float:
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )
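
The expression solves for the dimension index d whose rotary wavelength completes exactly num_rotations periods within max_position_embeddings, assuming the standard RoPE frequencies theta_d = base^(-2d/dim). A sketch of the derivation:

\lambda_d = \frac{2\pi}{\theta_d} = 2\pi \cdot \text{base}^{2d/\text{dim}},
\qquad
\frac{L_{\max}}{\lambda_d} = r
\;\Longrightarrow\;
d = \frac{\text{dim} \cdot \ln\!\left(L_{\max} / (2\pi r)\right)}{2 \ln(\text{base})}

with L_max = max_position_embeddings and r = num_rotations, which is exactly the value returned above.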

yarn_find_correction_range

yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
    truncate: bool = True,
) -> tuple[float | int, float | int]
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
    truncate: bool = True,
) -> tuple[float | int, float | int]:
    low = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    high = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    if truncate:
        low = math.floor(low)
        high = math.ceil(high)
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case

yarn_get_mscale

yarn_get_mscale(scale: float = 1) -> float
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

yarn_linear_ramp_mask

yarn_linear_ramp_mask(
    low: float, high: float, dim: int, dtype: dtype
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_linear_ramp_mask(
    low: float, high: float, dim: int, dtype: torch.dtype
) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
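
A sketch of how these YaRN helpers are typically combined to build blended inverse frequencies and the attention scale. The blending recipe and the beta_fast / beta_slow values below follow the YaRN method and are illustrative assumptions, not something defined in this file.

import torch
from vllm.model_executor.layers.rotary_embedding.common import (
    yarn_find_correction_range,
    yarn_get_mscale,
    yarn_linear_ramp_mask,
)

# Illustrative parameters (beta_fast / beta_slow are the common YaRN defaults).
rotary_dim, base, max_pos = 128, 10000.0, 4096
scaling_factor = 4.0
beta_fast, beta_slow = 32, 1

pos_freqs = base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

# inv_freq_mask is 1 for the high-frequency dims (indices below `low`), which
# keep the original (extrapolated) frequencies, and ramps down to 0 for the
# low-frequency dims (indices above `high`), which use the position-interpolated
# frequencies.
low, high = yarn_find_correction_range(beta_fast, beta_slow, rotary_dim, base, max_pos)
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, rotary_dim // 2, torch.float)
inv_freq = (
    inv_freq_interpolation * (1.0 - inv_freq_mask)
    + inv_freq_extrapolation * inv_freq_mask
)

# Attention temperature adjustment under YaRN, typically multiplied into the
# cos/sin cache.
mscale = yarn_get_mscale(scaling_factor)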