Skip to content

Commit

Permalink
[PT FE] Remove opset usage
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin committed Aug 22, 2024
1 parent 486fd0b commit d135a25
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 266 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,21 @@ class NodeContext : public frontend::NodeContext {
// TODO: int due to base class uses it, but naturally it should be size_t for PT
Output<Node> get_input(int index) const override {
size_t index_ = static_cast<size_t>(index);
FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index_), "Input doesn't exist with index: ", index);
FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index_),
"Input doesn't exist with index: ",
index,
" for operation ",
get_op_type());
auto input = m_decoder_inputs.at(index);
if (input == 0) {
// Case when input can be inlined (possible only for fx decoder)
if (m_decoder->is_input_inlined(index_)) {
auto inlined_input = m_decoder->inlined_input(index_);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1, "Incorrect inlined input with index:", index);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1,
"Incorrect inlined input with index: ",
index,
" for operation ",
get_op_type());
return inlined_input[0];
}
}
Expand Down
6 changes: 5 additions & 1 deletion src/frontends/pytorch/src/node_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ OutputVector NodeContext::inputs() const {
// Case when input can be inlined (possible only for fx decoder)
if (m_decoder->is_input_inlined(i)) {
auto inlined_input = m_decoder->inlined_input(i);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1, "Incorrect inlined input with index:", i);
FRONT_END_GENERAL_CHECK(inlined_input.size() == 1,
"Incorrect inlined input with index: ",
i,
" for operation ",
get_op_type());
res.push_back(inlined_input[0]);
continue;
}
Expand Down
7 changes: 4 additions & 3 deletions src/frontends/pytorch/src/op/log_softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/log_softmax.hpp"
#include "utils.hpp"

namespace ov {
Expand All @@ -30,11 +31,11 @@ OutputVector translate_log_softmax_common(const NodeContext& context, bool is_fx
const auto target_dtype_i64 = context.const_input<int64_t>(2);
const auto target_dtype = convert_dtype(target_dtype_i64);
if (elem_type != target_dtype) {
input = context.mark_node(std::make_shared<opset10::Convert>(input, target_dtype));
input = context.mark_node(std::make_shared<v0::Convert>(input, target_dtype));
}
}

const auto log_softmax = context.mark_node(std::make_shared<opset10::LogSoftmax>(input, dim));
const auto log_softmax = context.mark_node(std::make_shared<v5::LogSoftmax>(input, dim));
return {log_softmax};
};

Expand Down
Loading

0 comments on commit d135a25

Please sign in to comment.