diff --git a/src/common/snippets/include/snippets/lowered/buffer_manager.hpp b/src/common/snippets/include/snippets/lowered/buffer_manager.hpp index b62806e8a30c88..a6fe680833c835 100644 --- a/src/common/snippets/include/snippets/lowered/buffer_manager.hpp +++ b/src/common/snippets/include/snippets/lowered/buffer_manager.hpp @@ -24,13 +24,11 @@ class BufferManager { static int64_t allocate(const lowered::LinearIR& linear_ir); private: - using BufferSystem = std::vector; using BufferCluster = std::set; using BufferClusters = std::vector; - static BufferSystem create_buffer_system(const lowered::LinearIR& linear_ir); - static size_t init_default(const BufferSystem& buffer_system); - static BufferClusters init_clusters(const BufferSystem& buffer_system); + static size_t init_default(const lowered::LinearIR& linear_ir); + static BufferClusters init_clusters(const lowered::LinearIR& linear_ir); static std::vector init_boxes(const BufferClusters& buffer_clusters); static void set_buffer_offset(const ExpressionPtr& buffer_expr, const size_t offset); diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index 511894a030eeb3..9f02426f03b034 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -95,7 +95,7 @@ class LinearIR { iterator find_after(iterator it, const ExpressionPtr& target) const; void init_emitters(const std::shared_ptr& target); - void serialize(const std::string& xml, const std::string& bin); + void serialize(const std::string& xml, const std::string& bin) const; class LoopManager; using LoopManagerPtr = std::shared_ptr; diff --git a/src/common/snippets/src/lowered/buffer_manager.cpp b/src/common/snippets/src/lowered/buffer_manager.cpp index 6cc025837fc48e..d8d1de95e8c44b 100644 --- a/src/common/snippets/src/lowered/buffer_manager.cpp +++ b/src/common/snippets/src/lowered/buffer_manager.cpp @@ -19,11 +19,14 @@ namespace lowered { int64_t BufferManager::allocate(const lowered::LinearIR& linear_ir) { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BufferManager::allocate") - const auto buffer_system = create_buffer_system(linear_ir); - auto scratchpad_size = init_default(buffer_system); + int64_t order = 0; + for (const auto& expr : linear_ir) { + ov::snippets::pass::SetTopologicalOrder(expr->get_node(), order++); + } + auto scratchpad_size = init_default(linear_ir); if (m_enable_optimizations) { - const auto buffer_clusters = init_clusters(buffer_system); + const auto buffer_clusters = init_clusters(linear_ir); const auto boxes = init_boxes(buffer_clusters); MemorySolver staticMemSolver(boxes); @@ -40,44 +43,121 @@ int64_t BufferManager::allocate(const lowered::LinearIR& linear_ir) { return scratchpad_size; } -BufferManager::BufferSystem BufferManager::create_buffer_system(const lowered::LinearIR& linear_ir) { - int64_t order = 0; - BufferSystem system; +size_t BufferManager::init_default(const lowered::LinearIR& linear_ir) { + size_t buffer_id = 0; + size_t buffer_offset = 0; for (const auto& expr : linear_ir) { const auto op = expr->get_node(); if (const auto buffer = ov::as_type_ptr(op)) { - system.push_back(expr); + const auto byte_size = buffer->get_byte_size(); + set_buffer_offset(expr, buffer_offset); + buffer->set_id(buffer_id); + buffer_offset += byte_size; + buffer_id++; } - ov::snippets::pass::SetTopologicalOrder(op, order++); } - return system; + return buffer_offset; } -size_t BufferManager::init_default(const BufferSystem& buffer_system) { - size_t buffer_id = 0; - size_t buffer_offset = 0; - for (const auto& buffer_expr : buffer_system) { - const auto node = buffer_expr->get_node(); - const auto buffer = ov::as_type_ptr(node); - if (!buffer) - continue; +BufferManager::BufferClusters BufferManager::init_clusters(const lowered::LinearIR& linear_ir) { + BufferClusters buffer_clusters; + auto find_cluster = [&buffer_clusters](const ExpressionPtr& target) { + for (auto it = buffer_clusters.begin(); it != buffer_clusters.end(); ++it) { + if (it->count(target) > 0) + return it; + } + return buffer_clusters.end(); + }; + auto create_cluster = [&buffer_clusters, &find_cluster](const ExpressionPtr& target_expr, const ExpressionPtr& node_expr) { + const auto buffer = ov::as_type_ptr(target_expr->get_node()); + // Buffer must be explicitly source for the target LoopEnd expr or MemoryAccess op (there aren't other loop between them) + if (buffer && target_expr->get_loop_ids() == node_expr->get_loop_ids()) { + const auto cluster_it = find_cluster(target_expr); + // If Buffer is missed in clusters, create new cluster with the single Buffer node inside + if (cluster_it == buffer_clusters.cend()) { + buffer_clusters.push_back(BufferCluster{target_expr}); + } + return true; + } + return false; + }; - const auto byte_size = buffer->get_byte_size(); - set_buffer_offset(buffer_expr, buffer_offset); - buffer->set_id(buffer_id); + for (const auto& expr : linear_ir) { + const auto op = expr->get_node(); + if (const auto loop_end = ov::as_type_ptr(op)) { + const auto ptr_increments = loop_end->get_ptr_increments(); + const auto final_offsets = loop_end->get_finalization_offsets(); + const auto in_count = loop_end->get_input_num(); + const auto out_count = loop_end->get_output_num(); + const auto connectors = expr->get_input_port_connectors(); - buffer_offset += byte_size; - buffer_id++; + std::unordered_map> input_buffers; + for (size_t i = 0; i < in_count; ++i) { + const auto source_expr = connectors[i]->get_source().get_expr(); + const auto is_buffer = create_cluster(source_expr, expr); + if (is_buffer) { + // Save as input Buffer + const auto ret = input_buffers.insert(std::make_pair(source_expr, std::set{ i })).second; + if (!ret) + input_buffers[source_expr].insert(i); + } + } + for (size_t i = in_count; i < in_count + out_count; ++i) { + for (const auto& consumer : connectors[i]->get_consumers()) { + auto consumer_expr = consumer.get_expr(); + const auto buffer = ov::as_type_ptr(consumer_expr->get_node()); + // Buffer must be explicitly source for the target LoopEnd expr (there aren't other loop between them) + if (buffer && consumer_expr->get_loop_ids() == expr->get_loop_ids()) { + bool has_been_added = false; + for (const auto& input_buffer : input_buffers) { + const auto& input_buffer_expr = input_buffer.first; + const auto input_buffer_node = ov::as_type_ptr(input_buffer_expr->get_node()); + const auto& input_buffer_idxs = input_buffer.second; + for (const auto& input_buffer_idx : input_buffer_idxs) { + if (input_buffer_node->get_byte_size() == buffer->get_byte_size() && + input_buffer_expr->get_output_port_descriptor(0)->get_layout() == consumer.get_descriptor_ptr()->get_layout() && + ptr_increments[input_buffer_idx] == ptr_increments[i] && + final_offsets[input_buffer_idx] == final_offsets[i]) { + auto cluster_it = find_cluster(input_buffer_expr); + OPENVINO_ASSERT(cluster_it != buffer_clusters.cend(), "Buffer on inputs of Loop must be already saved in clusters"); + // Add to the existing cluster + auto res = cluster_it->insert(consumer_expr); + OPENVINO_ASSERT(res.second, "Buffer has not been saved in cluster"); + // Remove input buffer because we have already use its memory + input_buffers.erase(input_buffer_expr); + has_been_added = res.second; + break; + } + } + if (has_been_added) break; + } + if (!has_been_added) { + buffer_clusters.push_back(BufferCluster{consumer_expr}); + } + } + } + } + continue; + } + // TODO: Some full MemoryAccess ops can have inplace inputs and outputs in general. + // Need to add mechanism of inplace ports using MemoryAccess::PortDescriptor::inplace + if (const auto ma = ov::as_type_ptr(op)) { + if (ma->is_full_memory_access_op()) { + const auto target_loop_ids = expr->get_loop_ids(); + for (const auto& input : expr->get_input_port_connectors()) { + const auto source_expr = input->get_source().get_expr(); + create_cluster(source_expr, expr); + } + for (const auto& output : expr->get_output_port_connectors()) { + for (const auto& consumer : output->get_consumers()) { + const auto consumer_expr = consumer.get_expr(); + create_cluster(consumer_expr, expr); + } + } + } + } } - return buffer_offset; -} -BufferManager::BufferClusters BufferManager::init_clusters(const BufferSystem& buffer_system) { - // TODO: Add support of inplace - BufferClusters buffer_clusters; - for (const auto& buffer_expr : buffer_system) { - buffer_clusters.push_back(BufferCluster{buffer_expr}); - } return buffer_clusters; } diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index 6246ddef8838a4..f22f0efbc9b0c5 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -66,7 +66,7 @@ ov::NodeVector LinearIR::get_ordered_ops(const std::shared_ptr& m) { return ov::topological_sort(nodes); } -void LinearIR::serialize(const std::string& xml, const std::string& bin) { +void LinearIR::serialize(const std::string& xml, const std::string& bin) const { auto first_node = std::make_shared(element::f32, Shape{}); first_node->set_friendly_name("Start"); first_node->get_rt_info()["execTimeMcs"] = 0;