models: add support for Phi3 MoE
AlpinDale committed Dec 24, 2024
1 parent 032974a commit 201db10
Showing 8 changed files with 748 additions and 82 deletions.
19 changes: 14 additions & 5 deletions aphrodite/modeling/layers/fused_moe/fused_moe.py
@@ -2,7 +2,7 @@
 import functools
 import json
 import os
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 import torch
 import triton
@@ -444,7 +444,8 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
                      rand_perm1: torch.Tensor,
                      rand_perm2: torch.Tensor,
                      topk: int,
-                     renormalize: bool,
+                     custom_routing_function: Optional[Callable] = None,
+                     renormalize: bool = True,
                      override_config: Optional[Dict[str, Any]] = None,
                      use_fp8: bool = False,
                      w1_scale: Optional[torch.Tensor] = None,
@@ -495,8 +496,12 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     E = w1.shape[0]
     N = w2.shape[1] * 16
 
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
+    if custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                            renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states, gating_output, topk, renormalize)
 
     get_config_func = functools.partial(try_get_optimal_moe_config,
                                         w1.shape,
@@ -691,6 +696,7 @@ def fused_moe(
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -738,9 +744,12 @@ def fused_moe(
         topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
                                               topk, renormalize,
                                               num_expert_group, topk_group)
-    else:
+    elif custom_routing_function is None:
         topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                             renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states, gating_output, topk, renormalize)
 
     return fused_experts(hidden_states,
                          w1,
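
The only contract the new custom_routing_function hook has to satisfy is the calling convention of fused_topk: it is called with the hidden states, the raw gating logits, the top-k value, and the renormalize flag, and must return a (topk_weights, topk_ids) pair. The sketch below is a plain softmax top-k router written against that contract; it is illustrative only, not the router the Phi-3 MoE model actually registers, and the name naive_topk_router is hypothetical.

import torch


def naive_topk_router(hidden_states: torch.Tensor,
                      gating_output: torch.Tensor,
                      topk: int,
                      renormalize: bool):
    # Same shape contract as fused_topk: one row of gating logits per token.
    assert hidden_states.shape[0] == gating_output.shape[0]
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Make the kept expert weights sum to one per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

Passing custom_routing_function=naive_topk_router to fused_moe (or fused_marlin_moe) would then replace only the expert-selection step; the fused expert computation path is unchanged.
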
90 changes: 56 additions & 34 deletions aphrodite/modeling/layers/fused_moe/layer.py
@@ -1,6 +1,6 @@
 from abc import abstractmethod
 from enum import Enum
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 
@@ -59,15 +59,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              router_logits: torch.Tensor,
-              top_k: int,
-              renormalize: bool,
-              use_grouped_topk: bool,
-              topk_group: Optional[int] = None,
-              num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
 
         return self.forward(x=x,
                             layer=layer,
@@ -76,17 +79,21 @@ def apply(self,
                             renormalize=renormalize,
                             use_grouped_topk=use_grouped_topk,
                             topk_group=topk_group,
-                            num_expert_group=num_expert_group)
-
-    def forward_cuda(self,
-                     layer: torch.nn.Module,
-                     x: torch.Tensor,
-                     use_grouped_topk: bool,
-                     top_k: int,
-                     router_logits: torch.Tensor,
-                     renormalize: bool,
-                     topk_group: Optional[int] = None,
-                     num_expert_group: Optional[int] = None) -> torch.Tensor:
+                            num_expert_group=num_expert_group,
+                            custom_routing_function=custom_routing_function)
+
+    def forward_cuda(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
 
         from aphrodite.modeling.layers.fused_moe.fused_moe import fused_experts
 
@@ -97,7 +104,8 @@ def forward_cuda(self,
             top_k=top_k,
             renormalize=renormalize,
             topk_group=topk_group,
-            num_expert_group=num_expert_group)
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
 
         return fused_experts(hidden_states=x,
                              w1=layer.w13_weight,
@@ -110,20 +118,24 @@ def forward_cpu(self, *args, **kwargs):
         raise NotImplementedError(
             "The CPU backend currently does not support MoE.")
 
-    def forward_tpu(self,
-                    layer: torch.nn.Module,
-                    x: torch.Tensor,
-                    use_grouped_topk: bool,
-                    top_k: int,
-                    router_logits: torch.Tensor,
-                    renormalize: bool,
-                    topk_group: Optional[int] = None,
-                    num_expert_group: Optional[int] = None) -> torch.Tensor:
+    def forward_tpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None
+    ) -> torch.Tensor:
 
         from aphrodite.modeling.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
         assert topk_group is None
+        assert custom_routing_function is None
         return fused_moe(hidden_states=x,
                          w1=layer.w13_weight,
                          w2=layer.w2_weight,
@@ -168,6 +180,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        custom_routing_function: Optional[Callable] = None,
     ):
         super().__init__()
 
@@ -186,6 +199,7 @@ def __init__(
             assert num_expert_group is not None and topk_group is not None
             self.num_expert_group = num_expert_group
             self.topk_group = topk_group
+        self.custom_routing_function = custom_routing_function
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -386,7 +400,8 @@ def select_experts(hidden_states: torch.Tensor,
                        use_grouped_topk: bool,
                        renormalize: bool,
                        topk_group: Optional[int] = None,
-                       num_expert_group: Optional[int] = None):
+                       num_expert_group: Optional[int] = None,
+                       custom_routing_function: Optional[Callable] = None):
         from aphrodite.modeling.layers.fused_moe.fused_moe import (
             fused_topk, grouped_topk)
 
@@ -401,11 +416,17 @@ def select_experts(hidden_states: torch.Tensor,
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group)
-        else:
+        elif custom_routing_function is None:
             topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
                                                 gating_output=router_logits,
                                                 topk=top_k,
                                                 renormalize=renormalize)
+        else:
+            topk_weights, topk_ids = custom_routing_function(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize)
 
         return topk_weights, topk_ids
 
@@ -422,7 +443,8 @@ def forward(self, hidden_states: torch.Tensor,
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,
             topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group)
+            num_expert_group=self.num_expert_group,
+            custom_routing_function=self.custom_routing_function)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
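
Taken together, the FusedMoE layer now stores the routing callable at construction time and threads it through select_experts into whichever quant-method apply implementation is active. Below is a minimal sketch of how a model's MoE block could opt in; every constructor argument other than custom_routing_function is assumed from the pre-existing FusedMoE interface rather than shown in this diff, and the sizes are placeholders.

from aphrodite.modeling.layers.fused_moe.layer import FusedMoE

# Sketch under assumptions: argument names other than custom_routing_function
# come from the existing FusedMoE signature, and the sizes are placeholders.
# naive_topk_router is the illustrative router sketched earlier.
experts = FusedMoE(
    num_experts=16,
    top_k=2,
    hidden_size=4096,
    intermediate_size=6400,
    renormalize=False,
    custom_routing_function=naive_topk_router,
)

On CUDA the callable ends up replacing fused_topk in expert selection; on TPU the new assert makes the unsupported combination fail loudly instead of silently falling back to the default routing.
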
25 changes: 14 additions & 11 deletions aphrodite/modeling/layers/rotary_embedding.py
@@ -508,8 +508,8 @@ def __init__(
         dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
-        short_mscale: float = 1.0,
-        long_mscale: float = 1.0,
+        short_mscale: Optional[float] = None,
+        long_mscale: Optional[float] = None,
     ):
         super().__init__()
 
@@ -528,18 +528,21 @@ def __init__(
         self.base = base
         self.short_factor = short_factor
         self.long_factor = long_factor
-        self.short_mscale = short_mscale
-        self.long_mscale = long_mscale
 
-        scale = (self.max_position_embeddings /
-                 self.original_max_position_embeddings)
-
+        scale = self.max_position_embeddings / \
+            self.original_max_position_embeddings
         if scale <= 1.0:
-            self.scaling_factor = 1.0
+            scaling_factor = 1.0
         else:
-            self.scaling_factor = math.sqrt(
+            scaling_factor = math.sqrt(
                 1 + math.log(scale) /
                 math.log(self.original_max_position_embeddings))
+        if short_mscale is None:
+            short_mscale = scaling_factor
+        if long_mscale is None:
+            long_mscale = scaling_factor
+        self.short_mscale = short_mscale
+        self.long_mscale = long_mscale
 
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
@@ -576,8 +579,8 @@ def _compute_cos_sin_cache(
         inv_freq = self._compute_inv_freq(rescale_factors)
         t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
-        cos = freqs.cos() * mscale * self.scaling_factor
-        sin = freqs.sin() * mscale * self.scaling_factor
+        cos = freqs.cos() * mscale
+        sin = freqs.sin() * mscale
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
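
The rotary change moves the long-context scaling out of _compute_cos_sin_cache and into the mscale defaults: previously the cache was multiplied by both the per-cache mscale (default 1.0) and a separate self.scaling_factor, whereas now short_mscale and long_mscale default to that same derived factor and are applied once, so callers such as the new PhiMoE model can override them with exact values. With the defaults the resulting cache is numerically unchanged; the snippet below just works the derived factor through for illustrative context lengths (4096 original, 131072 extended — assumptions for the example, not values taken from this diff).

import math

# Illustrative inputs only; real values come from the model's rope_scaling config.
original_max_position_embeddings = 4096
max_position_embeddings = 131072

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
if scale <= 1.0:
    scaling_factor = 1.0
else:
    scaling_factor = math.sqrt(
        1 + math.log(scale) / math.log(original_max_position_embeddings))

# Both mscales default to this factor unless the caller passes explicit
# short_mscale / long_mscale values.
print(round(scaling_factor, 4))  # 1.1902
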
1 change: 1 addition & 0 deletions aphrodite/modeling/models/__init__.py
@@ -47,6 +47,7 @@
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
+    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
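
Each registry entry maps the architectures string from a checkpoint's config.json to the module file and class that implement it, so this line is what lets a PhiMoEForCausalLM checkpoint resolve to the phimoe model added elsewhere in this commit. The lookup sketched below is only an illustration of that mapping; the dict name, resolver function, and package path are assumptions, not code from this diff.

import importlib

# Assumed shape of the registry and resolver; the real loader lives elsewhere.
_MODELS = {
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
}


def resolve_model_cls(architecture: str):
    module_name, class_name = _MODELS[architecture]
    module = importlib.import_module(
        f"aphrodite.modeling.models.{module_name}")
    return getattr(module, class_name)


# config.json of a PhiMoE-style checkpoint lists "PhiMoEForCausalLM" under
# "architectures", which would resolve like this:
model_cls = resolve_model_cls("PhiMoEForCausalLM")
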