Skip to content

Commit

Permalink
[luci/svc] Fix Reshape shape inference (#14234)
Browse files Browse the repository at this point in the history
This will fix Reshape shape inference to refer newShape attribute.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark authored Oct 17, 2024
1 parent 5803423 commit 190b538
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 6 deletions.
35 changes: 29 additions & 6 deletions compiler/luci/service/src/Nodes/CircleReshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,35 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
}
else
{
auto shape_node = loco::must_cast<luci::CircleNode *>(node->shape());
assert(shape_node->rank() == 1);
// shape_node tensor values will provide new shape, like [2, 3, 4]
auto num_elements = shape_node->dim(0).value(); // above example will give 3
shape_by_input.rank(num_elements);
is_static_shape = false;
// NOTE assumption is that `shape` and `newShape` having same value.
// for non-existing `shape`, we can use `newShape` if it's valid
auto new_shape = node->newShape();
auto rank = new_shape->rank();
auto shape_dummy = dynamic_cast<luci::CircleOutputDummy *>(node->shape());
if (shape_dummy && rank > 0)
{
is_static_shape = true;
shape_by_input.rank(rank);
for (uint32_t i = 0; i < rank; ++i)
{
if (new_shape->dim(i) > 0)
shape_by_input.dim(i) = static_cast<uint32_t>(new_shape->dim(i));
else
{
is_static_shape = false;
shape_by_input.dim(i).unset();
}
}
}
else
{
auto shape_node = loco::must_cast<luci::CircleNode *>(node->shape());
assert(shape_node->rank() == 1);
// shape_node tensor values will provide new shape, like [2, 3, 4]
auto num_elements = shape_node->dim(0).value(); // above example will give 3
shape_by_input.rank(num_elements);
is_static_shape = false;
}
}
}

Expand Down
35 changes: 35 additions & 0 deletions compiler/luci/service/src/Nodes/CircleReshape.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,38 @@ TEST(ShapeRuleTest, reshape_by_input_node)
ASSERT_FALSE(output_shape.dim(0).known());
ASSERT_FALSE(output_shape.dim(1).known());
}

TEST(ShapeRuleTest, reshape_by_newShape)
{
auto g = loco::make_graph();
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
auto tensor_input = g->nodes()->create<luci::CircleInput>();
auto shape_dummy = g->nodes()->create<luci::CircleOutputDummy>();

tensor_input->dtype(loco::DataType::S32);
tensor_input->shape({2, 3, 4});
tensor_input->shape_status(luci::ShapeStatus::VALID);

shape_dummy->dtype(loco::DataType::S32);
shape_dummy->shape({});
shape_dummy->shape_status(luci::ShapeStatus::VALID);

node_reshape->tensor(tensor_input);
node_reshape->shape(shape_dummy);

// reshape to {2, 12}
node_reshape->newShape()->rank(2);
node_reshape->newShape()->dim(0) = 2;
node_reshape->newShape()->dim(1) = 12;

loco::TensorShape output_shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));

ASSERT_EQ(2, output_shape.rank());
ASSERT_TRUE(output_shape.dim(0).known());
ASSERT_TRUE(output_shape.dim(1).known());
ASSERT_EQ(2, output_shape.dim(0).value());
ASSERT_EQ(12, output_shape.dim(1).value());
}

0 comments on commit 190b538

Please sign in to comment.