MLIR Reduce Op (elixir-nx#1339)
Co-authored-by: Paulo Valente <[email protected]>
seanmor5 and polvalente authored Nov 18, 2023
1 parent 07447be commit 7d13411
Showing 10 changed files with 230 additions and 23 deletions.
3 changes: 2 additions & 1 deletion exla/c_src/exla/exla.cc
@@ -648,7 +648,7 @@ static ErlNifFunc exla_funcs[] = {
{"mlir_less_equal", 3, mlir_less_equal},
{"mlir_greater", 3, mlir_greater},
{"mlir_greater_equal", 3, mlir_greater_equal},
{"mlir_build", 2, mlir_build},
{"mlir_build", 3, mlir_build},
{"dump_mlir_module", 1, dump_mlir_module},
{"mlir_get_shape", 1, mlir_get_shape},
{"mlir_convert", 3, mlir_convert},
@@ -719,6 +719,7 @@ static ErlNifFunc exla_funcs[] = {
{"mlir_create_token", 1, mlir_create_token},
{"mlir_triangular_solve", 6, mlir_triangular_solve},
{"mlir_dynamic_update_slice", 4, mlir_dynamic_update_slice},
{"mlir_reduce", 5, mlir_reduce},
// XlaBuilder
{"new_builder", 1, new_builder},
{"create_sub_builder", 2, create_sub_builder},
32 changes: 30 additions & 2 deletions exla/c_src/exla/mlir/builder.cc
@@ -762,6 +762,28 @@ mlir::Value MLIRFunction::ScatterOp(mlir::Value target, mlir::Value indices, mli
return scatter_op.getResult(0);
}

std::vector<mlir::Value> MLIRFunction::ReduceOp(
MLIRFunction * reducer,
std::vector<mlir::Value> init_values,
std::vector<mlir::Value> inputs,
std::vector<int64_t> dimensions
) {
auto builder = module_->builder();
builder->setInsertionPointToEnd(&func_->getBody().back());

mlir::ValueRange init_values_range(init_values);
mlir::ValueRange inputs_range(inputs);
mlir::DenseIntElementsAttr dimensions_attr = Int64ToDenseIntElementsAttr(builder, dimensions);

mlir::mhlo::ReduceOp reduce_op = builder->create<mlir::mhlo::ReduceOp>(builder->getUnknownLoc(), inputs_range, init_values_range, dimensions_attr);
mlir::Region &reduceBody = reduce_op.getRegion();
mlir::Region &funcBody = reducer->function()->getBody();
reduceBody.getBlocks().splice(reduceBody.end(), funcBody.getBlocks());

mlir::Operation::result_range results = reduce_op.getResults();
return std::vector<mlir::Value>(results.begin(), results.end());
}

mlir::Value MLIRFunction::SelectAndScatterOp(
mlir::Value target,
mlir::Value source,
@@ -918,9 +940,15 @@ ERL_NIF_TERM MLIRFunction::ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_T
return exla::nif::error(env, "invalid type received");
}

void MLIRFunction::Build(mlir::Value root) {
void MLIRFunction::Build(mlir::Value root, bool use_mhlo_return) {
module_->builder()->setInsertionPointToEnd(&func_->getBody().back());
auto op = module_->builder()->create<mlir::func::ReturnOp>(module_->builder()->getUnknownLoc(), root);

if (use_mhlo_return) {
module_->builder()->create<mlir::mhlo::ReturnOp>(module_->builder()->getUnknownLoc(), root);
} else {
module_->builder()->create<mlir::func::ReturnOp>(module_->builder()->getUnknownLoc(), root);
}

return;
}

5 changes: 4 additions & 1 deletion exla/c_src/exla/mlir/builder.h
@@ -104,13 +104,16 @@ class MLIRFunction {
mlir::Value CreateTokenOp();
mlir::Value TriangularSolveOp(mlir::Value a, mlir::Value b, bool left_side, bool lower, bool transpose_a);
mlir::Value DynamicUpdateSliceOp(mlir::Value operand, mlir::Value update, std::vector<mlir::Value> start_indices);
std::vector<mlir::Value> ReduceOp(MLIRFunction * function, std::vector<mlir::Value> init_values, std::vector<mlir::Value> inputs, std::vector<int64_t> dimensions);
ERL_NIF_TERM ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_TERM value_ptr, std::vector<int64_t> dims = {});
int get_mlir_type(ErlNifEnv *env, ERL_NIF_TERM term, mlir::Type *type);

void Build(mlir::Value root);
void Build(mlir::Value root, bool use_mhlo_return);

llvm::MutableArrayRef<mlir::BlockArgument> get_arguments() { return func_->getBody().front().getArguments(); }

mlir::func::FuncOp * function() { return func_.get(); }

private:
std::shared_ptr<MLIRModule> module_;
std::unique_ptr<mlir::func::FuncOp> func_;
52 changes: 50 additions & 2 deletions exla/c_src/exla/mlir/ops.cc
@@ -666,21 +666,25 @@ ERL_NIF_TERM mlir_select(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
}

ERL_NIF_TERM mlir_build(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 2) {
if (argc != 3) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRFunction** function;
mlir::Value* root;
bool use_mhlo_return;

if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
return exla::nif::error(env, "Unable to get function.");
}
if (!exla::nif::get<mlir::Value>(env, argv[1], root)) {
return exla::nif::error(env, "Unable to get root.");
}
if (!exla::nif::get(env, argv[2], &use_mhlo_return)) {
return exla::nif::error(env, "Unable to get return");
}

(*function)->Build(*root);
(*function)->Build(*root, use_mhlo_return);

return exla::nif::ok(env);
}
@@ -751,6 +755,50 @@ ERL_NIF_TERM mlir_sort(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
return exla::nif::ok(env, list);
}

ERL_NIF_TERM mlir_reduce(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 5) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRFunction** function;
exla::MLIRFunction** reducer;
std::vector<mlir::Value> init_values;
std::vector<mlir::Value> inputs;
std::vector<exla::int64> dimensions;

if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
return exla::nif::error(env, "Unable to get function.");
}
if (!exla::nif::get<exla::MLIRFunction*>(env, argv[1], reducer)) {
return exla::nif::error(env, "Unable to get reducer.");
}
if (!exla::nif::get_list(env, argv[2], init_values)) {
return exla::nif::error(env, "Unable to get init_values.");
}
if (!exla::nif::get_list(env, argv[3], inputs)) {
return exla::nif::error(env, "Unable to get inputs.");
}
if (!exla::nif::get_tuple(env, argv[4], dimensions)) {
return exla::nif::error(env, "Unable to get dimensions.");
}

std::vector<mlir::Value> res = (*function)->ReduceOp(*reducer, init_values, inputs, dimensions);

size_t n = res.size();

std::vector<ERL_NIF_TERM> nif_terms;
nif_terms.reserve(n);

for (size_t i = 0; i < n; i++) {
nif_terms[i] = exla::nif::make<mlir::Value>(env, res[i]);
}

auto data = nif_terms.data();
auto list = enif_make_list_from_array(env, &data[0], n);

return exla::nif::ok(env, list);
}

ERL_NIF_TERM mlir_bitcast_convert(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 4) {
return exla::nif::error(env, "Bad argument count.");
1 change: 1 addition & 0 deletions exla/c_src/exla/mlir/ops.h
@@ -100,3 +100,4 @@ ERL_NIF_TERM mlir_fft(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_create_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_triangular_solve(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_dynamic_update_slice(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_reduce(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
26 changes: 18 additions & 8 deletions exla/lib/exla/builder.ex
@@ -12,16 +12,22 @@ defmodule EXLA.Builder do
@enforce_keys [:ref]
defstruct [:ref, :parent, :name]

def new(name, _inputs, _outputs, :xla) do
def new(name, inputs, outputs, type, sub? \\ false)

def new(name, _inputs, _outputs, :xla, _sub?) do
new(name)
end

def new(_name, inputs, outputs, :mlir) do
def new(_name, inputs, outputs, :mlir, sub?) do
# TO-DO (mlir): check if using the function name makes sense
arg_shapes = Enum.map(inputs, fn {_, %Shape{} = s} -> s end)

return_shape =
[outputs] |> Nx.Defn.Composite.flatten_list() |> List.to_tuple() |> exla_shape()
if sub? do
exla_shape(outputs)
else
[outputs] |> Nx.Defn.Composite.flatten_list() |> List.to_tuple() |> exla_shape()
end

module = M.new()
M.create_function(module, "main", arg_shapes, return_shape)
@@ -34,8 +40,8 @@ defmodule EXLA.Builder do
|> EXLA.Shape.make_tuple_shape()
end

defp exla_shape(%Nx.Tensor{} = t) do
EXLA.Shape.make_shape(t.type, t.shape)
defp exla_shape(%{shape: shape, type: type}) do
EXLA.Shape.make_shape(type, shape)
end

defp new(name) when is_binary(name) do
@@ -48,15 +54,19 @@ defmodule EXLA.Builder do
%__MODULE__{ref: ref, parent: builder, name: name}
end

def build(%Op{} = root) do
def build(root, use_mhlo_return? \\ false)

def build(%Op{} = root, _) do
shape = EXLA.Op.get_shape(root)
{:ok, ref} = EXLA.NIF.build(root.builder, root.ref)
%Computation{ref: ref, output_shape: shape}
end

def build(%EXLA.MLIR.Value{function: function, ref: root_ref}) do
def build(%EXLA.MLIR.Value{function: function, ref: root_ref}, use_mhlo_return?) do
%EXLA.MLIR.Function{ref: function_ref} = function
:ok = EXLA.NIF.mlir_build(function_ref, root_ref)
return_int = if use_mhlo_return?, do: 1, else: 0

:ok = EXLA.NIF.mlir_build(function_ref, root_ref, return_int)
function
end
end
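
The builder.ex changes above let a reducer be compiled as its own MLIR function: new/5 gains a sub? flag that skips the tuple-flattening of the return shape, and build/2 gains a use_mhlo_return? flag so the body is terminated with mhlo.return instead of func.return. Below is a minimal sketch of assembling a scalar add reducer through these entry points. It mirrors op_computation/5 from the defn.ex diff further down; the f32 type, the "add" name, and the EXLA.MLIR.Value.add/2 call (inferred from the apply(Value, op, ...) usage there) are illustrative assumptions, not part of this commit.

# Sketch only: build a scalar `add` reducer as a standalone MLIR function.
# Assumes an EXLA checkout at this commit; shapes and names are illustrative.
type = {:f, 32}
scalar = EXLA.Shape.make_shape(type, {})
arg_shapes = [{"p0", scalar}, {"p1", scalar}]

# sub? defaults to false here, matching how op_computation/5 calls new/4.
function =
  EXLA.Builder.new("add", arg_shapes, struct(Nx.Tensor, %{type: type, shape: {}}), :mlir)

[a, b] = EXLA.MLIR.Function.get_arguments(function)
sum = EXLA.MLIR.Value.add(a, b)

# Passing true selects the mhlo.return terminator, which a region body
# spliced into mhlo.reduce requires.
reducer = EXLA.Builder.build(sum, true)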
71 changes: 63 additions & 8 deletions exla/lib/exla/defn.ex
@@ -1163,6 +1163,23 @@ defmodule EXLA.Defn do
to_aggregate(:min, type, shape, arg, max_number, opts, state)
end

defp to_operator(
:reduce,
[%Value{} = arg, %Value{} = acc, opts, fun],
%{type: type, shape: shape},
_state
) do
arg = to_type(arg, type)
keep_axes = opts[:keep_axes]
[result] = Value.reduce(fun, [to_type(acc, type)], [arg], reduce_axes(arg, opts[:axes]))

if keep_axes do
Value.reshape(result, shape)
else
result
end
end

defp to_operator(:reduce, [arg, acc, opts, fun], %{type: type, shape: shape}, _state) do
arg = to_type(arg, type)
keep_axes = opts[:keep_axes]
@@ -1253,8 +1270,8 @@ defmodule EXLA.Defn do
init_value = to_type(init_value, type)

args = [%{type: type, shape: {}}, %{type: type, shape: {}}]
select_fn = op_computation(:greater, args, state)
scatter_fn = op_computation(:add, args, state)
select_fn = op_computation(:greater, args, :unused, state)
scatter_fn = op_computation(:add, args, :unused, state)

EXLA.Op.select_and_scatter(
arg,
@@ -1307,8 +1324,8 @@

args = [%{type: type, shape: {}}, %{type: type, shape: {}}]

select_fn = op_computation(:less, args, state)
scatter_fn = op_computation(:add, args, state)
select_fn = op_computation(:less, args, :unused, state)
scatter_fn = op_computation(:add, args, :unused, state)

EXLA.Op.select_and_scatter(
arg,
@@ -1333,7 +1350,7 @@
state
) do
args = [%{type: type, shape: {}}, %{type: type, shape: {}}]
scatter_fn = op_computation(:add, args, state)
scatter_fn = op_computation(:add, args, :unused, state)

scatter(scatter_fn, tensors, out)
end
@@ -1575,6 +1592,7 @@ defmodule EXLA.Defn do
end

args = [%{type: ans.type, shape: {}}, %{type: ans.type, shape: {}}]

comp = sort_computation(op, ans.type, args, state)
EXLA.Op.sort(tensor, comp, dimension, opts[:stable] == true)
end
@@ -1800,7 +1818,22 @@ defmodule EXLA.Defn do
EXLA.Builder.build(op)
end

defp op_computation(op, args, state, prepare_args \\ & &1) do
defp op_computation(op, args, out, state, prepare_args \\ & &1)

defp op_computation(op, args, out, %{builder: %EXLA.MLIR.Function{}}, prepare_args) do
arg_shapes =
Enum.with_index(args, fn arg, i ->
{"p#{i}", computation_arg_shape(arg)}
end)

function = EXLA.Builder.new(Atom.to_string(op), arg_shapes, struct(Nx.Tensor, out), :mlir)

args = EXLA.MLIR.Function.get_arguments(function)

EXLA.Builder.build(apply(Value, op, prepare_args.(args)), true)
end

defp op_computation(op, args, _out, state, prepare_args) do
subbuilder = subbuilder(state.builder, Atom.to_string(op))

args =
@@ -1981,6 +2014,28 @@ defmodule EXLA.Defn do

## Aggregation

defp to_aggregate(op, type, shape, %Value{} = arg, initial, opts, state) do
arg = to_type(arg, type)

acc =
case initial do
%Value{} = initial -> initial
initial when is_number(initial) -> Value.constant_r0(state.builder, initial, type)
end

args = [%{type: type, shape: {}}, %{type: type, shape: {}}]
comp = op_computation(op, args, %{shape: shape, type: type}, state, &Enum.reverse/1)

keep_axes = opts[:keep_axes]
[result] = Value.reduce(comp, [acc], [arg], reduce_axes(arg, opts[:axes]))

if keep_axes do
Value.reshape(result, shape)
else
result
end
end

defp to_aggregate(op, type, shape, arg, initial, opts, state) do
arg = to_type(arg, type)

@@ -1995,7 +2050,7 @@
# returns :nan but :infinity + :nan returns :infinity.
# So we want to keep the current value as first argument
# to preserve such properties.
comp = op_computation(op, args, state, &Enum.reverse/1)
comp = op_computation(op, args, :unused, state, &Enum.reverse/1)

keep_axes = opts[:keep_axes]
result = EXLA.Op.reduce(arg, acc, comp, reduce_axes(arg, opts[:axes]))
@@ -2024,7 +2079,7 @@ defmodule EXLA.Defn do
# returns :nan but :infinity + :nan returns :infinity.
# So we want to keep the current value as first argument
# to preserve such properties.
comp = op_computation(op, args, state, &Enum.reverse/1)
comp = op_computation(op, args, :unused, state, &Enum.reverse/1)

strides = opts[:strides]
padding = opts[:padding]
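
For context, the new :reduce clause and the MLIR to_aggregate/7 clause above lower the operation users reach through Nx.reduce/4 (and through the aggregate lowerings that call to_aggregate/7): each reduced input is paired with an init value, the reducer is applied pairwise across the chosen axes, and keep_axes: true is restored afterwards with Value.reshape/2. A small illustration of those semantics in plain Nx terms follows; whether EXLA's MLIR pipeline compiles it depends on compiler options not shown in this diff.

# Illustration only: plain Nx semantics of the operation the new :reduce
# clause lowers; independent of which EXLA backend compiles it.
t = Nx.tensor([[1, 2, 3], [4, 5, 6]])

# Reduce along axis 1 with an explicit accumulator and reducer function.
# keep_axes: true keeps the reduced axis with size 1, which is what the
# Value.reshape(result, shape) call above reintroduces after the reduce.
Nx.reduce(t, 0, [axes: [1], keep_axes: true], &Nx.add/2)
# => shape {2, 1}, values [[6], [15]]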
23 changes: 23 additions & 0 deletions exla/lib/exla/mlir/value.ex
@@ -163,6 +163,13 @@ defmodule EXLA.MLIR.Value do
end

def constant_r0(%Function{} = func, value, type) do
value =
if Nx.Type.float?(type) and not is_float(value) do
value * 1.0
else
value
end

ref =
EXLA.NIF.mlir_constant_r0(func.ref, value, EXLA.Shape.dtype_to_charlist(type)) |> unwrap!()

@@ -453,6 +460,22 @@ defmodule EXLA.MLIR.Value do
%{operand | ref: ref}
end

def reduce(
%Function{ref: reducer},
[%Value{function: func} | _] = init_values,
[%Value{function: func} | _] = inputs,
dimensions
) do
init_value_refs = Enum.map(init_values, & &1.ref)
input_refs = Enum.map(inputs, & &1.ref)

refs =
EXLA.NIF.mlir_reduce(func.ref, reducer, init_value_refs, input_refs, dimensions)
|> unwrap!()

Enum.map(refs, &%Value{ref: &1, function: func})
end

defp unwrap!({:ok, value}), do: value
defp unwrap!(other), do: raise("#{inspect(other)}")
end
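
Putting the pieces together, Value.reduce/4 takes the already-built reducer function, a list of init values, a list of inputs belonging to the same parent function, and the reduction axes (a tuple, since mlir_reduce in ops.cc reads them with get_tuple), and returns one result value per input. A rough end-to-end sketch follows, with the same caveats as the builder sketch earlier: shapes, names, and the Value.add/2 and constant_r0/3 calls are assumptions grounded only in this diff, not a definitive usage example.

# Sketch only: reduce an f32[2][3] parameter over axis 1 with an `add`
# reducer. Assumes an EXLA checkout at this commit.
type = {:f, 32}
scalar = EXLA.Shape.make_shape(type, {})

# Reducer compiled as its own MLIR function, terminated with mhlo.return.
reducer_fun =
  EXLA.Builder.new("add", [{"p0", scalar}, {"p1", scalar}], struct(Nx.Tensor, %{type: type, shape: {}}), :mlir)
[a, b] = EXLA.MLIR.Function.get_arguments(reducer_fun)
reducer = EXLA.Builder.build(EXLA.MLIR.Value.add(a, b), true)

# Parent function with a single f32[2][3] parameter.
fun =
  EXLA.Builder.new(
    "reduce_example",
    [{"x", EXLA.Shape.make_shape(type, {2, 3})}],
    struct(Nx.Tensor, %{type: type, shape: {2}}),
    :mlir
  )
[x] = EXLA.MLIR.Function.get_arguments(fun)

# Init value; constant_r0/3 now coerces integer literals for float types.
init = EXLA.MLIR.Value.constant_r0(fun, 0, type)

# One result per reduced input; axes are passed as a tuple.
[result] = EXLA.MLIR.Value.reduce(reducer, [init], [x], {1})
EXLA.Builder.build(result)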