Skip to content

vllm.model_executor.models.minimax_m2

Inference-only MiniMaxM2 model.

MiniMaxM2Attention

Bases: Module

Source code in vllm/model_executor/models/minimax_m2.py
class MiniMaxM2Attention(nn.Module):
    """Attention block used by a MiniMaxM2 decoder layer.

    The block runs a fused QKV projection, normalizes the query and key
    slices, applies rotary embeddings to them, calls the attention op and
    finally projects the result back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rotary_dim: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        attn_window_size: int | None = None,
        max_position_embeddings: int = 8192,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, norms, rotary embedding and attention op.

        Args:
            hidden_size: Input and output feature size of the block.
            num_heads: Total number of query heads before the TP split.
            num_kv_heads: Total number of key/value heads before the TP split.
            rotary_dim: Passed to ``get_rope`` as ``rotary_dim``.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Passed to ``get_rope`` unchanged.
            attn_window_size: Passed to ``Attention`` as
                ``per_layer_sliding_window``.
            max_position_embeddings: Passed to ``get_rope`` as
                ``max_position``.
            head_dim: Per-head size; when falsy it is derived as
                ``hidden_size // num_heads``.
            rms_norm_eps: Epsilon used by the query and key norms.
            qkv_bias: Whether the fused QKV projection carries a bias.
            cache_config: Passed to ``Attention``.
            quant_config: Passed to both linear layers and ``Attention``.
            prefix: Dotted name prefix used for the sub-layers.
        """
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()

        # Query heads must split evenly over the tensor-parallel ranks.
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # An explicit (truthy) head_dim wins over the derived one.
        if head_dim:
            self.head_dim = head_dim
        else:
            self.head_dim = hidden_size // self.total_num_heads

        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.head_dim * self.num_heads
        self.kv_size = self.head_dim * self.num_kv_heads
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        attn_out_width = self.total_num_heads * self.head_dim
        self.o_proj = RowParallelLinear(
            attn_out_width,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            per_layer_sliding_window=attn_window_size,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # The norms are sized with the *total* head counts, not per-rank ones.
        q_norm_width = self.head_dim * self.total_num_heads
        k_norm_width = self.head_dim * self.total_num_kv_heads
        self.q_norm = MiniMaxText01RMSNormTP(q_norm_width, eps=rms_norm_eps)
        self.k_norm = MiniMaxText01RMSNormTP(k_norm_width, eps=rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            positions: Position input handed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The first element returned by the output projection.
        """
        fused, _ = self.qkv_proj(hidden_states)
        # Slice the fused projection into its query / key / value parts.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(split_sizes, dim=-1)
        query = self.q_norm(query)
        key = self.k_norm(key)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected

attn instance-attribute

attn = Attention(
    num_heads,
    head_dim,
    scaling,
    num_kv_heads=num_kv_heads,
    per_layer_sliding_window=attn_window_size,
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

head_dim instance-attribute

head_dim = head_dim or hidden_size // total_num_heads

hidden_size instance-attribute

hidden_size = hidden_size

k_norm instance-attribute

k_norm = MiniMaxText01RMSNormTP(
    head_dim * total_num_kv_heads, eps=rms_norm_eps
)

kv_size instance-attribute

kv_size = num_kv_heads * head_dim

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

num_heads instance-attribute

num_heads = total_num_heads // tp_size

num_kv_heads instance-attribute

num_kv_heads = max(1, total_num_kv_heads // tp_size)

o_proj instance-attribute

o_proj = RowParallelLinear(
    total_num_heads * head_dim,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

q_norm instance-attribute

q_norm = MiniMaxText01RMSNormTP(
    head_dim * total_num_heads, eps=rms_norm_eps
)

q_size instance-attribute

q_size = num_heads * head_dim

qkv_proj instance-attribute

qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=qkv_bias,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

rope_theta instance-attribute

rope_theta = rope_theta

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_dim,
    rotary_dim=rotary_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    rope_scaling=rope_scaling,
)

scaling instance-attribute

scaling = head_dim ** -0.5

total_num_heads instance-attribute

total_num_heads = num_heads

total_num_kv_heads instance-attribute

total_num_kv_heads = num_kv_heads

__init__

__init__(
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    rotary_dim: int,
    rope_theta: float = 10000,
    rope_scaling: dict[str, Any] | None = None,
    attn_window_size: int | None = None,
    max_position_embeddings: int = 8192,
    head_dim: int | None = None,
    rms_norm_eps: float = 1e-06,
    qkv_bias: bool = False,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/minimax_m2.py
def __init__(
    self,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    rotary_dim: int,
    rope_theta: float = 10000,
    rope_scaling: dict[str, Any] | None = None,
    attn_window_size: int | None = None,
    max_position_embeddings: int = 8192,
    head_dim: int | None = None,
    rms_norm_eps: float = 1e-06,
    qkv_bias: bool = False,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
) -> None:
    """Build the projections, norms, rotary embedding and attention op.

    Args:
        hidden_size: Input and output feature size of the block.
        num_heads: Total number of query heads before the TP split.
        num_kv_heads: Total number of key/value heads before the TP split.
        rotary_dim: Passed to ``get_rope`` as ``rotary_dim``.
        rope_theta: Passed to ``get_rope`` as ``base``.
        rope_scaling: Passed to ``get_rope`` unchanged.
        attn_window_size: Passed to ``Attention`` as
            ``per_layer_sliding_window``.
        max_position_embeddings: Passed to ``get_rope`` as ``max_position``.
        head_dim: Per-head size; when falsy it is derived as
            ``hidden_size // num_heads``.
        rms_norm_eps: Epsilon used by the query and key norms.
        qkv_bias: Whether the fused QKV projection carries a bias.
        cache_config: Passed to ``Attention``.
        quant_config: Passed to both linear layers and ``Attention``.
        prefix: Dotted name prefix used for the sub-layers.
    """
    super().__init__()
    world_size = get_tensor_model_parallel_world_size()

    # Query heads must split evenly over the tensor-parallel ranks.
    self.hidden_size = hidden_size
    self.total_num_heads = num_heads
    assert self.total_num_heads % world_size == 0
    self.num_heads = self.total_num_heads // world_size

    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads < world_size:
        # Fewer KV heads than ranks: the KV heads are replicated
        # across multiple tensor parallel GPUs.
        assert world_size % self.total_num_kv_heads == 0
    else:
        # At least as many KV heads as ranks: the KV heads are
        # partitioned across multiple tensor parallel GPUs.
        assert self.total_num_kv_heads % world_size == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

    # An explicit (truthy) head_dim wins over the derived one.
    if head_dim:
        self.head_dim = head_dim
    else:
        self.head_dim = hidden_size // self.total_num_heads

    # Per-rank widths of the q and k/v slices of the fused projection.
    self.q_size = self.head_dim * self.num_heads
    self.kv_size = self.head_dim * self.num_kv_heads
    self.scaling = self.head_dim**-0.5
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings

    self.qkv_proj = QKVParallelLinear(
        hidden_size,
        self.head_dim,
        self.total_num_heads,
        self.total_num_kv_heads,
        bias=qkv_bias,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )

    attn_out_width = self.total_num_heads * self.head_dim
    self.o_proj = RowParallelLinear(
        attn_out_width,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

    self.rotary_emb = get_rope(
        self.head_dim,
        rotary_dim=rotary_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        rope_scaling=rope_scaling,
    )
    self.attn = Attention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        per_layer_sliding_window=attn_window_size,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
    )

    # The norms are sized with the *total* head counts, not per-rank ones.
    q_norm_width = self.head_dim * self.total_num_heads
    k_norm_width = self.head_dim * self.total_num_kv_heads
    self.q_norm = MiniMaxText01RMSNormTP(q_norm_width, eps=rms_norm_eps)
    self.k_norm = MiniMaxText01RMSNormTP(k_norm_width, eps=rms_norm_eps)

forward

forward(positions: Tensor, hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_m2.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Run attention over ``hidden_states``.

    Args:
        positions: Position input handed to the rotary embedding.
        hidden_states: Input to the fused QKV projection.

    Returns:
        The first element returned by the output projection.
    """
    fused, _ = self.qkv_proj(hidden_states)
    # Slice the fused projection into its query / key / value parts.
    split_sizes = [self.q_size, self.kv_size, self.kv_size]
    query, key, value = fused.split(split_sizes, dim=-1)
    query = self.q_norm(query)
    key = self.k_norm(key)
    query, key = self.rotary_emb(positions, query, key)
    context = self.attn(query, key, value)
    projected, _ = self.o_proj(context)
    return projected

MiniMaxM2DecoderLayer

Bases: Module

Source code in vllm/model_executor/models/minimax_m2.py
class MiniMaxM2DecoderLayer(nn.Module):
    """One MiniMaxM2 decoder layer: self-attention followed by an MoE block."""

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Build the attention block, the MoE block and the two layer norms.

        Args:
            config: Model config; attention, rope and MoE settings are read
                from it.
            prefix: Dotted layer name whose last component is the layer index.
            model_config: Accepted for interface compatibility; not read here.
            cache_config: Passed to the attention block.
            quant_config: Passed to the attention and MoE blocks.

        Raises:
            ValueError: If ``config.attn_window_size`` is neither a list nor
                an int.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
            # Compare against the value resolved above so the 8192 fallback
            # is honored when the config lacks `max_position_embeddings`.
            max_position_embeddings = max(
                max_position_embeddings, config.max_model_len
            )
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep=".")[-1])

        # TODO: support MTP
        attn_window_size = getattr(config, "attn_window_size", None)
        if attn_window_size is not None:
            if isinstance(attn_window_size, list):
                # Per-layer setting: pick this layer's entry.
                attn_window_size = attn_window_size[layer_idx]
            elif not isinstance(attn_window_size, int):
                raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
            # Non-positive window sizes mean "no sliding window".
            attn_window_size = None if attn_window_size <= 0 else attn_window_size

        # different rope theta for full layer and swa layer
        swa_rope_theta = getattr(config, "swa_rope_theta", -1)
        # default to full rope theta
        swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta
        rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta

        self.layer_idx = layer_idx
        self.self_attn = MiniMaxM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rotary_dim=config.rotary_dim,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attn_window_size=attn_window_size,
            max_position_embeddings=max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # NOTE(review): the sub-layer prefix is ".mlp" while the attribute is
        # named `block_sparse_moe` — confirm this is the intended naming.
        self.block_sparse_moe = MiniMaxM2MoE(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Position input passed to the attention block.
            hidden_states: Layer input.
            residual: Residual from the previous layer, or ``None`` when
                there is none yet (the input then becomes the residual).

        Returns:
            A tuple of the MoE output and the updated residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        hidden_states = self.block_sparse_moe(hidden_states)

        return hidden_states, residual

block_sparse_moe instance-attribute

block_sparse_moe = MiniMaxM2MoE(
    config=config,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

hidden_size instance-attribute

hidden_size = hidden_size

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

layer_idx instance-attribute

layer_idx = layer_idx

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(
    hidden_size, eps=rms_norm_eps
)

self_attn instance-attribute

self_attn = MiniMaxM2Attention(
    hidden_size=hidden_size,
    num_heads=num_attention_heads,
    num_kv_heads=num_key_value_heads,
    rotary_dim=rotary_dim,
    rope_theta=rope_theta,
    rope_scaling=rope_scaling,
    attn_window_size=attn_window_size,
    max_position_embeddings=max_position_embeddings,
    rms_norm_eps=rms_norm_eps,
    qkv_bias=getattr(config, "attention_bias", False),
    head_dim=getattr(config, "head_dim", None),
    cache_config=cache_config,
    quant_config=quant_config,
    prefix=f"{prefix}.self_attn",
)

__init__

__init__(
    config: PretrainedConfig,
    prefix: str,
    model_config: ModelConfig,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
) -> None
Source code in vllm/model_executor/models/minimax_m2.py
def __init__(
    self,
    config: PretrainedConfig,
    prefix: str,
    model_config: ModelConfig,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
) -> None:
    """Build the attention block, the MoE block and the two layer norms.

    Args:
        config: Model config; attention, rope and MoE settings are read
            from it.
        prefix: Dotted layer name whose last component is the layer index.
        model_config: Accepted for interface compatibility; not read here.
        cache_config: Passed to the attention block.
        quant_config: Passed to the attention and MoE blocks.

    Raises:
        ValueError: If ``config.attn_window_size`` is neither a list nor
            an int.
    """
    super().__init__()
    self.hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
    if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
        # Compare against the value resolved above so the 8192 fallback
        # is honored when the config lacks `max_position_embeddings`.
        max_position_embeddings = max(
            max_position_embeddings, config.max_model_len
        )
    # DecoderLayers are created with `make_layers` which passes the prefix
    # with the layer's index.
    layer_idx = int(prefix.split(sep=".")[-1])

    # TODO: support MTP
    attn_window_size = getattr(config, "attn_window_size", None)
    if attn_window_size is not None:
        if isinstance(attn_window_size, list):
            # Per-layer setting: pick this layer's entry.
            attn_window_size = attn_window_size[layer_idx]
        elif not isinstance(attn_window_size, int):
            raise ValueError(f"Invalid attn_window_size: {attn_window_size}")
        # Non-positive window sizes mean "no sliding window".
        attn_window_size = None if attn_window_size <= 0 else attn_window_size

    # different rope theta for full layer and swa layer
    swa_rope_theta = getattr(config, "swa_rope_theta", -1)
    # default to full rope theta
    swa_rope_theta = rope_theta if swa_rope_theta <= 0 else swa_rope_theta
    rope_theta = swa_rope_theta if attn_window_size is not None else rope_theta

    self.layer_idx = layer_idx
    self.self_attn = MiniMaxM2Attention(
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        num_kv_heads=config.num_key_value_heads,
        rotary_dim=config.rotary_dim,
        rope_theta=rope_theta,
        rope_scaling=rope_scaling,
        attn_window_size=attn_window_size,
        max_position_embeddings=max_position_embeddings,
        rms_norm_eps=config.rms_norm_eps,
        qkv_bias=getattr(config, "attention_bias", False),
        head_dim=getattr(config, "head_dim", None),
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.self_attn",
    )

    # NOTE(review): the sub-layer prefix is ".mlp" while the attribute is
    # named `block_sparse_moe` — confirm this is the intended naming.
    self.block_sparse_moe = MiniMaxM2MoE(
        config=config,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(
        config.hidden_size, eps=config.rms_norm_eps
    )

forward

forward(
    positions: Tensor,
    hidden_states: Tensor,
    residual: Tensor | None,
) -> Tensor
Source code in vllm/model_executor/models/minimax_m2.py
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: torch.Tensor | None,
) -> torch.Tensor:
    # Self Attention
    if residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

    hidden_states = self.block_sparse_moe(hidden_states)

    return hidden_states, residual

MiniMaxM2ForCausalLM

Bases: Module, SupportsPP

Source code in vllm/model_executor/models/minimax_m2.py
class MiniMaxM2ForCausalLM(nn.Module, SupportsPP):
    """MiniMaxM2 model with a language-model head for causal generation."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Source of the HF config, quantization config and
                ``max_model_len``.
            prefix: Name prefix under which the backbone is registered.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        self.config = hf_config
        self.quant_config = vllm_config.quant_config

        # Copy max_model_len onto the HF config object when it is available.
        if hasattr(model_config, "max_model_len"):
            self.config.max_model_len = model_config.max_model_len

        self.model = MiniMaxM2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # Only the last pipeline rank holds a real LM head.
        self.lm_head = (
            ParallelLMHead(
                hf_config.vocab_size, hf_config.hidden_size, quant_config=None
            )
            if get_pp_group().is_last_rank
            else PPMissingLayer()
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # NOTE(review): this instance attribute appears to shadow the method
        # of the same name defined on the class — confirm that is intended.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's input embeddings for ``input_ids``."""
        backbone = self.model
        return backbone.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output; ``kwargs`` are ignored."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Return the logits processor's output for ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states)

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Return zero-filled ``hidden_states`` and ``residual`` tensors.

        Each tensor has shape ``(batch_size, config.hidden_size)`` and the
        requested ``dtype`` and ``device``.
        """
        shape = (batch_size, self.config.hidden_size)
        return IntermediateTensors(
            {
                key: torch.zeros(shape, dtype=dtype, device=device)
                for key in ("hidden_states", "residual")
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` through ``AutoWeightsLoader`` and return its result."""
        return AutoWeightsLoader(self).load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert mapping reported by the backbone."""
        backbone = self.model
        return backbone.get_expert_mapping()

config instance-attribute

config = config

lm_head instance-attribute

lm_head = ParallelLMHead(
    vocab_size, hidden_size, quant_config=None
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

model instance-attribute

model = MiniMaxM2Model(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/minimax_m2.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the backbone, the LM head and the logits processor.

    Args:
        vllm_config: Source of the HF config, quantization config and
            ``max_model_len``.
        prefix: Name prefix under which the backbone is registered.
    """
    super().__init__()
    model_config = vllm_config.model_config
    hf_config = model_config.hf_config
    self.config = hf_config
    self.quant_config = vllm_config.quant_config

    # Copy max_model_len onto the HF config object when it is available.
    if hasattr(model_config, "max_model_len"):
        self.config.max_model_len = model_config.max_model_len

    self.model = MiniMaxM2Model(
        vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
    )

    # Only the last pipeline rank holds a real LM head.
    self.lm_head = (
        ParallelLMHead(
            hf_config.vocab_size, hf_config.hidden_size, quant_config=None
        )
        if get_pp_group().is_last_rank
        else PPMissingLayer()
    )
    self.logits_processor = LogitsProcessor(hf_config.vocab_size)
    # NOTE(review): this instance attribute appears to shadow the class
    # method of the same name — confirm that is intended.
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors
    )

compute_logits

compute_logits(hidden_states: Tensor) -> Tensor | None
Source code in vllm/model_executor/models/minimax_m2.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor | None:
    logits = self.logits_processor(self.lm_head, hidden_states)
    return logits

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs,
) -> Tensor | IntermediateTensors
Source code in vllm/model_executor/models/minimax_m2.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor | IntermediateTensors:
    """Run the backbone and return its output.

    The four named inputs are forwarded positionally to ``self.model``;
    any extra keyword arguments are accepted and ignored.
    """
    return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

get_expert_mapping

get_expert_mapping() -> list[tuple[str, str, int, str]]
Source code in vllm/model_executor/models/minimax_m2.py
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the expert mapping reported by the backbone model."""
    backbone = self.model
    return backbone.get_expert_mapping()

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_m2.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the backbone model's input embeddings for ``input_ids``."""
    backbone = self.model
    return backbone.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/minimax_m2.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load ``weights`` through ``AutoWeightsLoader`` and return its result."""
    return AutoWeightsLoader(self).load_weights(weights)

make_empty_intermediate_tensors

make_empty_intermediate_tensors(
    batch_size: int, dtype: dtype, device: device
) -> IntermediateTensors
Source code in vllm/model_executor/models/minimax_m2.py
def make_empty_intermediate_tensors(
    self, batch_size: int, dtype: torch.dtype, device: torch.device
) -> IntermediateTensors:
    """Return zero-filled ``hidden_states`` and ``residual`` tensors.

    Each tensor has shape ``(batch_size, config.hidden_size)`` and the
    requested ``dtype`` and ``device``.
    """
    shape = (batch_size, self.config.hidden_size)
    return IntermediateTensors(
        {
            key: torch.zeros(shape, dtype=dtype, device=device)
            for key in ("hidden_states", "residual")
        }
    )

MiniMaxM2MoE

Bases: Module

Source code in vllm/model_executor/models/minimax_m2.py
class MiniMaxM2MoE(nn.Module):
    """Mixture-of-experts block: a float32 router gate plus fused experts."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the optional routing bias, the fused experts and the gate.

        Args:
            config: Model config; expert counts, sizes and the scoring
                function are read from it.
            quant_config: Passed to the fused experts (the gate is built
                with ``quant_config=None``).
            prefix: Dotted name prefix used for the sub-layers.

        Raises:
            ValueError: If the tensor-parallel size exceeds
                ``config.num_local_experts``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_local_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_local_experts}."
            )
        self.use_routing_bias = getattr(config, "use_routing_bias", False)
        if self.use_routing_bias:
            # One float32 bias entry per expert, filled in at weight-load time.
            self.e_score_correction_bias = nn.Parameter(
                torch.empty(config.num_local_experts, dtype=torch.float32)
            )
            self.e_score_correction_bias.weight_loader = (
                MiniMaxM2MoE.ebias_weight_loader
            )
        else:
            self.e_score_correction_bias = None

        self.experts = FusedMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            scoring_func=config.scoring_func,
            use_grouped_topk=True,
            num_expert_group=1,
            topk_group=1,
            e_score_correction_bias=self.e_score_correction_bias,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            reduce_results=False,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_local_experts,
            bias=False,
            params_dtype=torch.float32,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

    @staticmethod
    def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
        """Copy ``loaded_weight`` into ``param`` as float32.

        The two tensors must have identical sizes.
        """
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight.to(torch.float32))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts.

        Args:
            hidden_states: 2-D input of shape ``(num_tokens, hidden_dim)``.

        Returns:
            The expert output viewed as ``(num_tokens, hidden_dim)``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (num_tokens, n_experts)
        # The gate is fed a float32 copy of the input.
        router_logits, _ = self.gate(hidden_states.to(torch.float32))
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        if self.tp_size > 1:
            # The experts are built with reduce_results=False, so the
            # all-reduce across tensor-parallel ranks is done here.
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)

e_score_correction_bias instance-attribute

e_score_correction_bias = Parameter(
    empty(num_local_experts, dtype=float32)
)

experts instance-attribute

experts = FusedMoE(
    num_experts=num_local_experts,
    top_k=num_experts_per_tok,
    scoring_func=scoring_func,
    use_grouped_topk=True,
    num_expert_group=1,
    topk_group=1,
    e_score_correction_bias=e_score_correction_bias,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    reduce_results=False,
    renormalize=True,
    quant_config=quant_config,
    prefix=f"{prefix}.experts",
)

gate instance-attribute

gate = ReplicatedLinear(
    hidden_size,
    num_local_experts,
    bias=False,
    params_dtype=float32,
    quant_config=None,
    prefix=f"{prefix}.gate",
)

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

use_routing_bias instance-attribute

use_routing_bias = getattr(
    config, "use_routing_bias", False
)

__init__

__init__(
    config: PretrainedConfig,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/minimax_m2.py
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
):
    """Build the optional routing bias, the fused experts and the gate.

    Args:
        config: Model config; expert counts, sizes and the scoring function
            are read from it.
        quant_config: Passed to the fused experts (the gate is built with
            ``quant_config=None``).
        prefix: Dotted name prefix used for the sub-layers.

    Raises:
        ValueError: If the tensor-parallel size exceeds
            ``config.num_local_experts``.
    """
    super().__init__()
    self.tp_size = get_tensor_model_parallel_world_size()

    if self.tp_size > config.num_local_experts:
        raise ValueError(
            f"Tensor parallel size {self.tp_size} is greater than "
            f"the number of experts {config.num_local_experts}."
        )

    self.use_routing_bias = getattr(config, "use_routing_bias", False)
    if not self.use_routing_bias:
        self.e_score_correction_bias = None
    else:
        # One float32 bias entry per expert, filled in at weight-load time.
        routing_bias = nn.Parameter(
            torch.empty(config.num_local_experts, dtype=torch.float32)
        )
        routing_bias.weight_loader = MiniMaxM2MoE.ebias_weight_loader
        self.e_score_correction_bias = routing_bias

    self.experts = FusedMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        scoring_func=config.scoring_func,
        use_grouped_topk=True,
        num_expert_group=1,
        topk_group=1,
        e_score_correction_bias=self.e_score_correction_bias,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        reduce_results=False,
        renormalize=True,
        quant_config=quant_config,
        prefix=f"{prefix}.experts",
    )

    # The router gate is kept in float32 and is not quantized.
    self.gate = ReplicatedLinear(
        config.hidden_size,
        config.num_local_experts,
        bias=False,
        params_dtype=torch.float32,
        quant_config=None,
        prefix=f"{prefix}.gate",
    )

ebias_weight_loader staticmethod

ebias_weight_loader(
    param: Parameter, loaded_weight: Tensor
) -> None
Source code in vllm/model_executor/models/minimax_m2.py
@staticmethod
def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    """Copy ``loaded_weight`` into ``param`` as float32.

    The two tensors must have identical sizes.
    """
    assert loaded_weight.size() == param.size()
    as_fp32 = loaded_weight.to(torch.float32)
    param.data.copy_(as_fp32)

forward

forward(hidden_states: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_m2.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Route tokens through the experts.

    Args:
        hidden_states: 2-D input of shape ``(num_tokens, hidden_dim)``.

    Returns:
        The expert output viewed as ``(num_tokens, hidden_dim)``.
    """
    num_tokens, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)

    # router_logits: (num_tokens, n_experts)
    # The gate is fed a float32 copy of the input.
    router_logits, _ = self.gate(hidden_states.to(torch.float32))
    final_hidden_states = self.experts(
        hidden_states=hidden_states, router_logits=router_logits
    )
    if self.tp_size > 1:
        # The experts are built with reduce_results=False, so the
        # all-reduce across tensor-parallel ranks is done here.
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

    return final_hidden_states.view(num_tokens, hidden_dim)

MiniMaxM2Model

Bases: Module

Source code in vllm/model_executor/models/minimax_m2.py
@support_torch_compile
class MiniMaxM2Model(nn.Module):
    """MiniMaxM2 decoder stack with pipeline-parallel aware construction.

    The module owns the token embedding (first pipeline rank only), the
    decoder layers returned by ``make_layers`` and the final ``RMSNorm``
    (last pipeline rank only). On ranks that do not own the embedding or
    the norm, a ``PPMissingLayer`` placeholder is stored instead.
    """

    # Not read anywhere in this class.
    # NOTE(review): presumably consulted by the model loader -- confirm.
    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, decoder layers and final norm.

        Args:
            vllm_config: Source of the HF config, model config, cache config
                and quantization config used to build the sub-modules.
            prefix: Dotted name prefix prepended to every sub-module prefix.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.vocab_size = config.vocab_size

        # Only the first pipeline rank holds a real embedding table.
        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=None,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # make_layers is given a factory that builds one decoder layer per
        # prefix; it returns the layer range used by forward() and the layers.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MiniMaxM2DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ),
            prefix=f"{prefix}.layers",
        )

        # Only the last pipeline rank applies the final normalization.
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        # Factory for the "hidden_states"/"residual" tensors, the same two
        # keys forward() reads from and writes to IntermediateTensors.
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the result of applying ``self.embed_tokens`` to ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: Token ids; only used on the first pipeline rank and
                only when ``inputs_embeds`` is ``None``.
            positions: Passed unchanged to every decoder layer.
            intermediate_tensors: Required on non-first ranks; supplies the
                ``"hidden_states"`` and ``"residual"`` inputs.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            On the last pipeline rank, the first element returned by
            ``self.norm(hidden_states, residual)``. On every other rank, an
            ``IntermediateTensors`` holding ``hidden_states`` and ``residual``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            # No residual exists before the first layer.
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.layers[self.start_layer : self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping built by ``FusedMoE``.

        The checkpoint projection names are ``w1`` (gate), ``w2`` (down) and
        ``w3`` (up); the number of experts is ``config.num_local_experts``.
        load_weights() unpacks each entry as
        ``(param_name, weight_name, expert_id, shard_id)``.
        """
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each weight is tried against three routes in order: the stacked
        q/k/v mapping, the expert mapping, and finally a direct lookup by
        (possibly remapped) name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of names recorded after each weight was processed. The
            recorded name is the rewritten one when a mapping applied.

        Raises:
            KeyError: If a resolved name is not present in
                ``self.named_parameters()``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = self.get_expert_mapping()

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            # Route 1: separate q/k/v checkpoint tensors -> fused qkv_proj.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: from here on `name` is the rewritten name; the
                # `continue` statements below move on to the next mapping
                # entry (not the next weight) with that rewritten name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when the stacked loop ended without `break`.
                # Route 2: per-expert checkpoint tensors.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The rewritten name is passed as the third positional
                    # argument of the expert weight loader.
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Route 3: no mapping consumed the weight. The `continue`
                    # statements here skip to the next checkpoint weight, so
                    # the name is not recorded in loaded_params.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    # Parameters without a custom loader use the default one.
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

config instance-attribute

config = config

embed_tokens instance-attribute

embed_tokens = VocabParallelEmbedding(
    vocab_size,
    hidden_size,
    quant_config=None,
    prefix=f"{prefix}.embed_tokens",
)

fall_back_to_pt_during_load class-attribute instance-attribute

fall_back_to_pt_during_load = False

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size
    )
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=rms_norm_eps)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/minimax_m2.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Assemble the embedding, the decoder layers and the final norm.

    Args:
        vllm_config: Provides the HF config, model config, cache config and
            quantization config handed to the sub-modules.
        prefix: Dotted name prefix prepended to every sub-module prefix.
    """
    super().__init__()

    model_config = vllm_config.model_config
    hf_config = model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config

    self.config = hf_config
    self.vocab_size = hf_config.vocab_size

    # A real embedding table exists only on the first pipeline rank.
    self.embed_tokens = (
        VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=None,
            prefix=f"{prefix}.embed_tokens",
        )
        if get_pp_group().is_first_rank
        else PPMissingLayer()
    )

    # The parameter keeps the name `prefix` in case make_layers passes it by
    # keyword -- the original lambda used the same parameter name.
    def build_decoder_layer(prefix):
        return MiniMaxM2DecoderLayer(
            hf_config,
            prefix,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    layer_range = make_layers(
        hf_config.num_hidden_layers,
        build_decoder_layer,
        prefix=f"{prefix}.layers",
    )
    self.start_layer, self.end_layer, self.layers = layer_range

    # The final normalization lives only on the last pipeline rank.
    self.norm = (
        RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        if get_pp_group().is_last_rank
        else PPMissingLayer()
    )

    intermediate_keys = ["hidden_states", "residual"]
    self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
        intermediate_keys, hf_config.hidden_size
    )

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None,
    inputs_embeds: Tensor | None = None,
) -> Tensor | IntermediateTensors
Source code in vllm/model_executor/models/minimax_m2.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None,
    inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
    """Run this pipeline rank's slice of decoder layers.

    Args:
        input_ids: Token ids; consulted only on the first pipeline rank and
            only when ``inputs_embeds`` is ``None``.
        positions: Forwarded unchanged to every decoder layer.
        intermediate_tensors: Mandatory on non-first ranks; provides the
            ``"hidden_states"`` and ``"residual"`` inputs.
        inputs_embeds: Optional embeddings used instead of the lookup.

    Returns:
        The normalized hidden states on the last pipeline rank, otherwise
        an ``IntermediateTensors`` with ``hidden_states`` and ``residual``.
    """
    if not get_pp_group().is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    else:
        hidden_states = (
            inputs_embeds
            if inputs_embeds is not None
            else self.get_input_embeddings(input_ids)
        )
        # Nothing has been accumulated before the first layer.
        residual = None

    for decoder_layer in self.layers[self.start_layer : self.end_layer]:
        hidden_states, residual = decoder_layer(positions, hidden_states, residual)

    if get_pp_group().is_last_rank:
        normed_states, _ = self.norm(hidden_states, residual)
        return normed_states

    handoff = {"hidden_states": hidden_states, "residual": residual}
    return IntermediateTensors(handoff)

get_expert_mapping

get_expert_mapping() -> list[tuple[str, str, int, str]]
Source code in vllm/model_executor/models/minimax_m2.py
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the expert parameter mapping produced by ``FusedMoE``.

    The checkpoint names the gate, down and up projections ``w1``, ``w2``
    and ``w3``; the expert count is ``config.num_local_experts``.
    """
    expert_count = self.config.num_local_experts
    mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="w1",
        ckpt_down_proj_name="w2",
        ckpt_up_proj_name="w3",
        num_experts=expert_count,
    )
    return mapping

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/minimax_m2.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of applying ``self.embed_tokens`` to ``input_ids``."""
    embeddings = self.embed_tokens(input_ids)
    return embeddings

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/minimax_m2.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this module's parameters.

    Each weight is tried against three routes in order: the stacked q/k/v
    mapping, the expert mapping, and finally a direct lookup by (possibly
    remapped) name.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of names recorded after each weight was processed. The
        recorded name is the rewritten one when a mapping applied.

    Raises:
        KeyError: If a resolved name is not present in
            ``self.named_parameters()``.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = self.get_expert_mapping()

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue  # skip spec decode layers for main model

        # Route 1: separate q/k/v checkpoint tensors -> fused qkv_proj.
        for param_name, weight_name, shard_id in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if ("mlp.experts." in name) and name not in params_dict:
                continue
            name = name.replace(weight_name, param_name)
            # NOTE: from here on `name` is the rewritten name; the
            # `continue` statements below move on to the next mapping entry
            # (not the next weight) with that rewritten name.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Reached only when the stacked loop ended without `break`.
            # Route 2: per-expert checkpoint tensors.
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                # The rewritten name is passed as the third positional
                # argument of the expert weight loader.
                weight_loader(
                    param,
                    loaded_weight,
                    name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                break
            else:
                # Route 3: no mapping consumed the weight. The `continue`
                # statements here skip to the next checkpoint weight, so the
                # name is not recorded in loaded_params.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                # Parameters without a custom loader use the default one.
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params

get_spec_layer_idx_from_weight_name

get_spec_layer_idx_from_weight_name(
    config: PretrainedConfig, weight_name: str
) -> int | None
Source code in vllm/model_executor/models/minimax_m2.py
def get_spec_layer_idx_from_weight_name(
    config: PretrainedConfig, weight_name: str
) -> int | None:
    """Return the index of the speculative-decoding layer owning a weight.

    The candidate layer indices are the ``config.num_mtp_modules`` indices
    that start at ``config.num_hidden_layers``. A weight belongs to such a
    layer when its name starts with ``"model.layers.<idx>."``.

    Args:
        config: Model config. ``num_mtp_modules`` may be missing, ``None``
            or non-positive, all of which mean "no speculative layers".
            ``num_hidden_layers`` is only read when speculative layers exist.
        weight_name: Full checkpoint name of the weight.

    Returns:
        The matching layer index, or ``None`` when the weight does not
        belong to a speculative layer.
    """
    # `or 0` also covers configs that declare the attribute but leave it None,
    # which would otherwise fail on the `> 0` comparison.
    num_mtp_modules = getattr(config, "num_mtp_modules", None) or 0
    if num_mtp_modules <= 0:
        return None

    first_spec_idx = config.num_hidden_layers
    for spec_idx in range(first_spec_idx, first_spec_idx + num_mtp_modules):
        # The trailing dot prevents "model.layers.4." matching layer 40.
        if weight_name.startswith(f"model.layers.{spec_idx}."):
            return spec_idx
    return None