From d90833a5bb61a19b564764e7cb56d6a70a210644 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 16 Oct 2024 14:51:57 -0700
Subject: [PATCH] [Executorch] Handle broadcast semantics for last dim

This diff adds support for handling the element-wise mul op when the
broadcast is across the last dim.

Differential Revision: [D64156863](https://our.internmc.facebook.com/intern/diff/D64156863/)

[ghstack-poisoned]
---
 kernels/optimized/cpu/binary_ops.h      |  32 ++++-
 kernels/optimized/cpu/op_mul.cpp        | 167 ++++++++++++++++--
 kernels/optimized/cpu/targets.bzl       |   1 +
 kernels/optimized/vec/functional_base.h |  39 ++++++
 kernels/test/op_mul_test.cpp            |  67 ++++++++++
 5 files changed, 244 insertions(+), 62 deletions(-)

diff --git a/kernels/optimized/cpu/binary_ops.h b/kernels/optimized/cpu/binary_ops.h
index d02153ea44..ce19a8fa9d 100644
--- a/kernels/optimized/cpu/binary_ops.h
+++ b/kernels/optimized/cpu/binary_ops.h
@@ -43,17 +43,27 @@ enum class ElementwiseOptimizedPath {
   kBroadcast2dBy1dReverseArguments,
   kBroadcastNdByNd,
   kBroadcastNdByNdReverseArguments,
+  kBroadcastLastDim,
+  kBroadcastLastDimReverseArguments,
 };

 namespace internal {
-// Find the single broadcast dimension if it exists.
-// This path aims to handle broadcast of the following form
-// A = [a1, a2,., 1, .., an]
-// B = [b1, b2,., bm, .., bn]
-// OR
-// A = [a1, a2,., am, .., an]
-// B = [b1, b2,., 1, .., bn]
+/*
+  Given two tensors, this function returns the broadcast dim if it exists.
+  Returns 0 if no broadcast dim is found.
+  Otherwise, a negative index indicates the broadcast dim,
+  e.g. if size = [a, b, c, 1, e, f] then the broadcast dim is -3.
+
+  This path aims to handle broadcast of the following form
+  A = [a1, a2, ..., 1, ..., an]
+  B = [b1, b2, ..., bm, ..., bn]
+  OR
+  A = [a1, a2, ..., am, ..., an]
+  B = [b1, b2, ..., 1, ..., bn]
+  Note that this way of determining the broadcast dim also works
+  when the broadcast dim is the last dim.
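+
+  For example (illustrative values, following the convention above):
+  lhs = [2, 3, 4, 5] with rhs = [2, 3, 1, 5] gives broadcast dim -2, and
+  lhs = [2, 3, 4, 5] with rhs = [2, 3, 4, 1] gives broadcast dim -1,
+  i.e. the broadcast is over the last dim.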
+*/
 int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
   auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
   auto lhs_end = lhs.sizes().end();
@@ -125,6 +135,14 @@ inline ElementwiseOptimizedPath select_broadcast_optimized_path(
     } else {
       return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
     }
+  } else if (broadcast_dim == -1) {
+    if (std::count_if(lhs_begin, lhs_end, [](Tensor::SizesType x) {
+          return x == 1;
+        }) == 1) {
+      return ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments;
+    } else {
+      return ElementwiseOptimizedPath::kBroadcastLastDim;
+    }
   }
   return ElementwiseOptimizedPath::kNone;
 }
diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp
index ea95f16108..8de5ab8f74 100644
--- a/kernels/optimized/cpu/op_mul.cpp
+++ b/kernels/optimized/cpu/op_mul.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include  // IWYU pragma: export
 #include
 #include
@@ -66,6 +67,116 @@ template <
     typename CTYPE_OUT>
 struct MulInner : public ReportCanCastBug {};
+
+Tensor& handle_last_dim_broadcast(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if (selected_optimized_path ==
+      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
+  const auto broadcast_size = out.size(out.dim() - 1);
+  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    executorch::vec::broadcasting_map_broadcast_last_dim(
+        [](Vec x, Vec y) { return x * y; },
+        out.mutable_data_ptr<CTYPE>(),
+        lhs->const_data_ptr<CTYPE>(),
+        rhs->const_data_ptr<CTYPE>(),
+        outer_size,
+        broadcast_size);
+  });
+  return out;
+}
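+
+// Illustrative example (added for exposition, not part of the kernel logic):
+// for lhs sizes [2, 3, 5] and rhs sizes [2, 3, 1], the helper above resizes
+// out to [2, 3, 5] and calls broadcasting_map_broadcast_last_dim with
+// outer_size = 2 * 3 = 6 and broadcast_size = 5, i.e. the lhs is viewed as a
+// [6, 5] matrix and the rhs as a [6, 1] column broadcast along the last dim.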
+
+Tensor& handle_broadcast_mul(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out,
+    const ElementwiseOptimizedPath selected_optimized_path) {
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDim) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
+    return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
+  }
+
+  ScalarType out_type = out.scalar_type();
+  const Tensor* lhs;
+  const Tensor* rhs;
+  if ((selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    lhs = &b;
+    rhs = &a;
+  } else {
+    // Catch failure to update logic when adding new broadcasting possibility.
+    ET_DCHECK(
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
+        (selected_optimized_path ==
+         ElementwiseOptimizedPath::kBroadcastNdByNd));
+    lhs = &a;
+    rhs = &b;
+  }
+  auto error = resize_tensor(out, lhs->sizes());
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      error == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+  int64_t outer_size = 1;
+  int64_t broadcast_size;
+  int64_t inner_size;
+  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
+      (selected_optimized_path ==
+       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
+    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
+    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
+    int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
+    auto normalized_tensor_size_lhs =
+        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
+    outer_size = normalized_tensor_size_lhs[0];
+    broadcast_size = normalized_tensor_size_lhs[1];
+    inner_size = normalized_tensor_size_lhs[2];
+  } else {
+    broadcast_size = lhs->sizes()[lhs->dim() - 2];
+    inner_size = lhs->sizes()[lhs->dim() - 1];
+  }
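+  // Illustrative example (added for exposition): for lhs sizes [2, 3, 4, 5]
+  // and rhs sizes [2, 3, 1, 5], get_broadcast_dim returns -2, so
+  // broadcast_dim_lhs = 4 + (-2) = 2 and the lhs is normalized to
+  // [outer_size, broadcast_size, inner_size] = [6, 4, 5], while the rhs,
+  // which has size 1 at that dim, is effectively treated as [6, 1, 5].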
+  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+    using Vec = executorch::vec::Vectorized<CTYPE>;
+    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d(
+        [](Vec x, Vec y) { return x * y; },
+        out.mutable_data_ptr<CTYPE>(),
+        lhs->const_data_ptr<CTYPE>(),
+        rhs->const_data_ptr<CTYPE>(),
+        outer_size,
+        broadcast_size,
+        inner_size);
+  });
+  return out;
+}
 } // namespace

 Tensor& opt_mul_out(
@@ -128,61 +239,7 @@ Tensor& opt_mul_out(
         out.numel());
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      lhs = &b;
-      rhs = &a;
-    } else {
-      // Catch failure to update logic when adding new broadcasting possibility.
-      ET_DCHECK(
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
-          (selected_optimized_path ==
-           ElementwiseOptimizedPath::kBroadcastNdByNd));
-      lhs = &a;
-      rhs = &b;
-    }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    int64_t outer_size = 1;
-    int64_t broadcast_size;
-    int64_t inner_size;
-    if ((selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNd) ||
-        (selected_optimized_path ==
-         ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
-      int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
-      int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
-      int32_t broadcast_dim_rhs = rhs->dim() + broadcast_dim;
-      auto normalized_tensor_size_lhs =
-          get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
-      outer_size = normalized_tensor_size_lhs[0];
-      broadcast_size = normalized_tensor_size_lhs[1];
-      inner_size = normalized_tensor_size_lhs[2];
-    } else {
-      broadcast_size = lhs->sizes()[lhs->dim() - 2];
-      inner_size = lhs->sizes()[lhs->dim() - 1];
-    }
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::broadcasting_map_3d_and_unsqueezed_3d(
-          [](Vec x, Vec y) { return x * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          lhs->const_data_ptr<CTYPE>(),
-          rhs->const_data_ptr<CTYPE>(),
-          outer_size,
-          broadcast_size,
-          inner_size);
-    });
+    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /*half_to_float*/ true);
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 488d2af7fa..77a270cc45 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -72,6 +72,7 @@ _OPTIMIZED_ATEN_OPS = (
             ":binary_ops",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
         ],
     ),
     op_target(
diff --git a/kernels/optimized/vec/functional_base.h b/kernels/optimized/vec/functional_base.h
index 6f9fcd8e79..f05c2becc5 100644
--- a/kernels/optimized/vec/functional_base.h
+++ b/kernels/optimized/vec/functional_base.h
@@ -380,5 +380,44 @@ inline void broadcasting_map_2d_by_1d(
   broadcasting_map_3d_and_unsqueezed_3d(vec_fun, output_data, input_data, input_data2, 1, size, size2);
 }

+/*
+The following function implements a broadcasting binary operation on two
+tensors, where the lhs tensor is treated as having shape
+[outer_size, broadcast_size] and the rhs tensor as having shape [outer_size, 1].
+Any two N-dimensional tensors can be mapped to this form when
+lhs size = [lhs0, lhs1, ..., lhsN-1] and rhs size = [rhs0, rhs1, ..., 1]
+by viewing the two tensors as
+lhs size = [lhs0 * lhs1 * ... * lhsN-2, lhsN-1]
+rhs size = [rhs0 * rhs1 * ... * rhsN-2, 1]
+*/
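+//
+// Vectorization strategy (summary of the loop below): each iteration over
+// outer_idx broadcasts the single rhs value for that row into a Vec, then the
+// inner loop processes the lhs row Vec::size() elements at a time; any tail of
+// fewer than Vec::size() elements is handled with the count-limited
+// loadu/store overloads.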
* rshN-2, 1] +*/ +template +inline void broadcasting_map_broadcast_last_dim( + const Op& vec_fun, + scalar_t* output_data, + const scalar_t* lhs, + const scalar_t* rhs, + int64_t outer_size, + int64_t broadcast_size) { + using Vec = vec::Vectorized; + int64_t outer_stride_lhs = broadcast_size; + int64_t outer_stride_rhs = 1; + for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) { + const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs; + scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs; + int64_t inner_idx = 0; + Vec data_vec2 = Vec(rhs[outer_idx]); + for (; inner_idx < broadcast_size - (broadcast_size % Vec::size()); inner_idx += Vec::size()) { + Vec data_vec = Vec::loadu(lhs_outer + inner_idx); + Vec output_vec = vec_fun(data_vec, data_vec2); + output_vec.store(output_data_row + inner_idx); + } + if (broadcast_size - inner_idx > 0) { + Vec data_vec = Vec::loadu(lhs_outer + inner_idx, broadcast_size - inner_idx); + Vec output_vec = vec_fun(data_vec, data_vec2); + output_vec.store(output_data_row + inner_idx, broadcast_size - inner_idx); + } + } +} + } // namespace vec } // namespace executorch diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp index 6430bfbd8f..2f40e05fc9 100644 --- a/kernels/test/op_mul_test.cpp +++ b/kernels/test/op_mul_test.cpp @@ -220,6 +220,61 @@ class OpMulOutTest : public OperatorTest { EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected); } + template + void test_broadcast_last_dim() { + TensorFactory tf_a; + + Tensor a = + tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5}); + + // Destination for output of mul. + Tensor out = tf_a.zeros({4, 3}); + Tensor expected = tf_a.make( + {4, 3}, /*data=*/{2, 4, 6, 12, 15, 18, 28, 32, 36, 50, 55, 60}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected); + EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected); + + a = + tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5}); + + // Destination for output of mul. + out = tf_a.zeros({2, 2, 3}); + expected = tf_a.make( + {2, 2, 3}, /*data=*/{2, 4, 6, 12, 15, 18, 28, 32, 36, 50, 55, 60}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected); + EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected); + + a = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60}); + b = tf_a.make( + {2, 2, 3, 1}, + /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + // Destination for output of mul. + out = tf_a.zeros({2, 2, 3, 5}); + expected = tf_a.make( + {2, 2, 3, 5}, + /*data=*/{1, 2, 3, 4, 5, 12, 14, 16, 18, 20, 33, 36, + 39, 42, 45, 64, 68, 72, 76, 80, 105, 110, 115, 120, + 125, 156, 162, 168, 174, 180, 217, 224, 231, 238, 245, 288, + 296, 304, 312, 320, 369, 378, 387, 396, 405, 460, 470, 480, + 490, 500, 561, 572, 583, 594, 605, 672, 684, 696, 708, 720}); + + // Check that it matches the expected output. 
diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp
index 6430bfbd8f..2f40e05fc9 100644
--- a/kernels/test/op_mul_test.cpp
+++ b/kernels/test/op_mul_test.cpp
@@ -220,6 +220,61 @@ class OpMulOutTest : public OperatorTest {
     EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
   }

+  template <ScalarType DTYPE>
+  void test_broadcast_last_dim() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5});
+
+    // Destination for output of mul.
+    Tensor out = tf_a.zeros({4, 3});
+    Tensor expected = tf_a.make(
+        {4, 3}, /*data=*/{2, 4, 6, 12, 15, 18, 28, 32, 36, 50, 55, 60});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+
+    a = tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5});
+
+    // Destination for output of mul.
+    out = tf_a.zeros({2, 2, 3});
+    expected = tf_a.make(
+        {2, 2, 3}, /*data=*/{2, 4, 6, 12, 15, 18, 28, 32, 36, 50, 55, 60});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+
+    a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    b = tf_a.make(
+        {2, 2, 3, 1},
+        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+
+    // Destination for output of mul.
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,   2,   3,   4,   5,   12,  14,  16,  18,  20,  33,  36,
+                  39,  42,  45,  64,  68,  72,  76,  80,  105, 110, 115, 120,
+                  125, 156, 162, 168, 174, 180, 217, 224, 231, 238, 245, 288,
+                  296, 304, 312, 320, 369, 378, 387, 396, 405, 460, 470, 480,
+                  490, 500, 561, 572, 583, 594, 605, 672, 684, 696, 708, 720});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected);
+    EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected);
+  }
+
   template <ScalarType DTYPE>
   void test_broadcast_b2a() {
     TensorFactory<DTYPE> tf_a;
@@ -392,6 +447,18 @@ TEST_F(OpMulOutTest, BroadcastNDTest) {
   test_broadcast_4D<ScalarType::Float>();
   test_broadcast_4D<ScalarType::Half>();
   test_broadcast_4D<ScalarType::BFloat16>();
+
+  // Test broadcasting on the last dimension
+  test_broadcast_last_dim<ScalarType::Float>();
+  test_broadcast_last_dim<ScalarType::Half>();
+  test_broadcast_last_dim<ScalarType::BFloat16>();
+}
+
+TEST_F(OpMulOutTest, BroadcastLastDimTest) {
+  // Test broadcasting on the last dimension
+  test_broadcast_last_dim<ScalarType::Float>();
+  test_broadcast_last_dim<ScalarType::Half>();
+  test_broadcast_last_dim<ScalarType::BFloat16>();
 }

 // Broadcast tensor a and b's size to a new size c.