[PT FE] Support aten::randint and aten::index_put_ on mask
mvafin committed Aug 11, 2023
1 parent 45a6063 · commit 1656c55
Showing 6 changed files with 107 additions and 14 deletions.
34 changes: 28 additions & 6 deletions src/frontends/pytorch/src/op/rand.cpp
@@ -170,8 +170,6 @@ OutputVector translate_randn(const NodeContext& context) {
sizes = concat_list_construct(sizes);
}
sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
auto low = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0}));
auto high = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {1}));
auto dtype = element::f32;
size_t out_id = 1;
if (context.get_input_size() == 3) {
@@ -202,8 +200,6 @@ OutputVector translate_randn(const NodeContext& context) {
if (std::dynamic_pointer_cast<v0::Constant>(
context.get_input_from_visible_context(dtype_id).get_node_shared_ptr())) {
dtype = convert_dtype(context.const_input<int64_t>(dtype_id));
low = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
high = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
} else if (const auto& fw_node =
cast_fw_node(context.get_input(static_cast<int>(dtype_id)).get_node_shared_ptr(),
"prim::dtype")) {
@@ -228,8 +224,6 @@ OutputVector translate_randn_like(const NodeContext& context) {
num_inputs_check(context, 3, 6);
auto inp_tensor = context.get_input(0);
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(inp_tensor, element::i32));
auto low = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0}));
auto high = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {1}));
auto dtype = element::f32;
if (context.get_input_size() == 3) {
auto res = make_random_normal(context, sizes, dtype);
@@ -259,6 +253,34 @@ OutputVector translate_randn_like(const NodeContext& context) {
return res;
};

OutputVector translate_randint(const NodeContext& context) {
// aten::randint.low(int low, int high, SymInt[] size, *, ScalarType? dtype=4, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
num_inputs_check(context, 7, 7);
auto low = context.get_input(0);
auto high = context.get_input(1);
auto sizes = context.get_input(2);
auto dtype = element::i64;
bool dtype_applied = true;
Output<Node> convert_like_out;
if (!context.input_is_none(3)) {
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(3).get_node_shared_ptr())) {
dtype = convert_dtype(context.const_input<int64_t>(3));
} else if (const auto& fw_node = cast_fw_node(context.get_input(static_cast<int>(3)).get_node_shared_ptr(), "prim::dtype")) {
convert_like_out = fw_node->input_value(0);
dtype_applied = false;
} else {
FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
}
}
low = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
high = context.mark_node(std::make_shared<v0::Convert>(high, dtype));
auto res = context.mark_node(std::make_shared<v8::RandomUniform>(sizes, low, high, dtype));
if (!dtype_applied) {
res = context.mark_node(std::make_shared<v1::ConvertLike>(res, convert_like_out));
}
return {res};
};
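
For context, a minimal sketch (an illustration, not part of this commit) of a PyTorch module whose trace emits aten::randint, which the converter above lowers to v8::RandomUniform with Convert-ed low/high bounds:

```python
# Hypothetical example: tracing this module records aten::randint with
# low=0, high=10, and a size list; translate_randint maps it to
# RandomUniform(sizes, Convert(low), Convert(high), dtype).
import torch

class RandIntModel(torch.nn.Module):
    def forward(self, x):
        return x + torch.randint(0, 10, (2, 3))

traced = torch.jit.trace(RandIntModel(), torch.zeros(2, 3))
```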

} // namespace op
} // namespace pytorch
} // namespace frontend
2 changes: 2 additions & 0 deletions src/frontends/pytorch/src/op_table.cpp
@@ -125,6 +125,7 @@ OP_CONVERTER(translate_quantized_mul);
OP_CONVERTER(translate_range_length);
OP_CONVERTER(translate_rand);
OP_CONVERTER(translate_randn);
OP_CONVERTER(translate_randint);
OP_CONVERTER(translate_rand_like);
OP_CONVERTER(translate_randn_like);
OP_CONVERTER(translate_reciprocal);
@@ -378,6 +379,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::quantize_per_tensor", op::translate_quantize_per_tensor},
{"aten::rand", op::translate_rand},
{"aten::randn", op::translate_randn},
{"aten::randint", op::translate_randint},
{"aten::rand_like", op::translate_rand_like},
{"aten::randn_like", op::translate_randn_like},
{"aten::reciprocal", op::translate_reciprocal},
34 changes: 27 additions & 7 deletions src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp
@@ -13,10 +13,12 @@
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/mod.hpp"
#include "openvino/op/non_zero.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
@@ -123,14 +125,32 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
index = rg.make<v0::Concat>(indices_list, -1);
} else {
index = indices_inputs[0];
auto index_dtype = index.get_element_type();
// Do we need to also check u8?
if (index_dtype == element::boolean) {
auto nonzero = rg.make<v3::NonZero>(index, element::i32);
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
index = rg.make<v1::Transpose>(nonzero, input_order);
broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
auto start_0 = v0::Constant::create(element::i32, Shape{1}, {0});
auto end_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
auto values_shape = rg.make<v8::Slice>(broadcast_index_shape, start_0, end_neg_1, const_1);
values = rg.make<v3::Broadcast>(values, values_shape);
values = rg.make<v1::ConvertLike>(values, input);
auto result = rg.make<v3::ScatterNDUpdate>(input, index, values);
copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from);
replace_node(index_op, result);
return true;
} else {
// change negative indices to positive indices
auto dim_0 = (rg.make<v8::Gather>(input_shape, const_0, const_0));
auto dim_0_correct_type = (rg.make<v1::ConvertLike>(dim_0, index));
index = rg.make<v1::Add>(index, dim_0_correct_type);
index = rg.make<v1::Mod>(index, dim_0_correct_type);

broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
index = rg.make<v0::Unsqueeze>(index, const_neg_1);
}
}

auto sub_data_shape = rg.make<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
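
The boolean-mask branch above is equivalent to turning the mask into explicit coordinates and scattering the broadcast values. A NumPy sketch of the intended semantics (an assumption-level illustration, not the transform itself):

```python
import numpy as np

def index_put_mask(x, mask, values):
    # NonZero + Transpose: one [rank]-sized coordinate row per True element,
    # matching ScatterNDUpdate's expected index layout.
    index = np.stack(np.nonzero(mask), axis=-1)            # shape [N, rank]
    # Slice drops the trailing dim of the index shape; Broadcast expands
    # the update values to one value per selected coordinate.
    values = np.broadcast_to(values, index.shape[:-1]).astype(x.dtype)
    out = x.copy()
    out[tuple(index.T)] = values                           # ScatterNDUpdate
    return out

x = np.random.randn(4, 5).astype(np.float32)
assert np.allclose(index_put_mask(x, x < 0, 0.0), np.where(x < 0, 0.0, x))
```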
src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
@@ -13,6 +13,7 @@
#include "openvino/op/equal.hpp"
#include "openvino/op/interpolate.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/random_uniform.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/roll.hpp"
#include "openvino/op/select.hpp"
@@ -60,6 +61,8 @@ ListConstructReplacer::ListConstructReplacer() {
auto interpolate_mul_op = pattern::wrap_type<v1::Multiply>({interpolate_convert_op, pattern::any_input()});
auto interpolate_op =
pattern::wrap_type<v11::Interpolate>({pattern::any_input(), interpolate_mul_op, pattern::any_input()});
// aten::randint case
auto rand_op = pattern::wrap_type<v8::RandomUniform>({list, pattern::any_input(), pattern::any_input()});
auto lc_pattern = std::make_shared<pattern::op::Or>(OutputVector{reshape_op,
roll_op,
broadcast_op,
@@ -70,7 +73,8 @@
tile_op,
transpose_op,
vsplit_op,
interpolate_op});
interpolate_op,
rand_op});

ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
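
The RandomUniform pattern is needed because, in a traced graph, the size argument of aten::randint arrives as a prim::ListConstruct of scalar ints, which feeds RandomUniform's shape input after conversion. A quick way to observe the pattern (illustrative only, not part of this commit):

```python
# Hypothetical check: the traced graph should show
#   %sizes : int[] = prim::ListConstruct(...)
# feeding aten::randint, i.e. the list this replacer rewrites once it
# becomes RandomUniform's shape input.
import torch

def f(x):
    return x + torch.randint(0, 5, (x.shape[0], 4))

traced = torch.jit.trace(f, torch.zeros(3, 4))
print(traced.graph)
```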
26 changes: 26 additions & 0 deletions tests/layer_tests/pytorch_tests/test_index.py
@@ -125,3 +125,29 @@ def test_index_range(self, input_shape, idx, ie_device, precision, ir_version):
def test_index_range_free_dims(self, input_shape, idx, ie_device, precision, ir_version):
self._test(*self.create_model2(), ie_device, precision, ir_version, kwargs_to_prepare_input={
"input_shape": input_shape, "idx": idx}, trace_model=True, dynamic_shapes=False)

class TestIndexMask(PytorchLayerTest):
def _prepare_input(self, input_shape):
import numpy as np
return (np.random.randn(*input_shape).astype(np.float32),)

def create_model(self):
import torch

class aten_index_mask(torch.nn.Module):
def forward(self, x):
return x[x > 0]

ref_net = None

return aten_index_mask(), ref_net, "aten::index"

@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize(("input_shape"), ((1, 1),
[2, 3],
[7, 8, 9],
[2, 2, 3, 4]))
def test_index_mask(self, input_shape, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
"input_shape": input_shape}, trace_model=True)
19 changes: 19 additions & 0 deletions tests/layer_tests/pytorch_tests/test_index_put_.py
@@ -163,3 +163,22 @@ def test_nonzero_index_put_(self, ie_device, precision, ir_version, input_data,
self.indices_0 = indices[0]
self.indices_1 = indices[1]
self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True)

class TestMask_IndexPut(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(100, 5).astype(np.float32), np.random.randn(100, 5).astype(np.float32))

def create_model(self):
class aten_index_put_mask(torch.nn.Module):
def forward(self, x, y):
x[x < 0] = y[x < 0]
return x

ref_net = None

return aten_index_put_mask(), ref_net, "aten::index_put_"

@pytest.mark.nightly
@pytest.mark.precommit
def test_mask_index_put_(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
