Skip to content

Commit

Permalink
[Snippets] Added support of repacking tail of input_0 of BrgemmCPU
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Nov 8, 2024
1 parent 2b6673e commit 00f1e7b
Show file tree
Hide file tree
Showing 24 changed files with 569 additions and 165 deletions.
44 changes: 44 additions & 0 deletions src/common/snippets/include/snippets/op/set_scalar.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "snippets/op/memory_access.hpp"
#include "snippets/shape_inference/shape_inference.hpp"

namespace ov {
namespace snippets {
namespace op {

/**
* @interface SetScalar
* @brief Sets passed 1-byte value to the memory pointer by provided `value_offset`
* @ingroup snippets
*/
class SetScalar : public modifier::MemoryAccess, public ov::op::Op {
public:
OPENVINO_OP("SetScalar", "SnippetsOpset");
SetScalar(uint8_t value = 0x0, size_t value_offset = 0lu, size_t ma_offset = 0lu, ov::element::Type element_type = ov::element::u8);

uint8_t get_value() const { return m_value; }
size_t get_value_offset() const { return m_value_offset; }

void set_value(const uint8_t value) { m_value = value; }

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;

protected:
void validate_memory_access_params() const;

uint8_t m_value = 0x0;
size_t m_value_offset = 0;
ov::element::Type m_element_type = ov::element::u8;
};

} // namespace op
} // namespace snippets
} // namespace ov
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/snippets_isa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "op/reshape.hpp"
#include "op/nop.hpp"
#include "op/scalar.hpp"
#include "op/set_scalar.hpp"
#include "op/powerstatic.hpp"
#include "op/store.hpp"
#include "op/loop.hpp"
Expand Down
1 change: 1 addition & 0 deletions src/common/snippets/src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ RegType Generator::get_op_out_reg_type(const ov::Output<Node>& out) const {
std::dynamic_pointer_cast<op::Buffer>(op) ||
std::dynamic_pointer_cast<op::RankNormalization>(op) ||
std::dynamic_pointer_cast<op::Reshape>(op) ||
std::dynamic_pointer_cast<snippets::op::SetScalar>(op) ||
std::dynamic_pointer_cast<snippets::op::Store>(op)
#ifdef SNIPPETS_DEBUG_CAPS
|| std::dynamic_pointer_cast<op::PerfCountBeginBase>(op)
Expand Down
2 changes: 2 additions & 0 deletions src/common/snippets/src/op/memory_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ void MemoryAccess::set_output_port_descriptor(const PortDescriptor& desc, const

const MemoryAccess::PortDescriptor& MemoryAccess::get_input_port_descriptor(const size_t i) const {
const auto it = m_input_ports.find(i);
if (it == m_input_ports.end())
std::cout << std::endl;
OPENVINO_ASSERT(it != m_input_ports.end(), "Index of input port descriptor should be less than count of input ports");
return it->second;
}
Expand Down
53 changes: 53 additions & 0 deletions src/common/snippets/src/op/set_scalar.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/op/set_scalar.hpp"

#include "snippets/utils/utils.hpp"
#include "snippets/itt.hpp"


namespace ov {
namespace snippets {
namespace op {

SetScalar::SetScalar(uint8_t value, size_t value_offset, size_t ma_offset, ov::element::Type element_type)
: MemoryAccess(std::set<size_t>{}, std::set<size_t>{0}), Op(), m_value(value), m_value_offset(value_offset), m_element_type(element_type) {
set_output_port_descriptor({1, ma_offset}, 0);
set_output_size(1);
constructor_validate_and_infer_types();
}

void SetScalar::validate_memory_access_params() const {
// SetScalar has memory access port only on output
const auto input_ma_ports = get_memory_access_input_ports();
const auto output_ma_ports = get_memory_access_output_ports();
OPENVINO_ASSERT(input_ma_ports.size() == 0, "SetScalar node shouldn't have memory access input port");
OPENVINO_ASSERT(output_ma_ports.size() == 1 && is_memory_access_output_port(0), "SetScalar node must have one memory access output port");
}

void SetScalar::validate_and_infer_types() {
INTERNAL_OP_SCOPE(SetScalar_validate_and_infer_types);
OPENVINO_ASSERT(!utils::is_dynamic_value(m_value_offset), "Value offset must be static");
validate_memory_access_params();
set_output_type(0, m_element_type, ov::Shape{1});
}

bool SetScalar::visit_attributes(AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(SetScalar_visit_attributes);
auto i32_value = static_cast<uint32_t>(m_value);
visitor.on_attribute("value", i32_value);
visitor.on_attribute("value_offset", m_value_offset);
return MemoryAccess::visit_attributes(visitor);
}

std::shared_ptr<Node> SetScalar::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(SetScalar_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<SetScalar>(m_value, m_value_offset, get_output_offset(0), m_element_type);
}

}// namespace op
}// namespace snippets
}// namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
SHAPE_INFER_PREDEFINED(op::LoopBegin, SingleElementShapeInfer),
SHAPE_INFER_PREDEFINED(op::Scalar, SingleElementShapeInfer),
SHAPE_INFER_PREDEFINED(op::VectorBuffer, SingleElementShapeInfer),
SHAPE_INFER_PREDEFINED(op::SetScalar, SingleElementShapeInfer),
SHAPE_INFER_PREDEFINED(op::LoopEnd, EmptyShapeInfer),
#ifdef SNIPPETS_DEBUG_CAPS
SHAPE_INFER_PREDEFINED(op::PerfCountBegin, EmptyShapeInfer),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
jitters[ov::op::v4::HSwish::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_hswish_emitter);
jitters[ov::op::v0::Gelu::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_gelu_v0_emitter);
jitters[ov::op::v7::Gelu::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_gelu_v7_emitter);

jitters[snippets::op::Fill::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_fill_emitter);
jitters[snippets::op::SetScalar::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_set_emitter);

jitters[snippets::op::HorizonMax::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_horizon_emitter);
jitters[snippets::op::HorizonSum::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_horizon_emitter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

#include "jit_snippets_emitters.hpp"

#include "emitters/snippets/jit_snippets_call_args.hpp"
#include "utils.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
using namespace dnnl::impl::cpu::x64;
Expand Down Expand Up @@ -105,5 +108,30 @@ void jit_scalar_emitter::emit_isa(const std::vector<size_t> &in, const std::vect
}


jit_set_emitter::jit_set_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr)
: jit_emitter(h, isa, ov::element::f32, emitter_in_out_map::vec_to_vec) {
const auto setter = ov::as_type_ptr<snippets::op::SetScalar>(expr->get_node());

m_value = setter->get_value();
m_value_offset = setter->get_value_offset();
m_buffer_offset = setter->get_output_offset(0);
m_buffer_cluster_id = utils::get_buffer_cluster_id(expr->get_output_port(0));
}

void jit_set_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
auto reg = Xbyak::Reg64(out.front());

if (ov::snippets::utils::is_dynamic_value(m_buffer_offset)) {
const auto aux_reg = Xbyak::Reg64(aux_gpr_idxs.front());
const auto& offset = h->ptr[abi_param1 + GET_OFF(buffer_offsets) + m_buffer_cluster_id * sizeof(size_t)];
h->mov(aux_reg, reg);
h->add(aux_reg, offset);
h->mov(h->byte[aux_reg + m_value_offset], m_value);
} else {
h->mov(h->byte[reg + m_buffer_offset + m_value_offset], m_value);
}
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class jit_scalar_emitter : public jit_emitter {

size_t get_inputs_num() const override {return 0;}
size_t aux_gprs_count() const override {return 1;}

static int32_t read_value(const ov::snippets::lowered::ExpressionPtr& expr);

private:
Expand All @@ -68,5 +69,22 @@ class jit_scalar_emitter : public jit_emitter {
void emit_isa(const std::vector<size_t> &in, const std::vector<size_t> &out) const;
};

class jit_set_emitter : public jit_emitter {
public:
jit_set_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr);

size_t get_inputs_num() const override { return 0; }
size_t aux_gprs_count() const override { return ov::snippets::utils::is_dynamic_value(m_buffer_offset) ? 1 : 0; }

private:
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

size_t m_buffer_offset = 0;
size_t m_buffer_cluster_id = 0;
uint8_t m_value = 0x0;
size_t m_value_offset = 0;
};

} // namespace intel_cpu
} // namespace ov
Loading

0 comments on commit 00f1e7b

Please sign in to comment.