Skip to content

Commit

Permalink
[Snippets][CPU] Added SNIPPETS_REGISTER_PASS_RELATIVE and SNIPPETS_RE…
Browse files Browse the repository at this point in the history
…GISTER_PASS_ABSOLUTE
  • Loading branch information
a-sidorova committed Dec 26, 2023
1 parent eec3e80 commit c7e5aa6
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,27 +329,30 @@ void Snippet::initOptimalPrimitiveDescriptor() {
#if defined(OPENVINO_ARCH_X86_64)
using PassPosition = ov::snippets::pass::PassPosition;
using Place = PassPosition::Place;
# define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...) \
backend_passes.emplace_back(PASS_POS, std::make_shared<PASS>(__VA_ARGS__))
# define SNIPPETS_REGISTER_PASS_ABSOLUTE(PASS_PLACE, PASS, ...) \
backend_passes.emplace_back(PassPosition(PASS_PLACE), std::make_shared<PASS>(__VA_ARGS__))
# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...) \
backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), std::make_shared<PASS>(__VA_ARGS__))
#else
# define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...)
# define SNIPPETS_REGISTER_PASS_ABSOLUTE(PASS_PLACE, PASS, ...)
# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...)
#endif // OPENVINO_ARCH_X86_64

SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ConvertToSwishCPU);
SNIPPETS_REGISTER_PASS_ABSOLUTE(Place::PipelineStart, ConvertToSwishCPU);
if (context->getConfig().inferencePrecision == ov::element::bf16 && snippetAttrs.snippet->has_domain_sensitive_ops()) {
// enforce BF16 precisions to supported operations
// MatMul has to be decomposed to Brgemm operations before enforcement
// Note, MatMul decomposition will be run later again for case if BF16 enforcement is not happened
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineStart), ov::snippets::pass::MatMulToBrgemm);
SNIPPETS_REGISTER_PASS(PassPosition(Place::After, ov::snippets::pass::MatMulToBrgemm::get_type_info_static()),
pass::EnforcePrecision, element::f32, element::bf16);
SNIPPETS_REGISTER_PASS_ABSOLUTE(Place::PipelineStart, ov::snippets::pass::MatMulToBrgemm);
SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::snippets::pass::MatMulToBrgemm,
pass::EnforcePrecision, element::f32, element::bf16);
}
SNIPPETS_REGISTER_PASS(PassPosition(Place::Before, ov::snippets::pass::PropagatePrecision::get_type_info_static()),
ov::intel_cpu::pass::BrgemmToBrgemmCPU);
SNIPPETS_REGISTER_PASS(PassPosition(Place::After, ov::intel_cpu::pass::BrgemmToBrgemmCPU::get_type_info_static()),
ov::intel_cpu::pass::SetBrgemmCPUBlockingParams);
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::RemoveConverts);
SNIPPETS_REGISTER_PASS(PassPosition(Place::PipelineEnd), ov::intel_cpu::pass::MulAddToFMA);
SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before, ov::snippets::pass::PropagatePrecision,
ov::intel_cpu::pass::BrgemmToBrgemmCPU);
SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::intel_cpu::pass::BrgemmToBrgemmCPU,
ov::intel_cpu::pass::SetBrgemmCPUBlockingParams);
SNIPPETS_REGISTER_PASS_ABSOLUTE(Place::PipelineEnd, ov::intel_cpu::pass::RemoveConverts);
SNIPPETS_REGISTER_PASS_ABSOLUTE(Place::PipelineEnd, ov::intel_cpu::pass::MulAddToFMA);

#undef SNIPPETS_REGISTER_PASS

Expand Down Expand Up @@ -596,18 +599,18 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp)
#if defined(OPENVINO_ARCH_X86_64)
using PassPosition = ov::snippets::pass::PassPosition;
using Place = PassPosition::Place;
# define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...) \
backend_passes.emplace_back(PASS_POS, std::make_shared<PASS>(__VA_ARGS__))
# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...) \
backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), std::make_shared<PASS>(__VA_ARGS__))
#else
# define SNIPPETS_REGISTER_PASS(PASS_POS, PASS, ...)
# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...)
#endif // OPENVINO_ARCH_X86_64

SNIPPETS_REGISTER_PASS(PassPosition(Place::After, ov::snippets::lowered::pass::MarkLoops::get_type_info_static()),
ov::intel_cpu::pass::BrgemmBlocking);
SNIPPETS_REGISTER_PASS(PassPosition(Place::After, ov::snippets::lowered::pass::InsertLoops::get_type_info_static()),
ov::intel_cpu::pass::FuseLoadStoreConvert);
SNIPPETS_REGISTER_PASS(PassPosition(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert::get_type_info_static()),
ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape);
SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::snippets::lowered::pass::MarkLoops,
ov::intel_cpu::pass::BrgemmBlocking);
SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::snippets::lowered::pass::InsertLoops,
ov::intel_cpu::pass::FuseLoadStoreConvert);
SNIPPETS_REGISTER_PASS_RELATIVE(Place::After, ov::intel_cpu::pass::FuseLoadStoreConvert,
ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape);

schedule = snippetAttrs.snippet->generate_from_linear_ir(std::make_shared<ov::snippets::lowered::pass::PassConfig>(),
backend_passes,
Expand Down

0 comments on commit c7e5aa6

Please sign in to comment.