diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp index 778f8d45762..d30337193eb 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp @@ -138,8 +138,14 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) uint32_t input_element_count = 1; uint32_t output_element_count = 1; uint32_t unknown_dim_index = UINT32_MAX; + bool should_infer = true; for (uint32_t i = 0; i < input_shape.rank(); ++i) - input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1); + { + if (input_shape.dim(i).known()) + input_element_count *= input_shape.dim(i).value(); + else + should_infer = false; + } for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index) { const uint32_t dim_value = output_shape.dim(dim_index).value(); @@ -153,7 +159,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) output_element_count *= dim_value; } } - if (unknown_dim_index != UINT32_MAX) + if (unknown_dim_index != UINT32_MAX && should_infer) { output_shape.dim(unknown_dim_index) = input_element_count / output_element_count; } diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index a6ae6735500..65382ce1f8e 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -102,3 +102,35 @@ TEST(ShapeRuleTest, reshape_by_input_const_dynamic) ASSERT_EQ(6, output_shape.dim(0).value()); ASSERT_EQ(4, output_shape.dim(1).value()); } + +TEST(ShapeRuleTest, reshape_should_infer) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_by_input = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({0, 3, 4}); + tensor_input->dim(0).unset(); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_by_input->dtype(loco::DataType::S32); + shape_by_input->size(2); + shape_by_input->at(0) = -1; + shape_by_input->at(1) = 4; + shape_by_input->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_by_input); + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_FALSE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(4, output_shape.dim(1).value()); +}