fix: InternLM2 model with Tensor Parallel (#980)
AlpinDale authored Dec 24, 2024
1 parent 495435a commit 7632f91
Showing 1 changed file with 32 additions and 13 deletions.
45 changes: 32 additions & 13 deletions aphrodite/modeling/models/internlm2.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
@@ -9,7 +10,10 @@
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors
 from aphrodite.distributed import (get_pp_group,
-                                   get_tensor_model_parallel_world_size)
+                                   get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size,
+                                   split_tensor_along_last_dim,
+                                   tensor_model_parallel_all_gather)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -73,20 +77,21 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -125,11 +130,25 @@ def __init__(
             quant_config=quant_config)
 
     def split_qkv(self, qkv: torch.Tensor):
-        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
-        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
-        q = q.reshape(-1, self.q_size)
-        k = k.reshape(-1, self.kv_size)
-        v = v.reshape(-1, self.kv_size)
+        seq_len = qkv.shape[0]
+        if self.tp_size > 1:
+            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
+            qkv = tensor_model_parallel_all_gather(qkv)
+            qkv = torch.split(qkv, qkv_map, dim=-1)
+            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
+            qkv = torch.cat(qkv, dim=-1)
+        qkv = qkv.view(seq_len, self.total_num_kv_heads,
+                       self.key_value_groups + 2, self.head_dim)
+        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
+        q = q.reshape(seq_len, self.q_size * self.tp_size)
+        k = k.reshape(seq_len, self.kv_size * self.tp_size)
+        v = v.reshape(seq_len, self.kv_size * self.tp_size)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
         return q, k, v
 
     def forward(
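
The core of the change is in split_qkv: under tensor parallelism, qkv_map splits the all-gathered wqkv output into one [q | k | v] block per rank, and the stride-3 slicing (qkv[::3] + qkv[1::3] + qkv[2::3]) reorders those blocks so that all q, all k, and all v pieces become contiguous before the grouped view and the re-shard back to the local rank. Below is a minimal single-process sketch of just that regrouping step, assuming per-rank [q | k | v] blocks as implied by qkv_map; the all-gather is simulated with torch.cat and the sizes are illustrative, not InternLM2's real dimensions.

# Single-process sketch of the regrouping step in the new split_qkv.
# Assumptions: 2 simulated "ranks", toy sizes, all-gather faked with torch.cat.
import torch

tp_size, seq_len = 2, 3
q_size, kv_size = 4, 2  # per-rank widths of the q / k / v slices

# Each rank's local wqkv output: a [q | k | v] block along the last dim,
# filled with rank-specific markers so the regrouping is easy to verify.
rank_outputs = [
    torch.cat([torch.full((seq_len, q_size), 10.0 * r),        # q of rank r
               torch.full((seq_len, kv_size), 10.0 * r + 1),   # k of rank r
               torch.full((seq_len, kv_size), 10.0 * r + 2)],  # v of rank r
              dim=-1)
    for r in range(tp_size)
]

# "All-gather" across ranks, then split back into per-rank q/k/v pieces.
gathered = torch.cat(rank_outputs, dim=-1)
qkv_map = [q_size, kv_size, kv_size] * tp_size
parts = torch.split(gathered, qkv_map, dim=-1)

# Stride-3 slicing groups all q pieces, then all k pieces, then all v pieces.
regrouped = torch.cat(parts[::3] + parts[1::3] + parts[2::3], dim=-1)

# Verify the regrouped layout: all q, then all k, then all v.
q, k, v = torch.split(
    regrouped,
    [q_size * tp_size, kv_size * tp_size, kv_size * tp_size],
    dim=-1)
assert q[0].tolist() == [0, 0, 0, 0, 10, 10, 10, 10]
assert k[0].tolist() == [1, 1, 11, 11]
assert v[0].tolist() == [2, 2, 12, 12]

In the actual model code, the regrouped tensor is then viewed with the grouped (key_value_groups + 2) layout, split into q/k/v along that axis, and re-sharded to the local rank with split_tensor_along_last_dim, as shown in the hunk above.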

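The __init__ hunk mainly switches the per-rank head bookkeeping to instance attributes (self.tp_size, self.tp_rank) so split_qkv can use them; the partition/replicate rule for KV heads is unchanged. A small standalone check of that rule (numbers are illustrative, not tied to a specific InternLM2 config):

# Illustrative check of the KV-head partition/replication rule from __init__.
def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        # More KV heads than TP ranks: partition the heads across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than TP ranks: replicate KV heads across ranks.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

assert kv_heads_per_rank(8, 2) == 4  # partitioned: 4 KV heads per rank
assert kv_heads_per_rank(2, 8) == 1  # replicated: each rank keeps one KV head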