From 07a40e0146e920a8c6e8b660189d7dc35cdb1762 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Wed, 31 Jul 2024 19:57:19 -0700 Subject: [PATCH] refactor jni train index Signed-off-by: Junqiu Lei --- jni/include/faiss_index_service.h | 19 ++++-- jni/include/faiss_wrapper.h | 3 + jni/src/faiss_index_service.cpp | 61 +++++++++++++++++++ jni/src/faiss_methods.cpp | 3 + jni/src/faiss_wrapper.cpp | 40 ++++++++++++ .../org_opensearch_knn_jni_FaissService.cpp | 10 ++- 6 files changed, 128 insertions(+), 8 deletions(-) diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index 76e8cfdb97..5d4ec66506 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -88,6 +88,10 @@ class IndexService { std::unordered_map parameters, std::vector templateIndexData); + virtual std::vector trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map& parameters); + + virtual void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); + virtual ~IndexService() = default; protected: std::unique_ptr faissMethods; @@ -101,7 +105,7 @@ class BinaryIndexService : public IndexService { public: //TODO Remove dependency on JNIUtilInterface and JNIEnv //TODO Reduce the number of parameters - BinaryIndexService(std::unique_ptr faissMethods); + explicit BinaryIndexService(std::unique_ptr faissMethods); /** * Create binary index @@ -118,7 +122,7 @@ class BinaryIndexService : public IndexService { * @param indexPath path to write index * @param parameters parameters to be applied to faiss index */ - virtual void createIndex( + void createIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, faiss::MetricType metric, @@ -145,7 +149,7 @@ class BinaryIndexService : public IndexService { * @param parameters parameters to be applied to faiss index * @param templateIndexData vector containing the template index data */ - virtual void createIndexFromTemplate( + void createIndexFromTemplate( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, int dim, @@ -154,9 +158,14 @@ class BinaryIndexService : public IndexService { std::vector ids, std::string indexPath, std::unordered_map parameters, - std::vector templateIndexData); + std::vector templateIndexData) override; + + void InternalTrainIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x); + + std::vector trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map& parameters) override; + - virtual ~BinaryIndexService() = default; + ~BinaryIndexService() override = default; }; } diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 44882210d0..34dc62b1c4 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -103,6 +103,9 @@ namespace knn_jni { jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, + jlong trainVectorsPointerJ, IndexService* indexService); + /* * Perform a range search with filter against the index located in memory at indexPointerJ. * diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 9432db6033..c683a56354 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -138,6 +138,37 @@ void IndexService::createIndexFromTemplate( faissMethods->writeIndex(idMap.get(), indexPath.c_str()); } +void IndexService::InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { + if (auto * indexIvf = dynamic_cast(index)) { + if (indexIvf->quantizer_trains_alone == 2) { + InternalTrainIndex(indexIvf->quantizer, n, x); + } + indexIvf->make_direct_map(); + } + + if (!index->is_trained) { + index->train(n, x); + } +} + +std::vector IndexService::trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map& parameters) { + // Create faiss index + std::unique_ptr index(faissMethods->indexFactory(dimension, indexDescription.c_str(), metric)); + + // Train index if needed + if (!index->is_trained) { + InternalTrainIndex(index.get(), numVectors, trainingVectors); + } + + // Write index to a vector + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index(index.get(), &vectorIoWriter); + + return std::vector(vectorIoWriter.data.begin(), vectorIoWriter.data.end()); +} + + + BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} void BinaryIndexService::createIndex( @@ -223,5 +254,35 @@ void BinaryIndexService::createIndexFromTemplate( faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); } +void BinaryIndexService::InternalTrainIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { + if (auto * indexIvf = dynamic_cast(index)) { + if (!indexIvf->is_trained) { + indexIvf->train(n, reinterpret_cast(x)); + } + } + if (!index->is_trained) { + index->train(n, reinterpret_cast(x)); + } +} + +std::vector BinaryIndexService::trainIndex(JNIUtilInterface* jniUtil, JNIEnv* env, faiss::MetricType metric, std::string& indexDescription, int dimension, int numVectors, float* trainingVectors, std::unordered_map& parameters) { + // Convert Java parameters to C++ parameters + std::unique_ptr indexWriter; + indexWriter.reset(faiss::index_binary_factory(dimension, indexDescription.c_str())); + + // Train the index if it is not already trained + if (!indexWriter->is_trained) { + InternalTrainIndex(indexWriter.get(), numVectors, trainingVectors); + } + + // Serialize the trained index to a byte array + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index_binary(indexWriter.get(), &vectorIoWriter); + + // Convert the serialized data to a std::vector + std::vector trainedIndexData(vectorIoWriter.data.begin(), vectorIoWriter.data.end()); + + return trainedIndexData; +} } // namespace faiss_wrapper } // namespace knn_jni diff --git a/jni/src/faiss_methods.cpp b/jni/src/faiss_methods.cpp index abc70d4605..0c0924e74a 100644 --- a/jni/src/faiss_methods.cpp +++ b/jni/src/faiss_methods.cpp @@ -32,12 +32,15 @@ faiss::IndexIDMapTemplate* FaissMethods::indexBinaryIdMap(fa void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) { faiss::write_index(idx, fname); } + void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) { faiss::write_index_binary(idx, fname); } + faiss::Index* FaissMethods::readIndex(faiss::IOReader* f, int io_flags) { return faiss::read_index(f, io_flags); } + faiss::IndexBinary* FaissMethods::readIndexBinary(faiss::IOReader* f, int io_flags) { return faiss::read_index_binary(f, io_flags); } diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 3878e4f6c4..54024678f1 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -555,6 +555,46 @@ void knn_jni::faiss_wrapper::InitLibrary() { // omp_set_num_threads(1); } +jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, + jint dimensionJ, jlong trainVectorsPointerJ, IndexService* indexService) { + // First, we need to build the index + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Add extra parameters that can't be configured with the index factory + std::unordered_map subParametersCpp; + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); + } + + // Train index using IndexService + auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); + int numVectors = trainingVectorsPointerCpp->size() / (int) dimensionJ; + std::vector trainedIndexData = indexService->trainIndex(jniUtil, env, metric, indexDescriptionCpp, dimensionJ, numVectors, trainingVectorsPointerCpp->data(), subParametersCpp); + + // Now that indexWriter is trained, we just load the bytes into an array and return + jbyteArray ret = jniUtil->NewByteArray(env, trainedIndexData.size()); + jniUtil->SetByteArrayRegion(env, ret, 0, trainedIndexData.size(), reinterpret_cast(trainedIndexData.data())); + return ret; +} + jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimensionJ, jlong trainVectorsPointerJ) { // First, we need to build the index diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index d0df5adac9..36b91b74e8 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -86,7 +86,7 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); - CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService); + knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -232,7 +232,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex jlong trainVectorsPointerJ) { try { - return knn_jni::faiss_wrapper::TrainIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::TrainIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ, &indexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -245,7 +247,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinar jlong trainVectorsPointerJ) { try { - return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::TrainIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ, &indexService); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); }