Skip to content

Commit

Permalink
[luci] Removed beta(bias) from RmsNorm (#14207)
Browse files Browse the repository at this point in the history
This commit removes beta(bias) from RmsNorm in luci.

ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
  • Loading branch information
seockho-kim authored Oct 14, 2024
1 parent 7fd5a43 commit ccf3da1
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 21 deletions.
3 changes: 1 addition & 2 deletions compiler/luci/import/src/Nodes/CircleRmsNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace luci
bool CircleRmsNormGraphBuilder::validate(const ValidateArgs &args) const
{
// TODO check dtypes
return GraphBuilder::validate(args, 3);
return GraphBuilder::validate(args, 2);
}

CircleNode *CircleRmsNormGraphBuilder::build_node(const circle::OperatorT &op,
Expand All @@ -36,7 +36,6 @@ CircleNode *CircleRmsNormGraphBuilder::build_node(const circle::OperatorT &op,
auto *node = graph->nodes()->create<CircleRmsNorm>();
node->input(inputs.at(0));
node->gamma(inputs.at(1));
node->beta(inputs.at(2));

const auto *options = op.builtin_options.AsRmsNormOptions();
node->epsilon(options->epsilon);
Expand Down
5 changes: 1 addition & 4 deletions compiler/luci/lang/include/luci/IR/Nodes/CircleRmsNorm.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace luci
/**
* @brief RMS_NORM in Circle
*/
class CircleRmsNorm final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::RMS_NORM>>
class CircleRmsNorm final : public FixedArityNode<2, CircleNodeImpl<CircleOpcode::RMS_NORM>>
{
public:
loco::Node *input(void) const { return at(0)->node(); }
Expand All @@ -37,9 +37,6 @@ class CircleRmsNorm final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode
loco::Node *gamma(void) const { return at(1)->node(); }
void gamma(loco::Node *node) { at(1)->node(node); }

loco::Node *beta(void) const { return at(2)->node(); }
void beta(loco::Node *node) { at(2)->node(node); }

public:
float epsilon() const { return _epsilon; }
void epsilon(float epsilon) { _epsilon = epsilon; }
Expand Down
9 changes: 2 additions & 7 deletions compiler/luci/lang/src/Nodes/CircleRmsNorm.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ TEST(CircleRmsNormTest, constructor)

ASSERT_EQ(nullptr, rms_norm.input());
ASSERT_EQ(nullptr, rms_norm.gamma());
ASSERT_EQ(nullptr, rms_norm.beta());
ASSERT_FLOAT_EQ(rms_norm.epsilon(), 1e-06);
}

Expand All @@ -41,25 +40,21 @@ TEST(CircleRmsNormTest, input_NEG)

rms_norm.input(&node);
rms_norm.gamma(&node);
rms_norm.beta(&node);
ASSERT_NE(nullptr, rms_norm.input());
ASSERT_NE(nullptr, rms_norm.gamma());
ASSERT_NE(nullptr, rms_norm.beta());

rms_norm.input(nullptr);
rms_norm.gamma(nullptr);
rms_norm.beta(nullptr);
ASSERT_EQ(nullptr, rms_norm.input());
ASSERT_EQ(nullptr, rms_norm.gamma());
ASSERT_EQ(nullptr, rms_norm.beta());
}

TEST(CircleRmsNormTest, arity_NEG)
{
luci::CircleRmsNorm rms_norm;

ASSERT_NO_THROW(rms_norm.arg(2));
ASSERT_THROW(rms_norm.arg(3), std::out_of_range);
ASSERT_NO_THROW(rms_norm.arg(1));
ASSERT_THROW(rms_norm.arg(2), std::out_of_range);
}

TEST(CircleRmsNormTest, visit_mutable_NEG)
Expand Down
2 changes: 1 addition & 1 deletion compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ std::vector<std::string> CircleReverseV2SummaryBuilder::get_input_names(const lu

std::vector<std::string> CircleRmsNormSummaryBuilder::get_input_names(const luci::CircleNode *)
{
return {"input", "gamma", "beta"};
return {"input", "gamma"};
}

void CircleRmsNormSummaryBuilder::build_attributes(const luci::CircleNode *node,
Expand Down
2 changes: 0 additions & 2 deletions compiler/luci/partition/src/Nodes/CircleRmsNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@ void connect(luci::ConnectNode *cn, const luci::CircleRmsNorm *node)

luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input());
luci::CircleNode *gamma = loco::must_cast<luci::CircleNode *>(node->gamma());
luci::CircleNode *beta = loco::must_cast<luci::CircleNode *>(node->beta());

cloned->input(cn->find_clone(input));
cloned->gamma(cn->find_clone(gamma));
cloned->beta(cn->find_clone(beta));
}

} // namespace
Expand Down
8 changes: 3 additions & 5 deletions compiler/luci/partition/src/Nodes/CircleRmsNorm.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,19 @@ class NodeGraphlet : public NodeGraphletT<luci::CircleRmsNorm>
void init(loco::Graph *g) override { NodeGraphletT<luci::CircleRmsNorm>::init(g); }
};

class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet
{
public:
TestNodeGraph() = default;

public:
void init(const ShapeU32 shape)
{
TestIsOGraph<3>::init({shape, shape, shape}, shape);
TestIsOGraph<2>::init({shape, shape}, shape);
NodeGraphlet::init(g());

node()->input(input(0));
node()->gamma(input(1));
node()->beta(input(2));

output()->from(node());
}
Expand All @@ -73,10 +72,9 @@ TEST(ConnectNodeTest, connect_RmsNorm)

cth.clone_connect(node, clone);

ASSERT_EQ(3, clone->arity());
ASSERT_EQ(2, clone->arity());
ASSERT_EQ(cth.inputs(0), clone->arg(0));
ASSERT_EQ(cth.inputs(1), clone->arg(1));
ASSERT_EQ(cth.inputs(2), clone->arg(2));
}

TEST(ConnectNodeTest, connect_RmsNorm_NEG)
Expand Down

0 comments on commit ccf3da1

Please sign in to comment.