diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp index 553e1eabd5d..025ad4912a1 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp @@ -117,12 +117,35 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) } else { - auto shape_node = loco::must_cast(node->shape()); - assert(shape_node->rank() == 1); - // shape_node tensor values will provide new shape, like [2, 3, 4] - auto num_elements = shape_node->dim(0).value(); // above example will give 3 - shape_by_input.rank(num_elements); - is_static_shape = false; + // NOTE assumption is that `shape` and `newShape` having same value. + // for non-existing `shape`, we can use `newShape` if it's valid + auto new_shape = node->newShape(); + auto rank = new_shape->rank(); + auto shape_dummy = dynamic_cast(node->shape()); + if (shape_dummy && rank > 0) + { + is_static_shape = true; + shape_by_input.rank(rank); + for (uint32_t i = 0; i < rank; ++i) + { + if (new_shape->dim(i) > 0) + shape_by_input.dim(i) = static_cast(new_shape->dim(i)); + else + { + is_static_shape = false; + shape_by_input.dim(i).unset(); + } + } + } + else + { + auto shape_node = loco::must_cast(node->shape()); + assert(shape_node->rank() == 1); + // shape_node tensor values will provide new shape, like [2, 3, 4] + auto num_elements = shape_node->dim(0).value(); // above example will give 3 + shape_by_input.rank(num_elements); + is_static_shape = false; + } } } diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index 4bb13edc2f9..653cb690d18 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -162,3 +162,38 @@ TEST(ShapeRuleTest, reshape_by_input_node) ASSERT_FALSE(output_shape.dim(0).known()); ASSERT_FALSE(output_shape.dim(1).known()); } + +TEST(ShapeRuleTest, reshape_by_newShape) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_dummy = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_dummy->dtype(loco::DataType::S32); + shape_dummy->shape({}); + shape_dummy->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_dummy); + + // reshape to {2, 12} + node_reshape->newShape()->rank(2); + node_reshape->newShape()->dim(0) = 2; + node_reshape->newShape()->dim(1) = 12; + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_TRUE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(2, output_shape.dim(0).value()); + ASSERT_EQ(12, output_shape.dim(1).value()); +}