Skip to content

Commit

Permalink
fixes for knn
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandr-Solovev committed Sep 30, 2024
1 parent 8acbfac commit 9ec31b8
Showing 1 changed file with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,21 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
typedef kdtree_knn_classification::internal::Stack<SearchNode<algorithmFpType>, cpu> SearchStack;
typedef daal::services::internal::MaxVal<algorithmFpType> MaxVal;
typedef daal::internal::MathInst<algorithmFpType, cpu> Math;

size_t k;
size_t nClasses;
VoteWeights voteWeights = voteUniform;
DAAL_UINT64 resultsToEvaluate = classifier::computeClassLabels;
const auto par3 = dynamic_cast<const kdtree_knn_classification::interface3::Parameter *>(par);

const auto par3 = dynamic_cast<const kdtree_knn_classification::interface3::Parameter *>(par);
if (par3)
{
k = par3->k;
voteWeights = par3->voteWeights;
resultsToEvaluate = par3->resultsToEvaluate;
nClasses = par3->nClasses;
}

if (par3 == NULL) return Status(ErrorNullParameterNotSupported);

const Model * const model = static_cast<const Model *>(m);
Expand All @@ -151,6 +154,7 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
{
labels = model->impl()->getLabels().get();
}

const NumericTable * const modelIndices = model->impl()->getIndices().get();

size_t iSize = 1;
Expand Down Expand Up @@ -263,15 +267,14 @@ Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, cpu>::compu
});
status = safeStat.detach();
DAAL_CHECK_SAFE_STATUS()
localTLS.reduce([=](Local * ptr) -> void {
localTLS.reduce([&](Local * ptr) -> void {
if (ptr)
{
ptr->stack.clear();
ptr->heap.clear();
service_scalable_free<Local, cpu>(ptr);
}
});

return status;
}

Expand Down Expand Up @@ -460,7 +463,9 @@ services::Status KNNClassificationPredictKernel<algorithmFpType, defaultDense, c
{
distancesPtr[i] = heap[i].distance;
}

Math::vSqrt(heapSize, distancesPtr, distancesPtr);

for (size_t i = heapSize; i < nDistances; ++i)
{
distancesPtr[i] = -1;
Expand Down

0 comments on commit 9ec31b8

Please sign in to comment.