diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 3d396610e..e8fb4de20 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -41,9 +41,6 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, // Train an index with data provided void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); -// Create the SearchParams based on the Index Type -std::unique_ptr buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector); - // Helps to choose the right FilterIdsSelectorType for Faiss FilterIdsSelectorType getIdSelectorType(const int* filterIds, int filterIdsLength); @@ -249,9 +246,26 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter buildFilterIdsBitMap(filteredIdsArray, filterIdsLength, bitmap.data()); idSelector.reset(new faiss::IDSelectorBitmap(filterIdsLength, bitmap.data())); } - std::unique_ptr searchParameters = buildSearchParams(indexReader, idSelector.get()); + faiss::SearchParameters *searchParameters; + faiss::SearchParametersHNSW hnswParams; + faiss::SearchParametersIVF ivfParams; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader) { + // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default + // value of ef_search = 16 which will then be used. + hnswParams.efSearch = hnswReader->hnsw.efSearch; + hnswParams.sel = idSelector.get(); + searchParameters = &hnswParams; + } else { + auto ivfReader = dynamic_cast(indexReader->index); + auto ivfFlatReader = dynamic_cast(indexReader->index); + if(ivfReader || ivfFlatReader) { + ivfParams.sel = idSelector.get(); + searchParameters = &ivfParams; + } + } try { - indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters.get()); + indexReader->search(1, rawQueryvector, kJ, dis.data(), ids.data(), searchParameters); } catch (...) { jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); jniUtil->ReleaseIntArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); @@ -475,33 +489,3 @@ void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bi bitsetVector[bitsetArrayIndex] = bitsetVector[bitsetArrayIndex] | (1 << (value & 7)); } } - -/** - * Based on the type of the index reader we need to return the SearchParameters. The way we do this by dynamically - * casting the IndexReader. - * @param indexReader - * @param idSelector - * @return SearchParameters - */ -std::unique_ptr buildSearchParams(const faiss::IndexIDMap *indexReader, faiss::IDSelector* idSelector) { - auto hnswReader = dynamic_cast(indexReader->index); - if(hnswReader) { - // we need to make this variable unique_ptr so that the scope can be shared with caller function. - std::unique_ptr hnswParams(new faiss::SearchParametersHNSW); - // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default - // value of ef_search = 16 which will then be used. - hnswParams->efSearch = hnswReader->hnsw.efSearch; - hnswParams->sel = idSelector; - return hnswParams; - } - - auto ivfReader = dynamic_cast(indexReader->index); - auto ivfFlatReader = dynamic_cast(indexReader->index); - if(ivfReader || ivfFlatReader) { - // we need to make this variable unique_ptr so that the scope can be shared with caller function. - std::unique_ptr ivfParams(new faiss::SearchParametersIVF); - ivfParams->sel = idSelector; - return ivfParams; - } - throw std::runtime_error("Invalid Index Type supported for Filtered Search on Faiss"); -}