vllm.distributed.eplb

Expert parallelism load balancer (EPLB).

Modules:

| Name | Description |
| --- | --- |
| `eplb_state` | Expert parallelism load balancer (EPLB) metrics and states. |
| `rebalance_algo` | Expert parallelism load balancer (EPLB) for vLLM. |
| `rebalance_execute` | The actual execution of the rearrangement. |

logger module-attribute

logger = init_logger(__name__)

EplbState dataclass

EPLB metrics.

Source code in vllm/distributed/eplb/eplb_state.py
@dataclass
class EplbState:
    """EPLB metrics."""

    physical_to_logical_map: torch.Tensor
    """
    Mapping from physical experts to logical experts.

    Shape: (num_moe_layers, num_physical_experts)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this:

    ```
    [[0, 1, 2, 3, 0, 1],
     [0, 2, 0, 1, 0, 3]]
    ```
    """
    logical_to_physical_map: torch.Tensor
    """
    Mapping from logical experts to physical experts.

    This is a sparse matrix, where -1 indicates no mapping.

    Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1)

    # Example

    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the mapping could look like this:

    ```
    [[[0, 4, -1],
      [1, 5, -1],
      [2, -1, -1],
      [3, -1, -1]],
     [[0, 2, 4],
      [3, -1, -1],
      [1, -1, -1],
      [5, -1, -1]]]
    ```
    """
    logical_replica_count: torch.Tensor
    """
    Number of replicas for each logical expert.
    This is exactly the non-`-1` count in the `logical_to_physical_map`.

    Shape: (num_moe_layers, num_logical_experts)

    # Example
    For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3
    EP ranks, the count could look like this:

    ```
    [[2, 2, 1, 1],
     [3, 1, 1, 1]]
    """

    expert_load_pass: torch.Tensor
    """
    Expert load during this forward pass. 
    We use the token count each expert processes as the load.

    Shape: (num_moe_layers, num_physical_experts)
    """
    expert_load_window: torch.Tensor
    """
    A sliding window of expert load.

    Shape: (window_size, num_moe_layers, num_physical_experts)

    NOTE: The expert_load_view now records load for all physical experts
    rather than just local experts. This ensures consistent load statistics
    across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
    The recorded load will be multiplied by dp_size when using naive all-to-all
    due to each DP rank contributing the same token set to the calculation.
    See:
    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
    """
    expert_load_window_step: int = 0
    """
    Current step in the sliding window.

    Different from `expert_rearrangement_step`, each EP rank may have its own
    `expert_load_window_step`.
    """
    expert_load_window_size: int = 0
    """
    Size of the expert load sliding window.
    This is a constant and is taken from the config.
    """

    expert_rearrangement_step: int = 0
    """
    Steps after last rearrangement.
    Will trigger a rearrangement if it exceeds the threshold.

    NOTE: Keep in mind that all EP ranks need to have the same
    `expert_rearrangement_step` value to ensure synchronization.
    Otherwise, the rearrangement will hang at collective
    communication calls.
    """
    expert_rearrangement_step_interval: int = 0
    """
    Interval for expert rearrangement steps.
    This is a constant and is taken from the config.
    """

    @staticmethod
    def build_initial_global_physical_to_logical_map(
        num_routed_experts: int,
        num_redundant_experts: int,
    ) -> Sequence[int]:
        """
        Build an initial expert arrangement using the following structure:
        [original routed experts, redundant experts]

        Returns:
            physical_to_logical_map (Sequence[int]): A list of integers,
                where each integer is the index of the logical expert
                that the corresponding physical expert maps to.
        """
        global_physical_to_logical_map = list(range(num_routed_experts))
        global_physical_to_logical_map += [
            i % num_routed_experts for i in range(num_redundant_experts)
        ]
        return global_physical_to_logical_map

    @classmethod
    def build(
        cls,
        model: MixtureOfExperts,
        device: torch.device,
        parallel_config: ParallelConfig,
        global_expert_load: Optional[torch.Tensor] = None,
        old_global_expert_indices: Optional[torch.Tensor] = None,
        rank_mapping: Optional[dict[int, int]] = None,
    ) -> "EplbState":
        """
        Build the initial EPLB state.
        """
        physical_to_logical_map_list = (
            cls.build_initial_global_physical_to_logical_map(
                model.num_routed_experts,
                model.num_redundant_experts,
            ))
        physical_to_logical_map = torch.tensor(
            physical_to_logical_map_list,
            device=device,
        )
        # Assuming 8 GPUs per node, this supports up to
        # (1023 + 1) / 8 = 128 nodes for now.
        # TODO(rui): make this configurable
        MAX_EXPERT_REDUNDANCY = 1023
        assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
            f"num_redundant_experts {model.num_redundant_experts} "
            f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}")
        max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
        logical_to_physical_map = torch.full(
            (model.num_logical_experts, max_slots_per_logical_expert),
            -1,
            device=device,
        )
        logical_replica_count = torch.zeros(
            (model.num_logical_experts, ),
            device=device,
            dtype=torch.long,
        )

        for i in range(model.num_physical_experts):
            logical_idx = physical_to_logical_map[i]
            logical_to_physical_map[logical_idx,
                                    logical_replica_count[logical_idx]] = i
            logical_replica_count[logical_idx] += 1

        # Duplicate initial mapping for all layers
        physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand(
            model.num_moe_layers,
            -1,
        ).contiguous()
        logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand(
            model.num_moe_layers,
            -1,
            -1,
        ).contiguous()
        logical_replica_count = logical_replica_count.unsqueeze(0).expand(
            model.num_moe_layers,
            -1,
        ).contiguous()

        expert_load_pass = torch.zeros(
            (model.num_moe_layers, model.num_physical_experts),
            dtype=torch.int32,
            device=device,
        )
        expert_load_window_size = parallel_config.eplb_config.window_size
        expert_load_window = torch.zeros(
            (expert_load_window_size, model.num_moe_layers,
             model.num_physical_experts),
            dtype=torch.int32,
            device=device,
        )

        # Set the initial progress of rearrangement to 3/4
        eplb_step_interval = parallel_config.eplb_config.step_interval
        expert_rearrangement_step = max(
            0, eplb_step_interval - eplb_step_interval // 4)

        if global_expert_load is not None:
            ep_group = get_ep_group().device_group
            assert global_expert_load.shape == (model.num_moe_layers,
                                                model.num_logical_experts)
            assert global_expert_load.dtype == torch.int64

            num_replicas = model.num_physical_experts
            num_groups = model.num_expert_groups
            num_nodes = get_node_count()
            num_gpus = ep_group.size()

            if num_gpus % num_nodes != 0:
                num_nodes = 1
                logger.warning_once(
                    f"num_gpus % num_nodes != 0, "
                    "not using hierarchical rearrangement algorithm.\n"
                    f"{num_gpus=}, {num_nodes=}")

            # Get new expert mappings
            (
                new_physical_to_logical_map,
                new_logical_to_physical_map,
                new_logical_replica_count,
            ) = (rebalance_experts(
                global_expert_load,
                num_replicas,
                num_groups,
                num_nodes,
                num_gpus,
            ))

            max_physical_slots = new_logical_to_physical_map.shape[-1]
            assert max_physical_slots <= logical_to_physical_map.shape[-1]
            new_logical_to_physical_map = torch.nn.functional.pad(
                new_logical_to_physical_map,
                (0, logical_to_physical_map.shape[-1] - max_physical_slots),
                value=-1,
            )
            physical_to_logical_map = new_physical_to_logical_map.to(device)
            logical_to_physical_map.copy_(new_logical_to_physical_map)
            logical_replica_count.copy_(new_logical_replica_count)

        model.set_eplb_state(
            expert_load_pass,
            logical_to_physical_map,
            logical_replica_count,
        )
        if global_expert_load is not None:
            rearrange_expert_weights_inplace(
                old_global_expert_indices,
                new_physical_to_logical_map,
                model.expert_weights,
                ep_group,
                False,
                rank_mapping,
            )
            expert_rearrangement_step = 0

        return cls(
            physical_to_logical_map,
            logical_to_physical_map,
            logical_replica_count,
            expert_load_pass,
            expert_load_window,
            expert_load_window_size=expert_load_window_size,
            expert_rearrangement_step=expert_rearrangement_step,
            expert_rearrangement_step_interval=eplb_step_interval,
        )

    def step(self,
             model: MixtureOfExperts,
             is_dummy: bool = False,
             is_profile: bool = False,
             log_stats: bool = False) -> None:
        """
        Step the EPLB state.

        Args:
            model (MixtureOfExperts): The MoE model.
            is_dummy (bool): If `True`, this is a dummy step and the load
              metrics recorded in this forward pass will not count. Defaults
              to `False`.
            is_profile (bool): If `True`, perform a dummy rearrangement
              with maximum communication cost. This is used in `profile_run`
              to reserve enough memory for the communication buffer.
            log_stats (bool): If `True`, log the expert load metrics.

        # Stats
            The metrics are all summed up across layers.
            - `avg_tokens`: The average load across ranks.
            - `max_tokens`: The maximum load across ranks.
            - `balancedness`: The ratio of average load to maximum load.
        """

        if is_profile:
            self.rearrange(model, is_profile=True)
            return

        if is_dummy:
            # Do not record load metrics for dummy steps
            self.expert_load_pass.zero_()

        if log_stats:
            # total_expert_load_pass: (num_moe_layers, num_physical_experts)
            total_expert_load_pass = self.expert_load_pass.clone()

            # Collect load metrics from all ranks
            ep_group = get_ep_group().device_group
            all_reduce(total_expert_load_pass, group=ep_group)

            # num_tokens_per_rank: (num_moe_layers, num_ranks)
            num_tokens_per_rank = total_expert_load_pass.reshape(
                total_expert_load_pass.shape[0], ep_group.size(),
                -1).sum(dim=-1).float()

            # Compute balancedness ratio:
            # for each layer:
            #   (mean load across ranks) / (max load across ranks)
            avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
            max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(
                dim=0)

            # Just to make type checker happy
            tokens_tensors: list[float] = torch.stack(
                [avg_tokens_tensor, max_tokens_tensor]).tolist()
            avg_tokens, max_tokens = tokens_tensors
            balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0

            if ep_group.rank() == 0:
                logger.info(
                    "EPLB step: avg_tokens=%.2f, max_tokens=%d, "
                    "balancedness=%.4f", avg_tokens, max_tokens, balancedness)

        # Update the expert load sliding window
        if not is_dummy:
            self.expert_load_window[self.expert_load_window_step] = (
                self.expert_load_pass.clone())
            self.expert_load_window_step += 1
            if self.expert_load_window_step >= self.expert_load_window_size:
                self.expert_load_window_step = 0
            self.expert_load_pass.zero_()

        # Step the expert rearrangement step
        # Note that even if this is a dummy step, we still increment the
        # rearrangement step and perform rearrangement to ensure all ranks are
        # performing collective communication.
        self.expert_rearrangement_step += 1
        if (self.expert_rearrangement_step
                >= self.expert_rearrangement_step_interval):
            self.expert_rearrangement_step = 0
            self.rearrange(model)

    def rearrange(self,
                  model: MixtureOfExperts,
                  is_profile: bool = False,
                  execute_shuffle: bool = True,
                  global_expert_load: Optional[torch.Tensor] = None,
                  rank_mapping: Optional[dict[int, int]] = None
                  ) -> Optional[torch.Tensor]:
        """
        Rearrange the experts according to the current load.
        """

        ep_group = get_ep_group().device_group
        ep_rank = ep_group.rank()

        time_start = None
        is_main_rank = ep_rank == 0
        if is_main_rank:
            torch.cuda.synchronize()
            time_start = time.perf_counter()
            logger.info("Rearranging experts %s...",
                        "(profile)" if is_profile else "")

        if global_expert_load is None:
            # Map the physical expert load to global logical experts
            logical_expert_load_window = torch.zeros(
                self.expert_load_window_size,
                model.num_moe_layers,
                model.num_logical_experts,
                dtype=self.expert_load_window.dtype,
                device=self.expert_load_window.device,
            )
            logical_expert_load_window.scatter_add_(
                dim=-1,
                index=self.physical_to_logical_map.unsqueeze(0).expand_as(
                    self.expert_load_window).long(),
                src=self.expert_load_window,
            )

            if not execute_shuffle:
                metadata = torch.tensor(
                    [
                        model.num_moe_layers, model.num_logical_experts,
                        self.physical_to_logical_map.shape[1]
                    ],
                    dtype=torch.int32,
                    device="cpu",
                )
                torch.distributed.broadcast(metadata,
                                            group=get_ep_group().cpu_group,
                                            group_src=0)

            # Perform all-reduce to get the expert load across all ranks
            global_expert_load_window = logical_expert_load_window.sum(dim=0)
            all_reduce(global_expert_load_window, group=ep_group)

            if not execute_shuffle:
                # (num_moe_layers, old_num_physical_experts)
                old_global_expert_indices = self.physical_to_logical_map
                torch.distributed.broadcast(old_global_expert_indices,
                                            group=ep_group,
                                            group_src=0)
                return global_expert_load_window
        else:
            assert execute_shuffle
            global_expert_load_window = global_expert_load

        # TODO(bowen): Treat differently for prefill and decode nodes
        num_replicas = model.num_physical_experts
        num_groups = model.num_expert_groups
        if rank_mapping is not None and len(rank_mapping) == ep_group.size():
            # NOTE(yongji): scale down, we need to rebalance the experts on
            # remaining GPUs, transfer the experts while we haven't shutdown
            # the GPUs to be released.
            cpu_group = get_ep_group().cpu_group
            num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
            num_gpus = sum(new_rank != -1
                           for new_rank in rank_mapping.values())
            num_replicas = num_replicas // ep_group.size(
            ) * num_gpus  # handle num replicas change
        else:
            num_nodes = get_node_count()
            num_gpus = ep_group.size()

        if num_gpus % num_nodes != 0:
            num_nodes = 1
            logger.warning_once(
                f"num_gpus % num_nodes != 0, "
                "not using hierarchical rearrangement algorithm.\n"
                f"{num_gpus=}, {num_nodes=}")

        # Get new expert mappings
        (
            new_physical_to_logical_map,
            new_logical_to_physical_map,
            new_logical_replica_count,
        ) = (rebalance_experts(
            global_expert_load_window,
            num_replicas,
            num_groups,
            num_nodes,
            num_gpus,
        ))

        # Update expert weights
        rearrange_expert_weights_inplace(
            self.physical_to_logical_map,
            new_physical_to_logical_map,
            model.expert_weights,
            ep_group,
            is_profile,
            rank_mapping,
        )

        if not is_profile:
            if self.physical_to_logical_map.shape[
                    1] != new_physical_to_logical_map.shape[1]:
                self.physical_to_logical_map = new_physical_to_logical_map.to(
                    self.physical_to_logical_map.device)
            else:
                self.physical_to_logical_map.copy_(new_physical_to_logical_map)
            max_physical_slots = new_logical_to_physical_map.shape[-1]
            assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
            new_logical_to_physical_map = torch.nn.functional.pad(
                new_logical_to_physical_map,
                (0,
                 self.logical_to_physical_map.shape[-1] - max_physical_slots),
                value=-1,
            )
            self.logical_to_physical_map.copy_(new_logical_to_physical_map)
            self.logical_replica_count.copy_(new_logical_replica_count)

        if is_main_rank:
            assert time_start is not None
            torch.cuda.synchronize()
            time_end = time.perf_counter()
            logger.info(
                "Rearranged experts%sin %.2f seconds.",
                " (profile) " if is_profile else " ",
                time_end - time_start,
            )

    @staticmethod
    def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
        """
        Receive the expert load and old placement from the master rank.
        """
        ep_group = get_ep_group()
        metadata = torch.empty(3, dtype=torch.int32, device="cpu")
        torch.distributed.broadcast(metadata,
                                    group=ep_group.cpu_group,
                                    group_src=0)
        num_moe_layers, num_logical_experts, num_old_physical_experts = (
            metadata.tolist())
        global_expert_load = torch.zeros(
            (num_moe_layers, num_logical_experts),
            dtype=torch.int64,
            device=ep_group.device,
        )
        all_reduce(global_expert_load, group=ep_group.device_group)
        old_global_expert_indices = torch.empty(
            (num_moe_layers, num_old_physical_experts),
            dtype=torch.int64,
            device=ep_group.device,
        )
        torch.distributed.broadcast(old_global_expert_indices,
                                    group=ep_group.device_group,
                                    group_src=0)

        return global_expert_load, old_global_expert_indices
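
Taken together: `build()` constructs the state once at startup, `step()` runs after every forward pass, and `rearrange()` fires once every `step_interval` passes. Below is a hedged sketch of that driver loop; it assumes an already-initialized vLLM worker supplies `model` (a `MixtureOfExperts`) and `parallel_config`, and the `run_forward_pass` helper and `num_forward_passes` bound are hypothetical stand-ins.

```python
import torch

# Hedged sketch of the EPLB lifecycle; in vLLM this loop is orchestrated by
# the GPU model runner, not by user code.
eplb_state = EplbState.build(model, torch.device("cuda"), parallel_config)

for _ in range(num_forward_passes):  # hypothetical loop bound
    run_forward_pass(model)  # hypothetical; fills the expert_load_pass view
    # Logs balancedness when requested and, every step_interval passes,
    # rearranges expert weights across EP ranks.
    eplb_state.step(
        model,
        log_stats=parallel_config.eplb_config.log_balancedness,
    )
```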

expert_load_pass instance-attribute

expert_load_pass: Tensor

Expert load during this forward pass. We use the token count each expert processes as the load.

Shape: (num_moe_layers, num_physical_experts)

expert_load_window instance-attribute

expert_load_window: Tensor

A sliding window of expert load.

Shape: (window_size, num_moe_layers, num_physical_experts)

NOTE: The expert_load_view now records load for all physical experts rather than just local experts. This ensures consistent load statistics across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels). The recorded load will be multiplied by dp_size when using naive all-to-all due to each DP rank contributing the same token set to the calculation. See: https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856

expert_load_window_size class-attribute instance-attribute

expert_load_window_size: int = 0

Size of the expert load sliding window. This is a constant and is taken from the config.

expert_load_window_step class-attribute instance-attribute

expert_load_window_step: int = 0

Current step in the sliding window.

Different from expert_rearrangement_step, each EP rank may have its own expert_load_window_step.

expert_rearrangement_step class-attribute instance-attribute

expert_rearrangement_step: int = 0

Steps after last rearrangement. Will trigger a rearrangement if it exceeds the threshold.

NOTE: Keep in mind that all EP ranks need to have the same expert_rearrangement_step value to ensure synchronization. Otherwise, the rearrangement will hang at collective communication calls.

expert_rearrangement_step_interval class-attribute instance-attribute

expert_rearrangement_step_interval: int = 0

Interval for expert rearrangement steps. This is a constant and is taken from the config.

logical_replica_count instance-attribute

logical_replica_count: Tensor

Number of replicas for each logical expert. This is exactly the count of non-`-1` entries in `logical_to_physical_map`.

Shape: (num_moe_layers, num_logical_experts)

Example

For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 EP ranks, the count could look like this:

[[2, 2, 1, 1],
 [3, 1, 1, 1]]
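
As a sanity check, the replica counts can be recomputed from `logical_to_physical_map` by counting non-`-1` entries; a minimal sketch using the toy mapping from the class docstring:

```python
import torch

# Toy mapping: (num_moe_layers=2, num_logical_experts=4, slots=3).
logical_to_physical_map = torch.tensor(
    [[[0, 4, -1], [1, 5, -1], [2, -1, -1], [3, -1, -1]],
     [[0, 2, 4], [3, -1, -1], [1, -1, -1], [5, -1, -1]]])

logical_replica_count = (logical_to_physical_map != -1).sum(dim=-1)
print(logical_replica_count)
# tensor([[2, 2, 1, 1],
#         [3, 1, 1, 1]])
```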

logical_to_physical_map instance-attribute

logical_to_physical_map: Tensor

Mapping from logical experts to physical experts.

This is a sparse matrix, where -1 indicates no mapping.

Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1)

Example

For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 EP ranks, the mapping could look like this:

[[[0, 4, -1],
  [1, 5, -1],
  [2, -1, -1],
  [3, -1, -1]],
 [[0, 2, 4],
  [3, -1, -1],
  [1, -1, -1],
  [5, -1, -1]]]

physical_to_logical_map instance-attribute

physical_to_logical_map: Tensor

Mapping from physical experts to logical experts.

Shape: (num_moe_layers, num_physical_experts)

Example

For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 EP ranks, the mapping could look like this:

[[0, 1, 2, 3, 0, 1],
 [0, 2, 0, 1, 0, 3]]
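
The two maps describe the same placement from opposite directions: with the toy tensors above, every physical expert must appear in the replica list of the logical expert it maps to. A small consistency check:

```python
import torch

physical_to_logical = torch.tensor([[0, 1, 2, 3, 0, 1],
                                    [0, 2, 0, 1, 0, 3]])
logical_to_physical = torch.tensor(
    [[[0, 4, -1], [1, 5, -1], [2, -1, -1], [3, -1, -1]],
     [[0, 2, 4], [3, -1, -1], [1, -1, -1], [5, -1, -1]]])

for layer in range(physical_to_logical.shape[0]):
    for phys, logical in enumerate(physical_to_logical[layer].tolist()):
        # Each physical slot is one of its logical expert's replicas.
        assert phys in logical_to_physical[layer, logical].tolist()
```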

__init__

__init__(
    physical_to_logical_map: Tensor,
    logical_to_physical_map: Tensor,
    logical_replica_count: Tensor,
    expert_load_pass: Tensor,
    expert_load_window: Tensor,
    expert_load_window_step: int = 0,
    expert_load_window_size: int = 0,
    expert_rearrangement_step: int = 0,
    expert_rearrangement_step_interval: int = 0,
) -> None

build classmethod

build(
    model: MixtureOfExperts,
    device: device,
    parallel_config: ParallelConfig,
    global_expert_load: Optional[Tensor] = None,
    old_global_expert_indices: Optional[Tensor] = None,
    rank_mapping: Optional[dict[int, int]] = None,
) -> EplbState

Build the initial EPLB state.

Source code in vllm/distributed/eplb/eplb_state.py
@classmethod
def build(
    cls,
    model: MixtureOfExperts,
    device: torch.device,
    parallel_config: ParallelConfig,
    global_expert_load: Optional[torch.Tensor] = None,
    old_global_expert_indices: Optional[torch.Tensor] = None,
    rank_mapping: Optional[dict[int, int]] = None,
) -> "EplbState":
    """
    Build the initial EPLB state.
    """
    physical_to_logical_map_list = (
        cls.build_initial_global_physical_to_logical_map(
            model.num_routed_experts,
            model.num_redundant_experts,
        ))
    physical_to_logical_map = torch.tensor(
        physical_to_logical_map_list,
        device=device,
    )
    # Assuming 8 GPUs per node, this supports up to
    # (1023 + 1) / 8 = 128 nodes for now.
    # TODO(rui): make this configurable
    MAX_EXPERT_REDUNDANCY = 1023
    assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, (
        f"num_redundant_experts {model.num_redundant_experts} "
        f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}")
    max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1
    logical_to_physical_map = torch.full(
        (model.num_logical_experts, max_slots_per_logical_expert),
        -1,
        device=device,
    )
    logical_replica_count = torch.zeros(
        (model.num_logical_experts, ),
        device=device,
        dtype=torch.long,
    )

    for i in range(model.num_physical_experts):
        logical_idx = physical_to_logical_map[i]
        logical_to_physical_map[logical_idx,
                                logical_replica_count[logical_idx]] = i
        logical_replica_count[logical_idx] += 1

    # Duplicate initial mapping for all layers
    physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand(
        model.num_moe_layers,
        -1,
    ).contiguous()
    logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand(
        model.num_moe_layers,
        -1,
        -1,
    ).contiguous()
    logical_replica_count = logical_replica_count.unsqueeze(0).expand(
        model.num_moe_layers,
        -1,
    ).contiguous()

    expert_load_pass = torch.zeros(
        (model.num_moe_layers, model.num_physical_experts),
        dtype=torch.int32,
        device=device,
    )
    expert_load_window_size = parallel_config.eplb_config.window_size
    expert_load_window = torch.zeros(
        (expert_load_window_size, model.num_moe_layers,
         model.num_physical_experts),
        dtype=torch.int32,
        device=device,
    )

    # Set the initial progress of rearrangement to 3/4
    eplb_step_interval = parallel_config.eplb_config.step_interval
    expert_rearrangement_step = max(
        0, eplb_step_interval - eplb_step_interval // 4)

    if global_expert_load is not None:
        ep_group = get_ep_group().device_group
        assert global_expert_load.shape == (model.num_moe_layers,
                                            model.num_logical_experts)
        assert global_expert_load.dtype == torch.int64

        num_replicas = model.num_physical_experts
        num_groups = model.num_expert_groups
        num_nodes = get_node_count()
        num_gpus = ep_group.size()

        if num_gpus % num_nodes != 0:
            num_nodes = 1
            logger.warning_once(
                f"num_gpus % num_nodes != 0, "
                "not using hierarchical rearrangement algorithm.\n"
                f"{num_gpus=}, {num_nodes=}")

        # Get new expert mappings
        (
            new_physical_to_logical_map,
            new_logical_to_physical_map,
            new_logical_replica_count,
        ) = (rebalance_experts(
            global_expert_load,
            num_replicas,
            num_groups,
            num_nodes,
            num_gpus,
        ))

        max_physical_slots = new_logical_to_physical_map.shape[-1]
        assert max_physical_slots <= logical_to_physical_map.shape[-1]
        new_logical_to_physical_map = torch.nn.functional.pad(
            new_logical_to_physical_map,
            (0, logical_to_physical_map.shape[-1] - max_physical_slots),
            value=-1,
        )
        physical_to_logical_map = new_physical_to_logical_map.to(device)
        logical_to_physical_map.copy_(new_logical_to_physical_map)
        logical_replica_count.copy_(new_logical_replica_count)

    model.set_eplb_state(
        expert_load_pass,
        logical_to_physical_map,
        logical_replica_count,
    )
    if global_expert_load is not None:
        rearrange_expert_weights_inplace(
            old_global_expert_indices,
            new_physical_to_logical_map,
            model.expert_weights,
            ep_group,
            False,
            rank_mapping,
        )
        expert_rearrangement_step = 0

    return cls(
        physical_to_logical_map,
        logical_to_physical_map,
        logical_replica_count,
        expert_load_pass,
        expert_load_window,
        expert_load_window_size=expert_load_window_size,
        expert_rearrangement_step=expert_rearrangement_step,
        expert_rearrangement_step_interval=eplb_step_interval,
    )

build_initial_global_physical_to_logical_map staticmethod

build_initial_global_physical_to_logical_map(
    num_routed_experts: int, num_redundant_experts: int
) -> Sequence[int]

Build an initial expert arrangement using the following structure: [original routed experts, redundant experts]

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `physical_to_logical_map` | `Sequence[int]` | A list of integers, where each integer is the index of the logical expert that the corresponding physical expert maps to. |

Source code in vllm/distributed/eplb/eplb_state.py
@staticmethod
def build_initial_global_physical_to_logical_map(
    num_routed_experts: int,
    num_redundant_experts: int,
) -> Sequence[int]:
    """
    Build an initial expert arrangement using the following structure:
    [original routed experts, redundant experts]

    Returns:
        physical_to_logical_map (Sequence[int]): A list of integers,
            where each integer is the index of the logical expert
            that the corresponding physical expert maps to.
    """
    global_physical_to_logical_map = list(range(num_routed_experts))
    global_physical_to_logical_map += [
        i % num_routed_experts for i in range(num_redundant_experts)
    ]
    return global_physical_to_logical_map
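
For example, with 4 routed experts and 2 redundant slots this yields `[0, 1, 2, 3, 0, 1]`, matching the first layer of the `physical_to_logical_map` example above. The same logic, replayed standalone:

```python
num_routed_experts, num_redundant_experts = 4, 2

# Originals first, then redundant copies assigned round-robin
# over the logical experts.
mapping = list(range(num_routed_experts))
mapping += [i % num_routed_experts for i in range(num_redundant_experts)]
assert mapping == [0, 1, 2, 3, 0, 1]
```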

rearrange

rearrange(
    model: MixtureOfExperts,
    is_profile: bool = False,
    execute_shuffle: bool = True,
    global_expert_load: Optional[Tensor] = None,
    rank_mapping: Optional[dict[int, int]] = None,
) -> Optional[Tensor]

Rearrange the experts according to the current load. If `execute_shuffle` is `False`, skip the weight movement and instead return the global logical expert load after broadcasting the old expert placement.

Source code in vllm/distributed/eplb/eplb_state.py
def rearrange(self,
              model: MixtureOfExperts,
              is_profile: bool = False,
              execute_shuffle: bool = True,
              global_expert_load: Optional[torch.Tensor] = None,
              rank_mapping: Optional[dict[int, int]] = None
              ) -> Optional[torch.Tensor]:
    """
    Rearrange the experts according to the current load.
    """

    ep_group = get_ep_group().device_group
    ep_rank = ep_group.rank()

    time_start = None
    is_main_rank = ep_rank == 0
    if is_main_rank:
        torch.cuda.synchronize()
        time_start = time.perf_counter()
        logger.info("Rearranging experts %s...",
                    "(profile)" if is_profile else "")

    if global_expert_load is None:
        # Map the physical expert load to global logical experts
        logical_expert_load_window = torch.zeros(
            self.expert_load_window_size,
            model.num_moe_layers,
            model.num_logical_experts,
            dtype=self.expert_load_window.dtype,
            device=self.expert_load_window.device,
        )
        logical_expert_load_window.scatter_add_(
            dim=-1,
            index=self.physical_to_logical_map.unsqueeze(0).expand_as(
                self.expert_load_window).long(),
            src=self.expert_load_window,
        )

        if not execute_shuffle:
            metadata = torch.tensor(
                [
                    model.num_moe_layers, model.num_logical_experts,
                    self.physical_to_logical_map.shape[1]
                ],
                dtype=torch.int32,
                device="cpu",
            )
            torch.distributed.broadcast(metadata,
                                        group=get_ep_group().cpu_group,
                                        group_src=0)

        # Perform all-reduce to get the expert load across all ranks
        global_expert_load_window = logical_expert_load_window.sum(dim=0)
        all_reduce(global_expert_load_window, group=ep_group)

        if not execute_shuffle:
            # (num_moe_layers, old_num_physical_experts)
            old_global_expert_indices = self.physical_to_logical_map
            torch.distributed.broadcast(old_global_expert_indices,
                                        group=ep_group,
                                        group_src=0)
            return global_expert_load_window
    else:
        assert execute_shuffle
        global_expert_load_window = global_expert_load

    # TODO(bowen): Treat differently for prefill and decode nodes
    num_replicas = model.num_physical_experts
    num_groups = model.num_expert_groups
    if rank_mapping is not None and len(rank_mapping) == ep_group.size():
        # NOTE(yongji): scale down, we need to rebalance the experts on
        # remaining GPUs, transfer the experts while we haven't shutdown
        # the GPUs to be released.
        cpu_group = get_ep_group().cpu_group
        num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
        num_gpus = sum(new_rank != -1
                       for new_rank in rank_mapping.values())
        num_replicas = num_replicas // ep_group.size(
        ) * num_gpus  # handle num replicas change
    else:
        num_nodes = get_node_count()
        num_gpus = ep_group.size()

    if num_gpus % num_nodes != 0:
        num_nodes = 1
        logger.warning_once(
            f"num_gpus % num_nodes != 0, "
            "not using hierarchical rearrangement algorithm.\n"
            f"{num_gpus=}, {num_nodes=}")

    # Get new expert mappings
    (
        new_physical_to_logical_map,
        new_logical_to_physical_map,
        new_logical_replica_count,
    ) = (rebalance_experts(
        global_expert_load_window,
        num_replicas,
        num_groups,
        num_nodes,
        num_gpus,
    ))

    # Update expert weights
    rearrange_expert_weights_inplace(
        self.physical_to_logical_map,
        new_physical_to_logical_map,
        model.expert_weights,
        ep_group,
        is_profile,
        rank_mapping,
    )

    if not is_profile:
        if self.physical_to_logical_map.shape[
                1] != new_physical_to_logical_map.shape[1]:
            self.physical_to_logical_map = new_physical_to_logical_map.to(
                self.physical_to_logical_map.device)
        else:
            self.physical_to_logical_map.copy_(new_physical_to_logical_map)
        max_physical_slots = new_logical_to_physical_map.shape[-1]
        assert max_physical_slots <= self.logical_to_physical_map.shape[-1]
        new_logical_to_physical_map = torch.nn.functional.pad(
            new_logical_to_physical_map,
            (0,
             self.logical_to_physical_map.shape[-1] - max_physical_slots),
            value=-1,
        )
        self.logical_to_physical_map.copy_(new_logical_to_physical_map)
        self.logical_replica_count.copy_(new_logical_replica_count)

    if is_main_rank:
        assert time_start is not None
        torch.cuda.synchronize()
        time_end = time.perf_counter()
        logger.info(
            "Rearranged experts%sin %.2f seconds.",
            " (profile) " if is_profile else " ",
            time_end - time_start,
        )
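
The load aggregation above folds per-physical-expert counts into per-logical-expert counts via `scatter_add_`. A standalone sketch of that one step, using the toy mapping from the class docstring (the sliding-window dimension is omitted for brevity):

```python
import torch

# One layer: 6 physical slots mapping onto 4 logical experts (toy values).
physical_to_logical = torch.tensor([[0, 1, 2, 3, 0, 1]])
physical_load = torch.tensor([[10, 20, 30, 40, 50, 60]])

logical_load = torch.zeros(1, 4, dtype=physical_load.dtype)
# logical_load[l, physical_to_logical[l, p]] += physical_load[l, p]
logical_load.scatter_add_(dim=-1, index=physical_to_logical, src=physical_load)
print(logical_load)  # tensor([[60, 80, 30, 40]])
```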

recv_state staticmethod

recv_state() -> tuple[Tensor, Tensor]

Receive the expert load and old placement from the master rank.

Source code in vllm/distributed/eplb/eplb_state.py
@staticmethod
def recv_state() -> tuple[torch.Tensor, torch.Tensor]:
    """
    Receive the expert load and old placement from the master rank.
    """
    ep_group = get_ep_group()
    metadata = torch.empty(3, dtype=torch.int32, device="cpu")
    torch.distributed.broadcast(metadata,
                                group=ep_group.cpu_group,
                                group_src=0)
    num_moe_layers, num_logical_experts, num_old_physical_experts = (
        metadata.tolist())
    global_expert_load = torch.zeros(
        (num_moe_layers, num_logical_experts),
        dtype=torch.int64,
        device=ep_group.device,
    )
    all_reduce(global_expert_load, group=ep_group.device_group)
    old_global_expert_indices = torch.empty(
        (num_moe_layers, num_old_physical_experts),
        dtype=torch.int64,
        device=ep_group.device,
    )
    torch.distributed.broadcast(old_global_expert_indices,
                                group=ep_group.device_group,
                                group_src=0)

    return global_expert_load, old_global_expert_indices

step

step(
    model: MixtureOfExperts,
    is_dummy: bool = False,
    is_profile: bool = False,
    log_stats: bool = False,
) -> None

Step the EPLB state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `MixtureOfExperts` | The MoE model. | required |
| `is_dummy` | `bool` | If `True`, this is a dummy step and the load metrics recorded in this forward pass will not count. | `False` |
| `is_profile` | `bool` | If `True`, perform a dummy rearrangement with maximum communication cost. This is used in `profile_run` to reserve enough memory for the communication buffer. | `False` |
| `log_stats` | `bool` | If `True`, log the expert load metrics. | `False` |

Stats

The metrics are all summed up across layers.
- `avg_tokens`: The average load across ranks.
- `max_tokens`: The maximum load across ranks.
- `balancedness`: The ratio of average load to maximum load.
Source code in vllm/distributed/eplb/eplb_state.py
def step(self,
         model: MixtureOfExperts,
         is_dummy: bool = False,
         is_profile: bool = False,
         log_stats: bool = False) -> None:
    """
    Step the EPLB state.

    Args:
        model (MixtureOfExperts): The MoE model.
        is_dummy (bool): If `True`, this is a dummy step and the load
          metrics recorded in this forward pass will not count. Defaults
          to `False`.
        is_profile (bool): If `True`, perform a dummy rearrangement
          with maximum communication cost. This is used in `profile_run`
          to reserve enough memory for the communication buffer.
        log_stats (bool): If `True`, log the expert load metrics.

    # Stats
        The metrics are all summed up across layers.
        - `avg_tokens`: The average load across ranks.
        - `max_tokens`: The maximum load across ranks.
        - `balancedness`: The ratio of average load to maximum load.
    """

    if is_profile:
        self.rearrange(model, is_profile=True)
        return

    if is_dummy:
        # Do not record load metrics for dummy steps
        self.expert_load_pass.zero_()

    if log_stats:
        # total_expert_load_pass: (num_moe_layers, num_physical_experts)
        total_expert_load_pass = self.expert_load_pass.clone()

        # Collect load metrics from all ranks
        ep_group = get_ep_group().device_group
        all_reduce(total_expert_load_pass, group=ep_group)

        # num_tokens_per_rank: (num_moe_layers, num_ranks)
        num_tokens_per_rank = total_expert_load_pass.reshape(
            total_expert_load_pass.shape[0], ep_group.size(),
            -1).sum(dim=-1).float()

        # Compute balancedness ratio:
        # for each layer:
        #   (mean load across ranks) / (max load across ranks)
        avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0)
        max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(
            dim=0)

        # Just to make type checker happy
        tokens_tensors: list[float] = torch.stack(
            [avg_tokens_tensor, max_tokens_tensor]).tolist()
        avg_tokens, max_tokens = tokens_tensors
        balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0

        if ep_group.rank() == 0:
            logger.info(
                "EPLB step: avg_tokens=%.2f, max_tokens=%d, "
                "balancedness=%.4f", avg_tokens, max_tokens, balancedness)

    # Update the expert load sliding window
    if not is_dummy:
        self.expert_load_window[self.expert_load_window_step] = (
            self.expert_load_pass.clone())
        self.expert_load_window_step += 1
        if self.expert_load_window_step >= self.expert_load_window_size:
            self.expert_load_window_step = 0
        self.expert_load_pass.zero_()

    # Step the expert rearrangement step
    # Note that even if this is a dummy step, we still increment the
    # rearrangement step and perform rearrangement to ensure all ranks are
    # performing collective communication.
    self.expert_rearrangement_step += 1
    if (self.expert_rearrangement_step
            >= self.expert_rearrangement_step_interval):
        self.expert_rearrangement_step = 0
        self.rearrange(model)
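
The balancedness metric reduces the per-rank token counts exactly as in the code above; a small self-contained sketch with made-up numbers:

```python
import torch

# Hypothetical token counts: (num_moe_layers=2, num_ranks=3).
num_tokens_per_rank = torch.tensor([[100., 120., 80.],
                                    [90., 110., 100.]])

avg_tokens = num_tokens_per_rank.mean(dim=0).sum(dim=0)        # 300.0
max_tokens = num_tokens_per_rank.max(dim=0).values.sum(dim=0)  # 320.0
balancedness = (avg_tokens / max_tokens).item()  # 0.9375; 1.0 is perfect
```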

MixtureOfExperts

Bases: Protocol

Check if the model is a mixture of experts (MoE) model.

Source code in vllm/model_executor/models/interfaces.py
@runtime_checkable
class MixtureOfExperts(Protocol):
    """
    Check if the model is a mixture of experts (MoE) model.
    """

    expert_weights: MutableSequence[Iterable[Tensor]]
    """
    Expert weights saved in this rank.

    The first dimension is the layer, and the second dimension is different
    parameters in the layer, e.g. up/down projection weights.
    """

    num_moe_layers: int
    """Number of MoE layers in this model."""

    num_expert_groups: int
    """Number of expert groups in this model."""

    num_logical_experts: int
    """Number of logical experts in this model."""

    num_physical_experts: int
    """Number of physical experts in this model."""

    num_local_physical_experts: int
    """Number of local physical experts in this model."""

    num_routed_experts: int
    """Number of routed experts in this model."""

    num_shared_experts: int
    """Number of shared experts in this model."""

    num_redundant_experts: int
    """Number of redundant experts in this model."""

    def set_eplb_state(
        self,
        expert_load_view: Tensor,
        logical_to_physical_map: Tensor,
        logical_replica_count: Tensor,
    ) -> None:
        """
        Register the EPLB state in the MoE model.

        Since these are views of the actual EPLB state, any changes made by
        the EPLB algorithm are automatically reflected in the model's behavior
        without requiring additional method calls to set new states.

        You should also collect model's `expert_weights` here instead of in
        the weight loader, since after initial weight loading, further
        processing like quantization may be applied to the weights.

        Args:
            expert_load_view: A view of the expert load metrics tensor.
            logical_to_physical_map: Mapping from logical to physical experts.
            logical_replica_count: Count of replicas for each logical expert.
        """
        ...

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        ...
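
Because the protocol is decorated with `@runtime_checkable`, models are detected by duck typing. A hypothetical toy class (all values below are made up) passes an `isinstance` check as long as every listed attribute and method exists; note that runtime-checkable protocols verify only presence, not types or signatures:

```python
from vllm.model_executor.models.interfaces import MixtureOfExperts

class ToyMoE:  # hypothetical, for illustration only
    expert_weights = []
    num_moe_layers = 2
    num_expert_groups = 1
    num_logical_experts = 4
    num_physical_experts = 6
    num_local_physical_experts = 2
    num_routed_experts = 4
    num_shared_experts = 0
    num_redundant_experts = 2

    def set_eplb_state(self, expert_load_view, logical_to_physical_map,
                       logical_replica_count) -> None:
        pass

    def update_physical_experts_metadata(self, num_physical_experts,
                                         num_local_physical_experts) -> None:
        pass

assert isinstance(ToyMoE(), MixtureOfExperts)
```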

expert_weights instance-attribute

expert_weights: MutableSequence[Iterable[Tensor]]

Expert weights saved in this rank.

The first dimension is the layer, and the second dimension is different parameters in the layer, e.g. up/down projection weights.

num_expert_groups instance-attribute

num_expert_groups: int

Number of expert groups in this model.

num_local_physical_experts instance-attribute

num_local_physical_experts: int

Number of local physical experts in this model.

num_logical_experts instance-attribute

num_logical_experts: int

Number of logical experts in this model.

num_moe_layers instance-attribute

num_moe_layers: int

Number of MoE layers in this model.

num_physical_experts instance-attribute

num_physical_experts: int

Number of physical experts in this model.

num_redundant_experts instance-attribute

num_redundant_experts: int

Number of redundant experts in this model.

num_routed_experts instance-attribute

num_routed_experts: int

Number of routed experts in this model.

num_shared_experts instance-attribute

num_shared_experts: int

Number of shared experts in this model.

set_eplb_state

set_eplb_state(
    expert_load_view: Tensor,
    logical_to_physical_map: Tensor,
    logical_replica_count: Tensor,
) -> None

Register the EPLB state in the MoE model.

Since these are views of the actual EPLB state, any changes made by the EPLB algorithm are automatically reflected in the model's behavior without requiring additional method calls to set new states.

You should also collect model's expert_weights here instead of in the weight loader, since after initial weight loading, further processing like quantization may be applied to the weights.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `expert_load_view` | `Tensor` | A view of the expert load metrics tensor. | required |
| `logical_to_physical_map` | `Tensor` | Mapping from logical to physical experts. | required |
| `logical_replica_count` | `Tensor` | Count of replicas for each logical expert. | required |
Source code in vllm/model_executor/models/interfaces.py
def set_eplb_state(
    self,
    expert_load_view: Tensor,
    logical_to_physical_map: Tensor,
    logical_replica_count: Tensor,
) -> None:
    """
    Register the EPLB state in the MoE model.

    Since these are views of the actual EPLB state, any changes made by
    the EPLB algorithm are automatically reflected in the model's behavior
    without requiring additional method calls to set new states.

    You should also collect model's `expert_weights` here instead of in
    the weight loader, since after initial weight loading, further
    processing like quantization may be applied to the weights.

    Args:
        expert_load_view: A view of the expert load metrics tensor.
        logical_to_physical_map: Mapping from logical to physical experts.
        logical_replica_count: Count of replicas for each logical expert.
    """
    ...

update_physical_experts_metadata

update_physical_experts_metadata(
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None
Source code in vllm/model_executor/models/interfaces.py
def update_physical_experts_metadata(
    self,
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None:
    ...

ParallelConfig

Configuration for the distributed execution.
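
To enable EPLB, the relevant knobs live on this config and its nested `EPLBConfig`. A hedged sketch; the import location and constructor keywords of `EPLBConfig` are assumptions inferred from the field names referenced below (`window_size`, `step_interval`, `num_redundant_experts`):

```python
# Assumed import path: EPLBConfig is referenced from vllm/config/parallel.py.
from vllm.config.parallel import EPLBConfig, ParallelConfig

parallel_config = ParallelConfig(
    enable_expert_parallel=True,   # shard MoE layers across EP ranks
    enable_eplb=True,              # turn on the load balancer
    eplb_config=EPLBConfig(
        window_size=1000,          # expert-load sliding window length
        step_interval=3000,        # forward passes between rearrangements
        num_redundant_experts=16,
    ),
)
```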

Source code in vllm/config/parallel.py
@config
@dataclass
class ParallelConfig:
    """Configuration for the distributed execution."""

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel groups."""
    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""
    data_parallel_size: int = 1
    """Number of data parallel groups. MoE layers will be sharded according to
    the product of the tensor parallel size and data parallel size."""
    data_parallel_size_local: int = 1
    """Number of local data parallel groups."""
    data_parallel_rank: int = 0
    """Rank of the data parallel group."""
    data_parallel_rank_local: Optional[int] = None
    """Local rank of the data parallel group,
    set only in SPMD mode."""
    data_parallel_master_ip: str = "127.0.0.1"
    """IP of the data parallel master."""
    data_parallel_rpc_port: int = 29550
    """Port for data parallel messaging."""
    data_parallel_master_port: int = 29500
    """Port of the data parallel master."""
    data_parallel_backend: str = "mp"
    """Backend to use for data parallel, either "mp" or "ray"."""
    data_parallel_external_lb: bool = False
    """Whether to use "external" DP LB mode. Applies only to online serving
    and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
    wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank
    is provided explicitly to vllm serve."""
    data_parallel_hybrid_lb: bool = False
    """Whether to use "hybrid" DP LB mode. Applies only to online serving
    and when data_parallel_size > 0. Enables running an AsyncLLM
    and API server on a "per-node" basis where vLLM load balances
    between local data parallel ranks, but an external LB balances
    between vLLM nodes/replicas. Set explicitly in conjunction with
    --data-parallel-start-rank."""
    enable_expert_parallel: bool = False
    """Use expert parallelism instead of tensor parallelism for MoE layers."""
    enable_eplb: bool = False
    """Enable expert parallelism load balancing for MoE layers."""
    eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
    """Expert parallelism configuration."""
    num_redundant_experts: Optional[int] = None
    """`num_redundant_experts` is deprecated and has been replaced with
    `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
    Please use `eplb_config.num_redundant_experts` instead."""
    eplb_window_size: Optional[int] = None
    """`eplb_window_size` is deprecated and has been replaced with
    `eplb_config.window_size`. This will be removed in v0.12.0.
    Please use `eplb_config.window_size` instead."""
    eplb_step_interval: Optional[int] = None
    """`eplb_step_interval` is deprecated and has been replaced with
    `eplb_config.step_interval`. This will be removed in v0.12.0.
    Please use `eplb_config.step_interval` instead."""
    eplb_log_balancedness: Optional[bool] = None
    """`eplb_log_balancedness` is deprecated and has been replaced with
    `eplb_config.log_balancedness`. This will be removed in v0.12.0.
    Please use `eplb_config.log_balancedness` instead."""

    max_parallel_loading_workers: Optional[int] = None
    """Maximum number of parallel loading workers when loading model
    sequentially in multiple batches. To avoid RAM OOM when using tensor
    parallel and large models."""

    disable_custom_all_reduce: bool = False
    """Disable the custom all-reduce kernel and fall back to NCCL."""

    ray_workers_use_nsight: bool = False
    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

    ray_runtime_env: Optional[RuntimeEnv] = None
    """Ray runtime environment to pass to distributed workers."""

    placement_group: Optional[PlacementGroup] = None
    """ray distributed model workers placement group."""

    distributed_executor_backend: Optional[Union[str,
                                                 DistributedExecutorBackend,
                                                 type[ExecutorBase]]] = None
    """Backend to use for distributed model
    workers, either "ray" or "mp" (multiprocessing). If the product
    of pipeline_parallel_size and tensor_parallel_size is less than
    or equal to the number of GPUs available, "mp" will be used to
    keep processing on a single host. Otherwise, this will default
    to "ray" if Ray is installed and fail otherwise. Note that tpu
    only support Ray for distributed inference."""

    worker_cls: str = "auto"
    """The full name of the worker class to use. If "auto", the worker class
    will be determined based on the platform."""
    sd_worker_cls: str = "auto"
    """The full name of the worker class to use for speculative decoding.
    If "auto", the worker class will be determined based on the platform."""
    worker_extension_cls: str = ""
    """The full name of the worker extension class to use. The worker extension
    class is dynamically inherited by the worker class. This is used to inject
    new attributes and methods to the worker class for use in collective_rpc
    calls."""

    world_size: int = field(init=False)
    """world_size is TPxPP, it affects the number of workers we create."""

    rank: int = 0
    """Global rank in distributed setup."""

    _data_parallel_master_port_list: list[int] = field(default_factory=list)
    """List of open port auto-queried for data parallel messaging.
    Set to be private as it's not intended to be configured by users.
    """

    @property
    def world_size_across_dp(self) -> int:
        """world_size_across_dp is TPxPPxDP, it is the size of the world
        including data parallelism."""
        return self.world_size * self.data_parallel_size

    def get_next_dp_init_port(self) -> int:
        """
        We might need to initialize process groups in multiple
        processes related to data parallelism,
        e.g. both in the worker and in the engine, which
        can live in different processes. To avoid port conflicts, we
        pop a new port from the prepared port list each time we need to
        initialize a new process group related to data parallelism.
        """
        if self._data_parallel_master_port_list:
            answer = self._data_parallel_master_port_list.pop()
        else:
            answer = self.data_parallel_master_port
            self.data_parallel_master_port += 1

        return answer

    def stateless_init_dp_group(self) -> ProcessGroup:
        # NOTE: In high-concurrency scenarios multiple processes
        # can pick the same (currently free) port through a race
        # condition when calling `get_open_port()`. When the first
        # process binds the port the others will subsequently fail
        # with `torch.distributed.DistNetworkError: EADDRINUSE`.
        # To make the initialization more robust we retry a few times
        # with a fresh port whenever this specific error is observed.
        from torch.distributed import DistNetworkError

        from vllm.distributed.utils import (
            stateless_init_torch_distributed_process_group)

        max_retries = 5
        last_exc: Optional[Exception] = None
        for _ in range(max_retries):
            try:
                # use gloo since the engine process might not have cuda device
                return stateless_init_torch_distributed_process_group(
                    self.data_parallel_master_ip,
                    self.get_next_dp_init_port(),
                    self.data_parallel_rank,
                    self.data_parallel_size,
                    backend="gloo")
            except DistNetworkError as e:
                # We only want to retry when the root cause is EADDRINUSE.
                if "EADDRINUSE" in str(e):
                    logger.warning(
                        "Address already in use. Retrying with a new port.")
                    last_exc = e
                    continue  # try again with a new port
                raise e

        # If we get here all retries have failed.
        assert last_exc is not None
        raise last_exc

    @staticmethod
    def has_unfinished_dp(dp_group: ProcessGroup,
                          has_unfinished: bool) -> bool:
        tensor = torch.tensor([has_unfinished],
                              dtype=torch.int32,
                              device="cpu")
        # dp rank 0: has_unfinished_seqs=True
        # dp rank 1: has_unfinished_seqs=False
        # aggregated: has_unfinished_seqs=True
        # so this is an OR operation, i.e. MAX in integers
        torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
        aggregated_has_unfinished = bool(tensor.item())
        return aggregated_has_unfinished

    @staticmethod
    def sync_kv_cache_memory_size(dp_group: ProcessGroup,
                                  kv_cache_memory: int) -> int:
        if kv_cache_memory == -1:
            kv_cache_memory = torch.iinfo(torch.int64).max
        tensor = torch.tensor([kv_cache_memory],
                              dtype=torch.int64,
                              device="cpu")
        # we cannot use broadcast for stateless dp group since it depends
        # on global rank
        torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
        return tensor.item()

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = []
        factors.append(self.pipeline_parallel_size)
        factors.append(self.tensor_parallel_size)
        factors.append(self.enable_expert_parallel)
        factors.append(self.data_parallel_size)
        factors.append(envs.VLLM_ALL2ALL_BACKEND)
        return hashlib.sha256(str(factors).encode()).hexdigest()

    def __post_init__(self) -> None:
        # Forward deprecated fields to their new location
        if self.num_redundant_experts is not None:
            self.eplb_config.num_redundant_experts = (
                self.num_redundant_experts)
            logger.warning_once(
                "num_redundant_experts is deprecated and has been replaced "
                "with eplb_config.num_redundant_experts. This will be removed "
                "in v0.12.0. Changing this field after initialization will "
                "have no effect.")
        if self.eplb_window_size is not None:
            self.eplb_config.window_size = self.eplb_window_size
            logger.warning_once(
                "eplb_window_size is deprecated and has been replaced "
                "with eplb_config.window_size. This will be removed "
                "in v0.12.0. Changing this field after initialization will "
                "have no effect.")
        if self.eplb_step_interval is not None:
            self.eplb_config.step_interval = self.eplb_step_interval
            logger.warning_once(
                "eplb_step_interval is deprecated and has been replaced "
                "with eplb_config.step_interval. This will be removed "
                "in v0.12.0. Changing this field after initialization will "
                "have no effect.")
        if self.eplb_log_balancedness is not None:
            self.eplb_config.log_balancedness = self.eplb_log_balancedness
            logger.warning_once(
                "eplb_log_balancedness is deprecated and has been replaced "
                "with eplb_config.log_balancedness. This will be removed "
                "in v0.12.0. Changing this field after initialization will "
                "have no effect.")

        # Continue with the rest of the initialization
        self.world_size = self.pipeline_parallel_size * \
            self.tensor_parallel_size

        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                f"data_parallel_size_local ({self.data_parallel_size_local}) "
                f"must be <= data_parallel_size ({self.data_parallel_size})")

        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
            # Data parallel was specified in the engine args.
            if not self._data_parallel_master_port_list:
                self._data_parallel_master_port_list = get_open_ports_list(5)
            self.data_parallel_master_port = \
                self._data_parallel_master_port_list.pop()

            if not (0 <= self.data_parallel_rank < self.data_parallel_size):
                raise ValueError(
                    f"data_parallel_rank ({self.data_parallel_rank})"
                    f" must be in the range [0, {self.data_parallel_size})")
        else:
            # Otherwise fall back to env vars (e.g. for offline SPMD case).
            self.data_parallel_size = envs.VLLM_DP_SIZE
            self.data_parallel_rank = envs.VLLM_DP_RANK
            self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
            self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
            self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

            if self.data_parallel_external_lb:
                raise ValueError("data_parallel_external_lb can only "
                                 "be set when data_parallel_size > 1")

        if self.distributed_executor_backend == "external_launcher":
            import os
            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
            logger.info("Disabling V1 multiprocessing for external launcher.")

        if self.enable_eplb:
            if not current_platform.is_cuda():
                raise ValueError(
                    "Expert parallelism load balancing is only supported on "
                    "CUDA devices now.")
            if self.eplb_config.num_redundant_experts < 0:
                raise ValueError(
                    "num_redundant_experts must be non-negative, but got "
                    f"{self.eplb_config.num_redundant_experts}.")
            if not self.enable_expert_parallel:
                raise ValueError(
                    "enable_expert_parallel must be True to use EPLB.")
            if self.tensor_parallel_size * self.data_parallel_size <= 1:
                raise ValueError(
                    "EPLB requires tensor_parallel_size or data_parallel_size "
                    f"to be greater than 1, but got "
                    f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
                )
        else:
            if self.eplb_config.num_redundant_experts != 0:
                raise ValueError(
                    "num_redundant_experts should be used with EPLB."
                    f"{self.eplb_config.num_redundant_experts}.")
        if self.distributed_executor_backend is None and self.world_size > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.

            from vllm.executor import ray_utils
            backend: DistributedExecutorBackend = "mp"
            ray_found = ray_utils.ray_is_available()
            if current_platform.is_neuron():
                # neuron uses single process to control multiple devices
                backend = "uni"
            elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
                backend = "uni"
            elif (current_platform.is_cuda()
                  and cuda_device_count_stateless() < self.world_size):
                if not ray_found:
                    raise ValueError("Unable to load Ray: "
                                     f"{ray_utils.ray_import_err}. Ray is "
                                     "required for multi-node inference, "
                                     "please install Ray with `pip install "
                                     "ray`.")
                backend = "ray"
            elif self.data_parallel_backend == "ray":
                logger.info("Using ray distributed inference because "
                            "data_parallel_backend is ray")
                backend = "ray"
            elif ray_found:
                if self.placement_group:
                    backend = "ray"
                else:
                    from ray import is_initialized as ray_is_initialized
                    if ray_is_initialized():
                        from ray.util import get_current_placement_group
                        if get_current_placement_group():
                            backend = "ray"
            self.distributed_executor_backend = backend
            logger.debug("Defaulting to use %s for distributed inference",
                         backend)

        if self.distributed_executor_backend is None and self.world_size == 1:
            self.distributed_executor_backend = "uni"

    @property
    def use_ray(self) -> bool:
        return self.distributed_executor_backend == "ray" or (
            isinstance(self.distributed_executor_backend, type)
            and getattr(self.distributed_executor_backend, "uses_ray", False))

    @model_validator(mode='after')
    def _verify_args(self) -> Self:
        # Lazy import to avoid circular import
        from vllm.executor.executor_base import ExecutorBase
        from vllm.platforms import current_platform
        if self.distributed_executor_backend is not None and not isinstance(
                self.distributed_executor_backend, str) and not (isinstance(
                    self.distributed_executor_backend, type) and issubclass(
                        self.distributed_executor_backend, ExecutorBase)):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp' 'uni', 'external_launcher', "
                " custom ExecutorBase subclass or its import path.")
        if self.use_ray:
            from vllm.executor import ray_utils
            ray_utils.assert_ray_available()

        if not current_platform.use_custom_allreduce():
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on current platform.")
        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")

        return self

_data_parallel_master_port_list class-attribute instance-attribute

_data_parallel_master_port_list: list[int] = field(
    default_factory=list
)

List of open ports auto-queried for data parallel messaging. Set to be private as it's not intended to be configured by users.

data_parallel_backend class-attribute instance-attribute

data_parallel_backend: str = 'mp'

Backend to use for data parallel, either "mp" or "ray".

data_parallel_external_lb class-attribute instance-attribute

data_parallel_external_lb: bool = False

Whether to use "external" DP LB mode. Applies only to online serving and when data_parallel_size > 0. This is useful for a "one-pod-per-rank" wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank is provided explicitly to vllm serve.

data_parallel_hybrid_lb class-attribute instance-attribute

data_parallel_hybrid_lb: bool = False

Whether to use "hybrid" DP LB mode. Applies only to online serving and when data_parallel_size > 0. Enables running an AsyncLLM and API server on a "per-node" basis where vLLM load balances between local data parallel ranks, but an external LB balances between vLLM nodes/replicas. Set explicitly in conjunction with --data-parallel-start-rank.

data_parallel_master_ip class-attribute instance-attribute

data_parallel_master_ip: str = '127.0.0.1'

IP of the data parallel master.

data_parallel_master_port class-attribute instance-attribute

data_parallel_master_port: int = 29500

Port of the data parallel master.

data_parallel_rank class-attribute instance-attribute

data_parallel_rank: int = 0

Rank of the data parallel group.

data_parallel_rank_local class-attribute instance-attribute

data_parallel_rank_local: Optional[int] = None

Local rank of the data parallel group, set only in SPMD mode.

data_parallel_rpc_port class-attribute instance-attribute

data_parallel_rpc_port: int = 29550

Port for data parallel messaging.

data_parallel_size class-attribute instance-attribute

data_parallel_size: int = 1

Number of data parallel groups. MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.

data_parallel_size_local class-attribute instance-attribute

data_parallel_size_local: int = 1

Number of local data parallel groups.

disable_custom_all_reduce class-attribute instance-attribute

disable_custom_all_reduce: bool = False

Disable the custom all-reduce kernel and fall back to NCCL.

distributed_executor_backend class-attribute instance-attribute

distributed_executor_backend: Optional[
    Union[
        str, DistributedExecutorBackend, type[ExecutorBase]
    ]
] = None

Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size is less than or equal to the number of GPUs available, "mp" will be used to keep processing on a single host. Otherwise, this will default to "ray" if Ray is installed and fail otherwise. Note that TPU only supports Ray for distributed inference.

enable_eplb class-attribute instance-attribute

enable_eplb: bool = False

Enable expert parallelism load balancing for MoE layers.

enable_expert_parallel class-attribute instance-attribute

enable_expert_parallel: bool = False

Use expert parallelism instead of tensor parallelism for MoE layers.

eplb_config class-attribute instance-attribute

eplb_config: EPLBConfig = field(default_factory=EPLBConfig)

Expert parallelism configuration.

eplb_log_balancedness class-attribute instance-attribute

eplb_log_balancedness: Optional[bool] = None

eplb_log_balancedness is deprecated and has been replaced with eplb_config.log_balancedness. This will be removed in v0.12.0. Please use eplb_config.log_balancedness instead.

eplb_step_interval class-attribute instance-attribute

eplb_step_interval: Optional[int] = None

eplb_step_interval is deprecated and has been replaced with eplb_config.step_interval. This will be removed in v0.12.0. Please use eplb_config.step_interval instead.

eplb_window_size class-attribute instance-attribute

eplb_window_size: Optional[int] = None

eplb_window_size is deprecated and has been replaced with eplb_config.window_size. This will be removed in v0.12.0. Please use eplb_config.window_size instead.

max_parallel_loading_workers class-attribute instance-attribute

max_parallel_loading_workers: Optional[int] = None

Maximum number of parallel loading workers when loading model sequentially in multiple batches, to avoid RAM OOM when using tensor parallelism with large models.

num_redundant_experts class-attribute instance-attribute

num_redundant_experts: Optional[int] = None

num_redundant_experts is deprecated and has been replaced with eplb_config.num_redundant_experts. This will be removed in v0.12.0. Please use eplb_config.num_redundant_experts instead.
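
Since `num_redundant_experts` and the `eplb_*` fields above are deprecated, new code should populate `eplb_config` directly, as in this hedged migration sketch (it assumes `EPLBConfig` is importable from `vllm.config.parallel`, the module that defines this class):

```python
# Hedged sketch: set the new eplb_config fields instead of the
# deprecated top-level ones. Import paths are assumptions.
from vllm.config import ParallelConfig
from vllm.config.parallel import EPLBConfig

cfg = ParallelConfig(
    eplb_config=EPLBConfig(
        window_size=1000,       # replaces eplb_window_size=1000
        step_interval=3000,     # replaces eplb_step_interval=3000
        log_balancedness=True,  # replaces eplb_log_balancedness=True
    ),
)
```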

pipeline_parallel_size class-attribute instance-attribute

pipeline_parallel_size: int = 1

Number of pipeline parallel groups.

placement_group class-attribute instance-attribute

placement_group: Optional[PlacementGroup] = None

Ray placement group for the distributed model workers.

rank class-attribute instance-attribute

rank: int = 0

Global rank in distributed setup.

ray_runtime_env class-attribute instance-attribute

ray_runtime_env: Optional[RuntimeEnv] = None

Ray runtime environment to pass to distributed workers.

ray_workers_use_nsight class-attribute instance-attribute

ray_workers_use_nsight: bool = False

Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.

sd_worker_cls class-attribute instance-attribute

sd_worker_cls: str = 'auto'

The full name of the worker class to use for speculative decoding. If "auto", the worker class will be determined based on the platform.

tensor_parallel_size class-attribute instance-attribute

tensor_parallel_size: int = 1

Number of tensor parallel groups.

use_ray property

use_ray: bool

worker_cls class-attribute instance-attribute

worker_cls: str = 'auto'

The full name of the worker class to use. If "auto", the worker class will be determined based on the platform.

worker_extension_cls class-attribute instance-attribute

worker_extension_cls: str = ''

The full name of the worker extension class to use. The worker extension class is dynamically inherited by the worker class. This is used to inject new attributes and methods to the worker class for use in collective_rpc calls.

world_size class-attribute instance-attribute

world_size: int = field(init=False)

world_size is TPxPP; it affects the number of workers we create.

world_size_across_dp property

world_size_across_dp: int

world_size_across_dp is TPxPPxDP; it is the size of the world including data parallelism.
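
A small worked example of how the two sizes relate (a hedged sketch; it assumes `ParallelConfig` is importable from `vllm.config` and is constructed on a single host):

```python
from vllm.config import ParallelConfig

# world_size counts the TP x PP workers of one DP group; the
# across-DP size replicates that over all DP ranks.
cfg = ParallelConfig(
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    data_parallel_size=2,
)
assert cfg.world_size == 1            # TP x PP
assert cfg.world_size_across_dp == 2  # TP x PP x DP
```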

__post_init__

__post_init__() -> None
Source code in vllm/config/parallel.py
def __post_init__(self) -> None:
    # Forward deprecated fields to their new location
    if self.num_redundant_experts is not None:
        self.eplb_config.num_redundant_experts = (
            self.num_redundant_experts)
        logger.warning_once(
            "num_redundant_experts is deprecated and has been replaced "
            "with eplb_config.num_redundant_experts. This will be removed "
            "in v0.12.0. Changing this field after initialization will "
            "have no effect.")
    if self.eplb_window_size is not None:
        self.eplb_config.window_size = self.eplb_window_size
        logger.warning_once(
            "eplb_window_size is deprecated and has been replaced "
            "with eplb_config.window_size. This will be removed "
            "in v0.12.0. Changing this field after initialization will "
            "have no effect.")
    if self.eplb_step_interval is not None:
        self.eplb_config.step_interval = self.eplb_step_interval
        logger.warning_once(
            "eplb_step_interval is deprecated and has been replaced "
            "with eplb_config.step_interval. This will be removed "
            "in v0.12.0. Changing this field after initialization will "
            "have no effect.")
    if self.eplb_log_balancedness is not None:
        self.eplb_config.log_balancedness = self.eplb_log_balancedness
        logger.warning_once(
            "eplb_log_balancedness is deprecated and has been replaced "
            "with eplb_config.log_balancedness. This will be removed "
            "in v0.12.0. Changing this field after initialization will "
            "have no effect.")

    # Continue with the rest of the initialization
    self.world_size = self.pipeline_parallel_size * \
        self.tensor_parallel_size

    if self.data_parallel_size_local > self.data_parallel_size:
        raise ValueError(
            f"data_parallel_size_local ({self.data_parallel_size_local}) "
            f"must be <= data_parallel_size ({self.data_parallel_size})")

    if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
        # Data parallel was specified in the engine args.
        if not self._data_parallel_master_port_list:
            self._data_parallel_master_port_list = get_open_ports_list(5)
        self.data_parallel_master_port = \
            self._data_parallel_master_port_list.pop()

        if not (0 <= self.data_parallel_rank < self.data_parallel_size):
            raise ValueError(
                f"data_parallel_rank ({self.data_parallel_rank})"
                f" must be in the range [0, {self.data_parallel_size})")
    else:
        # Otherwise fall back to env vars (e.g. for offline SPMD case).
        self.data_parallel_size = envs.VLLM_DP_SIZE
        self.data_parallel_rank = envs.VLLM_DP_RANK
        self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
        self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
        self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

        if self.data_parallel_external_lb:
            raise ValueError("data_parallel_external_lb can only "
                             "be set when data_parallel_size > 1")

    if self.distributed_executor_backend == "external_launcher":
        import os
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        logger.info("Disabling V1 multiprocessing for external launcher.")

    if self.enable_eplb:
        if not current_platform.is_cuda():
            raise ValueError(
                "Expert parallelism load balancing is only supported on "
                "CUDA devices now.")
        if self.eplb_config.num_redundant_experts < 0:
            raise ValueError(
                "num_redundant_experts must be non-negative, but got "
                f"{self.eplb_config.num_redundant_experts}.")
        if not self.enable_expert_parallel:
            raise ValueError(
                "enable_expert_parallel must be True to use EPLB.")
        if self.tensor_parallel_size * self.data_parallel_size <= 1:
            raise ValueError(
                "EPLB requires tensor_parallel_size or data_parallel_size "
                f"to be greater than 1, but got "
                f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
            )
    else:
        if self.eplb_config.num_redundant_experts != 0:
            raise ValueError(
                "num_redundant_experts should be used with EPLB."
                f"{self.eplb_config.num_redundant_experts}.")
    if self.distributed_executor_backend is None and self.world_size > 1:
        # We use multiprocessing by default if world_size fits on the
        # current node and we aren't in a ray placement group.

        from vllm.executor import ray_utils
        backend: DistributedExecutorBackend = "mp"
        ray_found = ray_utils.ray_is_available()
        if current_platform.is_neuron():
            # neuron uses single process to control multiple devices
            backend = "uni"
        elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
            backend = "uni"
        elif (current_platform.is_cuda()
              and cuda_device_count_stateless() < self.world_size):
            if not ray_found:
                raise ValueError("Unable to load Ray: "
                                 f"{ray_utils.ray_import_err}. Ray is "
                                 "required for multi-node inference, "
                                 "please install Ray with `pip install "
                                 "ray`.")
            backend = "ray"
        elif self.data_parallel_backend == "ray":
            logger.info("Using ray distributed inference because "
                        "data_parallel_backend is ray")
            backend = "ray"
        elif ray_found:
            if self.placement_group:
                backend = "ray"
            else:
                from ray import is_initialized as ray_is_initialized
                if ray_is_initialized():
                    from ray.util import get_current_placement_group
                    if get_current_placement_group():
                        backend = "ray"
        self.distributed_executor_backend = backend
        logger.debug("Defaulting to use %s for distributed inference",
                     backend)

    if self.distributed_executor_backend is None and self.world_size == 1:
        self.distributed_executor_backend = "uni"

_verify_args

_verify_args() -> Self
Source code in vllm/config/parallel.py
@model_validator(mode='after')
def _verify_args(self) -> Self:
    # Lazy import to avoid circular import
    from vllm.executor.executor_base import ExecutorBase
    from vllm.platforms import current_platform
    if self.distributed_executor_backend is not None and not isinstance(
            self.distributed_executor_backend, str) and not (isinstance(
                self.distributed_executor_backend, type) and issubclass(
                    self.distributed_executor_backend, ExecutorBase)):
        raise ValueError(
            "Unrecognized distributed executor backend "
            f"{self.distributed_executor_backend}. Supported "
            "values are 'ray', 'mp' 'uni', 'external_launcher', "
            " custom ExecutorBase subclass or its import path.")
    if self.use_ray:
        from vllm.executor import ray_utils
        ray_utils.assert_ray_available()

    if not current_platform.use_custom_allreduce():
        self.disable_custom_all_reduce = True
        logger.debug(
            "Disabled the custom all-reduce kernel because it is not "
            "supported on current platform.")
    if self.ray_workers_use_nsight and not self.use_ray:
        raise ValueError("Unable to use nsight profiling unless workers "
                         "run with Ray.")

    return self

compute_hash

compute_hash()

Provide a hash that uniquely identifies all the configs that affect the structure of the computation graph from input ids/embeddings to the final hidden states, excluding anything before input ids/embeddings and after the final hidden states.

Source code in vllm/config/parallel.py
def compute_hash(self):
    """
    Provide a hash that uniquely identifies all the configs
    that affect the structure of the computation
    graph from input ids/embeddings to the final hidden states,
    excluding anything before input ids/embeddings and after
    the final hidden states.
    """
    factors: list[Any] = []
    factors.append(self.pipeline_parallel_size)
    factors.append(self.tensor_parallel_size)
    factors.append(self.enable_expert_parallel)
    factors.append(self.data_parallel_size)
    factors.append(envs.VLLM_ALL2ALL_BACKEND)
    return hashlib.sha256(str(factors).encode()).hexdigest()
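
To illustrate what the hash captures, a hedged sketch: fields outside the factor list (such as `worker_cls`) leave the hash unchanged, while any listed factor (such as `data_parallel_size`) changes it. The `ParallelConfig` import path is an assumption.

```python
from vllm.config import ParallelConfig

a = ParallelConfig()
b = ParallelConfig(worker_cls="auto")     # not a hash factor
c = ParallelConfig(data_parallel_size=2)  # a hash factor

assert a.compute_hash() == b.compute_hash()
assert a.compute_hash() != c.compute_hash()
```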

get_next_dp_init_port

get_next_dp_init_port() -> int

We might need to initialize process groups in multiple processes that are related to data parallelism, e.g. both in the worker and in the engine, which can live in different processes. To avoid port conflicts, we pop a new port from the prepared port list each time we need to initialize a new process group related to data parallelism.

Source code in vllm/config/parallel.py
def get_next_dp_init_port(self) -> int:
    """
    We might need to initialize process groups in multiple
    processes that are related to data parallelism,
    e.g. both in the worker and in the engine, which
    can live in different processes. To avoid port conflicts, we
    pop a new port from the prepared port list each time we need to
    initialize a new process group related to data parallelism.
    """
    if self._data_parallel_master_port_list:
        answer = self._data_parallel_master_port_list.pop()
    else:
        answer = self.data_parallel_master_port
        self.data_parallel_master_port += 1

    return answer
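
A hedged usage sketch: each call hands out a distinct port, first from the pre-queried list and then by incrementing the master port.

```python
from vllm.config import ParallelConfig  # assumed import path

cfg = ParallelConfig(data_parallel_size=2)  # pre-queries a small port list

p1 = cfg.get_next_dp_init_port()
p2 = cfg.get_next_dp_init_port()
assert p1 != p2  # each DP process-group init gets its own port
```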

has_unfinished_dp staticmethod

has_unfinished_dp(
    dp_group: ProcessGroup, has_unfinished: bool
) -> bool
Source code in vllm/config/parallel.py
@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup,
                      has_unfinished: bool) -> bool:
    tensor = torch.tensor([has_unfinished],
                          dtype=torch.int32,
                          device="cpu")
    # dp rank 0: has_unfinished_seqs=True
    # dp rank 1: has_unfinished_seqs=False
    # aggregated: has_unfinished_seqs=True
    # so this is an OR operation, i.e. MAX in integers
    torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
    aggregated_has_unfinished = bool(tensor.item())
    return aggregated_has_unfinished
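
The MAX-as-OR trick does not depend on the process group; here is a minimal single-process sketch of the same reduction semantics:

```python
import torch

# One flag per DP rank: rank 0 still has work, ranks 1 and 2 are idle.
flags = torch.tensor([1, 0, 0], dtype=torch.int32)

# Over {0, 1} values, MAX is exactly a logical OR.
aggregated = bool(flags.max().item())
assert aggregated == any(bool(f) for f in flags.tolist())
```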

stateless_init_dp_group

stateless_init_dp_group() -> ProcessGroup
Source code in vllm/config/parallel.py
def stateless_init_dp_group(self) -> ProcessGroup:
    # NOTE: In high-concurrency scenarios multiple processes
    # can pick the same (currently free) port through a race
    # condition when calling `get_open_port()`. When the first
    # process binds the port the others will subsequently fail
    # with `torch.distributed.DistNetworkError: EADDRINUSE`.
    # To make the initialization more robust we retry a few times
    # with a fresh port whenever this specific error is observed.
    from torch.distributed import DistNetworkError

    from vllm.distributed.utils import (
        stateless_init_torch_distributed_process_group)

    max_retries = 5
    last_exc: Optional[Exception] = None
    for _ in range(max_retries):
        try:
            # use gloo since the engine process might not have cuda device
            return stateless_init_torch_distributed_process_group(
                self.data_parallel_master_ip,
                self.get_next_dp_init_port(),
                self.data_parallel_rank,
                self.data_parallel_size,
                backend="gloo")
        except DistNetworkError as e:
            # We only want to retry when the root cause is EADDRINUSE.
            if "EADDRINUSE" in str(e):
                logger.warning(
                    "Address already in use. Retrying with a new port.")
                last_exc = e
                continue  # try again with a new port
            raise e

    # If we get here all retries have failed.
    assert last_exc is not None
    raise last_exc

sync_kv_cache_memory_size staticmethod

sync_kv_cache_memory_size(
    dp_group: ProcessGroup, kv_cache_memory: int
) -> int
Source code in vllm/config/parallel.py
@staticmethod
def sync_kv_cache_memory_size(dp_group: ProcessGroup,
                              kv_cache_memory: int) -> int:
    if kv_cache_memory == -1:
        kv_cache_memory = torch.iinfo(torch.int64).max
    tensor = torch.tensor([kv_cache_memory],
                          dtype=torch.int64,
                          device="cpu")
    # we cannot use broadcast for stateless dp group since it depends
    # on global rank
    torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
    return tensor.item()
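
The sentinel handling can be shown without a process group: `-1` (no local measurement) maps to the `int64` maximum, so it can never win the MIN reduction.

```python
import torch

reported = [-1, 8 * 1024**3, -1]  # per-rank KV cache bytes; -1 = sentinel
mapped = [torch.iinfo(torch.int64).max if v == -1 else v for v in reported]

# The reduction settles on the smallest real measurement.
assert min(mapped) == 8 * 1024**3
```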

StatelessProcessGroup dataclass

A dataclass to hold a metadata store, and the rank and world_size of the group. Only use it to communicate metadata between processes. For data-plane communication, create NCCL-related objects.

Source code in vllm/distributed/utils.py
@dataclasses.dataclass
class StatelessProcessGroup:
    """A dataclass to hold a metadata store, and the rank, world_size of the
    group. Only use it to communicate metadata between processes.
    For data-plane communication, create NCCL-related objects.
    """
    rank: int
    world_size: int
    store: torch._C._distributed_c10d.Store

    # stores a reference to the socket so that the file descriptor stays alive
    socket: Optional[socket.socket]

    data_expiration_seconds: int = 3600  # 1 hour

    # dst rank -> counter
    send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
    # src rank -> counter
    recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
    broadcast_send_counter: int = 0
    broadcast_recv_src_counter: dict[int, int] = dataclasses.field(
        default_factory=dict)

    # A deque to store the data entries, with key and timestamp.
    entries: deque[tuple[str,
                         float]] = dataclasses.field(default_factory=deque)

    def __post_init__(self):
        assert self.rank < self.world_size
        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
        self.broadcast_recv_src_counter = {
            i: 0
            for i in range(self.world_size)
        }

    def send_obj(self, obj: Any, dst: int):
        """Send an object to a destination rank."""
        self.expire_data()
        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
        self.store.set(key, pickle.dumps(obj))
        self.send_dst_counter[dst] += 1
        self.entries.append((key, time.time()))

    def expire_data(self):
        """Expire data that is older than `data_expiration_seconds` seconds."""
        while self.entries:
            # check the oldest entry
            key, timestamp = self.entries[0]
            if time.time() - timestamp > self.data_expiration_seconds:
                self.store.delete_key(key)
                self.entries.popleft()
            else:
                break

    def recv_obj(self, src: int) -> Any:
        """Receive an object from a source rank."""
        obj = pickle.loads(
            self.store.get(
                f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
        self.recv_src_counter[src] += 1
        return obj

    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
        """Broadcast an object from a source rank to all other ranks.
        It does not clean up after all ranks have received the object.
        Use it for limited times, e.g., for initialization.
        """
        if self.rank == src:
            self.expire_data()
            key = (f"broadcast_from/{src}/"
                   f"{self.broadcast_send_counter}")
            self.store.set(key, pickle.dumps(obj))
            self.broadcast_send_counter += 1
            self.entries.append((key, time.time()))
            return obj
        else:
            key = (f"broadcast_from/{src}/"
                   f"{self.broadcast_recv_src_counter[src]}")
            recv_obj = pickle.loads(self.store.get(key))
            self.broadcast_recv_src_counter[src] += 1
            return recv_obj

    def all_gather_obj(self, obj: Any) -> list[Any]:
        """All gather an object from all ranks."""
        gathered_objs = []
        for i in range(self.world_size):
            if i == self.rank:
                gathered_objs.append(obj)
                self.broadcast_obj(obj, src=self.rank)
            else:
                recv_obj = self.broadcast_obj(None, src=i)
                gathered_objs.append(recv_obj)
        return gathered_objs

    def barrier(self, timeout: float = 30.0):
        """A robust barrier to synchronize all ranks.


        Uses a multi-phase approach to ensure all processes reach the barrier
        before proceeding:

        1. Each process signals it has reached the barrier

        2. Each process signals that it has confirmed the arrival of all other
        ranks.

        3. Rank 0 waits for all other ranks to signal their departure to ensure
        that all ranks have departed the barrier first.

        Args:
            timeout: Maximum time in seconds to wait for each phase.

        Raises:
            RuntimeError: If coordination fails or times out
        """
        # Generate a barrier ID that is globally unique
        try:
            if self.rank == 0:
                barrier_id = f"barrier_{uuid.uuid4()}"
                self.broadcast_obj(barrier_id, src=0)
            else:
                barrier_id = self.broadcast_obj(None, src=0)
        except Exception as e:
            raise RuntimeError("Failed to broadcast barrier_id") from e

        # Phase 1: Signal arrival at barrier
        # Wait for all processes to arrive
        # We need all ranks to confirm the arrival of all other ranks.
        # This is the key synchronization point.
        arrival_key = f"arrival_{barrier_id}_{self.rank}"
        try:
            self.store.set(arrival_key, b"1")
        except Exception as e:
            raise RuntimeError("Failed to signal barrier arrival") from e

        start_time = time.time()
        processes_arrived: set[int] = set()

        while len(processes_arrived) < self.world_size:
            # Check for timeout
            cur_time = time.time()
            if cur_time - start_time > timeout:
                raise RuntimeError("Barrier timed out after %f seconds",
                                   timeout)

            # Check for each process
            for i in range(self.world_size):
                if i in processes_arrived:
                    continue

                key = f"arrival_{barrier_id}_{i}"
                try:
                    # Try to get the key - if it exists, we'll get a value
                    # If it doesn't exist, it will throw an exception
                    self.store.get(key)
                    processes_arrived.add(i)
                except KeyError:
                    # Key doesn't exist yet
                    pass
                except Exception as check_e:
                    logger.debug("Error checking key existence: %s", check_e)
                    sched_yield()

            # Short sleep to avoid tight polling
            if len(processes_arrived) < self.world_size:
                sched_yield()

        # Phase 2: Signal departure from barrier
        # We only care to block at this stage in rank 0, which runs the
        # server side of the TCPStore. We want to make sure that all
        # clients have departed the barrier before rank 0 in case the
        # next thing after the barrier is a shutdown, including tearing
        # down the TCPStore. Other ranks can exit the barrier immediately
        # after signaling their departure.
        departure_key = f"departure_{barrier_id}_{self.rank}"
        try:
            self.store.set(departure_key, b"1")
        except Exception as e:
            raise RuntimeError("Failed to signal barrier departure") from e

        if self.rank != 0:
            return

        # Make rank 0 wait for all processes to signal departure
        start_time = time.time()
        processes_departed: set[int] = set()

        while len(processes_departed) < self.world_size:
            # Check for timeout
            if time.time() - start_time > timeout:
                raise RuntimeError("Barrier departure timed out after %f s",
                                   timeout)

            # Check for each process
            for i in range(self.world_size):
                if i in processes_departed:
                    continue

                key = f"departure_{barrier_id}_{i}"
                try:
                    # Try to get the key - if it exists, we'll get a value
                    # If it doesn't exist, it will throw an exception
                    self.store.get(key)
                    processes_departed.add(i)
                except KeyError:
                    # Key doesn't exist yet
                    pass
                except Exception as check_e:
                    logger.debug("Error checking key existence: %s", check_e)
                    sched_yield()

            # Short sleep to avoid tight polling
            if len(processes_departed) < self.world_size:
                sched_yield()

        # Clean up keys to avoid leaking memory in the store
        for i in range(self.world_size):
            try:
                self.store.delete_key(f"arrival_{barrier_id}_{i}")
            except Exception:
                logger.debug("Error deleting key: %s",
                             f'arrival_{barrier_id}_{i}')

            try:
                self.store.delete_key(f"departure_{barrier_id}_{i}")
            except Exception:
                logger.debug("Error deleting key: %s",
                             f'departure_{barrier_id}_{i}')

    @staticmethod
    def create(
        host: str,
        port: int,
        rank: int,
        world_size: int,
        data_expiration_seconds: int = 3600,
        store_timeout: int = 300,
    ) -> "StatelessProcessGroup":
        """A replacement for `torch.distributed.init_process_group` that does not
        pollute the global state.

        If process A and process B have called `torch.distributed.init_process_group`
        to form a group, and we then want to form another group with processes A, B,
        C, and D, this is not possible in PyTorch, because processes A and B have
        already formed a group, and processes C and D cannot join it. This
        function is a workaround for this issue.

        `torch.distributed.init_process_group` is a global call, while this function
        is a stateless call. It will return a `StatelessProcessGroup` object that can be
        used for exchanging metadata. With this function, process A and process B
        can call `StatelessProcessGroup.create` to form a group, and then process A, B,
        C, and D can call `StatelessProcessGroup.create` to form another group.
        """ # noqa
        launch_server = rank == 0
        if launch_server:
            # listen on the specified interface (instead of 0.0.0.0)
            listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listen_socket.bind((host, port))
            listen_socket.listen()
            listen_fd = listen_socket.fileno()
        else:
            listen_socket = None
            listen_fd = None

        store = TCPStore(
            host_name=host,
            port=port,
            world_size=world_size,
            is_master=launch_server,
            timeout=timedelta(seconds=store_timeout),
            use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
            master_listen_fd=listen_fd,
        )

        return StatelessProcessGroup(
            rank=rank,
            world_size=world_size,
            store=store,
            socket=listen_socket,
            data_expiration_seconds=data_expiration_seconds)

broadcast_recv_src_counter class-attribute instance-attribute

broadcast_recv_src_counter: dict[int, int] = field(
    default_factory=dict
)

broadcast_send_counter class-attribute instance-attribute

broadcast_send_counter: int = 0

data_expiration_seconds class-attribute instance-attribute

data_expiration_seconds: int = 3600

entries class-attribute instance-attribute

entries: deque[tuple[str, float]] = field(
    default_factory=deque
)

rank instance-attribute

rank: int

recv_src_counter class-attribute instance-attribute

recv_src_counter: dict[int, int] = field(
    default_factory=dict
)

send_dst_counter class-attribute instance-attribute

send_dst_counter: dict[int, int] = field(
    default_factory=dict
)

socket instance-attribute

socket: Optional[socket]

store instance-attribute

store: Store

world_size instance-attribute

world_size: int

__init__

__init__(
    rank: int,
    world_size: int,
    store: Store,
    socket: Optional[socket],
    data_expiration_seconds: int = 3600,
    send_dst_counter: dict[int, int] = dict(),
    recv_src_counter: dict[int, int] = dict(),
    broadcast_send_counter: int = 0,
    broadcast_recv_src_counter: dict[int, int] = dict(),
    entries: deque[tuple[str, float]] = deque(),
) -> None

__post_init__

__post_init__()
Source code in vllm/distributed/utils.py
def __post_init__(self):
    assert self.rank < self.world_size
    self.send_dst_counter = {i: 0 for i in range(self.world_size)}
    self.recv_src_counter = {i: 0 for i in range(self.world_size)}
    self.broadcast_recv_src_counter = {
        i: 0
        for i in range(self.world_size)
    }

all_gather_obj

all_gather_obj(obj: Any) -> list[Any]

All gather an object from all ranks.

Source code in vllm/distributed/utils.py
def all_gather_obj(self, obj: Any) -> list[Any]:
    """All gather an object from all ranks."""
    gathered_objs = []
    for i in range(self.world_size):
        if i == self.rank:
            gathered_objs.append(obj)
            self.broadcast_obj(obj, src=self.rank)
        else:
            recv_obj = self.broadcast_obj(None, src=i)
            gathered_objs.append(recv_obj)
    return gathered_objs

barrier

barrier(timeout: float = 30.0)

A robust barrier to synchronize all ranks.

Uses a multi-phase approach to ensure all processes reach the barrier before proceeding:

  1. Each process signals it has reached the barrier

  2. Each process signals that it has confirmed the arrival of all other ranks.

  3. Rank 0 waits for all other ranks to signal their departure to ensure that all ranks have departed the barrier first.

Parameters:

Name Type Description Default
timeout float

Maximum time in seconds to wait for each phase.

30.0

Raises:

Type Description
RuntimeError

If coordination fails or times out

Source code in vllm/distributed/utils.py
def barrier(self, timeout: float = 30.0):
    """A robust barrier to synchronize all ranks.


    Uses a multi-phase approach to ensure all processes reach the barrier
    before proceeding:

    1. Each process signals it has reached the barrier

    2. Each process signals that it has confirmed the arrival of all other
    ranks.

    3. Rank 0 waits for all other ranks to signal their departure to ensure
    that all ranks have departed the barrier first.

    Args:
        timeout: Maximum time in seconds to wait for each phase.

    Raises:
        RuntimeError: If coordination fails or times out
    """
    # Generate a barrier ID that is globally unique
    try:
        if self.rank == 0:
            barrier_id = f"barrier_{uuid.uuid4()}"
            self.broadcast_obj(barrier_id, src=0)
        else:
            barrier_id = self.broadcast_obj(None, src=0)
    except Exception as e:
        raise RuntimeError("Failed to broadcast barrier_id") from e

    # Phase 1: Signal arrival at barrier
    # Wait for all processes to arrive
    # We need all ranks to confirm the arrival of all other ranks.
    # This is the key synchronization point.
    arrival_key = f"arrival_{barrier_id}_{self.rank}"
    try:
        self.store.set(arrival_key, b"1")
    except Exception as e:
        raise RuntimeError("Failed to signal barrier arrival") from e

    start_time = time.time()
    processes_arrived: set[int] = set()

    while len(processes_arrived) < self.world_size:
        # Check for timeout
        cur_time = time.time()
        if cur_time - start_time > timeout:
            raise RuntimeError("Barrier timed out after %f seconds",
                               timeout)

        # Check for each process
        for i in range(self.world_size):
            if i in processes_arrived:
                continue

            key = f"arrival_{barrier_id}_{i}"
            try:
                # Try to get the key - if it exists, we'll get a value
                # If it doesn't exist, it will throw an exception
                self.store.get(key)
                processes_arrived.add(i)
            except KeyError:
                # Key doesn't exist yet
                pass
            except Exception as check_e:
                logger.debug("Error checking key existence: %s", check_e)
                sched_yield()

        # Short sleep to avoid tight polling
        if len(processes_arrived) < self.world_size:
            sched_yield()

    # Phase 2: Signal departure from barrier
    # We only care to block at this stage in rank 0, which runs the
    # server side of the TCPStore. We want to make sure that all
    # clients have departed the barrier before rank 0 in case the
    # next thing after the barrier is a shutdown, including tearing
    # down the TCPStore. Other ranks can exit the barrier immediately
    # after signaling their departure.
    departure_key = f"departure_{barrier_id}_{self.rank}"
    try:
        self.store.set(departure_key, b"1")
    except Exception as e:
        raise RuntimeError("Failed to signal barrier departure") from e

    if self.rank != 0:
        return

    # Make rank 0 wait for all processes to signal departure
    start_time = time.time()
    processes_departed: set[int] = set()

    while len(processes_departed) < self.world_size:
        # Check for timeout
        if time.time() - start_time > timeout:
            raise RuntimeError("Barrier departure timed out after %f s",
                               timeout)

        # Check for each process
        for i in range(self.world_size):
            if i in processes_departed:
                continue

            key = f"departure_{barrier_id}_{i}"
            try:
                # Try to get the key - if it exists, we'll get a value
                # If it doesn't exist, it will throw an exception
                self.store.get(key)
                processes_departed.add(i)
            except KeyError:
                # Key doesn't exist yet
                pass
            except Exception as check_e:
                logger.debug("Error checking key existence: %s", check_e)
                sched_yield()

        # Short sleep to avoid tight polling
        if len(processes_departed) < self.world_size:
            sched_yield()

    # Clean up keys to avoid leaking memory in the store
    for i in range(self.world_size):
        try:
            self.store.delete_key(f"arrival_{barrier_id}_{i}")
        except Exception:
            logger.debug("Error deleting key: %s",
                         f'arrival_{barrier_id}_{i}')

        try:
            self.store.delete_key(f"departure_{barrier_id}_{i}")
        except Exception:
            logger.debug("Error deleting key: %s",
                         f'departure_{barrier_id}_{i}')
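
A hedged usage sketch: every rank must call `barrier()`, and rank 0 (which hosts the TCPStore) deliberately returns last, so the store can be torn down right after the barrier. The helper name is hypothetical.

```python
from vllm.distributed.utils import StatelessProcessGroup

def sync_before_shutdown(pg: StatelessProcessGroup) -> None:
    # Collective: all ranks must enter the barrier.
    pg.barrier(timeout=60.0)
    # Rank 0 exits only after every rank has signaled departure, so
    # tearing down the TCPStore after this point strands no one.
```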

broadcast_obj

broadcast_obj(obj: Optional[Any], src: int) -> Any

Broadcast an object from a source rank to all other ranks. It does not clean up after all ranks have received the object. Use it for limited times, e.g., for initialization.

Source code in vllm/distributed/utils.py
def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
    """Broadcast an object from a source rank to all other ranks.
    It does not clean up after all ranks have received the object.
    Use it for limited times, e.g., for initialization.
    """
    if self.rank == src:
        self.expire_data()
        key = (f"broadcast_from/{src}/"
               f"{self.broadcast_send_counter}")
        self.store.set(key, pickle.dumps(obj))
        self.broadcast_send_counter += 1
        self.entries.append((key, time.time()))
        return obj
    else:
        key = (f"broadcast_from/{src}/"
               f"{self.broadcast_recv_src_counter[src]}")
        recv_obj = pickle.loads(self.store.get(key))
        self.broadcast_recv_src_counter[src] += 1
        return recv_obj

create staticmethod

create(
    host: str,
    port: int,
    rank: int,
    world_size: int,
    data_expiration_seconds: int = 3600,
    store_timeout: int = 300,
) -> StatelessProcessGroup

A replacement for torch.distributed.init_process_group that does not pollute the global state.

If process A and process B have called torch.distributed.init_process_group to form a group, and we then want to form another group with processes A, B, C, and D, this is not possible in PyTorch, because processes A and B have already formed a group, and processes C and D cannot join it. This function is a workaround for this issue.

torch.distributed.init_process_group is a global call, while this function is a stateless call. It will return a StatelessProcessGroup object that can be used for exchanging metadata. With this function, process A and process B can call StatelessProcessGroup.create to form a group, and then process A, B, C, and D can call StatelessProcessGroup.create to form another group.

Source code in vllm/distributed/utils.py
@staticmethod
def create(
    host: str,
    port: int,
    rank: int,
    world_size: int,
    data_expiration_seconds: int = 3600,
    store_timeout: int = 300,
) -> "StatelessProcessGroup":
    """A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state.

    If process A and process B have called `torch.distributed.init_process_group`
    to form a group, and we then want to form another group with processes A, B,
    C, and D, this is not possible in PyTorch, because processes A and B have
    already formed a group, and processes C and D cannot join it. This
    function is a workaround for this issue.

    `torch.distributed.init_process_group` is a global call, while this function
    is a stateless call. It will return a `StatelessProcessGroup` object that can be
    used for exchanging metadata. With this function, process A and process B
    can call `StatelessProcessGroup.create` to form a group, and then process A, B,
    C, and D can call `StatelessProcessGroup.create` to form another group.
    """ # noqa
    launch_server = rank == 0
    if launch_server:
        # listen on the specified interface (instead of 0.0.0.0)
        listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listen_socket.bind((host, port))
        listen_socket.listen()
        listen_fd = listen_socket.fileno()
    else:
        listen_socket = None
        listen_fd = None

    store = TCPStore(
        host_name=host,
        port=port,
        world_size=world_size,
        is_master=launch_server,
        timeout=timedelta(seconds=store_timeout),
        use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
        master_listen_fd=listen_fd,
    )

    return StatelessProcessGroup(
        rank=rank,
        world_size=world_size,
        store=store,
        socket=listen_socket,
        data_expiration_seconds=data_expiration_seconds)
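
A hedged end-to-end sketch: two processes form a group and exchange a piece of metadata. Run each half in its own process; the host and port values are placeholders.

```python
from vllm.distributed.utils import StatelessProcessGroup

# Process hosting rank 0 (also runs the TCPStore server):
pg = StatelessProcessGroup.create(
    host="127.0.0.1", port=29600, rank=0, world_size=2)
pg.send_obj({"stage": "ready"}, dst=1)
pg.barrier()

# Process hosting rank 1 (run separately):
pg = StatelessProcessGroup.create(
    host="127.0.0.1", port=29600, rank=1, world_size=2)
msg = pg.recv_obj(src=0)  # -> {"stage": "ready"}
pg.barrier()
```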

expire_data

expire_data()

Expire data that is older than data_expiration_seconds seconds.

Source code in vllm/distributed/utils.py
def expire_data(self):
    """Expire data that is older than `data_expiration_seconds` seconds."""
    while self.entries:
        # check the oldest entry
        key, timestamp = self.entries[0]
        if time.time() - timestamp > self.data_expiration_seconds:
            self.store.delete_key(key)
            self.entries.popleft()
        else:
            break

recv_obj

recv_obj(src: int) -> Any

Receive an object from a source rank.

Source code in vllm/distributed/utils.py
def recv_obj(self, src: int) -> Any:
    """Receive an object from a source rank."""
    obj = pickle.loads(
        self.store.get(
            f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
    self.recv_src_counter[src] += 1
    return obj

send_obj

send_obj(obj: Any, dst: int)

Send an object to a destination rank.

Source code in vllm/distributed/utils.py
def send_obj(self, obj: Any, dst: int):
    """Send an object to a destination rank."""
    self.expire_data()
    key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
    self.store.set(key, pickle.dumps(obj))
    self.send_dst_counter[dst] += 1
    self.entries.append((key, time.time()))

get_ep_group

get_ep_group() -> GroupCoordinator
Source code in vllm/distributed/parallel_state.py
def get_ep_group() -> GroupCoordinator:
    assert _EP is not None, ("expert parallel group is not initialized")
    return _EP

get_node_count

get_node_count() -> int

Return the total number of nodes in the distributed environment.

Source code in vllm/distributed/parallel_state.py
def get_node_count() -> int:
    """Return the total number of nodes in the distributed environment. """
    assert _NODE_COUNT is not None, (
        "distributed environment is not initialized")
    return _NODE_COUNT

in_the_same_node_as

in_the_same_node_as(
    pg: Union[ProcessGroup, StatelessProcessGroup],
    source_rank: int = 0,
) -> list[bool]

This is a collective operation that returns whether each rank is in the same node as the source rank. It tests whether processes are attached to the same memory system (shared access to shared memory).

Source code in vllm/distributed/parallel_state.py
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                        source_rank: int = 0) -> list[bool]:
    """
    This is a collective operation that returns whether each rank is in the
    same node as the source rank. It tests whether processes are attached to
    the same memory system (shared access to shared memory).
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                if isinstance(pg, ProcessGroup):
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg)
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                if isinstance(pg, ProcessGroup):
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg)
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]
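
A typical use is deciding whether shared-memory transports are available between ranks. A minimal sketch, assuming `torch.distributed` is already initialized with a CPU-capable backend (the function asserts the group is non-NCCL):

```python
# Sketch: check which ranks share a node with rank 0 over a gloo group.
import torch.distributed as dist

cpu_group = dist.new_group(backend="gloo")
same_node = in_the_same_node_as(cpu_group, source_rank=0)
if all(same_node):
    # Every rank shares rank 0's node, so shared-memory paths are usable.
    ...
```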

init_logger

init_logger(name: str) -> _VllmLogger

This function ensures that loggers are retrieved in such a way that the root vllm logger is guaranteed to have already been configured.

Source code in vllm/logger.py
def init_logger(name: str) -> _VllmLogger:
    """The main purpose of this function is to ensure that loggers are
    retrieved in such a way that we can be sure the root vllm logger has
    already been configured."""

    logger = logging.getLogger(name)

    for method_name, method in _METHODS_TO_PATCH.items():
        setattr(logger, method_name, MethodType(method, logger))

    return cast(_VllmLogger, logger)

rearrange_expert_weights_inplace

rearrange_expert_weights_inplace(
    old_global_expert_indices: Tensor,
    new_global_expert_indices: Tensor,
    expert_weights: Sequence[Iterable[Tensor]],
    ep_group: ProcessGroup,
    is_profile: bool = False,
    rank_mapping: Optional[dict[int, int]] = None,
) -> None

Rearranges the expert weights in place according to the new expert indices.

The values of the index arguments are the logical indices of the experts, while their positions (keys) are the physical expert indices.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `old_global_expert_indices` | `Tensor` | Shape (num_moe_layers, num_physical_experts). | required |
| `new_global_expert_indices` | `Tensor` | Shape (num_moe_layers, num_physical_experts). | required |
| `expert_weights` | `Sequence[Iterable[Tensor]]` | A nested sequence, indexed as [num_moe_layers][weight_count], of tensors of shape (num_local_physical_experts, hidden_size_i). For example, a linear layer may have up and down projections, so weight_count = 2. Each weight's hidden size can be different. | required |
| `ep_group` | `ProcessGroup` | The device process group for expert parallelism. | required |
| `is_profile` | `bool` | If `True`, do not perform any actual weight copy. This is used during the profile run, where only dummy communications are performed to reserve enough memory for the buffers. | `False` |
| `rank_mapping` | `Optional[dict[int, int]]` | A dictionary mapping old ranks to new ranks. | `None` |
Source code in vllm/distributed/eplb/rebalance_execute.py
def rearrange_expert_weights_inplace(
    old_global_expert_indices: torch.Tensor,
    new_global_expert_indices: torch.Tensor,
    expert_weights: Sequence[Iterable[torch.Tensor]],
    ep_group: ProcessGroup,
    is_profile: bool = False,
    rank_mapping: Optional[dict[int, int]] = None,
) -> None:
    """
    Rearranges the expert weights in place according to the new expert indices.

    The values of the index arguments are the logical indices of the experts,
    while their positions (keys) are the physical indices.

    Args:
        old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
        expert_weights: A sequence of shape (num_moe_layers)(weight_count)
            of tensors of shape (num_local_physical_experts, hidden_size_i).
            For example, a linear layer may have up and down projection,
            so weight_count = 2. Each weight's hidden size can be different.
        ep_group: The device process group for expert parallelism.
        is_profile (bool): If `True`, do not perform any actual weight copy.
            This is used during profile run, where we only perform dummy
            communications to reserve enough memory for the buffers.
        rank_mapping: A dictionary mapping old rank to new rank.
    """
    if rank_mapping is not None:
        if len(rank_mapping) == ep_group.size():
            # scale down
            new_global_expert_indices = \
                _map_new_expert_indices_with_rank_mapping(
                new_global_expert_indices,
                rank_mapping,
            )
        else:
            # scale up
            old_global_expert_indices = \
                _map_old_expert_indices_with_rank_mapping(
                old_global_expert_indices,
                rank_mapping,
                ep_group.size(),
            )

    assert old_global_expert_indices.shape[
        1] == new_global_expert_indices.shape[1]

    num_moe_layers, num_physical_experts = old_global_expert_indices.shape
    assert len(expert_weights) == num_moe_layers

    num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
    assert new_global_expert_indices.shape == (num_moe_layers,
                                               num_physical_experts)

    ep_rank = ep_group.rank()
    ep_size = ep_group.size()
    assert num_physical_experts == ep_size * num_local_physical_experts

    # A buffer to hold the expert weights in one layer during the exchange.
    # NOTE: Currently we assume the same weights across different layers
    # have the same shape.
    expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]

    if is_profile:
        # Maximum send size is to send all local experts to all ranks,
        # So we use a dummy `all_gather` to reserve enough communication buffer
        for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
            # A `/dev/null`-like buffer to avoid real memory allocation
            dummy_recv_buffer = [buffer for _ in range(ep_size)]
            # NOTE(bowen): Needed this barrier to avoid OOM during actual
            # execution. I'm not very sure why this is needed
            torch.distributed.barrier()
            all_gather(
                dummy_recv_buffer,
                weight,
                group=ep_group,
            )
        return

    for layer in range(num_moe_layers):
        # NOTE(bowen): We need this synchronize to run, but I don't know why.
        # If you figure out the reason, please let me know -- thank you!
        torch.cuda.synchronize()
        shuffle_layer(
            num_local_physical_experts,
            ep_rank,
            old_global_expert_indices[layer].tolist(),
            new_global_expert_indices[layer].tolist(),
            expert_weights[layer],
            expert_weights_buffer,
            ep_group,
        )
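
To make the expected shapes concrete, here is an illustrative construction of the arguments for a 2-layer model with 4 physical experts on 2 EP ranks. It is a sketch only: `ep_group` must be an initialized expert-parallel process group, and in practice the weights live on each rank's GPU.

```python
import torch

num_moe_layers, hidden = 2, 8
num_local = 2  # local physical experts per rank (4 total over 2 ranks)

# Logical expert held by each physical slot, before and after rebalancing.
old_idx = torch.tensor([[0, 1, 2, 3],
                        [0, 1, 2, 3]])
new_idx = torch.tensor([[0, 1, 0, 2],
                        [3, 1, 2, 0]])

# Per layer: two weight tensors (e.g. up/down projections), holding only
# this rank's local physical experts. Shapes must match across layers.
expert_weights = [
    [torch.randn(num_local, hidden), torch.randn(num_local, 2 * hidden)]
    for _ in range(num_moe_layers)
]

rearrange_expert_weights_inplace(old_idx, new_idx, expert_weights, ep_group)
```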

rebalance_experts

rebalance_experts(
    weight: Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
) -> tuple[Tensor, Tensor, Tensor]

Entry point for expert-parallelism load balancer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `weight` | `Tensor` | [layers, num_logical_experts], the load statistics for all logical experts. | required |
| `num_replicas` | `int` | Number of physical experts; must be a multiple of `num_gpus`. | required |
| `num_groups` | `int` | Number of expert groups. | required |
| `num_nodes` | `int` | Number of server nodes, where the intra-node network (e.g., NVLink) is faster. | required |
| `num_gpus` | `int` | Number of GPUs; must be a multiple of `num_nodes`. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `physical_to_logical_map` | `Tensor` | [layers, num_replicas], the expert index of each replica. |
| `logical_to_physical_map` | `Tensor` | [layers, num_logical_experts, X], the replica indices for each expert. |
| `expert_count` | `Tensor` | [layers, num_logical_experts], the number of physical replicas of each logical expert. |

Source code in vllm/distributed/eplb/rebalance_algo.py
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all
            logical experts
        num_replicas: number of physical experts, must be a multiple of
            `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network
            (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of
            each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica
            indices for each expert
        expert_count: [layers, num_logical_experts], number of physical
            replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()
    if num_groups % num_nodes == 0:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus)
    else:
        # use global load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus)
    num_redundant_experts = num_replicas - num_logical_experts
    maxlogcnt = num_redundant_experts + 1
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(num_replicas, dtype=torch.int64,
                     device=log2phy.device).expand(num_layers, -1),
    )
    return phy2log, log2phy, logcnt
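
Since this function runs entirely on the CPU, it can be exercised standalone. A runnable sketch for a single layer with 4 logical experts, 6 physical slots (2 redundant), and 2 GPUs on one node (the load values are arbitrary):

```python
import torch

from vllm.distributed.eplb.rebalance_algo import rebalance_experts

weight = torch.tensor([[90, 10, 40, 60]])  # per-expert load, shape [1, 4]
phy2log, log2phy, logcnt = rebalance_experts(
    weight, num_replicas=6, num_groups=4, num_nodes=1, num_gpus=2)

print(phy2log.shape)  # torch.Size([1, 6])
print(log2phy.shape)  # torch.Size([1, 4, 3]); 3 == num_redundant + 1
print(logcnt.sum())   # tensor(6): total physical replicas
```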