From ccf3da1897511a11caa07481e5eced8db707f772 Mon Sep 17 00:00:00 2001 From: seockho-kim Date: Mon, 14 Oct 2024 10:40:35 +0900 Subject: [PATCH] [luci] Removed beta(bias) from RmsNorm (#14207) This commit removes beta(bias) from RmsNorm in luci. ONE-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com --- compiler/luci/import/src/Nodes/CircleRmsNorm.cpp | 3 +-- compiler/luci/lang/include/luci/IR/Nodes/CircleRmsNorm.h | 5 +---- compiler/luci/lang/src/Nodes/CircleRmsNorm.test.cpp | 9 ++------- compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp | 2 +- compiler/luci/partition/src/Nodes/CircleRmsNorm.cpp | 2 -- compiler/luci/partition/src/Nodes/CircleRmsNorm.test.cpp | 8 +++----- 6 files changed, 8 insertions(+), 21 deletions(-) diff --git a/compiler/luci/import/src/Nodes/CircleRmsNorm.cpp b/compiler/luci/import/src/Nodes/CircleRmsNorm.cpp index 28fef764a65..ad0dc601bcb 100644 --- a/compiler/luci/import/src/Nodes/CircleRmsNorm.cpp +++ b/compiler/luci/import/src/Nodes/CircleRmsNorm.cpp @@ -26,7 +26,7 @@ namespace luci bool CircleRmsNormGraphBuilder::validate(const ValidateArgs &args) const { // TODO check dtypes - return GraphBuilder::validate(args, 3); + return GraphBuilder::validate(args, 2); } CircleNode *CircleRmsNormGraphBuilder::build_node(const circle::OperatorT &op, @@ -36,7 +36,6 @@ CircleNode *CircleRmsNormGraphBuilder::build_node(const circle::OperatorT &op, auto *node = graph->nodes()->create(); node->input(inputs.at(0)); node->gamma(inputs.at(1)); - node->beta(inputs.at(2)); const auto *options = op.builtin_options.AsRmsNormOptions(); node->epsilon(options->epsilon); diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleRmsNorm.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleRmsNorm.h index 3395c13a4f6..d073ddec643 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleRmsNorm.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleRmsNorm.h @@ -28,7 +28,7 @@ namespace luci /** * @brief RMS_NORM in Circle */ -class CircleRmsNorm final : public FixedArityNode<3, CircleNodeImpl> +class CircleRmsNorm final : public FixedArityNode<2, CircleNodeImpl> { public: loco::Node *input(void) const { return at(0)->node(); } @@ -37,9 +37,6 @@ class CircleRmsNorm final : public FixedArityNode<3, CircleNodeImplnode(); } void gamma(loco::Node *node) { at(1)->node(node); } - loco::Node *beta(void) const { return at(2)->node(); } - void beta(loco::Node *node) { at(2)->node(node); } - public: float epsilon() const { return _epsilon; } void epsilon(float epsilon) { _epsilon = epsilon; } diff --git a/compiler/luci/lang/src/Nodes/CircleRmsNorm.test.cpp b/compiler/luci/lang/src/Nodes/CircleRmsNorm.test.cpp index 5c705b345e0..de61d6b535a 100644 --- a/compiler/luci/lang/src/Nodes/CircleRmsNorm.test.cpp +++ b/compiler/luci/lang/src/Nodes/CircleRmsNorm.test.cpp @@ -30,7 +30,6 @@ TEST(CircleRmsNormTest, constructor) ASSERT_EQ(nullptr, rms_norm.input()); ASSERT_EQ(nullptr, rms_norm.gamma()); - ASSERT_EQ(nullptr, rms_norm.beta()); ASSERT_FLOAT_EQ(rms_norm.epsilon(), 1e-06); } @@ -41,25 +40,21 @@ TEST(CircleRmsNormTest, input_NEG) rms_norm.input(&node); rms_norm.gamma(&node); - rms_norm.beta(&node); ASSERT_NE(nullptr, rms_norm.input()); ASSERT_NE(nullptr, rms_norm.gamma()); - ASSERT_NE(nullptr, rms_norm.beta()); rms_norm.input(nullptr); rms_norm.gamma(nullptr); - rms_norm.beta(nullptr); ASSERT_EQ(nullptr, rms_norm.input()); ASSERT_EQ(nullptr, rms_norm.gamma()); - ASSERT_EQ(nullptr, rms_norm.beta()); } TEST(CircleRmsNormTest, arity_NEG) { luci::CircleRmsNorm rms_norm; - ASSERT_NO_THROW(rms_norm.arg(2)); - ASSERT_THROW(rms_norm.arg(3), std::out_of_range); + ASSERT_NO_THROW(rms_norm.arg(1)); + ASSERT_THROW(rms_norm.arg(2), std::out_of_range); } TEST(CircleRmsNormTest, visit_mutable_NEG) diff --git a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp index e30219690e8..f60fecc9f10 100644 --- a/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp +++ b/compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp @@ -905,7 +905,7 @@ std::vector CircleReverseV2SummaryBuilder::get_input_names(const lu std::vector CircleRmsNormSummaryBuilder::get_input_names(const luci::CircleNode *) { - return {"input", "gamma", "beta"}; + return {"input", "gamma"}; } void CircleRmsNormSummaryBuilder::build_attributes(const luci::CircleNode *node, diff --git a/compiler/luci/partition/src/Nodes/CircleRmsNorm.cpp b/compiler/luci/partition/src/Nodes/CircleRmsNorm.cpp index fa7f58af357..b086559ad40 100644 --- a/compiler/luci/partition/src/Nodes/CircleRmsNorm.cpp +++ b/compiler/luci/partition/src/Nodes/CircleRmsNorm.cpp @@ -25,11 +25,9 @@ void connect(luci::ConnectNode *cn, const luci::CircleRmsNorm *node) luci::CircleNode *input = loco::must_cast(node->input()); luci::CircleNode *gamma = loco::must_cast(node->gamma()); - luci::CircleNode *beta = loco::must_cast(node->beta()); cloned->input(cn->find_clone(input)); cloned->gamma(cn->find_clone(gamma)); - cloned->beta(cn->find_clone(beta)); } } // namespace diff --git a/compiler/luci/partition/src/Nodes/CircleRmsNorm.test.cpp b/compiler/luci/partition/src/Nodes/CircleRmsNorm.test.cpp index 625e66c2a14..d48edc0bc63 100644 --- a/compiler/luci/partition/src/Nodes/CircleRmsNorm.test.cpp +++ b/compiler/luci/partition/src/Nodes/CircleRmsNorm.test.cpp @@ -36,7 +36,7 @@ class NodeGraphlet : public NodeGraphletT void init(loco::Graph *g) override { NodeGraphletT::init(g); } }; -class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet { public: TestNodeGraph() = default; @@ -44,12 +44,11 @@ class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet public: void init(const ShapeU32 shape) { - TestIsOGraph<3>::init({shape, shape, shape}, shape); + TestIsOGraph<2>::init({shape, shape}, shape); NodeGraphlet::init(g()); node()->input(input(0)); node()->gamma(input(1)); - node()->beta(input(2)); output()->from(node()); } @@ -73,10 +72,9 @@ TEST(ConnectNodeTest, connect_RmsNorm) cth.clone_connect(node, clone); - ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(2, clone->arity()); ASSERT_EQ(cth.inputs(0), clone->arg(0)); ASSERT_EQ(cth.inputs(1), clone->arg(1)); - ASSERT_EQ(cth.inputs(2), clone->arg(2)); } TEST(ConnectNodeTest, connect_RmsNorm_NEG)