[Executorch] Handle broadcast semantics for last dim #6309

Open · wants to merge 1 commit into base: gh/kimishpatel/121/base
Changes from all commits
32 changes: 25 additions & 7 deletions kernels/optimized/cpu/binary_ops.h
@@ -43,17 +43,27 @@ enum class ElementwiseOptimizedPath {
kBroadcast2dBy1dReverseArguments,
kBroadcastNdByNd,
kBroadcastNdByNdReverseArguments,
kBroadcastLastDim,
kBroadcastLastDimReverseArguments,
};

namespace internal {

// Find the single broadcast dimension if it exists.
// This path aims to handle broadcast of the following form
// A = [a1, a2, ..., 1, ..., an]
// B = [b1, b2, ..., bm, ..., bn]
// OR
// A = [a1, a2, ..., am, ..., an]
// B = [b1, b2, ..., 1, ..., bn]
/*
Given two tensors, this function returns the broadcast dim if one exists.
Returns 0 if no broadcast dim is found; otherwise a negative index, counted
from the last dim, identifies the broadcast dim.
e.g. if size = [a, b, c, 1, e, f] then the broadcast dim is -3

This path aims to handle broadcasts of the following form:
A = [a1, a2, ..., 1, ..., an]
B = [b1, b2, ..., bm, ..., bn]
OR
A = [a1, a2, ..., am, ..., an]
B = [b1, b2, ..., 1, ..., bn]
Note that this way of determining the broadcast dim also works
when the broadcast dim is the last dim.
*/
int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
auto lhs_end = lhs.sizes().end();
@@ -125,6 +135,14 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
} else {
return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
}
} else if (broadcast_dim == -1) {
if (std::count_if(lhs_begin, lhs_end, [](Tensor::SizesType x) {
return x == 1;
}) == 1) {
return ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments;
} else {
return ElementwiseOptimizedPath::kBroadcastLastDim;
}
}
return ElementwiseOptimizedPath::kNone;
}
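For intuition, here is a simplified, hypothetical re-implementation of the negative-index convention described in the comment above. It is not the PR's code: it assumes both shapes have the same rank, skips the leading-1 stripping done by arrayref_begin_ignoring_leading_1s, and only looks for the size-1 dim in the second argument.

#include <cstdint>
#include <vector>

// Returns the broadcast dim as a negative index counted from the last dim,
// or 0 if `b` does not differ from `a` in exactly one dim of size 1.
int32_t sketch_get_broadcast_dim(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  int32_t broadcast_dim = 0;
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i]) {
      if (b[i] != 1 || broadcast_dim != 0) {
        return 0; // Mismatch that is not a single size-1 broadcast.
      }
      broadcast_dim = static_cast<int32_t>(i) - static_cast<int32_t>(a.size());
    }
  }
  return broadcast_dim;
}

// sketch_get_broadcast_dim({2, 3, 5}, {2, 1, 5}) == -2  -> NdByNd paths
// sketch_get_broadcast_dim({2, 2, 3}, {2, 2, 1}) == -1  -> the new
//   kBroadcastLastDim / kBroadcastLastDimReverseArguments paths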
167 changes: 112 additions & 55 deletions kernels/optimized/cpu/op_mul.cpp
@@ -11,6 +11,7 @@
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

@@ -66,6 +67,116 @@ template <
typename CTYPE_OUT>
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

Tensor& handle_last_dim_broadcast(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out,
const ElementwiseOptimizedPath selected_optimized_path) {
ScalarType out_type = out.scalar_type();
const Tensor* lhs;
const Tensor* rhs;
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
lhs = &b;
rhs = &a;
} else {
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
const size_t outer_size = getLeadingDims(out, out.dim() - 1);
const auto broadcast_size = out.size(out.dim() - 1);
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
[](Vec x, Vec y) { return x * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size);
});
return out;
}

Tensor& handle_broadcast_mul(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out,
const ElementwiseOptimizedPath selected_optimized_path) {
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDim) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
}

ScalarType out_type = out.scalar_type();
const Tensor* lhs;
const Tensor* rhs;
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
lhs = &b;
rhs = &a;
} else {
// Catch failure to update logic when adding new broadcasting possibility.
ET_DCHECK(
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNd));
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
int64_t outer_size = 1;
int64_t broadcast_size;
int64_t inner_size;
if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
auto normalized_tensor_size_lhs =
get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
outer_size = normalized_tensor_size_lhs[0];
broadcast_size = normalized_tensor_size_lhs[1];
inner_size = normalized_tensor_size_lhs[2];
} else {
broadcast_size = lhs->sizes()[lhs->dim() - 2];
inner_size = lhs->sizes()[lhs->dim() - 1];
}
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
[](Vec x, Vec y) { return x * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size,
inner_size);
});
return out;
}
} // namespace

Tensor& opt_mul_out(
@@ -128,61 +239,7 @@ Tensor& opt_mul_out(
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
const Tensor* lhs;
const Tensor* rhs;
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
lhs = &b;
rhs = &a;
} else {
// Catch failure to update logic when adding new broadcasting possibility.
ET_DCHECK(
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNd));
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
int64_t outer_size = 1;
int64_t broadcast_size;
int64_t inner_size;
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNd) ||
(selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
auto normalized_tensor_size_lhs =
get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
outer_size = normalized_tensor_size_lhs[0];
broadcast_size = normalized_tensor_size_lhs[1];
inner_size = normalized_tensor_size_lhs[2];
} else {
broadcast_size = lhs->sizes()[lhs->dim() - 2];
inner_size = lhs->sizes()[lhs->dim() - 1];
}
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
[](Vec x, Vec y) { return x * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
outer_size,
broadcast_size,
inner_size);
});
return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
} else {
ScalarType common_type =
promoteTypes(a_type, b_type, /*half_to_float*/ true);
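As a sanity check on the semantics of the new path, the computation that handle_last_dim_broadcast vectorizes can be written as a plain scalar loop. This is a hypothetical reference sketch, not code from this PR, with lhs viewed as [outer_size, broadcast_size] and rhs viewed as [outer_size, 1].

#include <cstdint>

// Scalar reference for mul with the last dim of rhs broadcast.
template <typename T>
void mul_broadcast_last_dim_reference(
    T* out,
    const T* lhs,
    const T* rhs,
    int64_t outer_size,
    int64_t broadcast_size) {
  for (int64_t row = 0; row < outer_size; ++row) {
    for (int64_t col = 0; col < broadcast_size; ++col) {
      // Every element of row `row` in lhs is scaled by the single rhs value
      // stored for that row.
      out[row * broadcast_size + col] =
          lhs[row * broadcast_size + col] * rhs[row];
    }
  }
}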
1 change: 1 addition & 0 deletions kernels/optimized/cpu/targets.bzl
@@ -72,6 +72,7 @@ _OPTIMIZED_ATEN_OPS = (
":binary_ops",
"//executorch/kernels/portable/cpu:scalar_utils",
"//executorch/kernels/portable/cpu/util:broadcast_util",
"//executorch/runtime/core/exec_aten/util:tensor_util",
],
),
op_target(
39 changes: 39 additions & 0 deletions kernels/optimized/vec/functional_base.h
@@ -380,5 +380,44 @@ inline void broadcasting_map_2d_by_1d(
broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
}

/*
The following function implements a broadcasting binary operation on two
tensors, where the lhs tensor is treated as having shape
[outer_size, broadcast_size] and the rhs tensor as having shape
[outer_size, 1].
Any two N-dimensional tensors can be mapped to this form when
lhs size = [lhs0, lhs1, ..., lhsN-1] and rhs size = [rhs0, rhs1, ..., 1]
by viewing the two tensors as
lhs size = [lhs0 * lhs1 * ... * lhsN-2, lhsN-1]
rhs size = [rhs0 * rhs1 * ... * rhsN-2, 1]
*/
template <typename scalar_t, typename Op>
inline void broadcasting_map_broadcast_last_dim(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* lhs,
const scalar_t* rhs,
int64_t outer_size,
int64_t broadcast_size) {
using Vec = vec::Vectorized<scalar_t>;
int64_t outer_stride_lhs = broadcast_size;
int64_t outer_stride_rhs = 1;
for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
int64_t inner_idx = 0;
Vec data_vec2 = Vec(rhs[outer_idx]);
for (; inner_idx < broadcast_size - (broadcast_size % Vec::size()); inner_idx += Vec::size()) {
Vec data_vec = Vec::loadu(lhs_outer + inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row + inner_idx);
}
if (broadcast_size - inner_idx > 0) {
Vec data_vec = Vec::loadu(lhs_outer + inner_idx, broadcast_size - inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row + inner_idx, broadcast_size - inner_idx);
}
}
}

} // namespace vec
} // namespace executorch
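A minimal usage sketch of the new helper follows. It assumes the two includes below are sufficient and that Vectorized<float> is usable on the target; the data mirrors the {4, 3} x {4, 1} case in the new test.

#include <executorch/kernels/optimized/vec/functional_base.h>
#include <executorch/kernels/optimized/vec/vec.h>

void broadcast_last_dim_example() {
  // lhs viewed as [4, 3], rhs viewed as [4, 1].
  float lhs[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  float rhs[4] = {2, 3, 4, 5};
  float out[12] = {};

  using Vec = executorch::vec::Vectorized<float>;
  executorch::vec::broadcasting_map_broadcast_last_dim<float>(
      [](Vec x, Vec y) { return x * y; },
      out,
      lhs,
      rhs,
      /*outer_size=*/4,
      /*broadcast_size=*/3);
  // out now holds {2, 4, 6, 12, 15, 18, 28, 32, 36, 50, 55, 60}, matching
  // the expected tensor in test_broadcast_last_dim.
}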
67 changes: 67 additions & 0 deletions kernels/test/op_mul_test.cpp
@@ -220,6 +220,61 @@ class OpMulOutTest : public OperatorTest {
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
}

template <ScalarType DTYPE>
void test_broadcast_last_dim() {
TensorFactory<DTYPE> tf_a;

Tensor a =
tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5});

// Destination for output of mul.
Tensor out = tf_a.zeros({4, 3});
Tensor expected = tf_a.make(
{4, 3}, /*data=*/{2, 4, 6, 12, 15, 18, 28, 32, 36, 50, 55, 60});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);

a =
tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5});

// Destination for output of mul.
out = tf_a.zeros({2, 2, 3});
expected = tf_a.make(
{2, 2, 3}, /*data=*/{2, 4, 6, 12, 15, 18, 28, 32, 36, 50, 55, 60});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);

a = tf_a.make(
{2, 2, 3, 5},
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
b = tf_a.make(
{2, 2, 3, 1},
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

// Destination for output of mul.
out = tf_a.zeros({2, 2, 3, 5});
expected = tf_a.make(
{2, 2, 3, 5},
/*data=*/{1, 2, 3, 4, 5, 12, 14, 16, 18, 20, 33, 36,
39, 42, 45, 64, 68, 72, 76, 80, 105, 110, 115, 120,
125, 156, 162, 168, 174, 180, 217, 224, 231, 238, 245, 288,
296, 304, 312, 320, 369, 378, 387, 396, 405, 460, 470, 480,
490, 500, 561, 572, 583, 594, 605, 672, 684, 696, 708, 720});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
}

template <ScalarType DTYPE>
void test_broadcast_b2a() {
TensorFactory<DTYPE> tf_a;
@@ -392,6 +447,18 @@ TEST_F(OpMulOutTest, BroadcastNDTest) {
test_broadcast_4D<ScalarType::Float>();
test_broadcast_4D<ScalarType::Half>();
test_broadcast_4D<ScalarType::BFloat16>();

// Test broadcasting on the last dimension
test_broadcast_last_dim<ScalarType::Float>();
test_broadcast_last_dim<ScalarType::Half>();
test_broadcast_last_dim<ScalarType::BFloat16>();
}

TEST_F(OpMulOutTest, BroadcastLastDimTest) {
// Test broadcasting on the last dimension
test_broadcast_last_dim<ScalarType::Float>();
test_broadcast_last_dim<ScalarType::Half>();
test_broadcast_last_dim<ScalarType::BFloat16>();
}

// Broadcast tensor a and b's size to a new size c.