vllm.distributed.tpu_distributed_utils
MODULE_TYPE_TO_WRAPPING_FUNC module-attribute ¶
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
    [
        ("QKVParallelLinear", partition_qkv_parallel_linear),
        ("ColumnParallelLinear", partition_column_parallel_linear),
        ("RowParallelLinear", partition_row_parallel_linear),
    ]
)
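As a rough sketch of how this mapping could drive sharding, the example below walks a model's submodules and dispatches on each submodule's class name. The placeholder partition functions, the `(module, mesh)` call signature, and the `apply_sharding` helper are illustrative assumptions, not the actual vLLM implementation.

```python
from collections import OrderedDict

import torch.nn as nn


def partition_qkv_parallel_linear(module: nn.Module, mesh) -> None:
    """Placeholder standing in for the real partition function."""


def partition_column_parallel_linear(module: nn.Module, mesh) -> None:
    """Placeholder standing in for the real partition function."""


def partition_row_parallel_linear(module: nn.Module, mesh) -> None:
    """Placeholder standing in for the real partition function."""


MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
    [
        ("QKVParallelLinear", partition_qkv_parallel_linear),
        ("ColumnParallelLinear", partition_column_parallel_linear),
        ("RowParallelLinear", partition_row_parallel_linear),
    ]
)


def apply_sharding(model: nn.Module, mesh) -> None:
    # Assumed dispatch: compare each submodule's class name against the
    # mapping keys and call the first matching partition function.
    for module in model.modules():
        for type_name, wrap_fn in MODULE_TYPE_TO_WRAPPING_FUNC.items():
            if type(module).__name__ == type_name:
                wrap_fn(module, mesh)
                break
```

Listing QKVParallelLinear ahead of ColumnParallelLinear matters if matching is ever done by isinstance rather than by class name, since QKVParallelLinear is a ColumnParallelLinear subclass in vLLM.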
XlaQKVParallelLinear ¶
Bases: Module
__init__ ¶
_load_weights_from_qkv_linear ¶
_load_weights_from_qkv_linear(qkv_linear: Module)
_shard_weight ¶
forward ¶
get_fqn ¶
partition_column_parallel_linear ¶
partition_qkv_parallel_linear ¶
partition_row_parallel_linear ¶
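None of the three partition functions is documented beyond its name here; below is a minimal sketch, using torch_xla's SPMD API, of what partitioning a column-parallel linear layer could look like. The function body, the mesh-axis choice (axis 0 for the sharded dimension), and the attribute names are assumptions for illustration, not the vLLM implementation.

```python
import torch.nn as nn
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh


def partition_column_parallel_linear_sketch(module: nn.Module,
                                            mesh: Mesh) -> None:
    # A column-parallel weight is laid out as [output_size, input_size]:
    # shard the output dimension across mesh axis 0 and replicate the
    # input dimension (assumed axis choice).
    xs.mark_sharding(module.weight, mesh, (0, None))
    if getattr(module, "bias", None) is not None:
        # The bias follows the sharded output dimension.
        xs.mark_sharding(module.bias, mesh, (0,))
```

A row-parallel linear would instead shard the input dimension, e.g. a partition spec of (None, 0) for the weight.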
shard_model ¶
shard_model(model: Module, mesh: Mesh) -> None
Recursively traverse a PyTorch model and apply the appropriate sharding to each submodule, based on the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | The torch.nn.Module to process | required |
| mesh | Mesh | An XLA SPMD mesh object used for sharding | required |
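A minimal usage sketch, assuming a single-host TPU SPMD setup with one mesh axis; the mesh shape, the "model" axis name, and the stand-in nn.Linear model are illustrative only.

```python
import numpy as np
import torch.nn as nn
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh

from vllm.distributed.tpu_distributed_utils import shard_model

xr.use_spmd()  # enable XLA SPMD execution mode

# Build a 1-D mesh over every addressable TPU device.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (num_devices,), ("model",))

model = nn.Linear(16, 16)  # stand-in; in practice this is the vLLM model
shard_model(model, mesh)   # annotates matching submodules in place
```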