[PIR][oneDNN] Add quantization pattern for Concat (#70430)
* add quantize pattern for Concat

* fix typo

* fix the order of pass

* fix format

* fix format & add check

* correct header name
LLee233 authored Dec 26, 2024
1 parent fef1013 commit 1e9f2b3
Showing 5 changed files with 241 additions and 0 deletions.
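
At a high level, the new cpu_special_ops_bf16_pass handles ops whose inputs arrive packed into a VectorType through pir::CombineOp (currently just concat), which the generic bf16 quantization pass cannot rewrite. Schematically, the transformation is:

    before:  x, y, z (fp32) -> combine -> concat (fp32)
    after:   x, y, z (fp32) -> quantize (one per input, bf16 out) -> combine
             -> concat (bf16) -> dequantize (fp32 out)

The pass is slotted after the generic bf16 placement/quantization passes and before cpu_bf16_quantize_squash_pass, presumably so that redundant quantize/dequantize pairs it introduces can still be squashed by that pass.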
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -674,6 +674,7 @@ const std::vector<std::string> kPirMkldnnBf16Passes{
"cpu_bfloat16_placement_pass",
"cpu_bfloat16_pass",
"cpu_bfloat16_type_placement_pass",
"cpu_special_ops_bf16_pass",
"cpu_bf16_quantize_squash_pass",
};

166 changes: 166 additions & 0 deletions paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.cc
@@ -0,0 +1,166 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.h"

#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/utils/general_functions.h"

#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"

namespace {

// Rebuild a DenseTensorType-like type with the same shape, layout, lod and
// offset, but a new element dtype (used below to retag outputs as bf16).
template <class IrType1, class IrType2>
static pir::Type create_type(pir::Type type,
                             pir::Type out_dtype,
                             pir::IrContext *ctx) {
  auto input_type = type.dyn_cast<IrType1>();
  return IrType2::get(ctx,
                      out_dtype,
                      input_type.dims(),
                      input_type.data_layout(),
                      input_type.lod(),
                      input_type.offset());
}

// For ops like conv and concat, the inputs are sometimes packed into a
// VectorType (via pir::CombineOp), which the current quantization pass does
// not handle. This pattern deals with such ops specifically.
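//
// Matched subgraph (requires a single-use combine whose operands are all
// fp32 DenseTensors):
//
//   in_0 ... in_n -> pir::CombineOp -> onednn_op.concat (fp32)
//
// Rewritten to:
//
//   in_i -> onednn_op.quantize (bf16 out) -> pir::CombineOp
//        -> onednn_op.concat (bf16) -> onednn_op.dequantize (fp32 out)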
class ConcatBf16QuantizePattern
    : public pir::OpRewritePattern<paddle::onednn::dialect::ConcatOp> {
 public:
  using pir::OpRewritePattern<
      paddle::onednn::dialect::ConcatOp>::OpRewritePattern;
  bool MatchAndRewrite(
      paddle::onednn::dialect::ConcatOp op,
      pir::PatternRewriter &rewriter) const override {  // NOLINT
    // The input should come from a combine op with a single use.
    pir::Operation *defining_op = pir::GetDefiningOpForInput(op, 0);
    if (!defining_op) return false;
    pir::CombineOp pre_op = defining_op->dyn_cast<pir::CombineOp>();
    if (!pre_op) return false;
    if (!pre_op.out().HasOneUse()) return false;

    auto op_attributes = op->attributes();
    auto onednn_data_type = op_attributes.at("mkldnn_data_type")
                                .dyn_cast<pir::StrAttribute>()
                                .AsString();
    // Skip concats already marked bfloat16 (e.g. ones this pattern has
    // already rewritten), so the pattern converges.
    if (onednn_data_type == "bfloat16") return false;
    op_attributes["mkldnn_data_type"] = rewriter.str_attr("bfloat16");

    auto combine_inputs = pre_op.inputs();

    for (size_t idx = 0; idx < combine_inputs.size(); idx++) {
      auto type = pre_op->operand_type(idx);
      // Currently we only process the case where all elements are
      // DenseTensors.
      if (!type.isa<pir::DenseTensorType>()) return false;
      // All tensors should be fp32.
      auto dtype = pir::GetDataTypeFromValue(pre_op->operand_source(idx));
      if (!dtype.isa<pir::Float32Type>()) return false;
    }

    pir::IrContext *ctx = rewriter.ir_context();

    // scale = 1.0 and shift = 0.0 with bfloat16 = true make the quantize op
    // a pure fp32 -> bf16 cast.
    std::unordered_map<std::string, pir::Attribute> q_attributes;
    q_attributes["scale"] = rewriter.float_attr(1.0f);
    q_attributes["shift"] = rewriter.float_attr(0.0f);
    q_attributes["is_negative_input"] = rewriter.bool_attr(false);
    q_attributes["output_format"] = rewriter.str_attr("NCHW");
    q_attributes["bfloat16"] = rewriter.bool_attr(true);

    // Insert a quantize op before the combine for every input.
    std::vector<pir::Value> new_combine_inputs(combine_inputs.size());
    for (size_t idx = 0; idx < combine_inputs.size(); idx++) {
      paddle::onednn::dialect::QuantizeOp quant_op =
          rewriter.Build<paddle::onednn::dialect::QuantizeOp>(
              combine_inputs[idx], q_attributes);
      auto type = quant_op->result_type(0);
      pir::Type new_type =
          create_type<pir::DenseTensorType, paddle::dialect::DenseTensorType>(
              type, pir::BFloat16Type::get(ctx), ctx);
      quant_op->result(0).set_type(new_type);
      new_combine_inputs[idx] = quant_op.output();
    }
    // Create a new combine packing the quantized inputs.
    pir::CombineOp new_combine =
        rewriter.Build<pir::CombineOp>(new_combine_inputs);
    rewriter.ReplaceAllUsesWith(pre_op.out(), new_combine.out());
    rewriter.EraseOp(pre_op);

    // Create a new concat with a bf16 output type.
    auto concat_info =
        ctx->GetRegisteredOpInfo(paddle::onednn::dialect::ConcatOp::name());
    if (!concat_info) return false;

    std::vector<pir::Type> op_item_inner_output_types;
    auto type = op->result_type(0);
    pir::Type new_type =
        create_type<pir::DenseTensorType, paddle::dialect::DenseTensorType>(
            type, pir::BFloat16Type::get(ctx), ctx);
    op_item_inner_output_types.push_back(new_type);

    paddle::onednn::dialect::ConcatOp new_concat =
        rewriter
            .Build({new_combine.out(), op.axis()},
                   op_attributes,
                   op_item_inner_output_types,
                   concat_info)
            ->dyn_cast<paddle::onednn::dialect::ConcatOp>();

    // Insert a dequantize op after the concat to restore fp32 output.
    std::unordered_map<std::string, pir::Attribute> dq_attributes;
    dq_attributes["scale"] = rewriter.float_attr(1.0f);
    dq_attributes["shift"] = rewriter.float_attr(0.0f);
    paddle::onednn::dialect::DequantizeOp dequant_op =
        rewriter.Build<paddle::onednn::dialect::DequantizeOp>(new_concat.out(),
                                                              dq_attributes);

    rewriter.ReplaceAllUsesWith(op.out(), dequant_op.output());
    rewriter.EraseOp(op);
    return true;
  }
};

class CPUSpecialOpsBf16Pass : public pir::PatternRewritePass {
 public:
  CPUSpecialOpsBf16Pass()
      : pir::PatternRewritePass("cpu_special_ops_bf16_pass", 2) {}

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
    uint32_t benefit = 100;

    auto concat_bf16_quant_pattern =
        std::make_unique<ConcatBf16QuantizePattern>(
            context, benefit--, std::vector<std::string>{});
    ps.Add(std::move(concat_bf16_quant_pattern));

    return ps;
  }
};

}  // namespace

namespace pir {

std::unique_ptr<Pass> CreateCPUSpecialOpsBf16Pass() {
  return std::make_unique<CPUSpecialOpsBf16Pass>();
}

}  // namespace pir

REGISTER_IR_PASS(cpu_special_ops_bf16_pass, CPUSpecialOpsBf16Pass);
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.h
@@ -0,0 +1,26 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/include/core/dll_decl.h"

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateCPUSpecialOpsBf16Pass();

} // namespace pir
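
The header only exposes the pass factory. As a rough usage sketch (not part of this commit; the pass-manager header path, constructor, and Run signature are assumed from PIR's usual conventions), the pass would be run like this:

#include "paddle/fluid/pir/transforms/onednn/cpu_special_ops_bf16_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"  // assumed header path

// Hedged sketch: apply the new pass to a PIR program in place.
void ApplySpecialOpsBf16(pir::IrContext *ctx, pir::Program *program) {
  pir::PassManager pm(ctx);
  pm.AddPass(pir::CreateCPUSpecialOpsBf16Pass());
  pm.Run(program);  // rewrites matching fp32 concat subgraphs
}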
1 change: 1 addition & 0 deletions paddle/fluid/pir/transforms/passes.h
Expand Up @@ -91,6 +91,7 @@ USE_PIR_PASS(cpu_bfloat16_placement_pass);
USE_PIR_PASS(cpu_bfloat16_type_placement_pass);
USE_PIR_PASS(cpu_bfloat16_pass);
USE_PIR_PASS(cpu_bf16_quantize_squash_pass);
USE_PIR_PASS(cpu_special_ops_bf16_pass);
#endif

#ifdef PADDLE_WITH_XPU
47 changes: 47 additions & 0 deletions test/ir/pir/fused_pass/onednn/test_cpu_bfloat16_pir_pass.py
@@ -1121,5 +1121,52 @@ def test_check_output(self):
        self.check_pass_correct()


class TestConcatBfloatQuantizePass(PassTest):
    def is_program_valid(self, program=None):
        return True

    def build_ir_program(self):
        with paddle.pir_utils.IrGuard():
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.pir.core.program_guard(main_prog, start_prog):
                x = paddle.static.data(
                    name='x', shape=[5, 5, 5, 5], dtype='float32'
                )
                y = paddle.static.data(
                    name='y', shape=[5, 5, 5, 5], dtype='float32'
                )
                z = paddle.static.data(
                    name='z', shape=[5, 5, 5, 5], dtype='float32'
                )
                out = paddle.concat((x, y, z))
                out = paddle.assign(out)
                self.pass_attr_list = [
                    {'onednn_placement_pass': {}},
                    {'cpu_special_ops_bf16_pass': {}},
                ]
                self.feeds = {
                    "x": np.random.random((5, 5, 5, 5)).astype("float32"),
                    "y": np.random.random((5, 5, 5, 5)).astype("float32"),
                    "z": np.random.random((5, 5, 5, 5)).astype("float32"),
                }
                self.fetch_list = [out]
                # Three inputs -> three quantize ops; one bf16 concat; one
                # dequantize restoring fp32 on the output.
                self.valid_op_map = {
                    "onednn_op.concat": 1,
                    "onednn_op.dequantize": 1,
                    "onednn_op.quantize": 3,
                }
                return [main_prog, start_prog]

    def sample_program(self):
        yield self.build_ir_program(), False

    def setUp(self):
        self.places.append(paddle.CPUPlace())

    def test_check_output(self):
        # The bf16 round-trip loses precision, so compare with loose
        # tolerances.
        self.check_pass_correct(rtol=1e-02, atol=1e-02)


if __name__ == "__main__":
    unittest.main()
