Define KV connector functionality mixin for model runners.
KVConnectorModelRunnerMixin
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
|
class KVConnectorModelRunnerMixin:

    @staticmethod
    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
        # Update the KVConnector with the KVConnector metadata before forward().
        if has_kv_transfer_group():
            kv_connector = get_kv_transfer_group()
            assert isinstance(kv_connector, KVConnectorBase)
            assert scheduler_output.kv_connector_metadata is not None
            kv_connector.bind_connector_metadata(
                scheduler_output.kv_connector_metadata)

            # Background KV cache transfers happen here.
            # These transfers are designed to be async and the requests
            # involved may be disjoint from the running requests.
            # Do this here to save a collective_rpc.
            kv_connector.start_load_kv(get_forward_context())

    @staticmethod
    def maybe_wait_for_kv_save() -> None:
        if has_kv_transfer_group():
            get_kv_transfer_group().wait_for_save()

    @staticmethod
    def get_finished_kv_transfers(
            scheduler_output: "SchedulerOutput",
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_finished(
                scheduler_output.finished_req_ids)
        return None, None

    @staticmethod
    def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
                                vllm_config: VllmConfig) -> ModelRunnerOutput:
        # Run KV send/recv even if there is no forward pass to execute.
        with set_forward_context(
                None, vllm_config
        ), KVConnectorModelRunnerMixin._get_kv_connector_output(
                scheduler_output, wait_for_save=False) as kv_connector_output:
            pass

        if (not kv_connector_output.finished_sending
                and not kv_connector_output.finished_recving):
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_output = kv_connector_output
        return output

    @staticmethod
    def maybe_get_kv_connector_output(
            scheduler_output: "SchedulerOutput"
    ) -> AbstractContextManager[Optional[KVConnectorOutput]]:
        return KVConnectorModelRunnerMixin._get_kv_connector_output(
            scheduler_output) if has_kv_transfer_group() else nullcontext()

    # This context manager must be used within an active forward context.
    # It encapsulates the entire KV connector lifecycle within execute_model.
    @staticmethod
    @contextmanager
    def _get_kv_connector_output(
            scheduler_output: "SchedulerOutput",
            wait_for_save: bool = True
    ) -> Generator[KVConnectorOutput, None, None]:
        output = KVConnectorOutput()

        # Update the KVConnector with the KVConnector metadata before forward().
        kv_connector = get_kv_transfer_group()
        assert isinstance(kv_connector, KVConnectorBase)
        assert scheduler_output.kv_connector_metadata is not None
        kv_connector.bind_connector_metadata(
            scheduler_output.kv_connector_metadata)

        # Background KV cache transfers happen here.
        # These transfers are designed to be async and the requests
        # involved may be disjoint from the running requests.
        # Do this here to save a collective_rpc.
        kv_connector.start_load_kv(get_forward_context())

        try:
            yield output
        finally:
            if wait_for_save:
                kv_connector.wait_for_save()

            output.finished_sending, output.finished_recving = (
                kv_connector.get_finished(scheduler_output.finished_req_ids))

            kv_connector.clear_connector_metadata()
|
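Taken together, the mixin gives a model runner two ways to drive the KV connector: the context-manager path (maybe_get_kv_connector_output / kv_connector_no_forward) and the piecewise static helpers documented further below. The following is a minimal, illustrative sketch of how a runner built on the mixin might wire the context-manager path into its execute_model step. MyModelRunner, its _run_model helper, the attn_metadata placeholder, and the set_forward_context import path are assumptions for the example, not part of the documented module.

|
# Illustrative sketch only, not vLLM source: a runner that inherits the
# mixin and drives the KV connector around its forward pass.
from vllm.forward_context import set_forward_context  # assumed import path
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin)


class MyModelRunner(KVConnectorModelRunnerMixin):  # hypothetical runner

    def __init__(self, vllm_config):
        self.vllm_config = vllm_config

    def execute_model(self, scheduler_output):
        if scheduler_output.total_num_scheduled_tokens == 0:
            # Nothing to run, but the connector may still need to send/recv.
            return self.kv_connector_no_forward(scheduler_output,
                                                self.vllm_config)

        attn_metadata = None  # placeholder; a real runner builds this
        # Bind connector metadata and start async KV loads on entry; wait for
        # saves and collect finished transfers on exit.
        with set_forward_context(attn_metadata, self.vllm_config), \
                self.maybe_get_kv_connector_output(
                    scheduler_output) as kv_connector_output:
            output = self._run_model(scheduler_output)  # assumed helper

        if kv_connector_output is not None:
            output.kv_connector_output = kv_connector_output
        return output
|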
_get_kv_connector_output staticmethod
_get_kv_connector_output(scheduler_output: SchedulerOutput, wait_for_save: bool = True) -> Generator[KVConnectorOutput, None, None]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
|
@staticmethod
@contextmanager
def _get_kv_connector_output(
        scheduler_output: "SchedulerOutput",
        wait_for_save: bool = True
) -> Generator[KVConnectorOutput, None, None]:
    output = KVConnectorOutput()

    # Update the KVConnector with the KVConnector metadata before forward().
    kv_connector = get_kv_transfer_group()
    assert isinstance(kv_connector, KVConnectorBase)
    assert scheduler_output.kv_connector_metadata is not None
    kv_connector.bind_connector_metadata(
        scheduler_output.kv_connector_metadata)

    # Background KV cache transfers happen here.
    # These transfers are designed to be async and the requests
    # involved may be disjoint from the running requests.
    # Do this here to save a collective_rpc.
    kv_connector.start_load_kv(get_forward_context())

    try:
        yield output
    finally:
        if wait_for_save:
            kv_connector.wait_for_save()

        output.finished_sending, output.finished_recving = (
            kv_connector.get_finished(scheduler_output.finished_req_ids))

        kv_connector.clear_connector_metadata()
|
get_finished_kv_transfers staticmethod
get_finished_kv_transfers(scheduler_output: SchedulerOutput) -> tuple[Optional[set[str]], Optional[set[str]]]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
|
@staticmethod
def get_finished_kv_transfers(
        scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    if has_kv_transfer_group():
        return get_kv_transfer_group().get_finished(
            scheduler_output.finished_req_ids)
    return None, None
|
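The return value follows an Optional convention: when no KV transfer group is configured the method returns (None, None); otherwise it returns the sets of request IDs whose KV sending and receiving finished, as reported by the connector. Below is a hedged call-site sketch; the surrounding step loop, scheduler_output, and the report_finished callback are assumptions for illustration.

|
# Illustrative sketch: interpreting the (finished_sending, finished_recving)
# pair. `scheduler_output` and `report_finished` are assumed context.
finished_sending, finished_recving = (
    KVConnectorModelRunnerMixin.get_finished_kv_transfers(scheduler_output))

if finished_sending is None and finished_recving is None:
    pass  # no KV transfer group configured (the method returned (None, None))
else:
    # Each set holds request IDs whose transfer completed this step; an
    # empty set simply means nothing newly finished.
    report_finished(sending=finished_sending or set(),
                    recving=finished_recving or set())
|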
kv_connector_no_forward staticmethod
kv_connector_no_forward(scheduler_output: SchedulerOutput, vllm_config: VllmConfig) -> ModelRunnerOutput
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
|
@staticmethod
def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
                            vllm_config: VllmConfig) -> ModelRunnerOutput:
    # Run KV send/recv even if there is no forward pass to execute.
    with set_forward_context(
            None, vllm_config
    ), KVConnectorModelRunnerMixin._get_kv_connector_output(
            scheduler_output, wait_for_save=False) as kv_connector_output:
        pass

    if (not kv_connector_output.finished_sending
            and not kv_connector_output.finished_recving):
        return EMPTY_MODEL_RUNNER_OUTPUT

    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    output.kv_connector_output = kv_connector_output
    return output
|
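In other words, when the scheduler produced no runnable work the runner can still give the connector a send/recv opportunity, and the caller learns whether anything finished by checking which object came back: the shared EMPTY_MODEL_RUNNER_OUTPUT, or a copy that carries a kv_connector_output. A hedged call-site sketch follows; scheduler_output, vllm_config, and the import path for EMPTY_MODEL_RUNNER_OUTPUT are assumptions.

|
# Illustrative sketch of the no-work path, not vLLM source.
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT  # assumed import path

output = KVConnectorModelRunnerMixin.kv_connector_no_forward(
    scheduler_output, vllm_config)

if output is EMPTY_MODEL_RUNNER_OUTPUT:
    pass  # no KV transfers finished during this step
else:
    # A copy was returned; its kv_connector_output records which requests
    # finished sending and/or receiving KV blocks.
    kv_out = output.kv_connector_output
    print(kv_out.finished_sending, kv_out.finished_recving)
|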
maybe_get_kv_connector_output staticmethod
maybe_get_kv_connector_output(scheduler_output: SchedulerOutput) -> AbstractContextManager[Optional[KVConnectorOutput]]
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
|
@staticmethod
def maybe_get_kv_connector_output(
        scheduler_output: "SchedulerOutput"
) -> AbstractContextManager[Optional[KVConnectorOutput]]:
    return KVConnectorModelRunnerMixin._get_kv_connector_output(
        scheduler_output) if has_kv_transfer_group() else nullcontext()
|
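Note the nullcontext fallback: when no KV transfer group is configured the context manager yields None, so callers should guard before reading the output. A short hedged sketch, assuming an active forward context is already set up around the block (as _get_kv_connector_output requires) and that run_forward_pass stands in for the actual model execution.

|
# Illustrative sketch: the yielded value may be None (nullcontext fallback).
with KVConnectorModelRunnerMixin.maybe_get_kv_connector_output(
        scheduler_output) as kv_connector_output:
    run_forward_pass()  # placeholder for the actual model execution

if kv_connector_output is not None:
    # Populated on exit: saves were waited on and the finished request-id
    # sets were fetched from the connector.
    finished = (kv_connector_output.finished_sending,
                kv_connector_output.finished_recving)
|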
maybe_setup_kv_connector staticmethod
maybe_setup_kv_connector(scheduler_output: SchedulerOutput)
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
|
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
    # Update the KVConnector with the KVConnector metadata before forward().
    if has_kv_transfer_group():
        kv_connector = get_kv_transfer_group()
        assert isinstance(kv_connector, KVConnectorBase)
        assert scheduler_output.kv_connector_metadata is not None
        kv_connector.bind_connector_metadata(
            scheduler_output.kv_connector_metadata)

        # Background KV cache transfers happen here.
        # These transfers are designed to be async and the requests
        # involved may be disjoint from the running requests.
        # Do this here to save a collective_rpc.
        kv_connector.start_load_kv(get_forward_context())
|
maybe_wait_for_kv_save staticmethod
maybe_wait_for_kv_save() -> None
Source code in vllm/v1/worker/kv_connector_model_runner_mixin.py
|
@staticmethod
def maybe_wait_for_kv_save() -> None:
    if has_kv_transfer_group():
        get_kv_transfer_group().wait_for_save()
|
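maybe_setup_kv_connector, maybe_wait_for_kv_save, and get_finished_kv_transfers form a piecewise alternative to the _get_kv_connector_output context manager: bind metadata and start async loads before the forward pass, block on pending saves after it, then report finished transfers. The sketch below illustrates that ordering; the set_forward_context import path, attn_metadata, vllm_config, scheduler_output, and run_model are assumptions for the example.

|
# Illustrative sketch of the piecewise (non-context-manager) ordering.
from vllm.forward_context import set_forward_context  # assumed import path

with set_forward_context(attn_metadata, vllm_config):
    # 1. Bind connector metadata and start async KV loads
    #    (a no-op when no KV transfer group is configured).
    KVConnectorModelRunnerMixin.maybe_setup_kv_connector(scheduler_output)

    # 2. Run the model while background transfers proceed.
    run_model()

    # 3. Block until pending KV saves have completed.
    KVConnectorModelRunnerMixin.maybe_wait_for_kv_save()

# 4. Report which requests finished sending / receiving KV this step.
finished_sending, finished_recving = (
    KVConnectorModelRunnerMixin.get_finished_kv_transfers(scheduler_output))
|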