From 9ec31b81f7469dc1a99597f3a793caec29d63ee4 Mon Sep 17 00:00:00 2001 From: Alexandr-Solovev Date: Mon, 30 Sep 2024 09:34:00 -0700 Subject: [PATCH] fixes for knn --- ..._classification_predict_dense_default_batch_impl.i | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cpp/daal/src/algorithms/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i b/cpp/daal/src/algorithms/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i index 297f6a0179e..891d52e6b17 100644 --- a/cpp/daal/src/algorithms/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i +++ b/cpp/daal/src/algorithms/k_nearest_neighbors/kdtree_knn_classification_predict_dense_default_batch_impl.i @@ -128,11 +128,13 @@ Status KNNClassificationPredictKernel::compu typedef kdtree_knn_classification::internal::Stack, cpu> SearchStack; typedef daal::services::internal::MaxVal MaxVal; typedef daal::internal::MathInst Math; + size_t k; size_t nClasses; VoteWeights voteWeights = voteUniform; DAAL_UINT64 resultsToEvaluate = classifier::computeClassLabels; - const auto par3 = dynamic_cast(par); + + const auto par3 = dynamic_cast(par); if (par3) { k = par3->k; @@ -140,6 +142,7 @@ Status KNNClassificationPredictKernel::compu resultsToEvaluate = par3->resultsToEvaluate; nClasses = par3->nClasses; } + if (par3 == NULL) return Status(ErrorNullParameterNotSupported); const Model * const model = static_cast(m); @@ -151,6 +154,7 @@ Status KNNClassificationPredictKernel::compu { labels = model->impl()->getLabels().get(); } + const NumericTable * const modelIndices = model->impl()->getIndices().get(); size_t iSize = 1; @@ -263,7 +267,7 @@ Status KNNClassificationPredictKernel::compu }); status = safeStat.detach(); DAAL_CHECK_SAFE_STATUS() - localTLS.reduce([=](Local * ptr) -> void { + localTLS.reduce([&](Local * ptr) -> void { if (ptr) { ptr->stack.clear(); @@ -271,7 +275,6 @@ Status KNNClassificationPredictKernel::compu service_scalable_free(ptr); } }); - return status; } @@ -460,7 +463,9 @@ services::Status KNNClassificationPredictKernel