Skip to content

Commit

Permalink
[Snippets][CPU] Enable Transpose tokenization only on inputs in bf16 cases
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jul 27, 2023
1 parent b9c32cb commit faadf8f
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 40 deletions.
13 changes: 7 additions & 6 deletions src/common/snippets/include/snippets/pass/tokenization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,17 @@ class SnippetsTokenization : public ov::pass::ModelPass {
* @ingroup snippets
*/
struct Config {
    /// Tokenization knobs shared by the Snippets tokenization passes.
    ///
    /// @param minimal_concurrency        target parallelism (e.g. thread count) used by
    ///                                   shape-splitting heuristics
    /// @param split_m_dimension          enables the "SplitDimensionM" optimization
    /// @param enable_transpose_on_output allows tokenizing a Transpose on the MHA output
    Config(size_t minimal_concurrency = 1, bool split_m_dimension = true, bool enable_transpose_on_output = true)
        : minimal_concurrency(minimal_concurrency), split_m_dimension(split_m_dimension),
          mha_token_enable_transpose_on_output(enable_transpose_on_output) {}

    size_t minimal_concurrency = 1;
    // True if "SplitDimensionM" optimization is enabled. Otherwise, it's disabled.
    bool split_m_dimension = true;
    // False if Transpose on output isn't tokenized in MHA Tokenization.
    // Otherwise, it may be fused into Subgraph if possible
    // TODO [111813]: Remove please when the ticket 111813 is implemented
    bool mha_token_enable_transpose_on_output = true;
};

OPENVINO_RTTI("SnippetsTokenization", "0");
Expand Down
60 changes: 45 additions & 15 deletions src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,26 @@ void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outpu
for (size_t i = 0; i < outputShapes.size(); i++) {
const auto needed_out_type = std::get<2>(outputShapes[i]);
if (body_results[i]->get_input_element_type(0) != needed_out_type) {
const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
body_results[i]->get_input_node_shared_ptr(0), needed_out_type);
body_results[i]->set_argument(0, convert);
body_results[i]->validate_and_infer_types();
auto parent_output = body_results[i]->get_input_source_output(0);
std::shared_ptr<ov::Node> consumer = body_results[i];

// Snippets supports Transpose only after Parameter or before Result nodes
// So we have to insert Convert before Transpose (if there is) on Subgraph outputs
const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr());
if (transpose) {
OPENVINO_ASSERT(parent_output.get_target_inputs().size() == 1,
"If Result has Transpose on input, this Result must be single consumer of the Transpose");
parent_output = transpose->get_input_source_output(0);
consumer = transpose;
}

const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(parent_output, needed_out_type);
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);

consumer->set_argument(0, convert);
consumer->validate_and_infer_types();
if (consumer != body_results[i])
body_results[i]->validate_and_infer_types();
}
}

Expand All @@ -442,23 +458,37 @@ void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outpu
for (size_t i = 0; i < inputShapes.size(); ++i) {
const auto needed_in_type = std::get<2>(inputShapes[i]);
const auto& parameter = parameters[i];
if (parameter->get_element_type() != needed_in_type) {
const auto parameter_output = parameter->output(0);
const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
parameter_output,
parameter_output.get_element_type());
ov::copy_runtime_info(parameter, convert);

for (const auto input : parameter_output.get_target_inputs()) {
const auto original_type = parameter->get_element_type();
if (original_type != needed_in_type) {
parameter->set_element_type(needed_in_type);
parameter->validate_and_infer_types();

auto parent_output = parameter->output(0);
auto consumer_inputs = parent_output.get_target_inputs();

// Snippets supports Transpose only after Parameter or before Result nodes
// So we have to insert Convert after Transpose (if there is) on Subgraph inputs
if (std::any_of(consumer_inputs.cbegin(), consumer_inputs.cend(),
[](const ov::Input<ov::Node>& input) { return ov::is_type<ov::op::v1::Transpose>(input.get_node()); })) {
OPENVINO_ASSERT(consumer_inputs.size() == 1,
"If Parameter has Transpose on output, this Transpose must be single consumer of the Parameter");
const auto transpose = consumer_inputs.begin()->get_node()->shared_from_this();
transpose->validate_and_infer_types();

parent_output = transpose;
consumer_inputs = parent_output.get_target_inputs();
}

const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(parent_output, original_type);
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);

for (const auto input : consumer_inputs) {
const auto& input_node = input.get_node();
if (input_node == convert.get()) {
continue;
}
input_node->set_argument(input.get_index(), convert->output(0));
}

parameter->set_element_type(needed_in_type);
parameter->validate_and_infer_types();
}
}
}
Expand Down
15 changes: 5 additions & 10 deletions src/common/snippets/src/pass/mha_tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
MATCHER_SCOPE(TokenizeMHASnippets);

auto m_matmul0 = std::make_shared<ov::opset1::MatMul>(ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()),
ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()));
ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()));

register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul0, matcher_name),
[=](ov::pass::pattern::Matcher &m) {
Expand Down Expand Up @@ -388,14 +388,9 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
}
};

auto get_transpose = [config](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::opset1::Transpose> {
return config.mha_token_enable_transpose ? ov::as_type_ptr<ov::opset1::Transpose>(node)
: nullptr;
};

const auto transpose1 = get_transpose(parent);
const auto transpose0 = get_transpose(matmul0->get_input_node_shared_ptr(0));
const auto transpose2 = get_transpose(matmul1->get_input_node_shared_ptr(1));
const auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent);
const auto transpose0 = ov::as_type_ptr<ov::opset1::Transpose>(matmul0->get_input_node_shared_ptr(0));
const auto transpose2 = ov::as_type_ptr<ov::opset1::Transpose>(matmul1->get_input_node_shared_ptr(1));
tokenize_transpose(transpose1, is_transposed_b_0, {0, 2, 3, 1}, ordered_ops.begin());
tokenize_transpose(transpose0, matmul0->get_transpose_a(), {0, 2, 1, 3}, ordered_ops.begin());
tokenize_transpose(transpose2, matmul1->get_transpose_b(), {0, 2, 1, 3}, ordered_ops.end());
Expand Down Expand Up @@ -431,7 +426,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
// <Supported ops>
// Transpose3
if (!are_ops_after_matmul1) {
auto transpose3 = get_transpose(child);
auto transpose3 = config.mha_token_enable_transpose_on_output ? ov::as_type_ptr<ov::opset1::Transpose>(child) : nullptr;
if (is_valid_transpose(transpose3, {0, 2, 1, 3}) &&
transpose3->get_input_element_type(0) == matmul1_out_type) { // To avoid Convert between MatMul1 and Transpose3
ordered_ops.push_back(transpose3);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -546,12 +546,13 @@ void Transformations::MainSnippets(void) {
return;

ov::snippets::pass::SnippetsTokenization::Config tokenization_config;
// At the moment Snippets supports Transposes in MHA pattern only in FP32 case since
// - ConvertSaturation[BF16->FP32] will be inserted after Parameters and before Transposes in canonicalization stage
// - ConvertSaturation[FP32->BF16] will be inserted after Transposes and before Brgemm in precision propagation stage
// Because of that Transposes won't be fused into Brgemm
// TODO [111813]: Need to update this pipeline to avoid Converts between Transposes and Brgemm on inputs
tokenization_config.mha_token_enable_transpose = (inferencePrecision == ov::element::f32);
// [111813]: At the moment Snippets supports Transpose on output of MHA pattern only if it is an one node between MatMul and Result.
// However there may be Convert [f32->bf16] before Result since:
// - bf16 Brgemm has f32 output;
// - CPU Node Subgraph requires bf16 on output when inference precision is bf16.
// To avoid situations when Transpose is not the only node between MatMul and Result,
// Plugin disables Transpose tokenization on output
tokenization_config.mha_token_enable_transpose_on_output = (inferencePrecision == ov::element::f32);
tokenization_config.minimal_concurrency = parallel_get_num_threads();
// The optimization "SplitDimensionM" depends on target machine (thread count).
// To avoid uncontrolled behavior in tests, we disabled the optimization when there is Config::SnippetsMode::IgnoreCallback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,14 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHA, MHATest,
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);

// Snippets doesn't support Transpose tokenization when inference_precision = bf16:
// the plugin disables Transpose-on-output tokenization, so only one Subgraph (MHA)
// is expected and the output Transpose stays a standalone CPU node.
INSTANTIATE_TEST_SUITE_P(smoke_MHA_BF16, MHATest,
                         ::testing::Combine(
                                 ::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
                                 ::testing::Values(std::vector<ElementType>{ ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16 }),
                                 ::testing::ValuesIn(matMulIn0Precisions),
                                 ::testing::ValuesIn(patternTypes),
                                 ::testing::Values(ExpectedNodes{{"Subgraph", 1},
                                                                 {"Transpose", 1}}), // Plugin disables tokenization of Transpose on output
                                 ::testing::Values(ov::test::utils::DEVICE_CPU)),
                         MHATest::getTestCaseName);

Expand Down

0 comments on commit faadf8f

Please sign in to comment.