From 8123cfc06007d7a04f5fa0af8f384430adadf9ee Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Fri, 26 Jul 2024 15:12:11 -0700 Subject: [PATCH] refactor jni create template index Signed-off-by: Junqiu Lei --- jni/include/faiss_index_service.h | 58 +++++++- jni/include/faiss_methods.h | 2 + jni/include/faiss_wrapper.h | 8 +- jni/src/faiss_index_service.cpp | 65 ++++++++- jni/src/faiss_methods.cpp | 7 +- jni/src/faiss_wrapper.cpp | 108 ++------------- .../org_opensearch_knn_jni_FaissService.cpp | 8 +- jni/tests/faiss_index_service_test.cpp | 130 +++++++++++++++++- jni/tests/faiss_wrapper_test.cpp | 74 +++++++++- jni/tests/mocks/faiss_index_service_mock.h | 16 +++ jni/tests/mocks/faiss_methods_mock.h | 8 +- 11 files changed, 366 insertions(+), 118 deletions(-) diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index 59f15fda9c..76e8cfdb97 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -19,11 +19,13 @@ #include "jni_util.h" #include "faiss_methods.h" #include +#include +#include +#include namespace knn_jni { namespace faiss_wrapper { - /** * A class to provide operations on index * This class should evolve to have only cpp object but not jni object @@ -61,13 +63,38 @@ class IndexService { std::vector ids, std::string indexPath, std::unordered_map parameters); + + /** + * Create index from template + * + * @param jniUtil jni util + * @param env jni environment + * @param dim dimension of vectors + * @param numIds number of vectors + * @param vectorsAddress memory address which is holding vector data + * @param ids a list of document ids for corresponding vectors + * @param indexPath path to write index + * @param parameters parameters to be applied to faiss index + * @param templateIndexData vector containing the template index data + */ + virtual void createIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + int dim, + int numIds, + int64_t vectorsAddress, + std::vector 
ids, + std::string indexPath, + std::unordered_map parameters, + std::vector templateIndexData); + virtual ~IndexService() = default; protected: std::unique_ptr faissMethods; }; /** - * A class to provide operations on index + * A class to provide operations on binary index * This class should evolve to have only cpp object but not jni object */ class BinaryIndexService : public IndexService { @@ -75,6 +102,7 @@ class BinaryIndexService : public IndexService { //TODO Remove dependency on JNIUtilInterface and JNIEnv //TODO Reduce the number of parameters BinaryIndexService(std::unique_ptr faissMethods); + /** * Create binary index * @@ -103,11 +131,35 @@ class BinaryIndexService : public IndexService { std::string indexPath, std::unordered_map parameters ) override; + + /** + * Create binary index from template + * + * @param jniUtil jni util + * @param env jni environment + * @param dim dimension of vectors + * @param numIds number of vectors + * @param vectorsAddress memory address which is holding vector data + * @param ids a list of document ids for corresponding vectors + * @param indexPath path to write index + * @param parameters parameters to be applied to faiss index + * @param templateIndexData vector containing the template index data + */ + virtual void createIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + int dim, + int numIds, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + std::unordered_map parameters, + std::vector templateIndexData); + virtual ~BinaryIndexService() = default; }; } } - #endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H diff --git a/jni/include/faiss_methods.h b/jni/include/faiss_methods.h index 38d8d756a7..c7c8959406 100644 --- a/jni/include/faiss_methods.h +++ b/jni/include/faiss_methods.h @@ -32,6 +32,8 @@ class FaissMethods { virtual faiss::IndexIDMapTemplate* indexBinaryIdMap(faiss::IndexBinary* index); virtual void writeIndex(const faiss::Index* idx, const char* fname); virtual 
void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname); + virtual faiss::Index* readIndex(faiss::IOReader* f, int io_flags); + virtual faiss::IndexBinary* readIndexBinary(faiss::IOReader* f, int io_flags); virtual ~FaissMethods() = default; }; diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 5ad0dedc4d..44882210d0 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -27,13 +27,7 @@ namespace knn_jni { // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, - jobject parametersJ); - - // Create an index with ids and vectors. Instead of creating a new index, this function creates the index - // based off of the template index passed in. The index is serialized to indexPathJ. - void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, - jobject parametersJ); + jobject parametersJ, IndexService* indexService); // Load an index from indexPathJ into memory. 
// diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 8c5ba36af2..9432db6033 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -23,6 +23,7 @@ #include #include #include +#include namespace knn_jni { namespace faiss_wrapper { @@ -106,6 +107,37 @@ void IndexService::createIndex( faissMethods->writeIndex(idMap.get(), indexPath.c_str()); } +void IndexService::createIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + int dim, + int numIds, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + std::unordered_map parameters, + std::vector templateIndexData + ) { + faiss::VectorIOReader vectorIoReader; + vectorIoReader.data = templateIndexData; + + std::unique_ptr indexWriter(faissMethods->readIndex(&vectorIoReader, 0)); + + auto *inputVectors = reinterpret_cast*>(vectorsAddress); + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if (numIds != numVectors) { + throw std::runtime_error("Number of vectors or IDs does not match expected values"); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); + + std::unique_ptr idMap(faissMethods->indexIdMap(indexWriter.get())); + idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); + + faissMethods->writeIndex(idMap.get(), indexPath.c_str()); +} + BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} void BinaryIndexService::createIndex( @@ -160,5 +192,36 @@ void BinaryIndexService::createIndex( faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); } +void BinaryIndexService::createIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + int dim, + int numIds, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + std::unordered_map parameters, + std::vector templateIndexData + ) { + 
faiss::VectorIOReader vectorIoReader; + vectorIoReader.data = templateIndexData; + + std::unique_ptr indexWriter(faissMethods->readIndexBinary(&vectorIoReader, 0)); + + auto *inputVectors = reinterpret_cast*>(vectorsAddress); + int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); + if (numIds != numVectors) { + throw std::runtime_error("Number of vectors or IDs does not match expected values"); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); + + std::unique_ptr idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); + idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); + + faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); +} + } // namespace faiss_wrapper -} // namesapce knn_jni +} // namespace knn_jni diff --git a/jni/src/faiss_methods.cpp b/jni/src/faiss_methods.cpp index 05c8f459ae..abc70d4605 100644 --- a/jni/src/faiss_methods.cpp +++ b/jni/src/faiss_methods.cpp @@ -35,6 +35,11 @@ void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) { void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) { faiss::write_index_binary(idx, fname); } - +faiss::Index* FaissMethods::readIndex(faiss::IOReader* f, int io_flags) { + return faiss::read_index(f, io_flags); +} +faiss::IndexBinary* FaissMethods::readIndexBinary(faiss::IOReader* f, int io_flags) { + return faiss::read_index_binary(f, io_flags); +} } // namespace faiss_wrapper } // namesapce knn_jni diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 1d44374148..3878e4f6c4 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -33,7 +33,7 @@ #include // Defines type of IDSelector -enum FilterIdsSelectorType{ +enum FilterIdsSelectorType { BITMAP = 0, BATCH = 1, }; namespace faiss { @@ -76,7 +76,7 @@ void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const // Converts 
the int FilterIds to Faiss ids type array. void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); -// Concerts the FilterIds to BitMap +// Converts the FilterIds to BitMap void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector); std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap); @@ -161,7 +161,7 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, - jbyteArray templateIndexJ, jobject parametersJ) { + jbyteArray templateIndexJ, jobject parametersJ, IndexService* indexService) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } @@ -192,108 +192,22 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * // Read data set // Read vectors from memory address - auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); int dim = (int)dimJ; - int numVectors = (int) (inputVectors->size() / (uint64_t) dim); int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); - if (numIds != numVectors) { - throw std::runtime_error("Number of IDs does not match number of vectors"); - } // Get vector of bytes from jbytearray int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); - - faiss::VectorIOReader vectorIoReader; - for (int i = 0; i < indexBytesCount; i++) { - vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); - } + std::vector templateIndexData(indexBytesJ, indexBytesJ + indexBytesCount); jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); - // Create faiss index - std::unique_ptr indexWriter; - 
indexWriter.reset(faiss::read_index(&vectorIoReader, 0)); - - auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); - faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); - idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); - // Releasing the vectorsAddressJ memory as that is not required once we have created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 - delete inputVectors; - // Write the index to disk - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - faiss::write_index(&idMap, indexPathCpp.c_str()); -} - -void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, - jbyteArray templateIndexJ, jobject parametersJ) { - if (idsJ == nullptr) { - throw std::runtime_error("IDs cannot be null"); - } - - if (vectorsAddressJ <= 0) { - throw std::runtime_error("VectorsAddress cannot be less than 0"); - } - - if(dimJ <= 0) { - throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); - } - - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); - } - - if (templateIndexJ == nullptr) { - throw std::runtime_error("Template index cannot be null"); - } - - // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread - auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); - if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { - auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); - omp_set_num_threads(threadCount); - } - jniUtil->DeleteLocalRef(env, parametersJ); - - // Read data set - // Read vectors from memory address - auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); - int dim = (int)dimJ; - if (dim % 8 != 0) { - throw std::runtime_error("Dimensions should be multiply of 8"); - } - int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); - int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); - if (numIds != numVectors) { - throw std::runtime_error("Number of IDs does not match number of vectors"); - } - - // Get vector of bytes from jbytearray - int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); - jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); - - faiss::VectorIOReader vectorIoReader; - for (int i = 0; i < indexBytesCount; i++) { - vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); - } - jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + // Convert ids + auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + int64_t vectorsAddress = (int64_t)vectorsAddressJ; + std::string indexPathCpp = jniUtil->ConvertJavaStringToCppString(env, indexPathJ); - // Create faiss index - std::unique_ptr indexWriter; - indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0)); - - auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); - faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get()); - idMap.add_with_ids(numVectors, reinterpret_cast(inputVectors->data()), idVector.data()); - // Releasing the vectorsAddressJ memory as that is not required once we have 
created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 - delete inputVectors; - // Write the index to disk - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - faiss::write_index_binary(&idMap, indexPathCpp.c_str()); + // Create index using IndexService + indexService->createIndexFromTemplate(jniUtil, env, dim, numIds, vectorsAddress, ids, indexPathCpp, parametersCpp, templateIndexData); } jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { @@ -674,7 +588,7 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti omp_set_num_threads(threadCount); } - // Add extra parameters that cant be configured with the index factory + // Add extra parameters that can't be configured with the index factory if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 2394e2951f..d0df5adac9 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -84,7 +84,9 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -99,7 +101,9 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &binaryIndexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_index_service_test.cpp b/jni/tests/faiss_index_service_test.cpp index f876edced1..cb382ca1dd 100644 --- a/jni/tests/faiss_index_service_test.cpp +++ b/jni/tests/faiss_index_service_test.cpp @@ -15,6 +15,8 @@ #include "mocks/faiss_index_mock.h" #include "test_util.h" #include +#include + #include "gmock/gmock.h" #include "gtest/gtest.h" #include "commons.h" @@ -131,4 +133,130 @@ TEST(CreateBinaryIndexTest, BasicAssertions) { ids, indexPath, parametersMap); -} \ No newline at end of file +} + +std::vector createSampleTemplateIndexData(int dim) { + // Create a sample FAISS index + faiss::Index* index = faiss::index_factory(dim, "Flat", faiss::METRIC_L2); + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index(index, &vectorIoWriter); + delete index; + + // Copy the data from the VectorIOWriter to a vector + std::vector templateIndexData(vectorIoWriter.data.begin(), vectorIoWriter.data.end()); + return templateIndexData; +} + +std::vector createSampleBinaryTemplateIndexData(int dim) { + // Create a sample FAISS binary index + faiss::IndexBinary* index = faiss::index_binary_factory(dim, "BIVF4096,Flat"); + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index_binary(index, &vectorIoWriter); + delete index; + + // Copy the data from the VectorIOWriter to a vector + 
std::vector templateIndexData(vectorIoWriter.data.begin(), vectorIoWriter.data.end()); + return templateIndexData; +} + +TEST(CreateIndexFromTemplateTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 100; + std::vector ids; + std::vector vectors; + int dim = 2; + vectors.reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + std::vector templateIndexData = createSampleTemplateIndexData(dim); + std::unordered_map parametersMap; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + // Setup faiss method mock + MockIndex* mockIndex = new MockIndex(); + EXPECT_CALL(*mockIndex, add(numIds, vectors.data())) + .Times(1); + + faiss::IndexIDMap* indexIdMap = new faiss::IndexIDMap(mockIndex); + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + EXPECT_CALL(*mockFaissMethods, readIndex(_, 0)) + .WillOnce(Return(indexIdMap->index)); + EXPECT_CALL(*mockFaissMethods, indexIdMap(indexIdMap->index)) + .WillOnce(Return(indexIdMap)); + EXPECT_CALL(*mockFaissMethods, writeIndex(indexIdMap, ::testing::StrEq(indexPath.c_str()))) + .Times(1); + + // Create the index + knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); + indexService.createIndexFromTemplate( + &mockJNIUtil, + jniEnv, + dim, + numIds, + (int64_t) &vectors, + ids, + indexPath, + parametersMap, + templateIndexData); +} + +TEST(CreateBinaryIndexFromTemplateTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + std::vector vectors; + int dim = 128; + vectors.reserve(numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim / 8; ++j) { + vectors.push_back(test_util::RandomInt(0, 255)); + } + } + + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + 
std::vector templateIndexData = createSampleBinaryTemplateIndexData(dim); + std::unordered_map parametersMap; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + // Setup faiss method mock + // This object is handled by unique_ptr inside indexService.createIndexFromTemplate() + MockIndexBinary* mockIndex = new MockIndexBinary(); + EXPECT_CALL(*mockIndex, add(numIds, vectors.data())) + .Times(1); + + faiss::IndexBinaryIDMap* indexIdMap = new faiss::IndexBinaryIDMap(mockIndex); + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + + EXPECT_CALL(*mockFaissMethods, readIndexBinary(_, 0)) + .WillOnce(Return(mockIndex)); + EXPECT_CALL(*mockFaissMethods, indexBinaryIdMap(mockIndex)) + .WillOnce(Return(indexIdMap)); + EXPECT_CALL(*mockFaissMethods, writeIndexBinary(indexIdMap, ::testing::StrEq(indexPath.c_str()))) + .Times(1); + + // Create the index + knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(mockFaissMethods)); + indexService.createIndexFromTemplate( + &mockJNIUtil, + jniEnv, + dim, + numIds, + (int64_t) &vectors, + ids, + indexPath, + parametersMap, + templateIndexData); +} diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 5ae4438377..d47e43ac6a 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -100,7 +100,6 @@ TEST(FaissCreateBinaryIndexTest, BasicAssertions) { JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; - // Create the index std::unique_ptr faissMethods(new FaissMethods()); NiceMock mockIndexService(std::move(faissMethods)); EXPECT_CALL(mockIndexService, createIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, (int64_t)&vectors, ids, indexPath, subParametersMap)) @@ -129,10 +128,11 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); faiss::MetricType metricType = faiss::METRIC_L2; - std::string method = "HNSW32,Flat"; + std::string method =
"IVF32,Flat"; std::unique_ptr createdIndex( test_util::FaissCreateIndex(dim, method, metricType)); + createdIndex->train(numIds, vectors->data()); auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get()); // Setup jni @@ -148,15 +148,83 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { std::unordered_map parametersMap; parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; + // Create the index + std::unique_ptr faissMethods(new FaissMethods()); + IndexService indexService(std::move(faissMethods)); knn_jni::faiss_wrapper::CreateIndexFromTemplate( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), (jlong)vectors, dim, (jstring)&indexPath, reinterpret_cast(&(vectorIoWriter.data)), - (jobject) &parametersMap + (jobject) &parametersMap, + &indexService ); // Make sure index can be loaded std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); + auto indexIDMap = dynamic_cast(index.get()); + + // Assert that the index is of the correct type + ASSERT_NE(indexIDMap, nullptr); + ASSERT_NE(indexIDMap->index, nullptr); + + // Clean up + std::remove(indexPath.c_str()); +} + +TEST(FaissCreateBinaryIndexFromTemplateTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 100; + std::vector ids; + std::vector vectors; + int dim = 128; + vectors.reserve(numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim / 8; ++j) { + vectors.push_back(test_util::RandomInt(0, 255)); + } + } + + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + std::string spaceType = knn_jni::HAMMING; + std::string indexDescription = "BIVF32,Flat"; + + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; + parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject) &indexDescription; + std::unordered_map subParametersMap; + parametersMap[knn_jni::PARAMETERS] = (jobject)&subParametersMap; + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + +
EXPECT_CALL(mockJNIUtil, GetJavaObjectArrayLength(jniEnv, reinterpret_cast(&vectors))) + .WillRepeatedly(Return(vectors.size())); + + std::unique_ptr faissMethods(new FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(faissMethods)); + + std::unique_ptr createdIndex( + test_util::FaissCreateBinaryIndex(dim, indexDescription)); + createdIndex->train(numIds, vectors.data()); + auto vectorIoWriter = test_util::FaissGetSerializedBinaryIndex(createdIndex.get()); + + knn_jni::faiss_wrapper::CreateIndexFromTemplate( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong)&vectors, dim, (jstring)&indexPath, + reinterpret_cast(&(vectorIoWriter.data)), + (jobject)&parametersMap, + &indexService + ); + + // Make sure index can be loaded + std::unique_ptr index(test_util::FaissLoadBinaryIndex(indexPath)); + auto indexIDMap = dynamic_cast(index.get()); + + // Assert that the index is of the correct type + ASSERT_NE(indexIDMap, nullptr); + ASSERT_NE(indexIDMap->index, nullptr); // Clean up std::remove(indexPath.c_str()); diff --git a/jni/tests/mocks/faiss_index_service_mock.h b/jni/tests/mocks/faiss_index_service_mock.h index 7af08c82ea..e86596b06a 100644 --- a/jni/tests/mocks/faiss_index_service_mock.h +++ b/jni/tests/mocks/faiss_index_service_mock.h @@ -39,6 +39,22 @@ class MockIndexService : public IndexService { StringToJObjectMap parameters ), (override)); + + MOCK_METHOD( + void, + createIndexFromTemplate, + ( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + int dim, + int numIds, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + StringToJObjectMap parameters, + std::vector templateIndexData + ), + (override)); }; #endif // OPENSEARCH_KNN_FAISS_INDEX_SERVICE_MOCK_H \ No newline at end of file diff --git a/jni/tests/mocks/faiss_methods_mock.h b/jni/tests/mocks/faiss_methods_mock.h index 64a23b8951..69e1f0bc20 100644 --- a/jni/tests/mocks/faiss_methods_mock.h +++ b/jni/tests/mocks/faiss_methods_mock.h @@ -1,5 +1,5
@@ /* - * SPDX-License-Identifier: Apache-2.0 +* SPDX-License-Identifier: Apache-2.0 * * The OpenSearch Contributors require contributions made to * this file be licensed under the Apache-2.0 license or a @@ -9,8 +9,8 @@ * GitHub history for details. */ - #ifndef OPENSEARCH_KNN_FAISS_METHODS_MOCK_H - #define OPENSEARCH_KNN_FAISS_METHODS_MOCK_H +#ifndef OPENSEARCH_KNN_FAISS_METHODS_MOCK_H +#define OPENSEARCH_KNN_FAISS_METHODS_MOCK_H #include "faiss_methods.h" #include @@ -23,6 +23,8 @@ class MockFaissMethods : public knn_jni::faiss_wrapper::FaissMethods { MOCK_METHOD(faiss::IndexIDMapTemplate*, indexBinaryIdMap, (faiss::IndexBinary* index), (override)); MOCK_METHOD(void, writeIndex, (const faiss::Index* idx, const char* fname), (override)); MOCK_METHOD(void, writeIndexBinary, (const faiss::IndexBinary* idx, const char* fname), (override)); + MOCK_METHOD(faiss::Index*, readIndex, (faiss::IOReader* f, int io_flags), (override)); + MOCK_METHOD(faiss::IndexBinary*, readIndexBinary, (faiss::IOReader* f, int io_flags), (override)); }; #endif // OPENSEARCH_KNN_FAISS_METHODS_MOCK_H \ No newline at end of file