vllm.distributed.device_communicators.symm_mem

logger module-attribute

logger = init_logger(__name__)

symm_mem_available module-attribute

symm_mem_available = True

SymmMemCommunicator

Source code in vllm/distributed/device_communicators/symm_mem.py
class SymmMemCommunicator:
    _WORLD_SIZES_MULTIMEM = {
        "9.0": [4, 6, 8],
        "10.0": [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str,
                                                          torch.device]):
        self.disabled = True

        if not symm_mem_available:
            return

        if not current_platform.is_cuda():
            logger.warning("SymmMemCommunicator: symmetric "
                           "memory is not available.")
            return
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)
        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = current_platform.get_device_capability(
        ).as_version_str()
        if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
                self.device_capability]:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return
        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
            self.world_size]
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning("SymmMemCommunicator: symmetric memory "
                           "multicast operations are not supported.")
            return
        self.disabled = False

    def should_use_symm_mem(self, inp: torch.Tensor):
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def all_reduce(
            self,
            inp: torch.Tensor,
            *,
            out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        if not self.should_use_symm_mem(inp):
            return None
        if out is None:
            out = torch.empty_like(inp)
        self.buffer[:inp.numel()].copy_(inp.view(-1))
        if self.world_size in self._WORLD_SIZES_MULTIMEM[
                self.device_capability]:
            torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
                                                    "sum",
                                                    self.group.group_name)
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
                                                    "sum",
                                                    self.group.group_name)
        out.copy_(self.buffer[:inp.numel()].view(out.shape))
        return out
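
Usage sketch. The communicator runs an all-reduce over PyTorch symmetric memory (the multimem kernel for the world sizes listed in _WORLD_SIZES_MULTIMEM, the two-shot kernel otherwise) for bfloat16 tensors smaller than max_size. The minimal driver below is illustrative, not part of this module: it assumes torch.distributed is already initialized with one CUDA device per rank and uses the default world group, where a real deployment would pass vLLM's tensor-parallel device group.

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator

rank = dist.get_rank()
comm = SymmMemCommunicator(group=dist.group.WORLD, device=rank)

inp = torch.randn(4096, dtype=torch.bfloat16, device=f"cuda:{rank}")
out = comm.all_reduce(inp)
if out is None:
    # Communicator disabled or tensor ineligible: use the regular collective.
    dist.all_reduce(inp)
    out = inp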

_WORLD_SIZES_MULTIMEM class-attribute instance-attribute

_WORLD_SIZES_MULTIMEM = {'9.0': [4, 6, 8], '10.0': [6, 8]}

buffer instance-attribute

buffer = empty(
    max_size // itemsize, device=device, dtype=dtype
)

device instance-attribute

device = device

device_capability instance-attribute

device_capability = as_version_str()

disabled instance-attribute

disabled = False

dtype instance-attribute

dtype = bfloat16

group instance-attribute

group = group

max_size instance-attribute

max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[device_capability][
    world_size
]

world_size instance-attribute

world_size = get_world_size(group)

__init__

__init__(
    group: ProcessGroup, device: Union[int, str, device]
)
Source code in vllm/distributed/device_communicators/symm_mem.py
def __init__(self, group: ProcessGroup, device: Union[int, str,
                                                      torch.device]):
    self.disabled = True

    if not symm_mem_available:
        return

    if not current_platform.is_cuda():
        logger.warning("SymmMemCommunicator: symmetric "
                       "memory is not available.")
        return
    if isinstance(device, int):
        device = torch.device(f"cuda:{device}")
    elif isinstance(device, str):
        device = torch.device(device)
    torch.cuda.set_device(device)
    self.dtype = torch.bfloat16
    self.device = device
    self.group = group
    self.world_size = dist.get_world_size(self.group)
    self.device_capability = current_platform.get_device_capability(
    ).as_version_str()
    if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
        logger.warning(
            "SymmMemCommunicator: Device capability %s not supported, "
            "communicator is not available.",
            self.device_capability,
        )
        return
    if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
            self.device_capability]:
        logger.warning(
            "SymmMemCommunicator: World size %d not supported, "
            "communicator is not available.",
            self.world_size,
        )
        return
    self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
        self.world_size]
    self.buffer = torch_symm_mem.empty(
        self.max_size // self.dtype.itemsize,
        device=self.device,
        dtype=self.dtype,
    )
    handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
    if handle.multicast_ptr == 0:
        logger.warning("SymmMemCommunicator: symmetric memory "
                       "multicast operations are not supported.")
        return
    self.disabled = False

all_reduce

all_reduce(
    inp: Tensor, *, out: Optional[Tensor] = None
) -> Optional[Tensor]
Source code in vllm/distributed/device_communicators/symm_mem.py
def all_reduce(
        self,
        inp: torch.Tensor,
        *,
        out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
    if not self.should_use_symm_mem(inp):
        return None
    if out is None:
        out = torch.empty_like(inp)
    self.buffer[:inp.numel()].copy_(inp.view(-1))
    if self.world_size in self._WORLD_SIZES_MULTIMEM[
            self.device_capability]:
        torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
                                                "sum",
                                                self.group.group_name)
    else:
        torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
                                                "sum",
                                                self.group.group_name)
    out.copy_(self.buffer[:inp.numel()].view(out.shape))
    return out
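
Caller-side sketch of the contract implied by the early return: all_reduce returns None whenever should_use_symm_mem rejects the input, so the caller is expected to fall back to another all-reduce path. The helper name below is illustrative and not part of vLLM.

import torch
import torch.distributed as dist

def all_reduce_with_fallback(comm: "SymmMemCommunicator",
                             inp: torch.Tensor,
                             group: dist.ProcessGroup) -> torch.Tensor:
    # Fast path: symmetric-memory all-reduce; returns None if ineligible.
    out = comm.all_reduce(inp)
    if out is not None:
        return out
    # Slow path: standard collective on the same group (illustrative).
    dist.all_reduce(inp, group=group)
    return inp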

should_use_symm_mem

should_use_symm_mem(inp: Tensor)
Source code in vllm/distributed/device_communicators/symm_mem.py
def should_use_symm_mem(self, inp: torch.Tensor):
    if self.disabled:
        return False
    if inp.dtype != self.dtype:
        return False
    inp_size = inp.numel() * inp.element_size()
    if inp_size % 4 != 0:
        return False
    return inp_size < self.max_size
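
Eligibility arithmetic sketch. An input qualifies only if it is bfloat16, its byte size is a multiple of 4 (i.e. an even element count at 2 bytes per element), and it is strictly smaller than max_size. The cap below is a made-up value for illustration; the real limit comes from SYMM_MEM_ALL_REDUCE_MAX_SIZES.

import torch

max_size = 8 * 1024 * 1024  # hypothetical 8 MiB cap

a = torch.empty(4096, dtype=torch.bfloat16)  # 8192 bytes: bf16, multiple of 4, below cap -> eligible
b = torch.empty(4097, dtype=torch.bfloat16)  # 8194 bytes: not a multiple of 4 -> rejected
c = torch.empty(4096, dtype=torch.float16)   # wrong dtype (communicator is bf16-only) -> rejected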