diff --git a/src/common/snippets/include/snippets/pass/tokenization.hpp b/src/common/snippets/include/snippets/pass/tokenization.hpp
index 3d71ad281728f9..25c93a4b82f63a 100644
--- a/src/common/snippets/include/snippets/pass/tokenization.hpp
+++ b/src/common/snippets/include/snippets/pass/tokenization.hpp
@@ -61,16 +61,17 @@ class SnippetsTokenization : public ov::pass::ModelPass {
      * @ingroup snippets
      */
     struct Config {
-        Config(size_t minimal_concurrency = 1, bool split_m_dimension = true, bool enable_transpose = true)
-            : minimal_concurrency(minimal_concurrency), split_m_dimension(split_m_dimension), mha_token_enable_transpose(enable_transpose) {}
+        Config(size_t minimal_concurrency = 1, bool split_m_dimension = true, bool enable_transpose_on_output = true)
+            : minimal_concurrency(minimal_concurrency), split_m_dimension(split_m_dimension),
+              mha_token_enable_transpose_on_output(enable_transpose_on_output) {}

         size_t minimal_concurrency = 1;
         // True if "SplitDimensionM" optimization is enabled. Otherwise, it's disabled.
         bool split_m_dimension = true;
-        // False if all Transposes aren't tokenized in MHA Tokenization.
-        // Otherwise, they may be fused into Subgraph if possible
-        // TODO [106921]: Remove please when the ticket 106921 is implemented
-        bool mha_token_enable_transpose = true;
+        // False if the Transpose on output isn't tokenized in MHA Tokenization.
+        // Otherwise, it may be fused into Subgraph if possible
+        // TODO [111813]: Remove when ticket 111813 is implemented
+        bool mha_token_enable_transpose_on_output = true;
     };

     OPENVINO_RTTI("SnippetsTokenization", "0");
diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp
index c64511f1fd8641..a1eead792fc4cc 100644
--- a/src/common/snippets/src/op/subgraph.cpp
+++ b/src/common/snippets/src/op/subgraph.cpp
@@ -430,10 +430,26 @@ void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outpu
     for (size_t i = 0; i < outputShapes.size(); i++) {
         const auto needed_out_type = std::get<2>(outputShapes[i]);
         if (body_results[i]->get_input_element_type(0) != needed_out_type) {
-            const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
-                body_results[i]->get_input_node_shared_ptr(0), needed_out_type);
-            body_results[i]->set_argument(0, convert);
-            body_results[i]->validate_and_infer_types();
+            auto parent_output = body_results[i]->get_input_source_output(0);
+            std::shared_ptr<ov::Node> consumer = body_results[i];
+
+            // Snippets supports Transpose only after Parameter or before Result nodes
+            // So we have to insert Convert before Transpose (if there is one) on Subgraph outputs
+            const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr());
+            if (transpose) {
+                OPENVINO_ASSERT(parent_output.get_target_inputs().size() == 1,
+                                "If Result has Transpose on input, this Result must be the single consumer of the Transpose");
+                parent_output = transpose->get_input_source_output(0);
+                consumer = transpose;
+            }
+
+            const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(parent_output, needed_out_type);
+            ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
+
+            consumer->set_argument(0, convert);
+            consumer->validate_and_infer_types();
+            if (consumer != body_results[i])
+                body_results[i]->validate_and_infer_types();
         }
     }

@@ -442,23 +458,37 @@ void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outpu
     for (size_t i = 0; i < inputShapes.size(); ++i) {
         const auto needed_in_type = std::get<2>(inputShapes[i]);
         const auto& parameter = parameters[i];
-        if (parameter->get_element_type() != needed_in_type) {
-            const auto parameter_output = parameter->output(0);
-            const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(
-                parameter_output,
-                parameter_output.get_element_type());
-            ov::copy_runtime_info(parameter, convert);
-
-            for (const auto input : parameter_output.get_target_inputs()) {
+        const auto original_type = parameter->get_element_type();
+        if (original_type != needed_in_type) {
+            parameter->set_element_type(needed_in_type);
+            parameter->validate_and_infer_types();
+
+            auto parent_output = parameter->output(0);
+            auto consumer_inputs = parent_output.get_target_inputs();
+
+            // Snippets supports Transpose only after Parameter or before Result nodes
+            // So we have to insert Convert after Transpose (if there is one) on Subgraph inputs
+            if (std::any_of(consumer_inputs.cbegin(), consumer_inputs.cend(),
+                            [](const ov::Input<ov::Node>& input) { return ov::is_type<ov::op::v1::Transpose>(input.get_node()); })) {
+                OPENVINO_ASSERT(consumer_inputs.size() == 1,
+                                "If Parameter has Transpose on output, this Transpose must be the single consumer of the Parameter");
+                const auto transpose = consumer_inputs.begin()->get_node()->shared_from_this();
+                transpose->validate_and_infer_types();
+
+                parent_output = transpose;
+                consumer_inputs = parent_output.get_target_inputs();
+            }
+
+            const auto convert = std::make_shared<ov::snippets::op::ConvertSaturation>(parent_output, original_type);
+            ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
+
+            for (const auto input : consumer_inputs) {
                 const auto& input_node = input.get_node();
                 if (input_node == convert.get()) {
                     continue;
                 }
                 input_node->set_argument(input.get_index(), convert->output(0));
             }
-
-            parameter->set_element_type(needed_in_type);
-            parameter->validate_and_infer_types();
         }
     }
 }
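
For reference, the output-side rewiring above can be shown as a small standalone sketch on public OpenVINO ops. This is not the Subgraph code: the helper name insert_convert_on_output is made up, ov::op::v0::Convert stands in for the snippets-specific ConvertSaturation, and the runtime-info copy and single-consumer assert are omitted. It only illustrates the key idea: when a Result is fed by a Transpose, the Convert goes before the Transpose, so the Transpose stays directly attached to the Result.

#include <memory>

#include "openvino/core/type.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/transpose.hpp"

// Inserts a Convert to `needed_out_type` for the given body Result.
// If the Result's producer is a Transpose, the Convert is placed before the Transpose.
std::shared_ptr<ov::Node> insert_convert_on_output(const std::shared_ptr<ov::op::v0::Result>& result,
                                                   const ov::element::Type& needed_out_type) {
    auto parent_output = result->get_input_source_output(0);
    std::shared_ptr<ov::Node> consumer = result;

    // Hop over the Transpose (if any), as the Subgraph code above does
    if (const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr())) {
        parent_output = transpose->get_input_source_output(0);
        consumer = transpose;
    }

    // Plain Convert is used here only for illustration
    const auto convert = std::make_shared<ov::op::v0::Convert>(parent_output, needed_out_type);
    consumer->set_argument(0, convert);
    consumer->validate_and_infer_types();
    if (consumer != result)
        result->validate_and_infer_types();
    return convert;
}
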
diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp
index f714271627cc48..ae2e4dd360e908 100644
--- a/src/common/snippets/src/pass/mha_tokenization.cpp
+++ b/src/common/snippets/src/pass/mha_tokenization.cpp
@@ -191,7 +191,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
     MATCHER_SCOPE(TokenizeMHASnippets);
     auto m_matmul0 = std::make_shared<ov::opset1::MatMul>(ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()),
-                                                      ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()));
+                                                          ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()));

     register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul0, matcher_name),
         [=](ov::pass::pattern::Matcher &m) {
@@ -388,14 +388,9 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
             }
         };

-        auto get_transpose = [config](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::opset1::Transpose> {
-            return config.mha_token_enable_transpose ? ov::as_type_ptr<ov::opset1::Transpose>(node)
-                                                     : nullptr;
-        };
-
-        const auto transpose1 = get_transpose(parent);
-        const auto transpose0 = get_transpose(matmul0->get_input_node_shared_ptr(0));
-        const auto transpose2 = get_transpose(matmul1->get_input_node_shared_ptr(1));
+        const auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent);
+        const auto transpose0 = ov::as_type_ptr<ov::opset1::Transpose>(matmul0->get_input_node_shared_ptr(0));
+        const auto transpose2 = ov::as_type_ptr<ov::opset1::Transpose>(matmul1->get_input_node_shared_ptr(1));
         tokenize_transpose(transpose1, is_transposed_b_0, {0, 2, 3, 1}, ordered_ops.begin());
         tokenize_transpose(transpose0, matmul0->get_transpose_a(), {0, 2, 1, 3}, ordered_ops.begin());
         tokenize_transpose(transpose2, matmul1->get_transpose_b(), {0, 2, 1, 3}, ordered_ops.end());
@@ -431,7 +426,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
         //
         //           Transpose3
         if (!are_ops_after_matmul1) {
-            auto transpose3 = get_transpose(child);
+            auto transpose3 = config.mha_token_enable_transpose_on_output ? ov::as_type_ptr<ov::opset1::Transpose>(child) : nullptr;
             if (is_valid_transpose(transpose3, {0, 2, 1, 3}) &&
                 transpose3->get_input_element_type(0) == matmul1_out_type) {  // To avoid Convert between MatMul1 and Transpose3
                 ordered_ops.push_back(transpose3);
diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index 3e300bc70778f0..5ef5f913d77f23 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -546,12 +546,13 @@ void Transformations::MainSnippets(void) {
         return;

     ov::snippets::pass::SnippetsTokenization::Config tokenization_config;
-    // At the moment Snippets supports Transposes in MHA pattern only in FP32 case since
-    // - ConvertSaturation[BF16->FP32] will be inserted after Parameters and before Transposes in canonicalization stage
-    // - ConvertSaturation[FP32->BF16] will be inserted after Transposes and before Brgemm in precision propagation stage
-    // Because of that Transposes won't be fused into Brgemm
-    // TODO [111813]: Need to update this pipeline to avoid Converts between Transposes and Brgemm on inputs
-    tokenization_config.mha_token_enable_transpose = (inferencePrecision == ov::element::f32);
+    // [111813]: At the moment Snippets supports Transpose on output of MHA pattern only if it is the only node between MatMul and Result.
+    // However, there may be a Convert [f32->bf16] before Result since:
+    //  - bf16 Brgemm has f32 output;
+    //  - CPU Node Subgraph requires bf16 on output when inference precision is bf16.
+    // To avoid situations where Transpose is not the only node between MatMul and Result,
+    // the plugin disables Transpose tokenization on output
+    tokenization_config.mha_token_enable_transpose_on_output = (inferencePrecision == ov::element::f32);
     tokenization_config.minimal_concurrency = parallel_get_num_threads();
     // The optimization "SplitDimensionM" depends on target machine (thread count).
     // To avoid uncontrolled behavior in tests, we disabled the optimization when there is Config::SnippetsMode::IgnoreCallback
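
For readers who are not familiar with the CPU plugin pipeline, the renamed flag is consumed roughly as sketched below. This is an illustrative sketch, not the plugin code: the run_tokenization helper, the is_bf16 and n_threads parameters, and the bare pass::Manager wiring are assumptions, and it assumes the pass constructor accepts the Config, which the plugin-side usage above suggests. Only the Config fields themselves come from this patch.

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "snippets/pass/tokenization.hpp"

// Illustrative helper: configure and run Snippets tokenization on a model.
void run_tokenization(const std::shared_ptr<ov::Model>& model, bool is_bf16, size_t n_threads) {
    ov::snippets::pass::SnippetsTokenization::Config config;
    config.minimal_concurrency = n_threads;
    // Mirror the plugin logic above: Transpose on the MHA output is tokenized only in the f32 pipeline.
    config.mha_token_enable_transpose_on_output = !is_bf16;

    ov::pass::Manager manager;
    manager.register_pass<ov::snippets::pass::SnippetsTokenization>(config);
    manager.run_passes(model);
}
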
diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp
index 5c7fb58ca85654..a4936ac4ee3950 100644
--- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp
+++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp
@@ -290,15 +290,14 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHA, MHATest,
                              ::testing::Values(ov::test::utils::DEVICE_CPU)),
                          MHATest::getTestCaseName);

-// Snippets doesn't support Transpose tokenization when inference_precision = bf16
 INSTANTIATE_TEST_SUITE_P(smoke_MHA_BF16, MHATest,
                          ::testing::Combine(
                              ::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
                              ::testing::Values(std::vector<ElementType>{ ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16 }),
                              ::testing::ValuesIn(matMulIn0Precisions),
                              ::testing::ValuesIn(patternTypes),
-                             ::testing::Values(ExpectedNodes{{"Subgraph", 2},  // MHA + Decomposed Transpose
-                                               {"Transpose", 3}}),
+                             ::testing::Values(ExpectedNodes{{"Subgraph", 1},
+                                               {"Transpose", 1}}),  // Plugin disables tokenization of Transpose on output
                              ::testing::Values(ov::test::utils::DEVICE_CPU)),
                          MHATest::getTestCaseName);

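
For context, the bf16 expectation above refers to an MHA-like pattern with a trailing Transpose, roughly like the sketch below. The builder name, shapes, and op versions are illustrative and not copied from mha.cpp. With inference precision bf16 the trailing Transpose is no longer tokenized, so the compiled model is expected to contain one Subgraph and one standalone Transpose node.

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/softmax.hpp"
#include "openvino/op/transpose.hpp"

// Illustrative MHA-like pattern: MatMul0 -> Softmax -> MatMul1 -> Transpose (on output).
std::shared_ptr<ov::Model> make_mha_like_pattern() {
    const ov::Shape qkv_shape{1, 16, 128, 64};
    auto q = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, qkv_shape);
    auto k = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, qkv_shape);
    auto v = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, qkv_shape);

    auto matmul0 = std::make_shared<ov::op::v0::MatMul>(q, k, false, true);  // Q x K^T
    auto softmax = std::make_shared<ov::op::v8::Softmax>(matmul0, -1);
    auto matmul1 = std::make_shared<ov::op::v0::MatMul>(softmax, v);         // scores x V

    // Trailing Transpose on the MHA output with order {0, 2, 1, 3}
    auto order = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{4}, {0, 2, 1, 3});
    auto transpose3 = std::make_shared<ov::op::v1::Transpose>(matmul1, order);

    return std::make_shared<ov::Model>(ov::OutputVector{transpose3}, ov::ParameterVector{q, k, v});
}
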