diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py
index ae38f47825e4..c90e96aab4e2 100644
--- a/paddlenlp/peft/lora/lora_layers.py
+++ b/paddlenlp/peft/lora/lora_layers.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+import os
 from typing import List, Optional
 
 import paddle
@@ -24,6 +25,12 @@
     RowParallelLinear,
 )
 
+if "npu" in paddle.device.get_all_custom_device_type():
+    from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
+else:
+    MC2LoRaRowParallelLinear = None
+    MC2LoRaColumnParallelLinear = None
+
 
 class LoRALinear(nn.Linear):
     # LoRA implemented in a dense layer
@@ -188,14 +195,17 @@ def forward(self, x: paddle.Tensor):
             input_mp = x
 
         # x @ W : [bz, in_f / ws] ===> [bz, out_f]
-        result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
+        if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+            output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
+        else:
+            result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
 
-        output = mp_ops._mp_allreduce(
-            result_mp,
-            group=self.model_parallel_group,
-            use_calc_stream=True,
-            use_model_parallel=True,
-        )
+            output = mp_ops._mp_allreduce(
+                result_mp,
+                group=self.model_parallel_group,
+                use_calc_stream=True,
+                use_model_parallel=True,
+            )
 
         if not self.merged:
             # x @ A: [bz, in_f/ ws] ===> [bz, r]
@@ -294,13 +304,21 @@ def eval(self):
         self.merged = True
 
     def forward(self, input: paddle.Tensor):
-        input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
-        result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
+        if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+            res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
+            result_mp = res_mp + self.bias
+        else:
+            input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
+            result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
 
         if not self.merged:
             input_a = self.lora_dropout(input) @ self.lora_A
-            input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
-            delta_mp = (input_a_mp @ self.lora_B) * self.scaling
+            if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
+                tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
+                delta_mp = tmp * self.scaling
+            else:
+                input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
+                delta_mp = (input_a_mp @ self.lora_B) * self.scaling
             result_mp += delta_mp
 
         if self.gather_output and self.is_mp:
diff --git a/paddlenlp/peft/lora/mc2_lora_npu.py b/paddlenlp/peft/lora/mc2_lora_npu.py
new file mode 100644
index 000000000000..c57e0bed2590
--- /dev/null
+++ b/paddlenlp/peft/lora/mc2_lora_npu.py
@@ -0,0 +1,76 @@
+# !/usr/bin/env python3
+
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" mc2(tp overlap) """
+
+import paddle
+import paddle_custom_device
+from paddle.autograd import PyLayer
+
+
+class MC2LoRaRowParallelLinear(PyLayer):
+    @staticmethod
+    def forward(ctx, input_, weight, group):
+        ctx.save_for_backward(input_, weight)
+        rank = paddle.distributed.get_rank()
+        hcom_name = group.process_group.get_comm_name(rank)
+        x = input_.reshape([-1, input_.shape[-1]])
+        out = paddle_custom_device.npu.fused_mm_allreduce(
+            x, weight, bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0
+        )
+        output = out.reshape([input_.shape[0], input_.shape[1], weight.shape[1]])
+        ctx.ring_id = group.id
+        return output
+
+    @staticmethod
+    def backward(ctx, dy):
+        input_, weight = ctx.saved_tensor()
+        out_grad = dy
+        sub_grad = out_grad.reshape([-1, out_grad.shape[-1]])
+        input_grad = paddle.matmul(sub_grad, weight.t())
+        if weight.stop_gradient:
+            return input_grad.reshape(input_.shape)
+        else:
+            input_reshape = input_.reshape([-1, input_.shape[-1]])
+            weight_grad = input_reshape.t().matmul(sub_grad)
+            return input_grad.reshape(input_.shape), weight_grad
+
+
+class MC2LoRaColumnParallelLinear(PyLayer):
+    @staticmethod
+    def forward(ctx, input_, weight, group):
+        ctx.save_for_backward(input_, weight)
+        ctx.group = group
+        input_mp = input_
+        result_mp = paddle.matmul(input_mp, weight)
+        return result_mp
+
+    @staticmethod
+    def backward(ctx, dy):
+        input_, weight = ctx.saved_tensor()
+        sub_grad = dy.reshape([-1, dy.shape[-1]])
+        rank = paddle.distributed.get_rank()
+        hcom_name = ctx.group.process_group.get_comm_name(rank)
+
+        d_weight = input_.reshape([-1, input_.shape[-1]]).t().matmul(sub_grad) if not weight.stop_gradient else None
+        d_input = paddle_custom_device.npu.fused_mm_allreduce(
+            sub_grad, weight.t(), bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0
+        )
+
+        if d_weight is not None:
+            return d_input.reshape(input_.shape), d_weight
+        else:
+            return d_input.reshape(input_.shape)