
vllm.model_executor.layers.rotary_embedding.mrope

MRotaryEmbedding

Bases: RotaryEmbedding

Rotary Embedding with Multimodal Sections.
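
To make the sectioned rotation concrete, here is a minimal standalone sketch (plain PyTorch, illustrative sizes only, not the class itself) of how the rotary half-dimension is split into mrope_section = [t, h, w] chunks, each chunk reading its positions from the matching row of a [3, num_tokens] tensor; this mirrors the cos/sin assembly done in forward_native below.

import torch

mrope_section = [2, 3, 3]          # must sum to rotary_dim // 2 (8 here)
num_tokens = 4
half_dim = sum(mrope_section)

# Rows 0/1/2 hold temporal/height/width positions for each token.
positions = torch.arange(num_tokens).repeat(3, 1)            # [3, num_tokens]

# Illustrative inverse frequencies, one per half-dim slot.
inv_freq = 1.0 / (10000 ** (torch.arange(half_dim) / half_dim))
cos = (positions[..., None].float() * inv_freq).cos()        # [3, tokens, half_dim]

# Chunk i of the frequency axis takes its values from position row i,
# then the chunks are concatenated back into one [num_tokens, half_dim] table.
cos_mrope = torch.cat(
    [chunk[i] for i, chunk in enumerate(cos.split(mrope_section, dim=-1))],
    dim=-1,
)
print(cos_mrope.shape)                                        # torch.Size([4, 8])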

Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[list[int]] = None,
    ) -> None:
        # In Qwen2.5-VL, the maximum index value is related to the duration of
        # the input video. We enlarge max_position_embeddings to 4 times its
        # original value to get a larger cos and sin cache.
        self.cache_max_position_num = max_position_embeddings * 4
        super().__init__(head_size, rotary_dim, self.cache_max_position_num,
                         base, is_neox_style, dtype)

        self.mrope_section = mrope_section
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2

        self.use_triton = current_platform.is_cuda_alike()

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """MRope forward.

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        if self.use_triton:
            return self.forward_cuda(positions, query, key)
        else:
            return self.forward_native(positions, query, key)

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section

            cos = torch.cat([
                m[i]
                for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)
            sin = torch.cat([
                m[i]
                for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
                                              self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
                                            self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:

        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query_shape = query.shape
        key_shape = key.shape
        if positions.ndim == 2:
            assert self.mrope_section

            q, k = triton_mrope(
                query,
                key,
                cos,
                sin,
                self.mrope_section,
                self.head_size,
                self.rotary_dim,
            )

            return q.reshape(query_shape), k.reshape(key_shape)

        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
                                              self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
                                            self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    @classmethod
    def get_input_positions(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        second_per_grid_ts: Optional[list[float]],
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[list[list[int]], int]:
        """Get mrope input positions and delta value."""

        image_grid_thw = [] if image_grid_thw is None else image_grid_thw
        video_grid_thw = [] if video_grid_thw is None else video_grid_thw
        second_per_grid_ts = [] if second_per_grid_ts is None else \
            second_per_grid_ts

        llm_positions, mrope_position_delta = \
            cls.get_input_positions_tensor(
                input_tokens=input_tokens,
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                context_len=context_len,
                seq_len=seq_len,
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
            )

        return llm_positions.tolist(), mrope_position_delta

    @classmethod
    def get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: list[float],
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        from vllm.transformers_utils.config import thinker_uses_mrope
        if thinker_uses_mrope(hf_config):
            return cls._omni_get_input_positions_tensor(
                input_tokens=input_tokens,
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                context_len=context_len,
                seq_len=seq_len,
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
            )
        elif hf_config.model_type in ["glm4v", "glm4v_moe"]:
            return cls._glm4v_get_input_positions_tensor(
                input_tokens=input_tokens,
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                context_len=context_len,
                seq_len=seq_len,
            )
        else:
            return cls._vl_get_input_positions_tensor(
                input_tokens=input_tokens,
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                context_len=context_len,
                seq_len=seq_len,
            )

    @classmethod
    def _glm4v_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value for GLM4V."""

        image_token_id = hf_config.image_token_id
        video_start_token_id = hf_config.video_start_token_id
        video_end_token_id = hf_config.video_end_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        llm_pos_ids_list: list = []

        if not (image_grid_thw is None and video_grid_thw is None):
            if isinstance(image_grid_thw, torch.Tensor):
                image_grid_thw = image_grid_thw.tolist()

            input_token_type: list[str] = []
            video_check_flg = False
            for token in input_tokens:
                if token == video_start_token_id:
                    video_check_flg = True
                elif token == video_end_token_id:
                    video_check_flg = False

                if (token == image_token_id) and (video_check_flg is False):
                    input_token_type.append("image")
                elif (token == image_token_id) and (video_check_flg is True):
                    input_token_type.append("video")
                else:
                    input_token_type.append("text")

            input_type_group: list[tuple[str, int, int]] = []
            for key, group_iter in itertools.groupby(
                    enumerate(input_token_type), lambda x: x[1]):
                group_list = list(group_iter)
                start_index = group_list[0][0]
                end_index = group_list[-1][0] + 1
                input_type_group.append((key, start_index, end_index))

            video_frame_num = 1
            mm_data_idx = 0
            for modality_type, start_idx, end_idx in input_type_group:
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                    llm_pos_ids_list) > 0 else 0
                if modality_type == "image":
                    t, h, w = (
                        image_grid_thw[mm_data_idx][0],
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = \
                        t, h // spatial_merge_size, w // spatial_merge_size

                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                        llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                        llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)
                    mm_data_idx += 1

                elif modality_type == "video":
                    t, h, w = (
                        video_frame_num,
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = \
                        t, h // spatial_merge_size, w // spatial_merge_size

                    for t_idx in range(llm_grid_t):
                        t_index = torch.tensor(t_idx).view(-1, 1).expand(
                            -1, llm_grid_h * llm_grid_w).flatten()
                        h_index = torch.arange(llm_grid_h).view(
                            1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                        w_index = torch.arange(llm_grid_w).view(
                            1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                        llm_pos_ids_list.append(
                            torch.stack([t_index, h_index, w_index]) + st_idx)

                    mm_data_idx += 1
                    video_frame_num += 1

                else:
                    text_len = end_idx - start_idx
                    llm_pos_ids_list.append(
                        torch.arange(text_len).view(1, -1).expand(3, -1) +
                        st_idx)
                    video_frame_num = 1

        else:
            text_len = len(input_tokens)
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1))

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        llm_positions = llm_positions[:, context_len:seq_len]
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        return llm_positions, mrope_position_delta

    @classmethod
    def _vl_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: list[float],
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value."""

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        vision_start_token_id = hf_config.vision_start_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(hf_config.vision_config,
                                    "tokens_per_second", 1.0)

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            video_second_per_grid_t = 0.0
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_second_per_grid_t = 1.0
                if second_per_grid_ts:
                    video_second_per_grid_t = second_per_grid_ts[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
                       tokens_per_second).long().flatten()

            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @classmethod
    def _omni_get_input_positions_tensor(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value (Qwen2.5-Omni version).

        Differences from MRotaryEmbedding:
            1. Add audio support (and related `audio_feature_lengths`).
            2. Add `use_audio_in_video` option to read audio from video inputs.
                In this case, audio and vision position ids will be split into
                chunks and interleaved.

        Example:

            (V_i are vision position ids, A_i are audio position ids)

            |V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
            |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
        """

        # TODO(fyabc): refactor and share more code with
        #  _vl_get_input_positions_tensor.

        thinker_config = hf_config.thinker_config
        audio_token_id = thinker_config.audio_token_index
        image_token_id = thinker_config.image_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        vision_start_token_id = thinker_config.vision_start_token_id
        vision_end_token_id = thinker_config.vision_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(thinker_config.vision_config,
                                    "tokens_per_second", 25)

        if isinstance(image_grid_thw, list):
            image_grid_thw = torch.tensor(image_grid_thw)
        if isinstance(video_grid_thw, list):
            video_grid_thw = torch.tensor(video_grid_thw)

        src_item = input_tokens
        audio_seqlens = audio_feature_lengths
        if not second_per_grid_ts:
            second_per_grid_ts = [1] * video_grid_thw.shape[0]
        audio_idx = 0
        video_idx = 0
        image_idx = 0
        new_src_item: list[int] = []
        llm_pos_ids_list: list[torch.Tensor] = []

        idx = 0
        while idx < len(src_item):
            new_src_item_len = len(new_src_item)
            start_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            if src_item[idx] not in [
                    audio_token_id, video_token_id, image_token_id
            ]:
                if use_audio_in_video and idx > 0:
                    if src_item[idx] == vision_end_token_id and \
                        src_item[idx - 1] == audio_end_token_id:
                        # processing the <|audio_eos|> before <|vision_eos|>
                        start_idx -= 1
                    elif src_item[idx] == audio_start_token_id and \
                        src_item[idx - 1] == vision_start_token_id:
                        # processing the <|audio_bos|> after <|vision_bos|>
                        start_idx -= 1
                new_src_item.append(src_item[idx])
                llm_pos_ids = torch.tensor([start_idx],
                                           dtype=torch.long).expand(3, -1)
                llm_pos_ids_list.append(llm_pos_ids)
            elif src_item[idx] == audio_token_id:
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
                place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
                new_src_item.extend([audio_token_id] * place_num)
                llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
                llm_pos_ids_list.append(llm_pos_ids)
                audio_idx += 1
            elif src_item[idx] == image_token_id:
                grid_t = image_grid_thw[image_idx][0]
                grid_hs = image_grid_thw[:, 1]
                grid_ws = image_grid_thw[:, 2]
                t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
                    grid_ws)
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = image_grid_thw[image_idx].prod() // (
                    spatial_merge_size**2)
                new_src_item.extend([image_token_id] * vision_seqlen)
                image_idx += 1
            elif src_item[idx] == video_token_id and not use_audio_in_video:
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_index = (torch.arange(grid_t) *
                           second_per_grid_ts[video_idx] *
                           tokens_per_second).long()
                llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
                    grid_ws)
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2)
                new_src_item.extend([video_token_id] * vision_seqlen)
                video_idx += 1
            else:
                # read audio from video
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
                vision_seqlen = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2)
                grid_t = video_grid_thw[video_idx][0]
                grid_h = video_grid_thw[video_idx][1]
                grid_w = video_grid_thw[video_idx][2]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
                t_index = (torch.arange(grid_t) *
                           second_per_grid_ts[video_idx] *
                           tokens_per_second).long()
                t_index_split_chunk = cls._split_list_into_ranges(
                    t_index, t_ntoken_per_chunk)
                place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
                pure_audio_len = place_num - 2
                added_audio_len = 0
                audio_llm_pos_ids_list: list[torch.Tensor] = []
                for t_chunk in t_index_split_chunk:
                    vision_ntoken_per_chunk = len(
                        t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                    new_src_item.extend([video_token_id] *
                                        vision_ntoken_per_chunk)
                    vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
                        start_idx, video_idx, spatial_merge_size, t_chunk,
                        grid_hs, grid_ws).split(1, dim=1)
                    llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                    new_src_item.extend(
                        min(t_ntoken_per_chunk, pure_audio_len -
                            added_audio_len) * [audio_token_id])
                    audio_start_idx = start_idx if len(
                        audio_llm_pos_ids_list
                    ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
                    if min(t_ntoken_per_chunk,
                           pure_audio_len - added_audio_len) > 0:
                        audio_llm_pos_ids_list = (torch.arange(
                            min(t_ntoken_per_chunk, pure_audio_len -
                                added_audio_len)).expand(3, -1) +
                                                  audio_start_idx).split(1,
                                                                         dim=1)
                    else:
                        audio_llm_pos_ids_list = []
                    added_audio_len += min(t_ntoken_per_chunk,
                                           pure_audio_len - added_audio_len)
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                if added_audio_len < pure_audio_len:
                    new_src_item.extend(
                        (pure_audio_len - added_audio_len) * [audio_token_id])
                    audio_llm_pos_ids_list = (
                        torch.arange(pure_audio_len - added_audio_len).expand(
                            3, -1) + llm_pos_ids_list[-1].max() + 1).split(
                                1, dim=1)
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                audio_idx += 1
                video_idx += 1
            # move to the next token
            idx += len(new_src_item) - new_src_item_len

        llm_positions = torch.cat(llm_pos_ids_list, dim=1)
        mrope_position_delta = torch.cat(llm_pos_ids_list,
                                         dim=1).max() + 1 - len(src_item)
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    @staticmethod
    def _get_llm_pos_ids_for_vision(
        start_idx: int,
        vision_idx: int,
        spatial_merge_size: int,
        t_index: list[int],
        grid_hs: torch.Tensor,
        grid_ws: torch.Tensor,
    ) -> torch.Tensor:
        llm_pos_ids_list = []
        llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
        llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
        h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
            len(t_index), -1, llm_grid_w).flatten())
        w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
            len(t_index), llm_grid_h, -1).flatten())
        t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
            -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
        _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
        llm_pos_ids_list.append(_llm_pos_ids + start_idx)
        llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
        return llm_pos_ids

    @staticmethod
    def _split_list_into_ranges(lst: torch.Tensor,
                                interval: int) -> list[list[int]]:
        ranges: list[list[int]] = [[]
                                   for _ in range((max(lst) // interval) + 1)]
        for num in lst:
            index = num // interval
            ranges[index].append(num)
        return ranges

    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
    ) -> list[list[int]]:
        return [
            list(
                range(context_len + mrope_position_delta,
                      seq_len + mrope_position_delta)) for _ in range(3)
        ]

    @staticmethod
    def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
                                        mrope_position_delta: int,
                                        context_len: int, num_new_tokens: int):

        values = np.arange(mrope_position_delta + context_len,
                           mrope_position_delta + context_len + num_new_tokens,
                           dtype=out.dtype)
        out[:, out_offset:out_offset + num_new_tokens] = values

    @classmethod
    def omni_get_updates_use_audio_in_video(
        cls,
        thinker_config: PretrainedConfig,
        audio_len: int,
        video_grid_thw: Union[list[int], torch.Tensor],
        video_second_per_grid_t: float,
    ) -> list[int]:
        """Get video prompt updates when `use_audio_in_video` is True.

        In this case, audio and vision update ids will be split into
        chunks and interleaved (details in `_omni_get_input_positions_tensor`).

        <|video_bos|><|VIDEO|><|video_eos|> =>
        <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
        """

        audio_token_id = thinker_config.audio_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(thinker_config.vision_config,
                                    "tokens_per_second", 25)

        grid_t = video_grid_thw[0]
        grid_h = video_grid_thw[1]
        grid_w = video_grid_thw[2]
        t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
        t_index = (torch.arange(grid_t) * video_second_per_grid_t *
                   tokens_per_second).long()
        t_index_split_chunk = cls._split_list_into_ranges(
            t_index, t_ntoken_per_chunk)

        updates = [audio_start_token_id]
        added_audio_len = 0
        for t_chunk in t_index_split_chunk:
            vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
                spatial_merge_size**2)
            updates.extend([video_token_id] * vision_ntoken_per_chunk)

            audio_chunk_size = min(t_ntoken_per_chunk,
                                   audio_len - added_audio_len)
            updates.extend(audio_chunk_size * [audio_token_id])
            added_audio_len += audio_chunk_size
        if added_audio_len < audio_len:
            updates.extend((audio_len - added_audio_len) * [audio_token_id])
        updates.extend([audio_end_token_id])

        return updates

cache_max_position_num instance-attribute

cache_max_position_num = max_position_embeddings * 4

mrope_section instance-attribute

mrope_section = mrope_section

use_triton instance-attribute

use_triton = is_cuda_alike()

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
    mrope_section: Optional[list[int]] = None,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    mrope_section: Optional[list[int]] = None,
) -> None:
    # In Qwen2.5-VL, the maximum index value is related to the duration of
    # the input video. We enlarge max_position_embeddings to 4 times its
    # original value to get a larger cos and sin cache.
    self.cache_max_position_num = max_position_embeddings * 4
    super().__init__(head_size, rotary_dim, self.cache_max_position_num,
                     base, is_neox_style, dtype)

    self.mrope_section = mrope_section
    if self.mrope_section:
        assert sum(self.mrope_section) == rotary_dim // 2

    self.use_triton = current_platform.is_cuda_alike()
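
As a usage note, a hedged construction sketch (assumes a working vLLM install; head_size, base, and mrope_section are illustrative, roughly Qwen2-VL-like, and not taken from this page):

import torch
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

rope = MRotaryEmbedding(
    head_size=128,
    rotary_dim=128,
    max_position_embeddings=32768,
    base=1000000.0,
    is_neox_style=True,
    dtype=torch.float32,
    mrope_section=[16, 24, 24],    # 16 + 24 + 24 == rotary_dim // 2
)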

_get_llm_pos_ids_for_vision staticmethod

_get_llm_pos_ids_for_vision(
    start_idx: int,
    vision_idx: int,
    spatial_merge_size: int,
    t_index: list[int],
    grid_hs: Tensor,
    grid_ws: Tensor,
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@staticmethod
def _get_llm_pos_ids_for_vision(
    start_idx: int,
    vision_idx: int,
    spatial_merge_size: int,
    t_index: list[int],
    grid_hs: torch.Tensor,
    grid_ws: torch.Tensor,
) -> torch.Tensor:
    llm_pos_ids_list = []
    llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
    llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
    h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
        len(t_index), -1, llm_grid_w).flatten())
    w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
        len(t_index), llm_grid_h, -1).flatten())
    t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
        -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
    _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
    llm_pos_ids_list.append(_llm_pos_ids + start_idx)
    llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
    return llm_pos_ids
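
To see what the expansion produces, a small standalone sketch (plain torch, not calling the helper) for a single-frame grid of height 4 and width 6 with a hypothetical spatial_merge_size of 2:

import torch

t_index = [0]                                   # one temporal step
llm_grid_h, llm_grid_w = 4 // 2, 6 // 2         # grid after spatial merge
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
    len(t_index), -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
    len(t_index), llm_grid_h, -1).flatten()
t_ids = torch.tensor(t_index).view(-1, 1).expand(
    -1, llm_grid_h * llm_grid_w).flatten()
print(torch.stack([t_ids, h_index, w_index]))
# tensor([[0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1, 1],
#         [0, 1, 2, 0, 1, 2]])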

_glm4v_get_input_positions_tensor classmethod

_glm4v_get_input_positions_tensor(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], Tensor],
    video_grid_thw: Union[list[list[int]], Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[Tensor, int]

Get mrope input positions and delta value for GLM4V.

Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@classmethod
def _glm4v_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value for GLM4V."""

    image_token_id = hf_config.image_token_id
    video_start_token_id = hf_config.video_start_token_id
    video_end_token_id = hf_config.video_end_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size
    llm_pos_ids_list: list = []

    if not (image_grid_thw is None and video_grid_thw is None):
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()

        input_token_type: list[str] = []
        video_check_flg = False
        for token in input_tokens:
            if token == video_start_token_id:
                video_check_flg = True
            elif token == video_end_token_id:
                video_check_flg = False

            if (token == image_token_id) and (video_check_flg is False):
                input_token_type.append("image")
            elif (token == image_token_id) and (video_check_flg is True):
                input_token_type.append("video")
            else:
                input_token_type.append("text")

        input_type_group: list[tuple[str, int, int]] = []
        for key, group_iter in itertools.groupby(
                enumerate(input_token_type), lambda x: x[1]):
            group_list = list(group_iter)
            start_index = group_list[0][0]
            end_index = group_list[-1][0] + 1
            input_type_group.append((key, start_index, end_index))

        video_frame_num = 1
        mm_data_idx = 0
        for modality_type, start_idx, end_idx in input_type_group:
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            if modality_type == "image":
                t, h, w = (
                    image_grid_thw[mm_data_idx][0],
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size

                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                    -1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                    llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                    llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index]) + st_idx)
                mm_data_idx += 1

            elif modality_type == "video":
                t, h, w = (
                    video_frame_num,
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = \
                    t, h // spatial_merge_size, w // spatial_merge_size

                for t_idx in range(llm_grid_t):
                    t_index = torch.tensor(t_idx).view(-1, 1).expand(
                        -1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(
                        1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(
                        1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(
                        torch.stack([t_index, h_index, w_index]) + st_idx)

                mm_data_idx += 1
                video_frame_num += 1

            else:
                text_len = end_idx - start_idx
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) +
                    st_idx)
                video_frame_num = 1

    else:
        text_len = len(input_tokens)
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1))

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    llm_positions = llm_positions[:, context_len:seq_len]
    mrope_position_delta = (llm_positions.max() + 1 -
                            len(input_tokens)).item()
    return llm_positions, mrope_position_delta
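
To illustrate the grouping step above, a short standalone sketch (plain Python, hypothetical token-type labels) of how itertools.groupby turns per-token labels into (type, start, end) spans:

import itertools

input_token_type = ["text", "text", "image", "image", "text"]
input_type_group = []
for key, group_iter in itertools.groupby(enumerate(input_token_type),
                                         lambda x: x[1]):
    group_list = list(group_iter)
    input_type_group.append((key, group_list[0][0], group_list[-1][0] + 1))
print(input_type_group)   # [('text', 0, 2), ('image', 2, 4), ('text', 4, 5)]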

_omni_get_input_positions_tensor classmethod

_omni_get_input_positions_tensor(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], Tensor],
    video_grid_thw: Union[list[list[int]], Tensor],
    second_per_grid_ts: Optional[list[float]] = None,
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[Tensor, int]

Get mrope input positions and delta value (Qwen2.5-Omni version).

Differences from MRotaryEmbedding
  1. Add audio support (and related audio_feature_lengths).
  2. Add use_audio_in_video option to read audio from video inputs. In this case, audio and vision position ids will be split into chunks and interleaved.

Example:

(V_i are vision position ids, A_i are audio position ids)

|V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@classmethod
def _omni_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: Optional[list[float]] = None,
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value (Qwen2.5-Omni version).

    Differences from MRotaryEmbedding:
        1. Add audio support (and related `audio_feature_lengths`).
        2. Add `use_audio_in_video` option to read audio from video inputs.
            In this case, audio and vision position ids will be split into
            chunks and interleaved.

    Example:

        (V_i are vision position ids, A_i are audio position ids)

        |V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
        |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
    """

    # TODO(fyabc): refactor and share more code with
    #  _vl_get_input_positions_tensor.

    thinker_config = hf_config.thinker_config
    audio_token_id = thinker_config.audio_token_index
    image_token_id = thinker_config.image_token_index
    video_token_id = thinker_config.video_token_index
    audio_start_token_id = thinker_config.audio_start_token_id
    audio_end_token_id = thinker_config.audio_end_token_id
    vision_start_token_id = thinker_config.vision_start_token_id
    vision_end_token_id = thinker_config.vision_end_token_id
    seconds_per_chunk = thinker_config.seconds_per_chunk
    spatial_merge_size = thinker_config.vision_config.spatial_merge_size
    tokens_per_second = getattr(thinker_config.vision_config,
                                "tokens_per_second", 25)

    if isinstance(image_grid_thw, list):
        image_grid_thw = torch.tensor(image_grid_thw)
    if isinstance(video_grid_thw, list):
        video_grid_thw = torch.tensor(video_grid_thw)

    src_item = input_tokens
    audio_seqlens = audio_feature_lengths
    if not second_per_grid_ts:
        second_per_grid_ts = [1] * video_grid_thw.shape[0]
    audio_idx = 0
    video_idx = 0
    image_idx = 0
    new_src_item: list[int] = []
    llm_pos_ids_list: list[torch.Tensor] = []

    idx = 0
    while idx < len(src_item):
        new_src_item_len = len(new_src_item)
        start_idx = llm_pos_ids_list[-1].max() + 1 if len(
            llm_pos_ids_list) > 0 else 0
        if src_item[idx] not in [
                audio_token_id, video_token_id, image_token_id
        ]:
            if use_audio_in_video and idx > 0:
                if src_item[idx] == vision_end_token_id and \
                    src_item[idx - 1] == audio_end_token_id:
                    # processing the <|audio_eos|> before <|vision_eos|>
                    start_idx -= 1
                elif src_item[idx] == audio_start_token_id and \
                    src_item[idx - 1] == vision_start_token_id:
                    # processing the <|audio_bos|> after <|vision_bos|>
                    start_idx -= 1
            new_src_item.append(src_item[idx])
            llm_pos_ids = torch.tensor([start_idx],
                                       dtype=torch.long).expand(3, -1)
            llm_pos_ids_list.append(llm_pos_ids)
        elif src_item[idx] == audio_token_id:
            assert audio_seqlens is not None
            audio_seqlen = audio_seqlens[audio_idx]
            place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
            new_src_item.extend([audio_token_id] * place_num)
            llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
            llm_pos_ids_list.append(llm_pos_ids)
            audio_idx += 1
        elif src_item[idx] == image_token_id:
            grid_t = image_grid_thw[image_idx][0]
            grid_hs = image_grid_thw[:, 1]
            grid_ws = image_grid_thw[:, 2]
            t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
            llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
                grid_ws)
            llm_pos_ids_list.append(llm_pos_ids)
            vision_seqlen = image_grid_thw[image_idx].prod() // (
                spatial_merge_size**2)
            new_src_item.extend([image_token_id] * vision_seqlen)
            image_idx += 1
        elif src_item[idx] == video_token_id and not use_audio_in_video:
            grid_t = video_grid_thw[video_idx][0]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            t_index = (torch.arange(grid_t) *
                       second_per_grid_ts[video_idx] *
                       tokens_per_second).long()
            llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
                grid_ws)
            llm_pos_ids_list.append(llm_pos_ids)
            vision_seqlen = video_grid_thw[video_idx].prod() // (
                spatial_merge_size**2)
            new_src_item.extend([video_token_id] * vision_seqlen)
            video_idx += 1
        else:
            # read audio from video
            assert audio_seqlens is not None
            audio_seqlen = audio_seqlens[audio_idx]
            vision_seqlen = video_grid_thw[video_idx].prod() // (
                spatial_merge_size**2)
            grid_t = video_grid_thw[video_idx][0]
            grid_h = video_grid_thw[video_idx][1]
            grid_w = video_grid_thw[video_idx][2]
            grid_hs = video_grid_thw[:, 1]
            grid_ws = video_grid_thw[:, 2]
            t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
            t_index = (torch.arange(grid_t) *
                       second_per_grid_ts[video_idx] *
                       tokens_per_second).long()
            t_index_split_chunk = cls._split_list_into_ranges(
                t_index, t_ntoken_per_chunk)
            place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
            pure_audio_len = place_num - 2
            added_audio_len = 0
            audio_llm_pos_ids_list: list[torch.Tensor] = []
            for t_chunk in t_index_split_chunk:
                vision_ntoken_per_chunk = len(
                    t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                new_src_item.extend([video_token_id] *
                                    vision_ntoken_per_chunk)
                vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_chunk,
                    grid_hs, grid_ws).split(1, dim=1)
                llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                new_src_item.extend(
                    min(t_ntoken_per_chunk, pure_audio_len -
                        added_audio_len) * [audio_token_id])
                audio_start_idx = start_idx if len(
                    audio_llm_pos_ids_list
                ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
                if min(t_ntoken_per_chunk,
                       pure_audio_len - added_audio_len) > 0:
                    audio_llm_pos_ids_list = (torch.arange(
                        min(t_ntoken_per_chunk, pure_audio_len -
                            added_audio_len)).expand(3, -1) +
                                              audio_start_idx).split(1,
                                                                     dim=1)
                else:
                    audio_llm_pos_ids_list = []
                added_audio_len += min(t_ntoken_per_chunk,
                                       pure_audio_len - added_audio_len)
                llm_pos_ids_list.extend(audio_llm_pos_ids_list)
            if added_audio_len < pure_audio_len:
                new_src_item.extend(
                    (pure_audio_len - added_audio_len) * [audio_token_id])
                audio_llm_pos_ids_list = (
                    torch.arange(pure_audio_len - added_audio_len).expand(
                        3, -1) + llm_pos_ids_list[-1].max() + 1).split(
                            1, dim=1)
                llm_pos_ids_list.extend(audio_llm_pos_ids_list)
            audio_idx += 1
            video_idx += 1
        # move to the next token
        idx += len(new_src_item) - new_src_item_len

    llm_positions = torch.cat(llm_pos_ids_list, dim=1)
    mrope_position_delta = torch.cat(llm_pos_ids_list,
                                     dim=1).max() + 1 - len(src_item)
    llm_positions = llm_positions[:, context_len:seq_len]

    return llm_positions, mrope_position_delta
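
For intuition about the audio placeholder count used above, a worked check of the place_num formula with a hypothetical audio feature length of 100:

audio_seqlen = 100
place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
# (99 // 2 + 1 - 2) // 2 + 1  ->  48 // 2 + 1  ->  25
print(place_num)   # 25
# In the use_audio_in_video branch, 2 is added to place_num and then
# pure_audio_len = place_num - 2, so the pure audio length is still 25 here.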

_split_list_into_ranges staticmethod

_split_list_into_ranges(
    lst: Tensor, interval: int
) -> list[list[int]]
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@staticmethod
def _split_list_into_ranges(lst: torch.Tensor,
                            interval: int) -> list[list[int]]:
    ranges: list[list[int]] = [[]
                               for _ in range((max(lst) // interval) + 1)]
    for num in lst:
        index = num // interval
        ranges[index].append(num)
    return ranges
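
A quick standalone check of the chunking behaviour (hypothetical values; in the callers, interval is tokens_per_second * seconds_per_chunk):

import torch

t_index = torch.tensor([0, 25, 50, 75, 100, 125])
interval = 50
ranges = [[] for _ in range((int(t_index.max()) // interval) + 1)]
for num in t_index.tolist():
    ranges[num // interval].append(num)
print(ranges)   # [[0, 25], [50, 75], [100, 125]]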

_vl_get_input_positions_tensor classmethod

_vl_get_input_positions_tensor(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], Tensor],
    video_grid_thw: Union[list[list[int]], Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[Tensor, int]

Get mrope input positions and delta value.

Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@classmethod
def _vl_get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value."""

    image_token_id = hf_config.image_token_id
    video_token_id = hf_config.video_token_id
    vision_start_token_id = hf_config.vision_start_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size
    tokens_per_second = getattr(hf_config.vision_config,
                                "tokens_per_second", 1.0)

    input_tokens_tensor = torch.tensor(input_tokens)
    vision_start_indices = torch.argwhere(
        input_tokens_tensor == vision_start_token_id).squeeze(1)
    vision_tokens = input_tokens_tensor[vision_start_indices + 1]
    image_nums = (vision_tokens == image_token_id).sum()
    video_nums = (vision_tokens == video_token_id).sum()
    llm_pos_ids_list: list = []

    st = 0
    remain_images, remain_videos = image_nums, video_nums

    image_index, video_index = 0, 0
    for _ in range(image_nums + video_nums):
        video_second_per_grid_t = 0.0
        if image_token_id in input_tokens and remain_images > 0:
            ed_image = input_tokens.index(image_token_id, st)
        else:
            ed_image = len(input_tokens) + 1
        if video_token_id in input_tokens and remain_videos > 0:
            ed_video = input_tokens.index(video_token_id, st)
        else:
            ed_video = len(input_tokens) + 1
        if ed_image < ed_video:
            t, h, w = (
                image_grid_thw[image_index][0],
                image_grid_thw[image_index][1],
                image_grid_thw[image_index][2],
            )
            image_index += 1
            remain_images -= 1
            ed = ed_image
        else:
            t, h, w = (
                video_grid_thw[video_index][0],
                video_grid_thw[video_index][1],
                video_grid_thw[video_index][2],
            )
            video_second_per_grid_t = 1.0
            if second_per_grid_ts:
                video_second_per_grid_t = second_per_grid_ts[video_index]
            video_index += 1
            remain_videos -= 1
            ed = ed_video

        llm_grid_t, llm_grid_h, llm_grid_w = \
            t, h // spatial_merge_size, w // spatial_merge_size
        text_len = ed - st

        st_idx = llm_pos_ids_list[-1].max() + 1 if len(
            llm_pos_ids_list) > 0 else 0
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
            -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
                   tokens_per_second).long().flatten()

        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
            llm_grid_t, -1, llm_grid_w).flatten()
        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
            llm_grid_t, llm_grid_h, -1).flatten()
        llm_pos_ids_list.append(
            torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
        st = ed + llm_grid_t * llm_grid_h * llm_grid_w

    if st < len(input_tokens):
        st_idx = llm_pos_ids_list[-1].max() + 1 if len(
            llm_pos_ids_list) > 0 else 0
        text_len = len(input_tokens) - st
        llm_pos_ids_list.append(
            torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    mrope_position_delta = (llm_positions.max() + 1 -
                            len(input_tokens)).item()
    llm_positions = llm_positions[:, context_len:seq_len]

    return llm_positions, mrope_position_delta
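
For intuition, here is a small self-contained sketch (toy token counts and a made-up 2x2 merged image grid, not real vLLM inputs) of how the 3-row T/H/W positions and the returned mrope_position_delta relate: text tokens advance all three rows together, an image block spreads over its grid indices, and the delta records how far the final position lags behind the raw token count.

import torch

text_prefix = torch.arange(3).view(1, -1).expand(3, -1)         # 3 text tokens: positions 0, 1, 2
t = torch.zeros(4, dtype=torch.long)                            # temporal index of the 4 image tokens
h = torch.tensor([0, 0, 1, 1])                                  # height index after spatial merging
w = torch.tensor([0, 1, 0, 1])                                  # width index after spatial merging
image = torch.stack([t, h, w]) + 3                              # offset by the preceding text length
text_suffix = torch.arange(3).view(1, -1).expand(3, -1) + image.max() + 1

positions = torch.cat([text_prefix, image, text_suffix], dim=1)  # shape [3, 10]
delta = (positions.max() + 1 - positions.shape[1]).item()
print(delta)  # -2: the image grid advances positions more slowly than its token count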

forward

forward(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]

MRope forward.

Parameters:

Name        Type              Description                                       Default
positions   Tensor            [num_tokens,] (text only) or [3, num_tokens]      required
                              (T/H/W positions with multimodal inputs)
query       Tensor            [num_tokens, num_heads * head_size]               required
key         Optional[Tensor]  [num_tokens, num_kv_heads * head_size]            None
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """MRope forward.

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    if self.use_triton:
        return self.forward_cuda(positions, query, key)
    else:
        return self.forward_native(positions, query, key)
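
To make the call concrete, here is a hedged usage sketch: every size, the base, and the mrope_section split below are illustrative values, not taken from any real model config, and the layer is constructed standalone rather than through a model. forward dispatches to the Triton path on CUDA-like platforms and to forward_native otherwise.

import torch
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

device = "cuda" if torch.cuda.is_available() else "cpu"
head_size = 128

rope = MRotaryEmbedding(
    head_size=head_size,
    rotary_dim=head_size,
    max_position_embeddings=32768,       # illustrative
    base=1_000_000,                      # illustrative
    is_neox_style=True,
    dtype=torch.float32,
    mrope_section=[16, 24, 24],          # t/h/w split; must sum to rotary_dim // 2
).to(device)

num_tokens, num_heads, num_kv_heads = 8, 4, 2
# A [3, num_tokens] position tensor; identical rows behave like plain 1-D RoPE,
# and the rows only diverge inside image/video spans.
positions = torch.arange(num_tokens, device=device).view(1, -1).expand(3, -1)
query = torch.randn(num_tokens, num_heads * head_size, device=device)
key = torch.randn(num_tokens, num_kv_heads * head_size, device=device)

query, key = rope.forward(positions, query, key)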

forward_cuda

forward_cuda(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:

    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None

    num_tokens = positions.shape[-1]
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    query_shape = query.shape
    key_shape = key.shape
    if positions.ndim == 2:
        assert self.mrope_section

        q, k = triton_mrope(
            query,
            key,
            cos,
            sin,
            self.mrope_section,
            self.head_size,
            self.rotary_dim,
        )

        return q.reshape(query_shape), k.reshape(key_shape)

    query = query.view(num_tokens, -1, self.head_size)
    query_rot = query[..., :self.rotary_dim]
    query_pass = query[..., self.rotary_dim:]
    query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
                                          self.is_neox_style)
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

    key = key.view(num_tokens, -1, self.head_size)
    key_rot = key[..., :self.rotary_dim]
    key_pass = key[..., self.rotary_dim:]
    key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
                                        self.is_neox_style)
    key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    return query, key

forward_native

forward_native(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]

PyTorch-native implementation equivalent to forward().

Parameters:

Name        Type              Description                                       Default
positions   Tensor            [num_tokens,] (text only) or [3, num_tokens]      required
                              (T/H/W positions with multimodal inputs)
query       Tensor            [num_tokens, num_heads * head_size]               required
key         Optional[Tensor]  [num_tokens, num_kv_heads * head_size]            None
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward().

    Args:
        positions:
            [num_tokens,] (text only) or
            [3, num_tokens] (T/H/W positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
    """
    assert positions.ndim == 1 or positions.ndim == 2
    assert key is not None

    num_tokens = positions.shape[-1]
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    if positions.ndim == 2:
        assert self.mrope_section

        cos = torch.cat([
            m[i]
            for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
        ],
                        dim=-1)
        sin = torch.cat([
            m[i]
            for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
        ],
                        dim=-1)

    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
    query_rot = query[..., :self.rotary_dim]
    query_pass = query[..., self.rotary_dim:]
    query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
                                          self.is_neox_style)
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

    key_shape = key.shape
    key = key.view(num_tokens, -1, self.head_size)
    key_rot = key[..., :self.rotary_dim]
    key_pass = key[..., self.rotary_dim:]
    key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
                                        self.is_neox_style)
    key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    return query, key
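
The mrope_section gather above is easy to check in isolation. A minimal sketch (shapes are illustrative): the per-axis cos cache has one row per axis, and the split/cat keeps the first mrope_section[0] frequency bands from the T row, the next mrope_section[1] from the H row, and the rest from the W row.

import torch

mrope_section = [2, 3, 3]                    # t/h/w sizes, summing to rotary_dim // 2
cos = torch.randn(3, 5, sum(mrope_section))  # [3, num_tokens, rotary_dim // 2]

merged = torch.cat(
    [m[i] for i, m in enumerate(cos.split(mrope_section, dim=-1))],
    dim=-1,
)
print(merged.shape)  # torch.Size([5, 8]): one row per token, T/H/W bands concatenated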

get_input_positions classmethod

get_input_positions(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Optional[
        Union[list[list[int]], Tensor]
    ],
    video_grid_thw: Optional[
        Union[list[list[int]], Tensor]
    ],
    second_per_grid_ts: Optional[list[float]],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[list[list[int]], int]

Get mrope input positions and delta value.

Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@classmethod
def get_input_positions(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
    video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
    second_per_grid_ts: Optional[list[float]],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[list[list[int]], int]:
    """Get mrope input positions and delta value."""

    image_grid_thw = [] if image_grid_thw is None else image_grid_thw
    video_grid_thw = [] if video_grid_thw is None else video_grid_thw
    second_per_grid_ts = [] if second_per_grid_ts is None else \
        second_per_grid_ts

    llm_positions, mrope_position_delta = \
        cls.get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )

    return llm_positions.tolist(), mrope_position_delta

get_input_positions_tensor classmethod

get_input_positions_tensor(
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], Tensor],
    video_grid_thw: Union[list[list[int]], Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[Tensor, int]
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@classmethod
def get_input_positions_tensor(
    cls,
    input_tokens: list[int],
    hf_config: PretrainedConfig,
    image_grid_thw: Union[list[list[int]], torch.Tensor],
    video_grid_thw: Union[list[list[int]], torch.Tensor],
    second_per_grid_ts: list[float],
    context_len: int = 0,
    seq_len: Optional[int] = None,
    audio_feature_lengths: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    from vllm.transformers_utils.config import thinker_uses_mrope
    if thinker_uses_mrope(hf_config):
        return cls._omni_get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
            audio_feature_lengths=audio_feature_lengths,
            use_audio_in_video=use_audio_in_video,
        )
    elif hf_config.model_type in ["glm4v", "glm4v_moe"]:
        return cls._glm4v_get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            context_len=context_len,
            seq_len=seq_len,
        )
    else:
        return cls._vl_get_input_positions_tensor(
            input_tokens=input_tokens,
            hf_config=hf_config,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            context_len=context_len,
            seq_len=seq_len,
        )

get_next_input_positions staticmethod

get_next_input_positions(
    mrope_position_delta: int,
    context_len: int,
    seq_len: int,
) -> list[list[int]]
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@staticmethod
def get_next_input_positions(
    mrope_position_delta: int,
    context_len: int,
    seq_len: int,
) -> list[list[int]]:
    return [
        list(
            range(context_len + mrope_position_delta,
                  seq_len + mrope_position_delta)) for _ in range(3)
    ]
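
A short hedged sketch of decode-time use (the delta value is made up): given the mrope_position_delta returned for the prompt, newly generated tokens share a single position stream on all three rows, shifted by that delta.

from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

mrope_position_delta = -2   # e.g. from a prompt whose image grid compressed the positions
next_positions = MRotaryEmbedding.get_next_input_positions(
    mrope_position_delta,
    context_len=10,         # tokens already processed
    seq_len=13,             # positions for 3 new tokens
)
print(next_positions)       # [[8, 9, 10], [8, 9, 10], [8, 9, 10]]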

get_next_input_positions_tensor staticmethod

get_next_input_positions_tensor(
    out: ndarray,
    out_offset: int,
    mrope_position_delta: int,
    context_len: int,
    num_new_tokens: int,
)
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@staticmethod
def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
                                    mrope_position_delta: int,
                                    context_len: int, num_new_tokens: int):

    values = np.arange(mrope_position_delta + context_len,
                       mrope_position_delta + context_len + num_new_tokens,
                       dtype=out.dtype)
    out[:, out_offset:out_offset + num_new_tokens] = values
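
A companion sketch for the in-place NumPy variant, assuming a preallocated [3, max_tokens] output buffer (the buffer size and values are illustrative):

import numpy as np
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

out = np.zeros((3, 8), dtype=np.int64)
MRotaryEmbedding.get_next_input_positions_tensor(
    out,
    out_offset=2,              # write starting at column 2
    mrope_position_delta=-2,
    context_len=10,
    num_new_tokens=3,
)
print(out[:, 2:5])             # each of the three rows reads [8, 9, 10]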

omni_get_updates_use_audio_in_video classmethod

omni_get_updates_use_audio_in_video(
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], Tensor],
    video_second_per_grid_t: float,
) -> list[int]

Get video prompt updates when use_audio_in_video is True.

In this case, audio and vision update ids will be split into chunks and interleaved (details in _omni_get_input_positions_tensor).

<|video_bos|><|VIDEO|><|video_eos|> => <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>

Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@classmethod
def omni_get_updates_use_audio_in_video(
    cls,
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], torch.Tensor],
    video_second_per_grid_t: float,
) -> list[int]:
    """Get video prompt updates when `use_audio_in_video` is True.

    In this case, audio and vision update ids will be split into
    chunks and interleaved (details in `_omni_get_input_positions_tensor`).

    <|video_bos|><|VIDEO|><|video_eos|> =>
    <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
    """

    audio_token_id = thinker_config.audio_token_index
    video_token_id = thinker_config.video_token_index
    audio_start_token_id = thinker_config.audio_start_token_id
    audio_end_token_id = thinker_config.audio_end_token_id
    seconds_per_chunk = thinker_config.seconds_per_chunk
    spatial_merge_size = thinker_config.vision_config.spatial_merge_size
    tokens_per_second = getattr(thinker_config.vision_config,
                                "tokens_per_second", 25)

    grid_t = video_grid_thw[0]
    grid_h = video_grid_thw[1]
    grid_w = video_grid_thw[2]
    t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
    t_index = (torch.arange(grid_t) * video_second_per_grid_t *
               tokens_per_second).long()
    t_index_split_chunk = cls._split_list_into_ranges(
        t_index, t_ntoken_per_chunk)

    updates = [audio_start_token_id]
    added_audio_len = 0
    for t_chunk in t_index_split_chunk:
        vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
            spatial_merge_size**2)
        updates.extend([video_token_id] * vision_ntoken_per_chunk)

        audio_chunk_size = min(t_ntoken_per_chunk,
                               audio_len - added_audio_len)
        updates.extend(audio_chunk_size * [audio_token_id])
        added_audio_len += audio_chunk_size
    if added_audio_len < audio_len:
        updates.extend((audio_len - added_audio_len) * [audio_token_id])
    updates.extend([audio_end_token_id])

    return updates
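
A hedged sketch of calling this helper with a stand-in config (a SimpleNamespace with made-up token ids and chunk settings; real values come from the HF thinker config of an Omni-style model):

from types import SimpleNamespace
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding

thinker_config = SimpleNamespace(
    audio_token_index=151646,        # made-up token ids
    video_token_index=151656,
    audio_start_token_id=151647,
    audio_end_token_id=151648,
    seconds_per_chunk=2.0,
    vision_config=SimpleNamespace(spatial_merge_size=2, tokens_per_second=25),
)

updates = MRotaryEmbedding.omni_get_updates_use_audio_in_video(
    thinker_config,
    audio_len=100,                   # number of audio placeholder tokens
    video_grid_thw=[4, 16, 16],      # (t, h, w) patch grid of the video
    video_second_per_grid_t=1.0,
)
# updates starts with audio_start_token_id, ends with audio_end_token_id, and
# interleaves video and audio placeholder chunks in between.
print(updates[0], updates[-1], len(updates))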

_triton_qwen2vl_mrope_forward

_triton_qwen2vl_mrope_forward(
    q_ptr,
    k_ptr,
    cos,
    sin,
    num_tokens,
    n_qh: constexpr,
    n_kh: constexpr,
    hd: constexpr,
    rd: constexpr,
    pad_n_qh: constexpr,
    pad_n_kh: constexpr,
    pad_hd: constexpr,
    mrope_section_t: constexpr,
    mrope_section_h: constexpr,
)
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
@triton.jit
def _triton_qwen2vl_mrope_forward(
    q_ptr,
    k_ptr,
    cos,
    sin,
    num_tokens,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    rd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
):
    # Adapted from
    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
    # This version supports flatten input tensors from vllm
    # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
    # instead of (3, bsz, seq_len, head_dim)
    pid = tl.program_id(0)
    # locate start address
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################
    # Note: cos and sin now have shape (3, num_tokens, head_dim // 2)

    t_end = mrope_section_t
    h_end = t_end + mrope_section_h

    # Updated stride calculation for half head_dim
    half_rd = rd // 2
    t_cos = cos + pid * half_rd
    h_cos = t_cos + num_tokens * half_rd
    w_cos = h_cos + num_tokens * half_rd
    t_sin = sin + pid * half_rd
    h_sin = t_sin + num_tokens * half_rd
    w_sin = h_sin + num_tokens * half_rd

    # Updated offsets for half head_dim
    cos_offsets = tl.arange(0, pad_hd // 2)
    t_mask = cos_offsets < t_end
    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)

    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)

    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(
        0, pad_hd // 2)[None, :]
    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
        0, pad_hd // 2)[None, :]
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
        0, pad_hd // 2)[None, :] < rd // 2)
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
        0, pad_hd // 2)[None, :] < rd // 2)

    q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
                       mask=first_q_mask,
                       other=0).to(sin_row.dtype)
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets,
                       mask=first_k_mask,
                       other=0).to(sin_row.dtype)

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (rd // 2)
    second_half_k_offsets = first_half_k_offsets + (rd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask

    q_tile_2 = tl.load(q_ptr + second_half_q_offsets,
                       mask=second_q_mask,
                       other=0).to(sin_row.dtype)
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets,
                       mask=second_k_mask,
                       other=0).to(sin_row.dtype)

    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    # Since cos and sin are now half-size,
    # we use the same cos_row and sin_row for both halves
    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)

triton_mrope

triton_mrope(
    q: Tensor,
    k: Tensor,
    cos: Tensor,
    sin: Tensor,
    mrope_section: list[int],
    head_size: int,
    rotary_dim: int,
) -> tuple[Tensor, Tensor]

Qwen2VL mrope kernel.

Parameters:

Name           Type       Description                                             Default
q              Tensor     [num_tokens, num_heads * head_size]                     required
k              Tensor     [num_tokens, num_kv_heads * head_size]                  required
cos            Tensor     [3, num_tokens, head_size // 2]                         required
                          (T/H/W positions with multimodal inputs)
sin            Tensor     [3, num_tokens, head_size // 2]                         required
                          (T/H/W positions with multimodal inputs)
mrope_section  list[int]  [t, h, w]                                               required
head_size      int        size of each attention head                             required
rotary_dim     int        number of leading channels per head that are rotated    required
Source code in vllm/model_executor/layers/rotary_embedding/mrope.py
def triton_mrope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section: list[int],
    head_size: int,
    rotary_dim: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Qwen2VL mrope kernel.

    Args:
        q: [num_tokens, num_heads * head_size]
        k: [num_tokens, num_kv_heads * head_size]
        cos: [3, num_tokens, head_size // 2]
            (T/H/W positions with multimodal inputs)
        sin: [3, num_tokens, head_size // 2]
            (T/H/W positions with multimodal inputs)
        mrope_section: [t, h, w]
        head_size: size of each attention head
        rotary_dim: number of leading channels per head that are rotated
    """
    n_row, n_q_head_head_dim = q.shape
    n_q_head = n_q_head_head_dim // head_size
    n_kv_head = k.shape[1] // head_size
    pad_hd = triton.next_power_of_2(head_size)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    # ensure tensors passed into the kernel are contiguous.
    # It will be no-op if they are already contiguous
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    _triton_qwen2vl_mrope_forward[(n_row, )](
        q,
        k,
        cos,
        sin,
        n_row,
        n_q_head,
        n_kv_head,
        head_size,
        rotary_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
    )
    return q, k
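
Finally, a hedged sketch of invoking the kernel wrapper directly (requires a CUDA device with Triton; all sizes and the mrope_section split are illustrative). The wrapper rotates the flattened query/key buffers (after making them contiguous) and returns them.

import torch
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope

num_tokens, num_heads, num_kv_heads, head_size = 8, 4, 2, 128
mrope_section = [16, 24, 24]        # sums to rotary_dim // 2 = 64

q = torch.randn(num_tokens, num_heads * head_size, device="cuda", dtype=torch.float16)
k = torch.randn(num_tokens, num_kv_heads * head_size, device="cuda", dtype=torch.float16)
cos = torch.randn(3, num_tokens, head_size // 2, device="cuda", dtype=torch.float16)
sin = torch.randn(3, num_tokens, head_size // 2, device="cuda", dtype=torch.float16)

q_rot, k_rot = triton_mrope(q, k, cos, sin, mrope_section, head_size, rotary_dim=head_size)
print(q_rot.shape, k_rot.shape)     # flattened shapes are preserved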