Skip to content

Commit

Permalink
refactor jni create template index
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Aug 1, 2024
1 parent ec6451c commit 8123cfc
Show file tree
Hide file tree
Showing 11 changed files with 366 additions and 118 deletions.
58 changes: 55 additions & 3 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
#include "jni_util.h"
#include "faiss_methods.h"
#include <memory>
#include <vector>
#include <unordered_map>
#include <string>

namespace knn_jni {
namespace faiss_wrapper {


/**
* A class to provide operations on index
* This class should evolve to have only cpp object but not jni object
Expand Down Expand Up @@ -61,20 +63,46 @@ class IndexService {
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters);

/**
* Create index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numIds number of vectors
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
* @param templateIndexData vector containing the template index data
*/
virtual void createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData);

virtual ~IndexService() = default;
protected:
std::unique_ptr<FaissMethods> faissMethods;
};

/**
* A class to provide operations on index
* A class to provide operations on binary index
* This class should evolve to have only cpp object but not jni object
*/
class BinaryIndexService : public IndexService {
public:
//TODO Remove dependency on JNIUtilInterface and JNIEnv
//TODO Reduce the number of parameters
BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);

/**
* Create binary index
*
Expand Down Expand Up @@ -103,11 +131,35 @@ class BinaryIndexService : public IndexService {
std::string indexPath,
std::unordered_map<std::string, jobject> parameters
) override;

/**
* Create binary index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numIds number of vectors
* @param vectorsAddress memory address which is holding vector data
* @param ids a list of document ids for corresponding vectors
* @param indexPath path to write index
* @param parameters parameters to be applied to faiss index
* @param templateIndexData vector containing the template index data
*/
virtual void createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData);

virtual ~BinaryIndexService() = default;
};

}
}


#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
2 changes: 2 additions & 0 deletions jni/include/faiss_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class FaissMethods {
virtual faiss::IndexIDMapTemplate<faiss::IndexBinary>* indexBinaryIdMap(faiss::IndexBinary* index);
virtual void writeIndex(const faiss::Index* idx, const char* fname);
virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname);
virtual faiss::Index* readIndex(faiss::IOReader* f, int io_flags);
virtual faiss::IndexBinary* readIndexBinary(faiss::IOReader* f, int io_flags);
virtual ~FaissMethods() = default;
};

Expand Down
8 changes: 1 addition & 7 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,7 @@ namespace knn_jni {
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ,
jobject parametersJ);
jobject parametersJ, IndexService* indexService);

// Load an index from indexPathJ into memory.
//
Expand Down
65 changes: 64 additions & 1 deletion jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <vector>
#include <memory>
#include <type_traits>
#include <faiss/impl/io.h>

namespace knn_jni {
namespace faiss_wrapper {
Expand Down Expand Up @@ -106,6 +107,37 @@ void IndexService::createIndex(
faissMethods->writeIndex(idMap.get(), indexPath.c_str());
}

void IndexService::createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData
) {
faiss::VectorIOReader vectorIoReader;
vectorIoReader.data = templateIndexData;

std::unique_ptr<faiss::Index> indexWriter(faissMethods->readIndex(&vectorIoReader, 0));

auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddress);
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
if (numIds != numVectors) {
throw std::runtime_error("Number of vectors or IDs does not match expected values");
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, indexWriter.get());

std::unique_ptr<faiss::IndexIDMap> idMap(faissMethods->indexIdMap(indexWriter.get()));
idMap->add_with_ids(numVectors, inputVectors->data(), ids.data());

faissMethods->writeIndex(idMap.get(), indexPath.c_str());
}

BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {}

void BinaryIndexService::createIndex(
Expand Down Expand Up @@ -160,5 +192,36 @@ void BinaryIndexService::createIndex(
faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
}

void BinaryIndexService::createIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numIds,
int64_t vectorsAddress,
std::vector<int64_t> ids,
std::string indexPath,
std::unordered_map<std::string, jobject> parameters,
std::vector<uint8_t> templateIndexData
) {
faiss::VectorIOReader vectorIoReader;
vectorIoReader.data = templateIndexData;

std::unique_ptr<faiss::IndexBinary> indexWriter(faissMethods->readIndexBinary(&vectorIoReader, 0));

auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddress);
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
if (numIds != numVectors) {
throw std::runtime_error("Number of vectors or IDs does not match expected values");
}

// Add extra parameters that cant be configured with the index factory
SetExtraParameters<faiss::IndexBinary, faiss::IndexBinaryIVF, faiss::IndexBinaryHNSW>(jniUtil, env, parameters, indexWriter.get());

std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get()));
idMap->add_with_ids(numVectors, inputVectors->data(), ids.data());

faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str());
}

} // namespace faiss_wrapper
} // namesapce knn_jni
} // namespace knn_jni
7 changes: 6 additions & 1 deletion jni/src/faiss_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) {
void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) {
faiss::write_index_binary(idx, fname);
}

faiss::Index* FaissMethods::readIndex(faiss::IOReader* f, int io_flags) {
return faiss::read_index(f, io_flags);
}
faiss::IndexBinary* FaissMethods::readIndexBinary(faiss::IOReader* f, int io_flags) {
return faiss::read_index_binary(f, io_flags);
}
} // namespace faiss_wrapper
} // namesapce knn_jni
108 changes: 11 additions & 97 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <vector>

// Defines type of IDSelector
enum FilterIdsSelectorType{
enum FilterIdsSelectorType {
BITMAP = 0, BATCH = 1,
};
namespace faiss {
Expand Down Expand Up @@ -76,7 +76,7 @@ void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const
// Converts the int FilterIds to Faiss ids type array.
void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds);

// Concerts the FilterIds to BitMap
// Converts the FilterIds to BitMap
void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector);

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap);
Expand Down Expand Up @@ -161,7 +161,7 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN

void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
jbyteArray templateIndexJ, jobject parametersJ, IndexService* indexService) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}
Expand Down Expand Up @@ -192,108 +192,22 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *

// Read data set
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
int dim = (int)dimJ;
int numVectors = (int) (inputVectors->size() / (uint64_t) dim);
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
std::vector<uint8_t> templateIndexData(indexBytesJ, indexBytesJ + indexBytesCount);
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::Index> indexWriter;
indexWriter.reset(faiss::read_index(&vectorIoReader, 0));

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
}

void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jstring indexPathJ,
jbyteArray templateIndexJ, jobject parametersJ) {
if (idsJ == nullptr) {
throw std::runtime_error("IDs cannot be null");
}

if (vectorsAddressJ <= 0) {
throw std::runtime_error("VectorsAddress cannot be less than 0");
}

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
}

if (templateIndexJ == nullptr) {
throw std::runtime_error("Template index cannot be null");
}

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
omp_set_num_threads(threadCount);
}
jniUtil->DeleteLocalRef(env, parametersJ);

// Read data set
// Read vectors from memory address
auto *inputVectors = reinterpret_cast<std::vector<uint8_t>*>(vectorsAddressJ);
int dim = (int)dimJ;
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiply of 8");
}
int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8));
int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
if (numIds != numVectors) {
throw std::runtime_error("Number of IDs does not match number of vectors");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);
// Convert ids
auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
int64_t vectorsAddress = (int64_t)vectorsAddressJ;
std::string indexPathCpp = jniUtil->ConvertJavaStringToCppString(env, indexPathJ);

// Create faiss index
std::unique_ptr<faiss::IndexBinary> indexWriter;
indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0));

auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, reinterpret_cast<const uint8_t*>(inputVectors->data()), idVector.data());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index_binary(&idMap, indexPathCpp.c_str());
// Create index using IndexService
indexService->createIndexFromTemplate(jniUtil, env, dim, numIds, vectorsAddress, ids, indexPathCpp, parametersCpp, templateIndexData);
}

jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
Expand Down Expand Up @@ -674,7 +588,7 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti
omp_set_num_threads(threadCount);
}

// Add extra parameters that cant be configured with the index factory
// Add extra parameters that can't be configured with the index factory
if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) {
jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS];
auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);
Expand Down
8 changes: 6 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &indexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand All @@ -99,7 +101,9 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde
jobject parametersJ)
{
try {
knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ);
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods));
CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ, &binaryIndexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
Loading

0 comments on commit 8123cfc

Please sign in to comment.