From 7d134117fae16720b82912e4e55b08c64de87514 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Sat, 18 Nov 2023 13:41:19 -0500
Subject: [PATCH] MLIR Reduce Op (#1339)

Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
---
 exla/c_src/exla/exla.cc                 |  3 +-
 exla/c_src/exla/mlir/builder.cc         | 32 ++++++++++-
 exla/c_src/exla/mlir/builder.h          |  5 +-
 exla/c_src/exla/mlir/ops.cc             | 52 +++++++++++++++++-
 exla/c_src/exla/mlir/ops.h              |  1 +
 exla/lib/exla/builder.ex                | 26 ++++++---
 exla/lib/exla/defn.ex                   | 71 ++++++++++++++++++++++---
 exla/lib/exla/mlir/value.ex             | 23 ++++++++
 exla/lib/exla/nif.ex                    |  5 +-
 exla/test/exla/mlir/executable_test.exs | 35 ++++++++++++
 10 files changed, 230 insertions(+), 23 deletions(-)

diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc
index 1fb1a5b0cf..fa903df80b 100644
--- a/exla/c_src/exla/exla.cc
+++ b/exla/c_src/exla/exla.cc
@@ -648,7 +648,7 @@
     {"mlir_less_equal", 3, mlir_less_equal},
     {"mlir_greater", 3, mlir_greater},
     {"mlir_greater_equal", 3, mlir_greater_equal},
-    {"mlir_build", 2, mlir_build},
+    {"mlir_build", 3, mlir_build},
     {"dump_mlir_module", 1, dump_mlir_module},
     {"mlir_get_shape", 1, mlir_get_shape},
     {"mlir_convert", 3, mlir_convert},
@@ -719,6 +719,7 @@
     {"mlir_create_token", 1, mlir_create_token},
     {"mlir_triangular_solve", 6, mlir_triangular_solve},
     {"mlir_dynamic_update_slice", 4, mlir_dynamic_update_slice},
+    {"mlir_reduce", 5, mlir_reduce},
     // XlaBuilder
     {"new_builder", 1, new_builder},
     {"create_sub_builder", 2, create_sub_builder},
diff --git a/exla/c_src/exla/mlir/builder.cc b/exla/c_src/exla/mlir/builder.cc
index aa7d3b0ba6..e97ad2511c 100644
--- a/exla/c_src/exla/mlir/builder.cc
+++ b/exla/c_src/exla/mlir/builder.cc
@@ -762,6 +762,28 @@ mlir::Value MLIRFunction::ScatterOp(mlir::Value target, mlir::Value indices, mli
   return scatter_op.getResult(0);
 }
 
+std::vector<mlir::Value> MLIRFunction::ReduceOp(
+    MLIRFunction* reducer,
+    std::vector<mlir::Value> init_values,
+    std::vector<mlir::Value> inputs,
+    std::vector<int64_t> dimensions) {
+  auto builder = module_->builder();
+  builder->setInsertionPointToEnd(&func_->getBody().back());
+
+  mlir::ValueRange init_values_range(init_values);
+  mlir::ValueRange inputs_range(inputs);
+  mlir::DenseIntElementsAttr dimensions_attr = Int64ToDenseIntElementsAttr(builder, dimensions);
+
+  mlir::mhlo::ReduceOp reduce_op = builder->create<mlir::mhlo::ReduceOp>(builder->getUnknownLoc(), inputs_range, init_values_range, dimensions_attr);
+
+  mlir::Region &reduceBody = reduce_op.getRegion();
+  mlir::Region &funcBody = reducer->function()->getBody();
+  reduceBody.getBlocks().splice(reduceBody.end(), funcBody.getBlocks());
+
+  mlir::Operation::result_range results = reduce_op.getResults();
+  return std::vector<mlir::Value>(results.begin(), results.end());
+}
+
 mlir::Value MLIRFunction::SelectAndScatterOp(
     mlir::Value target,
     mlir::Value source,
@@ -918,9 +940,15 @@ ERL_NIF_TERM MLIRFunction::ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_T
   return exla::nif::error(env, "invalid type received");
 }
 
-void MLIRFunction::Build(mlir::Value root) {
+void MLIRFunction::Build(mlir::Value root, bool use_mhlo_return) {
   module_->builder()->setInsertionPointToEnd(&func_->getBody().back());
-  auto op = module_->builder()->create<mlir::func::ReturnOp>(module_->builder()->getUnknownLoc(), root);
+
+  if (use_mhlo_return) {
+    module_->builder()->create<mlir::mhlo::ReturnOp>(module_->builder()->getUnknownLoc(), root);
+  } else {
+    module_->builder()->create<mlir::func::ReturnOp>(module_->builder()->getUnknownLoc(), root);
+  }
+
   return;
 }
 
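Note: MLIRFunction::ReduceOp above inlines the reducer by splicing the blocks of a separately built function into the region of a freshly created mhlo.reduce op, rather than emitting a call. A minimal sketch of what exercises this path from Elixir, mirroring the new tests at the end of this patch (it assumes EXLA is the configured Nx.Defn compiler and that the MLIR builder is enabled, as in the test suite):

    tensor = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])

    # Reference result through the pure-Elixir evaluator...
    expected = Nx.Defn.jit_apply(&Nx.sum(&1, axes: [1]), [tensor], compiler: Nx.Defn.Evaluator)

    # ...and the same reduction JIT-compiled through EXLA, which now lowers
    # Nx.sum/2 to an mhlo.reduce carrying an add reducer region.
    result = Nx.Defn.jit_apply(&Nx.sum(&1, axes: [1]), [tensor])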
diff --git a/exla/c_src/exla/mlir/builder.h b/exla/c_src/exla/mlir/builder.h
index 33ab2e170a..0a82488d70 100644
--- a/exla/c_src/exla/mlir/builder.h
+++ b/exla/c_src/exla/mlir/builder.h
@@ -104,13 +104,16 @@ class MLIRFunction {
   mlir::Value CreateTokenOp();
   mlir::Value TriangularSolveOp(mlir::Value a, mlir::Value b, bool left_side, bool lower, bool transpose_a);
   mlir::Value DynamicUpdateSliceOp(mlir::Value operand, mlir::Value update, std::vector<mlir::Value> start_indices);
+  std::vector<mlir::Value> ReduceOp(MLIRFunction* function, std::vector<mlir::Value> init_values, std::vector<mlir::Value> inputs, std::vector<int64_t> dimensions);
   ERL_NIF_TERM ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_TERM value_ptr, std::vector<int64_t> dims = {});
 
   int get_mlir_type(ErlNifEnv *env, ERL_NIF_TERM term, mlir::Type *type);
-  void Build(mlir::Value root);
+  void Build(mlir::Value root, bool use_mhlo_return);
 
   llvm::MutableArrayRef<mlir::BlockArgument> get_arguments() { return func_->getBody().front().getArguments(); }
 
+  mlir::func::FuncOp* function() { return func_.get(); }
+
  private:
   std::shared_ptr<MLIRModule> module_;
   std::unique_ptr<mlir::func::FuncOp> func_;
diff --git a/exla/c_src/exla/mlir/ops.cc b/exla/c_src/exla/mlir/ops.cc
index 0a8a173943..8e963cace1 100644
--- a/exla/c_src/exla/mlir/ops.cc
+++ b/exla/c_src/exla/mlir/ops.cc
@@ -666,12 +666,13 @@ ERL_NIF_TERM mlir_select(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
 }
 
 ERL_NIF_TERM mlir_build(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-  if (argc != 2) {
+  if (argc != 3) {
     return exla::nif::error(env, "Bad argument count.");
   }
 
   exla::MLIRFunction** function;
   mlir::Value* root;
+  bool use_mhlo_return;
 
   if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
     return exla::nif::error(env, "Unable to get function.");
@@ -679,8 +680,11 @@ ERL_NIF_TERM mlir_build(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
   if (!exla::nif::get<mlir::Value>(env, argv[1], root)) {
     return exla::nif::error(env, "Unable to get root.");
   }
+  if (!exla::nif::get(env, argv[2], &use_mhlo_return)) {
+    return exla::nif::error(env, "Unable to get return");
+  }
 
-  (*function)->Build(*root);
+  (*function)->Build(*root, use_mhlo_return);
 
   return exla::nif::ok(env);
 }
@@ -751,6 +755,50 @@ ERL_NIF_TERM mlir_sort(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
   return exla::nif::ok(env, list);
 }
 
+ERL_NIF_TERM mlir_reduce(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
+  if (argc != 5) {
+    return exla::nif::error(env, "Bad argument count.");
+  }
+
+  exla::MLIRFunction** function;
+  exla::MLIRFunction** reducer;
+  std::vector<mlir::Value> init_values;
+  std::vector<mlir::Value> inputs;
+  std::vector<exla::int64> dimensions;
+
+  if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
+    return exla::nif::error(env, "Unable to get function.");
+  }
+  if (!exla::nif::get<exla::MLIRFunction*>(env, argv[1], reducer)) {
+    return exla::nif::error(env, "Unable to get reducer.");
+  }
+  if (!exla::nif::get_list<mlir::Value>(env, argv[2], init_values)) {
+    return exla::nif::error(env, "Unable to get init_values.");
+  }
+  if (!exla::nif::get_list<mlir::Value>(env, argv[3], inputs)) {
+    return exla::nif::error(env, "Unable to get inputs.");
+  }
+  if (!exla::nif::get_tuple(env, argv[4], dimensions)) {
+    return exla::nif::error(env, "Unable to get dimensions.");
+  }
+
+  std::vector<mlir::Value> res = (*function)->ReduceOp(*reducer, init_values, inputs, dimensions);
+
+  size_t n = res.size();
+
+  std::vector<ERL_NIF_TERM> nif_terms;
+  nif_terms.resize(n);
+
+  for (size_t i = 0; i < n; i++) {
+    nif_terms[i] = exla::nif::make<mlir::Value>(env, res[i]);
+  }
+
+  auto data = nif_terms.data();
+  auto list = enif_make_list_from_array(env, &data[0], n);
+
+  return exla::nif::ok(env, list);
+}
+
 ERL_NIF_TERM mlir_bitcast_convert(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
   if (argc != 4) {
     return exla::nif::error(env, "Bad argument count.");
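Note: the mlir_reduce NIF traffics in resource handles rather than structs, and it must resize (not merely reserve) nif_terms before the indexed writes above, since reserve alone leaves the vector empty. A sketch of the term-level contract it expects, where every *_ref variable is a hypothetical placeholder for a resource obtained from earlier builder calls:

    # Mirrors the argument checks in mlir_reduce/5 above.
    {:ok, result_refs} =
      EXLA.NIF.mlir_reduce(
        function_ref,  # argv[0]: function under construction
        reducer_ref,   # argv[1]: already-built reducer function
        [init_ref],    # argv[2]: list of value refs, one init per input
        [input_ref],   # argv[3]: list of value refs being reduced
        {0}            # argv[4]: tuple of reduction dimensions (read via get_tuple)
      )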
diff --git a/exla/c_src/exla/mlir/ops.h b/exla/c_src/exla/mlir/ops.h
index 8cab3fc95c..6b146487ae 100644
--- a/exla/c_src/exla/mlir/ops.h
+++ b/exla/c_src/exla/mlir/ops.h
@@ -100,3 +100,4 @@ ERL_NIF_TERM mlir_fft(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
 ERL_NIF_TERM mlir_create_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
 ERL_NIF_TERM mlir_triangular_solve(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
 ERL_NIF_TERM mlir_dynamic_update_slice(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
+ERL_NIF_TERM mlir_reduce(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
diff --git a/exla/lib/exla/builder.ex b/exla/lib/exla/builder.ex
index 162a6ed663..98f6627b53 100644
--- a/exla/lib/exla/builder.ex
+++ b/exla/lib/exla/builder.ex
@@ -12,16 +12,22 @@ defmodule EXLA.Builder do
   @enforce_keys [:ref]
   defstruct [:ref, :parent, :name]
 
-  def new(name, _inputs, _outputs, :xla) do
+  def new(name, inputs, outputs, type, sub? \\ false)
+
+  def new(name, _inputs, _outputs, :xla, _sub?) do
     new(name)
   end
 
-  def new(_name, inputs, outputs, :mlir) do
+  def new(_name, inputs, outputs, :mlir, sub?) do
     # TO-DO (mlir): check if using the function name makes sense
     arg_shapes = Enum.map(inputs, fn {_, %Shape{} = s} -> s end)
 
     return_shape =
-      [outputs] |> Nx.Defn.Composite.flatten_list() |> List.to_tuple() |> exla_shape()
+      if sub? do
+        exla_shape(outputs)
+      else
+        [outputs] |> Nx.Defn.Composite.flatten_list() |> List.to_tuple() |> exla_shape()
+      end
 
     module = M.new()
     M.create_function(module, "main", arg_shapes, return_shape)
@@ -34,8 +40,8 @@ defmodule EXLA.Builder do
     |> EXLA.Shape.make_tuple_shape()
   end
 
-  defp exla_shape(%Nx.Tensor{} = t) do
-    EXLA.Shape.make_shape(t.type, t.shape)
+  defp exla_shape(%{shape: shape, type: type}) do
+    EXLA.Shape.make_shape(type, shape)
   end
 
   defp new(name) when is_binary(name) do
@@ -48,15 +54,19 @@ defmodule EXLA.Builder do
     %__MODULE__{ref: ref, parent: builder, name: name}
   end
 
-  def build(%Op{} = root) do
+  def build(root, use_mhlo_return? \\ false)
+
+  def build(%Op{} = root, _) do
     shape = EXLA.Op.get_shape(root)
     {:ok, ref} = EXLA.NIF.build(root.builder, root.ref)
     %Computation{ref: ref, output_shape: shape}
   end
 
-  def build(%EXLA.MLIR.Value{function: function, ref: root_ref}) do
+  def build(%EXLA.MLIR.Value{function: function, ref: root_ref}, use_mhlo_return?) do
     %EXLA.MLIR.Function{ref: function_ref} = function
-    :ok = EXLA.NIF.mlir_build(function_ref, root_ref)
+    return_int = if use_mhlo_return?, do: 1, else: 0
+
+    :ok = EXLA.NIF.mlir_build(function_ref, root_ref, return_int)
     function
   end
 end
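Note: together, sub? and use_mhlo_return? let a scalar reducer be assembled as its own MLIR function whose body ends in mhlo.return, ready to be spliced into a reduce region. A sketch under assumed shapes, the same way op_computation/5 does in the defn.ex changes below (EXLA.MLIR.Value.add/2 is assumed to be the existing elementwise wrapper over mlir_add):

    # Scalar f32 add reducer built as a standalone function.
    out = struct(Nx.Tensor, %{shape: {}, type: {:f, 32}})

    arg_shapes = [
      {"p0", EXLA.Shape.make_shape({:f, 32}, {})},
      {"p1", EXLA.Shape.make_shape({:f, 32}, {})}
    ]

    # sub? = true: the return shape is taken as-is instead of tuple-wrapped.
    function = EXLA.Builder.new("add", arg_shapes, out, :mlir, true)
    [a, b] = EXLA.MLIR.Function.get_arguments(function)

    # use_mhlo_return? = true: terminate with mhlo.return, not func.return.
    reducer = EXLA.Builder.build(EXLA.MLIR.Value.add(a, b), true)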
diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index de59934faa..2abb6aa414 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -1163,6 +1163,23 @@ defmodule EXLA.Defn do
     to_aggregate(:min, type, shape, arg, max_number, opts, state)
   end
 
+  defp to_operator(
+         :reduce,
+         [%Value{} = arg, %Value{} = acc, opts, fun],
+         %{type: type, shape: shape},
+         _state
+       ) do
+    arg = to_type(arg, type)
+    keep_axes = opts[:keep_axes]
+    [result] = Value.reduce(fun, [to_type(acc, type)], [arg], reduce_axes(arg, opts[:axes]))
+
+    if keep_axes do
+      Value.reshape(result, shape)
+    else
+      result
+    end
+  end
+
   defp to_operator(:reduce, [arg, acc, opts, fun], %{type: type, shape: shape}, _state) do
     arg = to_type(arg, type)
     keep_axes = opts[:keep_axes]
@@ -1253,8 +1270,8 @@ defmodule EXLA.Defn do
     init_value = to_type(init_value, type)
     args = [%{type: type, shape: {}}, %{type: type, shape: {}}]
 
-    select_fn = op_computation(:greater, args, state)
-    scatter_fn = op_computation(:add, args, state)
+    select_fn = op_computation(:greater, args, :unused, state)
+    scatter_fn = op_computation(:add, args, :unused, state)
 
     EXLA.Op.select_and_scatter(
       arg,
@@ -1307,8 +1324,8 @@ defmodule EXLA.Defn do
 
     args = [%{type: type, shape: {}}, %{type: type, shape: {}}]
 
-    select_fn = op_computation(:less, args, state)
-    scatter_fn = op_computation(:add, args, state)
+    select_fn = op_computation(:less, args, :unused, state)
+    scatter_fn = op_computation(:add, args, :unused, state)
 
     EXLA.Op.select_and_scatter(
       arg,
@@ -1333,7 +1350,7 @@ defmodule EXLA.Defn do
         state
       ) do
     args = [%{type: type, shape: {}}, %{type: type, shape: {}}]
-    scatter_fn = op_computation(:add, args, state)
+    scatter_fn = op_computation(:add, args, :unused, state)
 
     scatter(scatter_fn, tensors, out)
   end
@@ -1575,6 +1592,7 @@ defmodule EXLA.Defn do
     end
 
     args = [%{type: ans.type, shape: {}}, %{type: ans.type, shape: {}}]
+
     comp = sort_computation(op, ans.type, args, state)
     EXLA.Op.sort(tensor, comp, dimension, opts[:stable] == true)
   end
@@ -1800,7 +1818,22 @@ defmodule EXLA.Defn do
     EXLA.Builder.build(op)
   end
 
-  defp op_computation(op, args, state, prepare_args \\ & &1) do
+  defp op_computation(op, args, out, state, prepare_args \\ & &1)
+
+  defp op_computation(op, args, out, %{builder: %EXLA.MLIR.Function{}}, prepare_args) do
+    arg_shapes =
+      Enum.with_index(args, fn arg, i ->
+        {"p#{i}", computation_arg_shape(arg)}
+      end)
+
+    function = EXLA.Builder.new(Atom.to_string(op), arg_shapes, struct(Nx.Tensor, out), :mlir, true)
+
+    args = EXLA.MLIR.Function.get_arguments(function)
+
+    EXLA.Builder.build(apply(Value, op, prepare_args.(args)), true)
+  end
+
+  defp op_computation(op, args, _out, state, prepare_args) do
    subbuilder = subbuilder(state.builder, Atom.to_string(op))
 
     args =
@@ -1981,6 +2014,28 @@ defmodule EXLA.Defn do
 
   ## Aggregation
 
+  defp to_aggregate(op, type, shape, %Value{} = arg, initial, opts, state) do
+    arg = to_type(arg, type)
+
+    acc =
+      case initial do
+        %Value{} = initial -> initial
+        initial when is_number(initial) -> Value.constant_r0(state.builder, initial, type)
+      end
+
+    args = [%{type: type, shape: {}}, %{type: type, shape: {}}]
+    comp = op_computation(op, args, %{shape: shape, type: type}, state, &Enum.reverse/1)
+
+    keep_axes = opts[:keep_axes]
+    [result] = Value.reduce(comp, [acc], [arg], reduce_axes(arg, opts[:axes]))
+
+    if keep_axes do
+      Value.reshape(result, shape)
+    else
+      result
+    end
+  end
+
   defp to_aggregate(op, type, shape, arg, initial, opts, state) do
     arg = to_type(arg, type)
 
@@ -1995,7 +2050,7 @@ defmodule EXLA.Defn do
     # returns :nan but :infinity + :nan returns :infinity.
     # So we want to keep the current value as first argument
     # to preserve such properties.
-    comp = op_computation(op, args, state, &Enum.reverse/1)
+    comp = op_computation(op, args, :unused, state, &Enum.reverse/1)
 
     keep_axes = opts[:keep_axes]
     result = EXLA.Op.reduce(arg, acc, comp, reduce_axes(arg, opts[:axes]))
@@ -2024,7 +2079,7 @@ defmodule EXLA.Defn do
     # returns :nan but :infinity + :nan returns :infinity.
     # So we want to keep the current value as first argument
     # to preserve such properties.
-    comp = op_computation(op, args, state, &Enum.reverse/1)
+    comp = op_computation(op, args, :unused, state, &Enum.reverse/1)
 
     strides = opts[:strides]
     padding = opts[:padding]
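Note: in both new clauses, keep_axes is implemented as a reshape of the reduced result back to the annotated output shape, since the reduce op always drops the reduced dimensions. The shape bookkeeping, for reference:

    t = Nx.iota({2, 3, 4})

    Nx.sum(t, axes: [0, 2])                   # => shape {3}
    Nx.sum(t, axes: [0, 2], keep_axes: true)  # => shape {1, 3, 1}; the reduce
                                              #    still yields {3}, then the
                                              #    Value.reshape restores rank 3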
diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex
index 3e1add09fb..d686b572c2 100644
--- a/exla/lib/exla/mlir/value.ex
+++ b/exla/lib/exla/mlir/value.ex
@@ -163,6 +163,13 @@ defmodule EXLA.MLIR.Value do
   end
 
   def constant_r0(%Function{} = func, value, type) do
+    value =
+      if Nx.Type.float?(type) and not is_float(value) do
+        value * 1.0
+      else
+        value
+      end
+
     ref =
       EXLA.NIF.mlir_constant_r0(func.ref, value, EXLA.Shape.dtype_to_charlist(type))
       |> unwrap!()
@@ -453,6 +460,22 @@ defmodule EXLA.MLIR.Value do
     %{operand | ref: ref}
   end
 
+  def reduce(
+        %Function{ref: reducer},
+        [%Value{function: func} | _] = init_values,
+        [%Value{function: func} | _] = inputs,
+        dimensions
+      ) do
+    init_value_refs = Enum.map(init_values, & &1.ref)
+    input_refs = Enum.map(inputs, & &1.ref)
+
+    refs =
+      EXLA.NIF.mlir_reduce(func.ref, reducer, init_value_refs, input_refs, dimensions)
+      |> unwrap!()
+
+    Enum.map(refs, &%Value{ref: &1, function: func})
+  end
+
   defp unwrap!({:ok, value}), do: value
   defp unwrap!(other), do: raise("#{inspect(other)}")
 end
diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex
index 297e743773..907fcde601 100644
--- a/exla/lib/exla/nif.ex
+++ b/exla/lib/exla/nif.ex
@@ -46,7 +46,10 @@ defmodule EXLA.NIF do
   def mlir_get_tuple_element(_function, _tuple, _index), do: :erlang.nif_error(:undef)
 
   def mlir_pad(_function, _tensor, _pad, _low, _high, _mid), do: :erlang.nif_error(:undef)
-  def mlir_build(_function, _root), do: :erlang.nif_error(:undef)
+  def mlir_reduce(_function, _reducer, _init_values, _inputs, _dimensions),
+    do: :erlang.nif_error(:undef)
+
+  def mlir_build(_function, _root, _return?), do: :erlang.nif_error(:undef)
 
   def mlir_compile(
     _client,
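Note: EXLA.MLIR.Value.reduce/4 is the struct-level wrapper over the NIF: its head patterns pin every value to the same function, it unwraps the refs, and it rewraps the resulting refs as values. A usage sketch with hypothetical bindings (reducer_fun comes from EXLA.Builder.build(op, true); init and input are %EXLA.MLIR.Value{} structs from the function under construction):

    [result] =
      EXLA.MLIR.Value.reduce(
        reducer_fun,  # %EXLA.MLIR.Function{} holding the reducer body
        [init],       # one initial value per input
        [input],      # the tensors being reduced
        {0, 1}        # axes to collapse, as a tuple
      )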
diff --git a/exla/test/exla/mlir/executable_test.exs b/exla/test/exla/mlir/executable_test.exs
index fa543019d8..0fd5359b74 100644
--- a/exla/test/exla/mlir/executable_test.exs
+++ b/exla/test/exla/mlir/executable_test.exs
@@ -814,6 +814,41 @@ defmodule EXLA.MLIR.ExecutableTest do
     end
   end
 
+  describe "reduce" do
+    test "sum defaults" do
+      tensor = Nx.tensor([1, 2, 3, 4.0])
+
+      function = &Nx.sum/1
+
+      result_nx = Nx.Defn.jit_apply(function, [tensor], compiler: Nx.Defn.Evaluator)
+      result_mlir = Nx.Defn.jit_apply(function, [tensor])
+
+      assert_equal(result_nx, result_mlir)
+    end
+
+    test "sum custom axes" do
+      tensor = Nx.tensor([[[1, 2, 3.0], [4, 5, 6]]])
+
+      function = &Nx.sum(&1, axes: [0, 2])
+
+      result_nx = Nx.Defn.jit_apply(function, [tensor], compiler: Nx.Defn.Evaluator)
+      result_mlir = Nx.Defn.jit_apply(function, [tensor])
+
+      assert_equal(result_nx, result_mlir)
+    end
+
+    test "sum keep axes" do
+      tensor = Nx.tensor([[[1, 2, 3.0], [4, 5, 6]]])
+
+      function = &Nx.sum(&1, axes: [0, 2], keep_axes: true)
+
+      result_nx = Nx.Defn.jit_apply(function, [tensor], compiler: Nx.Defn.Evaluator)
+      result_mlir = Nx.Defn.jit_apply(function, [tensor])
+
+      assert_equal(result_nx, result_mlir)
+    end
+  end
+
   describe "triangular_solve" do
     test "supports options" do
       a = Nx.tensor([[1, 1, 1], [0, 1, 1], [0, 0, 1]])