Skip to content

vllm.v1.spec_decode.metadata

SpecDecodeMetadata dataclass

Source code in vllm/v1/spec_decode/metadata.py
@dataclass
class SpecDecodeMetadata:
    """Batch-level metadata for the draft tokens proposed during speculative decoding.

    Shape comments use ``num_tokens`` for the total number of draft tokens
    across the batch and ``batch_size`` for the number of requests.

    ``max_spec_len`` is not a constructor field; ``__post_init__`` derives it
    as the largest per-request entry of ``num_draft_tokens`` (0 when the
    list is empty).
    """

    # [num_tokens]
    draft_token_ids: torch.Tensor
    # [batch_size]
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [batch_size]
    cu_num_sampled_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
    bonus_logits_indices: torch.Tensor
    # [num_tokens + batch_size]
    logits_indices: torch.Tensor

    def __post_init__(self):
        # `default=0` keeps an empty batch from raising ValueError in max().
        self.max_spec_len = max(self.num_draft_tokens, default=0)

    @classmethod
    def make_dummy(
        cls,
        draft_token_ids: list[list[int]],
        device: torch.device,
    ) -> "SpecDecodeMetadata":
        """Build placeholder metadata from per-request draft token ids.

        Args:
            draft_token_ids: One list of draft token ids per request.
            device: Device on which every tensor field is allocated.

        Returns:
            A metadata object with the following contents:

            - ``draft_token_ids``, ``num_draft_tokens`` and the two cumulative
              count tensors reflect the input.
            - The three logits-index tensors are zero-filled placeholders of
              the correct length.
        """
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        # One extra sampled position per request, matching the
        # [batch_size]-sized bonus_logits_indices below.
        num_sampled_tokens = [n + 1 for n in num_draft_tokens]
        # Flatten with a comprehension: `sum(lists, [])` rebuilds the list on
        # every addition and is quadratic in the number of requests.
        flattened_draft_token_ids = [
            token_id for ids in draft_token_ids for token_id in ids
        ]
        num_tokens = len(flattened_draft_token_ids)

        draft_token_ids_tensor = torch.tensor(
            flattened_draft_token_ids, dtype=torch.int32, device=device
        )
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
            device
        )

        # Index tensors are only sized correctly here; their values are zeros.
        target_logits_indices = torch.zeros(
            num_tokens, dtype=torch.int32, device=device
        )
        bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device)
        logits_indices = torch.zeros(
            num_tokens + batch_size, dtype=torch.int32, device=device
        )
        return cls(
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )

bonus_logits_indices instance-attribute

bonus_logits_indices: Tensor

cu_num_draft_tokens instance-attribute

cu_num_draft_tokens: Tensor

cu_num_sampled_tokens instance-attribute

cu_num_sampled_tokens: Tensor

draft_token_ids instance-attribute

draft_token_ids: Tensor

logits_indices instance-attribute

logits_indices: Tensor

num_draft_tokens instance-attribute

num_draft_tokens: list[int]

target_logits_indices instance-attribute

target_logits_indices: Tensor

__init__

__init__(
    draft_token_ids: Tensor,
    num_draft_tokens: list[int],
    cu_num_draft_tokens: Tensor,
    cu_num_sampled_tokens: Tensor,
    target_logits_indices: Tensor,
    bonus_logits_indices: Tensor,
    logits_indices: Tensor,
) -> None

__post_init__

__post_init__()
Source code in vllm/v1/spec_decode/metadata.py
def __post_init__(self):
    """Derive ``max_spec_len``, the largest entry of ``num_draft_tokens``.

    An empty ``num_draft_tokens`` yields 0 instead of raising ``ValueError``.
    """
    # `default=0` keeps an empty batch from raising ValueError in max().
    self.max_spec_len = max(self.num_draft_tokens, default=0)

make_dummy classmethod

make_dummy(
    draft_token_ids: list[list[int]], device: device
) -> SpecDecodeMetadata
Source code in vllm/v1/spec_decode/metadata.py
@classmethod
def make_dummy(
    cls,
    draft_token_ids: list[list[int]],
    device: torch.device,
) -> "SpecDecodeMetadata":
    """Build placeholder metadata from per-request draft token ids.

    Args:
        draft_token_ids: One list of draft token ids per request.
        device: Device on which every tensor field is allocated.

    Returns:
        An instance of ``cls`` with the following contents:

        - ``draft_token_ids``, ``num_draft_tokens`` and the two cumulative
          count tensors reflect the input.
        - The three logits-index tensors are zero-filled placeholders of the
          correct length.
    """
    batch_size = len(draft_token_ids)
    num_draft_tokens = [len(ids) for ids in draft_token_ids]
    # One extra sampled position per request, matching the
    # [batch_size]-sized bonus_logits_indices below.
    num_sampled_tokens = [n + 1 for n in num_draft_tokens]
    # Flatten with a comprehension: `sum(lists, [])` rebuilds the list on
    # every addition and is quadratic in the number of requests.
    flattened_draft_token_ids = [
        token_id for ids in draft_token_ids for token_id in ids
    ]
    num_tokens = len(flattened_draft_token_ids)

    draft_token_ids_tensor = torch.tensor(
        flattened_draft_token_ids, dtype=torch.int32, device=device
    )
    cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
    cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
    cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
    cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
        device
    )

    # Index tensors are only sized correctly here; their values are zeros.
    target_logits_indices = torch.zeros(
        num_tokens, dtype=torch.int32, device=device
    )
    bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device)
    logits_indices = torch.zeros(
        num_tokens + batch_size, dtype=torch.int32, device=device
    )
    return cls(
        draft_token_ids=draft_token_ids_tensor,
        num_draft_tokens=num_draft_tokens,
        cu_num_draft_tokens=cu_num_draft_tokens_tensor,
        cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
        target_logits_indices=target_logits_indices,
        bonus_logits_indices=bonus_logits_indices,
        logits_indices=logits_indices,
    )