Skip to content

Commit

Permalink
[luci/service] Support CircleConst as reshape's shape
Browse files Browse the repository at this point in the history
This commit enhances robustness in handling CircleConst as reshape's
shape.

ONE-DCO-1.0-Signed-off-by: Jongwon Yang <[email protected]>
  • Loading branch information
jongwonyang committed Sep 19, 2024
1 parent 9028169 commit ae2e783
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
16 changes: 13 additions & 3 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReshape *node)
namespace sinf
{

/**
* @note CircleReshape always has two inputs: `tensor` and `shape`.
* The `shape` can be CircleConst, CircleOutputDummy, or CircleNode.
* - If the `shape` is CircleConst, the shape is inferred from the constant.
* - Else, the shape is inferred from the node iteself.
* - TODO support CircleOutputDummy and CircleNode
*/
loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
{
LOGGER(l);
Expand All @@ -77,8 +84,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)

// Only support node's shape() is CircleConst with S32
// TODO support other node with other types
auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
if (const_shape_node != nullptr)
if (auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape()))
{
LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");

Expand All @@ -87,6 +93,10 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
{
shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
if (const_shape_node->at<S32>(axis) < 0)
{
shape_by_input.dim(axis).unset();
}
}
}
else
Expand Down Expand Up @@ -139,7 +149,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
if (static_cast<int>(dim_value) == -1)
if (output_shape.dim(dim_index).known() == false)
{
LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
unknown_dim_index = dim_index;
Expand Down
4 changes: 2 additions & 2 deletions compiler/luci/service/src/Nodes/CircleReshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ TEST(CloneNodeTest, clone_Reshape)
ASSERT_EQ(node_reshape->newShape()->dim(1), cloned_reshape->newShape()->dim(1));
}

TEST(ShapeRuleTest, reshape_by_input_const_static)
TEST(ShapeRuleTest, reshape_by_const_static)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
Expand Down Expand Up @@ -71,7 +71,7 @@ TEST(ShapeRuleTest, reshape_by_input_const_static)
ASSERT_EQ(4, output_shape.dim(1).value());
}

TEST(ShapeRuleTest, reshape_by_input_const_dynamic)
TEST(ShapeRuleTest, reshape_by_const_dynamic)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
Expand Down

0 comments on commit ae2e783

Please sign in to comment.