Skip to content

Commit

Permalink
[luci/service] Fix dynamic shape inference for reshape operation
Browse files Browse the repository at this point in the history
This commit prevents shape inference when the reshape tensor's dimensions are not fully known.

ONE-DCO-1.0-Signed-off-by: Jongwon Yang <[email protected]>
  • Loading branch information
jongwonyang committed Sep 24, 2024
1 parent 82a7d90 commit 54a1bcc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
10 changes: 8 additions & 2 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,14 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
bool is_static_shape = true;
for (uint32_t i = 0; i < input_shape.rank(); ++i)
input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
{
if (input_shape.dim(i).known())
input_element_count *= input_shape.dim(i).value();
else
is_static_shape = false;
}
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
Expand All @@ -153,7 +159,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
output_element_count *= dim_value;
}
}
if (unknown_dim_index != UINT32_MAX)
if (unknown_dim_index != UINT32_MAX && is_static_shape)
{
output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
}
Expand Down
32 changes: 32 additions & 0 deletions compiler/luci/service/src/Nodes/CircleReshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,35 @@ TEST(ShapeRuleTest, reshape_by_input_const_dynamic)
ASSERT_EQ(6, output_shape.dim(0).value());
ASSERT_EQ(4, output_shape.dim(1).value());
}

TEST(ShapeRuleTest, reshape_should_infer)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_by_input = g->nodes()->create<luci::CircleConst>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({0, 3, 4});
tensor_input->dim(0).unset();
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_by_input->dtype(loco::DataType::S32);
shape_by_input->size<loco::DataType::S32>(2);
shape_by_input->at<loco::DataType::S32>(0) = -1;
shape_by_input->at<loco::DataType::S32>(1) = 4;
shape_by_input->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_by_input);

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));

ASSERT_EQ(2, output_shape.rank());
ASSERT_FALSE(output_shape.dim(0).known());
ASSERT_TRUE(output_shape.dim(1).known());
ASSERT_EQ(4, output_shape.dim(1).value());
}

0 comments on commit 54a1bcc

Please sign in to comment.