support mc2 for mp lora. (PaddlePaddle#8162)
wuhuachaocoding authored Mar 21, 2024
1 parent 18072a2 commit 0c65a47
Showing 2 changed files with 105 additions and 11 deletions.
40 changes: 29 additions & 11 deletions paddlenlp/peft/lora/lora_layers.py
@@ -13,6 +13,7 @@
# limitations under the License.

import math
+import os
from typing import List, Optional

import paddle
@@ -24,6 +25,12 @@
RowParallelLinear,
)

if "npu" in paddle.device.get_all_custom_device_type():
from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
else:
MC2LoRaRowParallelLinear = None
MC2LoRaColumnParallelLinear = None


class LoRALinear(nn.Linear):
# LoRA implemented in a dense layer
@@ -188,14 +195,17 @@ def forward(self, x: paddle.Tensor):
input_mp = x

# x @ W : [bz, in_f / ws] ===> [bz, out_f]
-        result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
+        if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+            output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
+        else:
+            result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)

-        output = mp_ops._mp_allreduce(
-            result_mp,
-            group=self.model_parallel_group,
-            use_calc_stream=True,
-            use_model_parallel=True,
-        )
+            output = mp_ops._mp_allreduce(
+                result_mp,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True,
+            )

if not self.merged:
# x @ A: [bz, in_f/ ws] ===> [bz, r]
@@ -294,13 +304,21 @@ def eval(self):
self.merged = True

def forward(self, input: paddle.Tensor):
-        input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
-        result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
+        if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+            res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
+            result_mp = res_mp + self.bias
+        else:
+            input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
+            result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)

        if not self.merged:
            input_a = self.lora_dropout(input) @ self.lora_A
-            input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
-            delta_mp = (input_a_mp @ self.lora_B) * self.scaling
+            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+                tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = tmp * self.scaling
+            else:
+                input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
+                delta_mp = (input_a_mp @ self.lora_B) * self.scaling
result_mp += delta_mp

if self.gather_output and self.is_mp:
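Note that both forward paths above take the fused branch only when an NPU custom device is registered and the MC2 environment variable parses to a non-zero integer; otherwise they fall back to the original F.linear plus the separate mp_ops collectives. A minimal sketch of that gate, with the environment variable set from Python purely for illustration:

import os

import paddle

# illustration only: MC2 must be enabled before forward() runs
os.environ["MC2"] = "1"

use_mc2 = "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0"))
# use_mc2 is truthy only on NPU builds with MC2 switched on; on other devices
# the LoRA layers keep using F.linear and the all-reduce shown above.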
76 changes: 76 additions & 0 deletions paddlenlp/peft/lora/mc2_lora_npu.py
@@ -0,0 +1,76 @@
# !/usr/bin/env python3

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" mc2(tp overlap) """

import paddle
import paddle_custom_device
from paddle.autograd import PyLayer


class MC2LoRaRowParallelLinear(PyLayer):
@staticmethod
def forward(ctx, input_, weight, group):
ctx.save_for_backward(input_, weight)
rank = paddle.distributed.get_rank()
hcom_name = group.process_group.get_comm_name(rank)
x = input_.reshape([-1, input_.shape[-1]])
out = paddle_custom_device.npu.fused_mm_allreduce(
x, weight, bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0
)
output = out.reshape([input_.shape[0], input_.shape[1], weight.shape[1]])
ctx.ring_id = group.id
return output

@staticmethod
def backward(ctx, dy):
input_, weight = ctx.saved_tensor()
out_grad = dy
sub_grad = out_grad.reshape([-1, out_grad.shape[-1]])
input_grad = paddle.matmul(sub_grad, weight.t())
if weight.stop_gradient:
return input_grad.reshape(input_.shape)
else:
input_reshape = input_.reshape([-1, input_.shape[-1]])
weight_grad = input_reshape.t().matmul(sub_grad)
return input_grad.reshape(input_.shape), weight_grad


class MC2LoRaColumnParallelLinear(PyLayer):
@staticmethod
def forward(ctx, input_, weight, group):
ctx.save_for_backward(input_, weight)
ctx.group = group
input_mp = input_
result_mp = paddle.matmul(input_mp, weight)
return result_mp

@staticmethod
def backward(ctx, dy):
input_, weight = ctx.saved_tensor()
sub_grad = dy.reshape([-1, dy.shape[-1]])
rank = paddle.distributed.get_rank()
hcom_name = ctx.group.process_group.get_comm_name(rank)

d_weight = input_.reshape([-1, input_.shape[-1]]).t().matmul(sub_grad) if not weight.stop_gradient else None
d_input = paddle_custom_device.npu.fused_mm_allreduce(
sub_grad, weight.t(), bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0
)

if d_weight is not None:
return d_input.reshape(input_.shape), d_weight
else:
return d_input.reshape(input_.shape)
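For reference, the fused row-parallel op is numerically a matmul followed by an all-reduce over the tensor-parallel group; fused_mm_allreduce performs both in one NPU kernel so the communication overlaps with the computation (the "tp overlap" in the module docstring). The column-parallel variant instead defers the fused matmul + all-reduce to the backward pass, where the input gradient has to be reduced across ranks. Below is a sketch of the unfused reference computation, using the public paddle.distributed.all_reduce collective rather than anything from this file, purely to show the equivalence:

import paddle
import paddle.distributed as dist


def row_parallel_reference(input_, weight, group):
    # each rank holds an [in_f / ws, out_f] shard of the weight, so the local
    # matmul produces a partial result that must be summed across the group
    partial = paddle.matmul(input_, weight)
    dist.all_reduce(partial, group=group)  # in-place sum over the model-parallel group
    return partial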
