Skip to content

vllm.distributed.kv_transfer.kv_connector.utils

KV cache helpers used by the KV-transfer connectors.

logger module-attribute

# Module-level logger; used by get_kv_connector_cache_layout for the NHD fallback message.
logger = init_logger(__name__)

KVOutputAggregator

Utility class that aggregates the outputs of all workers into a single output for the scheduler (by default, the output of rank 0).

Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
class KVOutputAggregator:
    """Utility class to aggregate the output of all workers into a single
    output corresponding to Rank 0 for scheduler.

    A request id is reported as finished sending/receiving only once every
    one of the ``world_size`` workers has listed it in its own
    ``kv_connector_output``.
    """

    def __init__(self, world_size: int):
        # Complete transfer tracker. Used to track finished requests
        # [req_id -> n_remaining_workers]; unseen ids start at world_size.
        self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
        self._send_remaining_count = defaultdict[str, int](lambda: world_size)

    def aggregate(self,
                  outputs: list[ModelRunnerOutput],
                  output_rank: int = 0) -> ModelRunnerOutput:
        """Merge the KV-connector results of all workers into one output.

        Args:
            outputs: One ``ModelRunnerOutput`` per worker.
            output_rank: Index into ``outputs`` of the output to return.

        Returns:
            ``outputs[output_rank]``, mutated so that its
            ``kv_connector_output`` holds only the request ids that every
            worker has now reported as finished sending / receiving (an
            empty set is stored as ``None``).

        Raises:
            IndexError: if ``output_rank`` is out of range for ``outputs``
                (the per-request counters have already been updated by then).
        """

        def update_finished_set(req_ids: Optional[set[str]],
                                remaining_count_dict: dict[str, int],
                                finished_set: set[str]) -> None:
            # One more worker reported each id; once no worker is left
            # outstanding the id is finished and its counter is dropped.
            for req_id in req_ids or ():
                remaining_count_dict[req_id] -= 1
                if remaining_count_dict[req_id] == 0:
                    finished_set.add(req_id)
                    del remaining_count_dict[req_id]

        finished_sending = set[str]()
        finished_recving = set[str]()
        for worker_output in outputs:
            kv_output = worker_output.kv_connector_output
            if not kv_output:
                continue
            update_finished_set(kv_output.finished_sending,
                                self._send_remaining_count, finished_sending)
            update_finished_set(kv_output.finished_recving,
                                self._recv_remaining_count, finished_recving)

        # select output of the worker specified by output_rank
        output = outputs[output_rank]

        output.kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending or None,
            finished_recving=finished_recving or None,
        )

        return output

    def async_aggregate(self,
                        output_futures: Sequence[Future[ModelRunnerOutput]],
                        output_rank: int = 0) -> Future[ModelRunnerOutput]:
        """Takes a list of futures and returns a single future which resolves
        to the aggregated output (see ``aggregate``) once all of them
        complete.

        If any input future raises, the returned future fails with that
        exception; if one is cancelled, the returned future is cancelled.
        An exception raised by ``aggregate`` itself is also set on the
        returned future. Raising it inside the done-callback would only be
        logged by ``concurrent.futures`` and leave the future pending forever.
        """
        result_future: Future[ModelRunnerOutput] = Future()

        outputs: list[Optional[ModelRunnerOutput]] = [None] * len(
            output_futures)

        def make_callback(idx):

            def callback(fut):
                if result_future.done():
                    return

                try:
                    outputs[idx] = fut.result()
                except CancelledError:
                    result_future.cancel()
                    return
                except Exception as e:
                    result_future.set_exception(e)
                    return

                # this check assumes io_thread_pool uses a single thread
                if not all(out is not None for out in outputs):
                    return

                try:
                    # String form of the type keeps cast() free of runtime
                    # name lookups inside the callback.
                    aggregated = self.aggregate(
                        cast("list[ModelRunnerOutput]", outputs), output_rank)
                except Exception as e:
                    result_future.set_exception(e)
                    return
                result_future.set_result(aggregated)

            return callback

        for i, output_future in enumerate(output_futures):
            output_future.add_done_callback(make_callback(i))

        return result_future

_recv_remaining_count instance-attribute

# Per-request count of workers that have not yet reported the request as
# finished receiving; ids not seen before default to world_size.
_recv_remaining_count = defaultdict[str, int](
    lambda: world_size
)

_send_remaining_count instance-attribute

# Per-request count of workers that have not yet reported the request as
# finished sending; ids not seen before default to world_size.
_send_remaining_count = defaultdict[str, int](
    lambda: world_size
)

__init__

__init__(world_size: int)
Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
def __init__(self, world_size: int):
    """Create per-request completion trackers for ``world_size`` workers.

    Both trackers map ``req_id -> number of workers that have not yet
    reported it``; an id seen for the first time starts at ``world_size``.
    """
    def _all_workers_pending() -> int:
        return world_size

    self._recv_remaining_count = defaultdict[str, int](_all_workers_pending)
    self._send_remaining_count = defaultdict[str, int](_all_workers_pending)

aggregate

aggregate(
    outputs: list[ModelRunnerOutput], output_rank: int = 0
) -> ModelRunnerOutput
Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
def aggregate(self,
              outputs: list[ModelRunnerOutput],
              output_rank: int = 0) -> ModelRunnerOutput:
    """Fold every worker's KV-connector report into a single output.

    Each request id listed in a worker's ``finished_sending`` /
    ``finished_recving`` decrements the matching per-request counter on
    ``self``. An id counts as finished once its counter reaches zero, at
    which point the counter entry is removed.

    Args:
        outputs: One ``ModelRunnerOutput`` per worker.
        output_rank: Index into ``outputs`` of the output that is returned.

    Returns:
        ``outputs[output_rank]`` with ``kv_connector_output`` replaced by the
        aggregated ``KVConnectorOutput`` (empty sets are stored as ``None``).
    """
    done_sending: set[str] = set()
    done_recving: set[str] = set()

    def _drain(reported: Optional[set[str]],
               counters: dict[str, int],
               done: set[str]) -> None:
        if not reported:
            return
        for rid in reported:
            counters[rid] -= 1
            if counters[rid] == 0:
                done.add(rid)
                del counters[rid]

    for kv_out in (worker.kv_connector_output for worker in outputs):
        if kv_out:
            _drain(kv_out.finished_sending, self._send_remaining_count,
                   done_sending)
            _drain(kv_out.finished_recving, self._recv_remaining_count,
                   done_recving)

    chosen = outputs[output_rank]
    chosen.kv_connector_output = KVConnectorOutput(
        finished_sending=done_sending or None,
        finished_recving=done_recving or None,
    )
    return chosen

async_aggregate

async_aggregate(
    output_futures: Sequence[Future[ModelRunnerOutput]],
    output_rank: int = 0,
) -> Future[ModelRunnerOutput]

Takes a list of futures and returns a single future which resolves to the aggregated output (see aggregate) once all of them complete.

Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
def async_aggregate(self,
                    output_futures: Sequence[Future[ModelRunnerOutput]],
                    output_rank: int = 0) -> Future[ModelRunnerOutput]:
    """Takes a list of futures and returns a single future which resolves
    to the aggregated output (see ``aggregate``) once all of them complete.

    If any input future raises, the returned future fails with that
    exception; if one is cancelled, the returned future is cancelled.
    An exception raised by ``aggregate`` itself is also set on the returned
    future. Raising it inside the done-callback would only be logged by
    ``concurrent.futures`` and leave the future pending forever.
    """
    result_future: Future[ModelRunnerOutput] = Future()

    outputs: list[Optional[ModelRunnerOutput]] = [None] * len(output_futures)

    def make_callback(idx):

        def callback(fut):
            if result_future.done():
                return

            try:
                outputs[idx] = fut.result()
            except CancelledError:
                result_future.cancel()
                return
            except Exception as e:
                result_future.set_exception(e)
                return

            # this check assumes io_thread_pool uses a single thread
            if not all(out is not None for out in outputs):
                return

            try:
                # String form of the type keeps cast() free of runtime
                # name lookups inside the callback.
                aggregated = self.aggregate(
                    cast("list[ModelRunnerOutput]", outputs), output_rank)
            except Exception as e:
                result_future.set_exception(e)
                return
            result_future.set_result(aggregated)

        return callback

    for i, output_future in enumerate(output_futures):
        output_future.add_done_callback(make_callback(i))

    return result_future

model_aware_kv_ops_helper

Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
class model_aware_kv_ops_helper:
    """Model-aware helpers for reading and writing per-layer KV caches.

    Handles two cache layouts. The DeepSeek MLA latent cache is used when the
    model is DeepSeek-MLA and ``VLLM_MLA_DISABLE`` is unset. Otherwise the
    standard cache is indexed as ``kv_cache[0]`` (keys) / ``kv_cache[1]``
    (values).
    """

    def __init__(self, config: VllmConfig):
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        # MLA-optimized latent caching is on unless VLLM_MLA_DISABLE is set.
        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
        self.tp_size = config.parallel_config.tensor_parallel_size

    def get_model_args(self, model_executable: torch.nn.Module):
        """Return ``(num_heads, head_size)`` for this rank's KV cache views.

        Also stores ``model_executable`` on ``self``.

        Args:
            model_executable: Module whose ``.model.config`` is read.

        Returns:
            Tuple ``(num_heads, head_size)`` used to reshape the KV cache.
        """
        model_config = model_executable.model.config
        self.model_executable = model_executable
        # KV heads are split across TP ranks. When there are fewer KV heads
        # than ranks, each rank still holds one (replicated) head, so never
        # return 0; a 0 would make the later reshape(-1, 0, head_size) fail.
        # This mirrors vLLM's per-rank KV head sizing.
        num_heads = max(1, model_config.num_key_value_heads // self.tp_size)

        # Deepseek's MLA (Multi-head Latent Attention) uses two different
        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
        # kv_lora_rank + qk_rope_head_dim].
        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
        # to a kv_cache shape of [2, num_blks, blk_size,
        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
        # For more details, see vllm/attention/backends/mla/common.py.
        if self.is_deepseek_mla and self.use_mla_opt:
            head_size = model_config.kv_lora_rank + \
                model_config.qk_rope_head_dim
            num_heads = 1
        elif self.is_deepseek_mla and not self.use_mla_opt:
            head_size = model_config.qk_nope_head_dim + \
                model_config.qk_rope_head_dim
        else:
            head_size = getattr(model_config, "head_dim", None)
            if head_size is None:
                head_size = int(model_config.hidden_size //
                                model_config.num_attention_heads)

        return num_heads, head_size

    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
        """Return ``(key_cache, value_cache)`` reshaped to
        ``(-1, num_heads, head_size)``.

        With MLA-optimized caching there is a single latent cache, so both
        returned values are reshapes of the whole ``kv_cache``.
        """
        if self.is_deepseek_mla and self.use_mla_opt:
            key_cache = kv_cache.reshape(-1, num_heads, head_size)
            value_cache = kv_cache.reshape(-1, num_heads, head_size)
        else:
            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
        return key_cache, value_cache

    def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
                        layer, kv_cache, slot_mapping, start_pos, end_pos):
        """Write ``keys``/``values`` into ``kv_cache`` at
        ``slot_mapping[start_pos:end_pos]``.

        Inputs are moved to the cache's device before the cache op runs.
        """
        model_config = model_executable.model.config

        if self.is_deepseek_mla and self.use_mla_opt:
            # NOTE(review): this rebinds layer.self_attn.attn to mla_attn as
            # a side effect on the layer; the reason is not visible here.
            layer.self_attn.attn = layer.self_attn.mla_attn
            # Presumably keys is (num_tokens, 1, kv_lora_rank +
            # qk_rope_head_dim) here; TODO confirm. The latent part and the
            # rope part are split at kv_lora_rank.
            k_c_normed_k_pe = keys.squeeze(1)
            k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
            ops.concat_and_cache_mla(
                k_c_normed.to(kv_cache.device),
                k_pe.to(kv_cache.device),
                kv_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
            )
        else:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            ops.reshape_and_cache_flash(
                keys.to(key_cache.device),
                values.to(value_cache.device),
                key_cache,
                value_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
                layer.self_attn.attn._v_scale,
            )

is_deepseek_mla instance-attribute

# Copied from config.model_config.is_deepseek_mla; selects the DeepSeek MLA cache handling.
is_deepseek_mla = is_deepseek_mla

tp_size instance-attribute

# Tensor-parallel size from config.parallel_config; divides the KV head count per rank.
tp_size = tensor_parallel_size

use_mla_opt instance-attribute

# True unless VLLM_MLA_DISABLE is set; enables the MLA latent (absorbed) cache path.
use_mla_opt = not VLLM_MLA_DISABLE

__init__

__init__(config: VllmConfig)
Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
def __init__(self, config: VllmConfig):
    """Cache the config flags the KV helper methods branch on.

    Args:
        config: vLLM config. ``model_config.is_deepseek_mla`` and
            ``parallel_config.tensor_parallel_size`` are read from it.
    """
    model_cfg = config.model_config
    self.is_deepseek_mla = model_cfg.is_deepseek_mla
    # MLA-optimized latent caching is on unless VLLM_MLA_DISABLE is set.
    mla_disabled = envs.VLLM_MLA_DISABLE
    self.use_mla_opt = not mla_disabled
    parallel_cfg = config.parallel_config
    self.tp_size = parallel_cfg.tensor_parallel_size

get_kv_from_cache

get_kv_from_cache(kv_cache, num_heads, head_size)
Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
def get_kv_from_cache(self, kv_cache, num_heads, head_size):
    """Return key and value views of ``kv_cache`` shaped
    ``(-1, num_heads, head_size)``.

    With MLA-optimized caching (``is_deepseek_mla`` and ``use_mla_opt``)
    there is one latent cache, so key and value are both reshapes of the
    whole ``kv_cache``. Otherwise ``kv_cache[0]`` supplies the keys and
    ``kv_cache[1]`` the values.
    """
    target_shape = (-1, num_heads, head_size)
    latent_only = self.is_deepseek_mla and self.use_mla_opt
    key_src = kv_cache if latent_only else kv_cache[0]
    value_src = kv_cache if latent_only else kv_cache[1]
    return key_src.reshape(*target_shape), value_src.reshape(*target_shape)

get_model_args

get_model_args(model_executable: Module)
Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
def get_model_args(self, model_executable: torch.nn.Module):
    """Return ``(num_heads, head_size)`` for this rank's KV cache views.

    Also stores ``model_executable`` on ``self``.

    Args:
        model_executable: Module whose ``.model.config`` is read.

    Returns:
        Tuple ``(num_heads, head_size)`` used to reshape the KV cache.
    """
    model_config = model_executable.model.config
    self.model_executable = model_executable
    # KV heads are split across TP ranks. When there are fewer KV heads than
    # ranks, each rank still holds one (replicated) head, so never return 0;
    # a 0 would make the later reshape(-1, 0, head_size) fail. This mirrors
    # vLLM's per-rank KV head sizing.
    num_heads = max(1, model_config.num_key_value_heads // self.tp_size)

    # Deepseek's MLA (Multi-head Latent Attention) uses two different
    # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
    # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
    # resulting in a kv_cache shape of [num_blks, blk_size, 1,
    # kv_lora_rank + qk_rope_head_dim].
    # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
    # to a kv_cache shape of [2, num_blks, blk_size,
    # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
    # For more details, see vllm/attention/backends/mla/common.py.
    if self.is_deepseek_mla and self.use_mla_opt:
        head_size = model_config.kv_lora_rank + \
            model_config.qk_rope_head_dim
        num_heads = 1
    elif self.is_deepseek_mla and not self.use_mla_opt:
        head_size = model_config.qk_nope_head_dim + \
            model_config.qk_rope_head_dim
    else:
        head_size = getattr(model_config, "head_dim", None)
        if head_size is None:
            head_size = int(model_config.hidden_size //
                            model_config.num_attention_heads)

    return num_heads, head_size

put_kv_to_cache

put_kv_to_cache(
    model_executable: Module,
    keys,
    values,
    layer,
    kv_cache,
    slot_mapping,
    start_pos,
    end_pos,
)
Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,
                    layer, kv_cache, slot_mapping, start_pos, end_pos):
    """Write ``keys``/``values`` into ``kv_cache`` at the slots
    ``slot_mapping[start_pos:end_pos]``.

    Inputs are moved to the cache's device first. The MLA-optimized path
    uses ``ops.concat_and_cache_mla``; otherwise ``kv_cache[0]`` /
    ``kv_cache[1]`` are filled via ``ops.reshape_and_cache_flash``.
    """
    cfg = model_executable.model.config
    self_attn = layer.self_attn

    if not (self.is_deepseek_mla and self.use_mla_opt):
        k_dest, v_dest = kv_cache[0], kv_cache[1]
        ops.reshape_and_cache_flash(
            keys.to(k_dest.device),
            values.to(v_dest.device),
            k_dest,
            v_dest,
            slot_mapping[start_pos:end_pos],
            self_attn.attn.kv_cache_dtype,
            self_attn.attn._k_scale,
            self_attn.attn._v_scale,
        )
        return

    # NOTE(review): rebinds layer.self_attn.attn to mla_attn as a side
    # effect on the layer; the reason is not visible here.
    self_attn.attn = self_attn.mla_attn
    # Presumably keys is (num_tokens, 1, kv_lora_rank + qk_rope_head_dim);
    # TODO confirm. Split into latent and rope parts at kv_lora_rank.
    packed = keys.squeeze(1)
    split_at = cfg.kv_lora_rank
    latent_part = packed[:, :split_at]
    rope_part = packed[:, split_at:]
    device = kv_cache.device
    ops.concat_and_cache_mla(
        latent_part.to(device),
        rope_part.to(device),
        kv_cache,
        slot_mapping[start_pos:end_pos],
        self_attn.attn.kv_cache_dtype,
        self_attn.attn._k_scale,
    )

get_kv_connector_cache_layout

get_kv_connector_cache_layout()
Source code in vllm/distributed/kv_transfer/kv_connector/utils.py
def get_kv_connector_cache_layout():
    """Return the KV cache layout string required by the configured connector.

    Falls back to ``"NHD"`` when no KV transfer is configured or when the
    connector class does not require a specific layout (the latter case is
    logged once).
    """
    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
    # used for faster transfer.
    vllm_config = get_current_vllm_config()
    transfer_cfg = vllm_config.kv_transfer_config
    if transfer_cfg is None:
        return "NHD"

    connector_type = KVConnectorFactory.get_connector_class(transfer_cfg)
    layout = connector_type.get_required_kvcache_layout(vllm_config)
    if layout is not None:
        return layout

    logger.info_once("Connectors do not specify a "
                     "kv cache layout, defaulting to NHD.")
    return "NHD"