Skip to content

Commit

Permalink
Online interfaces for LOM (#2483)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandr-Solovev authored Oct 2, 2023
1 parent d5c9935 commit b5e996a
Show file tree
Hide file tree
Showing 33 changed files with 1,966 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace oneapi
{
namespace internal
{
template class LowOrderMomentsOnlineKernelOneAPI<DAAL_FPTYPE, defaultDense>;
template class DAAL_EXPORT LowOrderMomentsOnlineKernelOneAPI<DAAL_FPTYPE, defaultDense>;
template class LowOrderMomentsOnlineKernelOneAPI<DAAL_FPTYPE, singlePassDense>;
template class LowOrderMomentsOnlineKernelOneAPI<DAAL_FPTYPE, sumDense>;
template class LowOrderMomentsOnlineKernelOneAPI<DAAL_FPTYPE, fastCSR>;
Expand Down
2 changes: 2 additions & 0 deletions cpp/oneapi/dal/algo/basic_statistics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@
#pragma once

#include "oneapi/dal/algo/basic_statistics/compute.hpp"
#include "oneapi/dal/algo/basic_statistics/partial_compute.hpp"
#include "oneapi/dal/algo/basic_statistics/finalize_compute.hpp"
1 change: 1 addition & 0 deletions cpp/oneapi/dal/algo/basic_statistics/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dal_module(
dal_deps = [
"@onedal//cpp/oneapi/dal:core",
"@onedal//cpp/oneapi/dal/backend/primitives:common",
"@onedal//cpp/oneapi/dal/backend/primitives:reduction",
],
extra_deps = [
"@onedal//cpp/daal/src/algorithms/low_order_moments:kernel",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace oneapi::dal::basic_statistics::backend {

namespace daal_lom = daal::algorithms::low_order_moments;
namespace interop = dal::backend::interop;
namespace bk = dal::backend;

using task_t = task::compute;
using descriptor_t = detail::descriptor_base<task_t>;
Expand Down Expand Up @@ -100,4 +101,32 @@ inline auto get_result(const descriptor_t& desc, const daal_lom::Result& daal_re
return res;
}

template <typename Float>
inline array<Float> copy_immutable(const array<Float>&& inp) {
if (inp.has_mutable_data()) {
return inp;
}
else {
const auto count = inp.get_count();
auto res = array<Float>::empty(count);
bk::copy(res.get_mutable_data(), inp.get_data(), count);
return res;
}
}

template <typename Float, typename Result, typename Input, typename Parameter>
inline void alloc_result(Result& result, const Input* input, const Parameter* params, int method) {
const auto status = result.template allocate<Float>(input, params, method);
interop::status_to_exception(status);
}

template <typename Float, typename Result, typename Input, typename Parameter>
inline void initialize_result(Result& result,
const Input* input,
const Parameter* params,
int method) {
const auto status = result.template initialize<Float>(input, params, method);
interop::status_to_exception(status);
}

} // namespace oneapi::dal::basic_statistics::backend
Original file line number Diff line number Diff line change
Expand Up @@ -64,31 +64,6 @@ std::int64_t propose_block_size(std::int64_t row_count, std::int64_t col_count)
return std::max<idx_t>(std::min<idx_t>(row_count, idx_t(1024l)), block_of_rows_size);
}

template <typename Float>
array<Float> copy_immutable(const array<Float>&& inp) {
if (inp.has_mutable_data()) {
return inp;
}
else {
const auto count = inp.get_count();
auto res = array<Float>::empty(count);
bk::copy(res.get_mutable_data(), inp.get_data(), count);
return res;
}
}

template <typename Float, typename Result, typename Input, typename Parameter>
void alloc_result(Result& result, const Input* input, const Parameter* params, int method) {
const auto status = result.template allocate<Float>(input, params, method);
interop::status_to_exception(status);
}

template <typename Float, typename Result, typename Input, typename Parameter>
void initialize_result(Result& result, const Input* input, const Parameter* params, int method) {
const auto status = result.template initialize<Float>(input, params, method);
interop::status_to_exception(status);
}

template <typename Float>
result_t call_daal_kernel_with_weights(const context_cpu& ctx,
const descriptor_t& desc,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "oneapi/dal/algo/basic_statistics/compute_types.hpp"
#include "oneapi/dal/backend/dispatcher.hpp"

namespace oneapi::dal::basic_statistics::backend {

template <typename Float, typename Method, typename Task>
struct finalize_compute_kernel_cpu {
compute_result<Task> operator()(const dal::backend::context_cpu& ctx,
const detail::descriptor_base<Task>& params,
const partial_compute_result<Task>& input) const;
};

} // namespace oneapi::dal::basic_statistics::backend
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>

#include "oneapi/dal/algo/basic_statistics/backend/cpu/finalize_compute_kernel.hpp"
#include "oneapi/dal/algo/basic_statistics/backend/basic_statistics_interop.hpp"

#include "oneapi/dal/backend/interop/common.hpp"
#include "oneapi/dal/backend/interop/error_converter.hpp"
#include "oneapi/dal/backend/interop/table_conversion.hpp"

#include "oneapi/dal/table/row_accessor.hpp"

#include <daal/src/algorithms/low_order_moments/moments_online.h>
#include <daal/src/algorithms/low_order_moments/low_order_moments_kernel.h>

namespace oneapi::dal::basic_statistics::backend {

using dal::backend::context_cpu;
using method_t = method::dense;
using task_t = task::compute;
using input_t = compute_input<task_t>;
using result_t = compute_result<task_t>;
using descriptor_t = detail::descriptor_base<task_t>;

namespace daal_lom = daal::algorithms::low_order_moments;
namespace interop = dal::backend::interop;
namespace bk = dal::backend;

template <typename Float, daal::CpuType Cpu>
using daal_lom_online_kernel_t =
daal_lom::internal::LowOrderMomentsOnlineKernel<Float, daal_lom::defaultDense, Cpu>;

template <typename Float, typename Task>
static compute_result<Task> call_daal_kernel_finalize_compute(
const context_cpu& ctx,
const descriptor_t& desc,
const partial_compute_result<Task>& input) {
const auto result_ids = daal_lom::estimatesAll;
const auto daal_parameter = daal_lom::Parameter(result_ids);

auto column_count = input.get_partial_min().get_column_count();

auto daal_partial_obs = interop::copy_to_daal_homogen_table<Float>(input.get_partial_n_rows());
auto daal_partial_min = interop::copy_to_daal_homogen_table<Float>(input.get_partial_min());
auto daal_partial_max = interop::copy_to_daal_homogen_table<Float>(input.get_partial_max());
auto daal_partial_sums = interop::copy_to_daal_homogen_table<Float>(input.get_partial_sum());
auto daal_partial_sum_squares =
interop::copy_to_daal_homogen_table<Float>(input.get_partial_sum_squares());
auto daal_partial_sum_squares_centered =
interop::copy_to_daal_homogen_table<Float>(input.get_partial_sum_squares_centered());

auto daal_means = interop::allocate_daal_homogen_table<Float>(1, column_count);
auto daal_rawt = interop::allocate_daal_homogen_table<Float>(1, column_count);

auto daal_variance = interop::allocate_daal_homogen_table<Float>(1, column_count);
auto daal_stdev = interop::allocate_daal_homogen_table<Float>(1, column_count);
auto daal_variation = interop::allocate_daal_homogen_table<Float>(1, column_count);
{
interop::status_to_exception(
interop::call_daal_kernel_finalize_compute<Float, daal_lom_online_kernel_t>(
ctx,
daal_partial_obs.get(),
daal_partial_sums.get(),
daal_partial_sum_squares.get(),
daal_partial_sum_squares_centered.get(),
daal_means.get(),
daal_rawt.get(),
daal_variance.get(),
daal_stdev.get(),
daal_variation.get(),
&daal_parameter));
}

compute_result<Task> res;
const auto res_op = desc.get_result_options();
res.set_result_options(desc.get_result_options());

if (res_op.test(result_options::min)) {
res.set_min(interop::convert_from_daal_homogen_table<Float>(daal_partial_min));
}
if (res_op.test(result_options::max)) {
res.set_max(interop::convert_from_daal_homogen_table<Float>(daal_partial_max));
}
if (res_op.test(result_options::sum)) {
res.set_sum(interop::convert_from_daal_homogen_table<Float>(daal_partial_sums));
}
if (res_op.test(result_options::sum_squares)) {
res.set_sum_squares(
interop::convert_from_daal_homogen_table<Float>(daal_partial_sum_squares));
}
if (res_op.test(result_options::sum_squares_centered)) {
res.set_sum_squares_centered(
interop::convert_from_daal_homogen_table<Float>(daal_partial_sum_squares_centered));
}
if (res_op.test(result_options::mean)) {
res.set_mean(interop::convert_from_daal_homogen_table<Float>(daal_means));
}
if (res_op.test(result_options::second_order_raw_moment)) {
res.set_second_order_raw_moment(interop::convert_from_daal_homogen_table<Float>(daal_rawt));
}
if (res_op.test(result_options::variance)) {
res.set_variance(interop::convert_from_daal_homogen_table<Float>(daal_variance));
}
if (res_op.test(result_options::standard_deviation)) {
res.set_standard_deviation(interop::convert_from_daal_homogen_table<Float>(daal_stdev));
}
if (res_op.test(result_options::variation)) {
res.set_variation(interop::convert_from_daal_homogen_table<Float>(daal_variation));
}

return res;
}

template <typename Float, typename Task>
static compute_result<Task> finalize_compute(const context_cpu& ctx,
const descriptor_t& desc,
const partial_compute_result<Task>& input) {
return call_daal_kernel_finalize_compute<Float, Task>(ctx, desc, input);
}

template <typename Float>
struct finalize_compute_kernel_cpu<Float, method_t, task_t> {
compute_result<task::compute> operator()(
const context_cpu& ctx,
const descriptor_t& desc,
const partial_compute_result<task::compute>& input) const {
return finalize_compute<Float, task::compute>(ctx, desc, input);
}
};

template struct finalize_compute_kernel_cpu<float, method_t, task_t>;
template struct finalize_compute_kernel_cpu<double, method_t, task_t>;

} // namespace oneapi::dal::basic_statistics::backend
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "oneapi/dal/algo/basic_statistics/compute_types.hpp"
#include "oneapi/dal/backend/dispatcher.hpp"

namespace oneapi::dal::basic_statistics::backend {

template <typename Float, typename Method, typename Task>
struct partial_compute_kernel_cpu {
partial_compute_result<Task> operator()(const dal::backend::context_cpu& ctx,
const detail::descriptor_base<Task>& params,
const partial_compute_input<Task>& input) const;
};

} // namespace oneapi::dal::basic_statistics::backend
Loading

0 comments on commit b5e996a

Please sign in to comment.