Test and fix optimizers LayerNormFusion, BiasSoftmaxFusion, Transpose… for opset 18 (microsoft#14542)

### Description

Due to the changes introduced in opset 18 to the Reduce operators (axes is
now an input rather than an attribute), the following optimizers no longer
catch the patterns they are supposed to optimize. This PR addresses that.

* layer_norm_fusion.cc: the optimizer was not detecting the pattern it is supposed to optimize
* bias_softmax_fusion.cc: the optimizer was not detecting the pattern it is supposed to optimize
* transpose_optimizer.cc: the optimizer was not handling Reduce operators other than ReduceSum (see the sketch below)
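
For reference, the change looks like this when the same ReduceMean is built for the two opsets (a minimal Python sketch using the standard `onnx.helper` API; tensor names are illustrative). A matcher that only inspects the `axes` attribute finds nothing on the opset-18 node, which is why these fusions stopped firing:

```python
from onnx import TensorProto, helper

# Opset <= 17: axes is an attribute baked into the node.
node_old = helper.make_node("ReduceMean", inputs=["X"], outputs=["Y"], axes=[-1])

# Opset 18: axes is an optional second input, usually a constant initializer.
axes_init = helper.make_tensor("axes", TensorProto.INT64, dims=[1], vals=[-1])
node_new = helper.make_node("ReduceMean", inputs=["X", "axes"], outputs=["Y"])
```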

### Motivation and Context
Better performance: without these fixes, opset-18 models silently miss these fusions.

---------

Signed-off-by: xadupre <[email protected]>
xadupre authored Feb 8, 2023
1 parent cfda876 commit 30ec8b0
Showing 10 changed files with 605 additions and 228 deletions.
1 change: 1 addition & 0 deletions onnxruntime/core/optimizer/bias_softmax_fusion.cc
@@ -135,6 +135,7 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node
new_axis = (int)HandleNegativeAxis(axis, rank);

// The axis attribute for Softmax in OpSet-11 and OpSet-13 are different.
+ // Details in function documentation.
if (is_since_opset_13 && new_axis != rank - 1) return false;

int singlebatch_rank = rank - new_axis;
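The comment above points at the semantic difference this guard enforces: since opset 13, Softmax normalizes along exactly one axis, whereas earlier opsets coerce the input to 2D and normalize the whole trailing block. A numpy sketch of the two behaviors (not ONNX Runtime code; `axis` is assumed non-negative):

```python
import numpy as np

def softmax_since_opset13(x, axis):
    # One-axis normalization, as specified from opset 13 onward.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_before_opset13(x, axis):
    # Input coerced to 2D [prod(dims[:axis]), prod(dims[axis:])];
    # each row of the flattened trailing block is normalized as a whole.
    x2d = x.reshape(int(np.prod(x.shape[:axis])), -1)
    e = np.exp(x2d - x2d.max(axis=1, keepdims=True))
    return (e / e.sum(axis=1, keepdims=True)).reshape(x.shape)
```

The two agree only when the softmax axis is the last one, which is why the fusion bails out for opset >= 13 unless `new_axis == rank - 1`.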
35 changes: 28 additions & 7 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
@@ -4,6 +4,7 @@
#include "core/optimizer/layer_norm_fusion.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/transpose_optimizer/optimizer_api.h"
#include "float.h"
#include <deque>

@@ -16,12 +17,17 @@ static constexpr std::array<std::string_view, 3> supported_data_types{"tensor(fl
// Default epsilon
static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f;

- static bool IsSupportedDataType(const Node& node) {
+ static bool IsSupportedDataType(const Node& node, int first_n_inputs=-1) {
+ int input_index = 0;
for (const auto& input_arg : node.InputDefs()) {
+ if (first_n_inputs != -1 && input_index >= first_n_inputs) {
+ return true;
+ }
if (std::find(supported_data_types.begin(), supported_data_types.end(),
*(input_arg->Type())) == supported_data_types.end()) {
return false;
}
+ ++input_index;
}
return true;
}
@@ -99,11 +105,11 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& reduce_mean_node = *p_reduce_mean;
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));

- if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13, 18}) ||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
graph.NodeProducesGraphOutput(reduce_mean_node) ||
- !IsSupportedDataType(reduce_mean_node)) {
+ !IsSupportedDataType(reduce_mean_node, 1)) {
continue;
}
nodes_to_remove.push_back(reduce_mean_node);
@@ -263,10 +269,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
- if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13}) ||
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13, 18}) ||
reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, reduce_mean2_node, 1) ||
- !IsSupportedDataType(reduce_mean2_node) ||
+ !IsSupportedDataType(reduce_mean2_node, 1) ||
reduce_mean2_node.GetInputEdgesCount() == 0) {
continue;
}
@@ -333,8 +339,16 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// get axes attributes
const onnxruntime::NodeAttributes& attributes = reduce_mean_node.GetAttributes();
std::vector<int64_t> axes_values;
+ // TODO: modify this code when opset >= 18 (axes is an input).
if (attributes.find("axes") != attributes.end()) {
axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+ } else if (reduce_mean_node.InputDefs().size() == 2) {
+ auto axes = reduce_mean_node.InputDefs()[1];
+ auto axes_const = graph.GetConstantInitializer(axes->Name(), true);
+ if (axes_const != nullptr) {
+ Initializer initializer{*axes_const, graph.ModelPath()};
+ axes_values.insert(axes_values.end(), initializer.DataAsSpan<int64_t>().begin(), initializer.DataAsSpan<int64_t>().end());
+ }
}

// Get the inputs for the new LayerNormalization node.
@@ -485,9 +499,9 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
continue;
}
Node& reduce_mean_node = *graph.GetNode(p_reduce_mean->Index());
- if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13, 18}) ||
reduce_mean_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
- !optimizer_utils::CheckOutputEdges(graph, reduce_mean_node, 1) || !IsSupportedDataType(reduce_mean_node) ||
+ !optimizer_utils::CheckOutputEdges(graph, reduce_mean_node, 1) || !IsSupportedDataType(reduce_mean_node, 1) ||
reduce_mean_node.GetInputEdgesCount() == 0) {
continue;
}
@@ -585,6 +599,13 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
std::vector<int64_t> axes_values;
if (attributes.find("axes") != attributes.end()) {
axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+ } else if (reduce_mean_node.InputDefs().size() == 2) {
+ auto axes = reduce_mean_node.InputDefs()[1];
+ auto axes_const = graph.GetConstantInitializer(axes->Name(), true);
+ if (axes_const != nullptr && axes_const->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
+ Initializer initializer{*axes_const, graph.ModelPath()};
+ axes_values.insert(axes_values.end(), initializer.DataAsSpan<int64_t>().begin(), initializer.DataAsSpan<int64_t>().end());
+ }
}

// Get the inputs for the new LayerNormalization node.
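Both new `else if` branches above implement the same fallback: if `axes` is not present as an attribute, read it from the optional second input, but only when that input is a constant initializer (otherwise the axes are not known at optimization time and the fusion cannot proceed). A Python analogue of that lookup, sketched with the `onnx` package rather than the actual ORT helpers:

```python
import onnx
from onnx import numpy_helper

def get_static_reduce_axes(model: onnx.ModelProto, node: onnx.NodeProto):
    # Pre-opset-18 form: axes lives in an attribute.
    for attr in node.attribute:
        if attr.name == "axes":
            return list(attr.ints)
    # Opset-18 form: axes is input #1, usable only if it is a constant initializer.
    if len(node.input) == 2:
        for init in model.graph.initializer:
            if init.name == node.input[1]:
                return numpy_helper.to_array(init).tolist()
    return None  # axes unknown statically; the caller must bail out
```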
onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc
@@ -1040,7 +1040,7 @@ static bool HandlePad(HandlerArgs& args) {

constexpr HandlerInfo pad_handler = {&FirstInput, &HandlePad};

- static bool HandleReduceOp(HandlerArgs& args) {
+ static bool HandleReduceOpWithArg(HandlerArgs& args) {
int64_t keepdims = args.node.GetAttributeIntDefault("keepdims", 1);

std::optional<std::vector<int64_t>> axes = args.node.GetAttributeInts("axes");
@@ -1078,11 +1078,11 @@ static bool HandleReduceOp(HandlerArgs& args) {
return true;
}

- constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOp};
-
- static bool HandleReduceSum(HandlerArgs& args) {
- if (args.ctx.opset < 13) {
- return HandleReduceOp(args);
+ static bool HandleReduceOps(HandlerArgs& args) {
+ if ((args.node.OpType() == "ReduceSum" && args.ctx.opset < 13) ||
+ // or all other reduce operators since opset 18
+ (args.node.OpType() != "ReduceSum" && args.ctx.opset < 18)) {
+ return HandleReduceOpWithArg(args);
}

bool keepdims = args.node.GetAttributeIntDefault("keepdims", 1) != 0;
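
The renamed handler exists because the Reduce family migrated `axes` in two waves: ReduceSum moved it from attribute to input in opset 13, and the remaining Reduce* operators followed in opset 18. The dispatch above encodes exactly that rule; as a standalone sketch (hypothetical helper, not ORT code):

```python
def axes_is_input(op_type: str, opset: int) -> bool:
    # ReduceSum switched in opset 13; the other Reduce* ops switched in opset 18.
    return opset >= (13 if op_type == "ReduceSum" else 18)
```

Below the threshold, the attribute-based path (`HandleReduceOpWithArg`) still applies.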
@@ -1147,7 +1147,7 @@ static bool HandleReduceSum(HandlerArgs& args) {
return true;
}

- constexpr HandlerInfo reduce_sum_handler = {&FirstInput, &HandleReduceSum};
+ constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps};

static bool HandleSqueeze(HandlerArgs& args) {
std::vector<int64_t> new_axes;
@@ -1709,7 +1709,7 @@ static const std::unordered_map<std::string_view, const HandlerInfo&> handler_ma
#if !defined(USE_CUDA) && !defined(USE_ROCM)
{"Resize", resize_handler},
#endif
{"ReduceSum", reduce_sum_handler},
{"ReduceSum", reduce_op_handler},

{"ReduceLogSum", reduce_op_handler},
{"ReduceLogSumExp", reduce_op_handler},
31 changes: 21 additions & 10 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -95,7 +95,6 @@ namespace onnxruntime {
namespace test {

#define MODEL_FOLDER ORT_TSTR("testdata/transform/")

TEST_F(GraphTransformationTests, IdentityElimination) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "abs-id-max.onnx";
std::shared_ptr<Model> model;
@@ -4390,11 +4389,12 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) {
return Status::OK();
};

- const std::vector<int> opsets{11, 12, 13, 14, 15, 15};
+ const std::vector<int> opsets{11, 12, 13, 14, 15, 18};
bool shape_test_for_opset15 = false;

- for (auto& opset_version : opsets) {
+ for (auto& opset : opsets) {
auto build_test_case = [&](ModelTestBuilder& builder) {
+ auto opset_version = builder.DomainToVersionMap().find(kOnnxDomain)->second;
auto* input_arg0 = builder.MakeInput<float>({{batch_size, seq_lenth, hidden_size}});
auto* input_arg1 = builder.MakeInput<float>({{hidden_size}});
auto* scalar_int_0 = builder.MakeInitializer<int64_t>({}, {0});
@@ -4414,7 +4414,7 @@
auto* out = builder.MakeOutput();

builder.AddNode("Add", {input_arg0, input_arg1}, {add_out});
- if (opset_version == 15) {
+ if (opset_version >= 15) {
if (shape_test_for_opset15) {
auto& shape_1 = builder.AddNode("Shape", {add_out}, {shape_out});
shape_1.AddAttribute("start", (int64_t)1);
@@ -4442,11 +4442,11 @@
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<ReshapeFusion>();
- if (opset_version == 15 && shape_test_for_opset15) {
- ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
+ if (opset >= 15 && shape_test_for_opset15) {
+ ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, pre_graph_checker));
} else {
- ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
+ ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}
}
@@ -4610,13 +4610,24 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
auto* cast_out_2 = builder.MakeIntermediate();
auto* mul_out = builder.MakeIntermediate();
auto* add_out_2 = builder.MakeOutput();
+ auto opset = builder.DomainToVersionMap().find(kOnnxDomain)->second;
+ onnxruntime::NodeArg* axes = nullptr;

builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1});
if (opset >= 18) {
axes = builder.MakeInitializer<int64_t>({1}, {-1});
builder.AddNode("ReduceMean", {data_arg, axes}, {reduce_mean_out_1});
} else {
builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1});
}
builder.AddNode("Sub", {data_arg, reduce_mean_out_1}, {sub_out});
builder.AddNode("Cast", {sub_out}, {cast_out_1})
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
builder.AddNode("Pow", {cast_out_1, pow_initializer}, {pow_out});
builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1});
if (opset >= 18) {
builder.AddNode("ReduceMean", {pow_out, axes}, {reduce_mean_out_2});
} else {
builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1});
}
builder.AddNode("Add", {reduce_mean_out_2, add_initializer}, {add_out_1});
builder.AddNode("Sqrt", {add_out_1}, {sqrt_out});
builder.AddNode("Div", {cast_out_1, sqrt_out}, {div_out});
@@ -4652,7 +4663,7 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
- ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
+ ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}

67 changes: 53 additions & 14 deletions onnxruntime/test/optimizer/graph_transform_test_builder.cc
@@ -17,6 +17,31 @@
namespace onnxruntime {
namespace test {

+ void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
+ const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
+ TransformerLevel baseline_level,
+ TransformerLevel target_level,
+ const std::vector<int64_t>& opset_versions,
+ double per_sample_tolerance,
+ double relative_per_sample_tolerance,
+ std::unique_ptr<GraphTransformer> transformer,
+ const std::function<void(SessionOptions&)>& add_session_options,
+ const InlinedHashSet<std::string>& disabled_optimizers) {
+ ASSERT_TRUE(transformer == nullptr);
+ for (auto opset_version : opset_versions) {
+ TransformerTester(build_test_case,
+ check_transformed_graph,
+ baseline_level,
+ target_level,
+ opset_version,
+ per_sample_tolerance,
+ relative_per_sample_tolerance,
+ nullptr,
+ add_session_options,
+ disabled_optimizers);
+ }
+ }

void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
TransformerLevel baseline_level,
@@ -101,22 +126,36 @@ Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>&
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
const std::function<Status(Graph&)>& post_graph_checker) {
- // Build the model for this test.
- std::unordered_map<std::string, int> domain_to_version;
- domain_to_version[kOnnxDomain] = opset_version;
- domain_to_version[kMSDomain] = 1;
- Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
- domain_to_version, {}, logger);
- Graph& graph = model.MainGraph();
- ModelTestBuilder helper(graph);
- build_test_case(helper);
- helper.SetGraphOutputs();
- ORT_RETURN_IF_ERROR(graph.Resolve());
- ORT_RETURN_IF_ERROR(pre_graph_checker(graph));
+ const std::vector<int64_t> opset_versions{opset_version};
+ return TestGraphTransformer(build_test_case, opset_versions, logger, std::move(transformer),
+ level, steps, pre_graph_checker, post_graph_checker);
+ }
+
+ Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
+ const std::vector<int64_t>& opset_versions,
+ const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
+ TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
+ const std::function<Status(Graph&)>& post_graph_checker) {
onnxruntime::GraphTransformerManager graph_transformation_mgr{steps};
ORT_RETURN_IF_ERROR(graph_transformation_mgr.Register(std::move(transformer), level));
- ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
- ORT_RETURN_IF_ERROR(post_graph_checker(graph));

+ for (auto opset : opset_versions) {
+ // Build the model for this test.
+ std::unordered_map<std::string, int> domain_to_version;
+ domain_to_version[kOnnxDomain] = opset;
+ domain_to_version[kMSDomain] = 1;
+ Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+ domain_to_version, {}, logger);
+ Graph& graph = model.MainGraph();
+ ModelTestBuilder helper(graph);
+ build_test_case(helper);
+ helper.SetGraphOutputs();
+ ORT_RETURN_IF_ERROR(graph.Resolve());
+ ORT_RETURN_IF_ERROR(pre_graph_checker(graph));
+ ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
+ ORT_RETURN_IF_ERROR(post_graph_checker(graph));
+ }

return Status::OK();
}

33 changes: 33 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_builder.h
@@ -50,6 +50,10 @@ class ModelTestBuilder {
ModelTestBuilder(Graph& graph) : graph_(graph) {
}

+ const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
+ return graph_.DomainToVersionMap();
+ }

template <typename T>
NodeArg* MakeInput(const std::vector<int64_t>& shape, const std::vector<T>& data) {
ONNX_NAMESPACE::TypeProto type_proto;
@@ -356,6 +360,17 @@ void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& buil
const std::function<void(SessionOptions&)>& add_session_options = {},
const InlinedHashSet<std::string>& disabled_optimizers = {});

+ void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
+ const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
+ TransformerLevel baseline_level,
+ TransformerLevel target_level,
+ const std::vector<int64_t>& opset_versions,
+ double per_sample_tolerance = 0.0,
+ double relative_per_sample_tolerance = 0.0,
+ std::unique_ptr<GraphTransformer> transformer = nullptr, // must be null in this case.
+ const std::function<void(SessionOptions&)>& add_session_options = {},
+ const InlinedHashSet<std::string>& disabled_optimizers = {});

/**
* @brief Apply a GraphTransformer to a graph, and run graph checkers before and after applying the transformer.
*
@@ -372,5 +387,23 @@ Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>&
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
const std::function<Status(Graph&)>& post_graph_checker);

+ /**
+ * @brief Apply a GraphTransformer to a graph, and run graph checkers before and after applying the transformer.
+ *
+ * @param build_test_case The function to build a graph for testing
+ * @param opset_versions A graph is created and tested for every opset in this set
+ * @param logger The logger
+ * @param transformer The GraphTransformer to be applied
+ * @param level The transformer level on which the transformer will be applied
+ * @param steps The step count of the GraphTransformerManager
+ * @param pre_graph_checker The graph checker function before applying the transformer
+ * @param post_graph_checker The graph checker function after applying the transformer
+ */
+ Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
+ const std::vector<int64_t>& opset_versions,
+ const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
+ TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
+ const std::function<Status(Graph&)>& post_graph_checker);
} // namespace test
} // namespace onnxruntime