Skip to content

Commit

Permalink
Ensure that native C++ index object doesn't get destroyed while Java …
Browse files Browse the repository at this point in the history
…threads are using it. (#36)

* Ensure that native C++ index object doesn't get destroyed while Java threads are still using it.

* Add test for native concurrency.

* Fix markDeleted/unmarkDeleted.

* Add ASAN for testing on Linux.

* Allow overriding CXX.

* Only run with ASan on x86.

* Fix flakey test.

* Fix ASan and formatting.

* Add LD_PRELOAD for ASan on Linux.

* Er, this way instead

* No clang on Linux.

* Ignore leaks and SEGV from the JVM.

* Update all.yml

* Disable ASan for Java.

* Downgrade dockcross version again.

* Always use g++ on Linux.

* Fix glibc version check again.
  • Loading branch information
psobot authored Mar 19, 2024
1 parent 0c2196e commit 6ffbf68
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 44 deletions.
42 changes: 37 additions & 5 deletions .github/workflows/all.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,37 @@ jobs:
- name: Run Tests
run: mvn --batch-mode test

run-java-tests-with-address-sanitizer:
if: false # Disabled due to ASan log spam in output
continue-on-error: true
name: Test with Java ${{ matrix.java-version }} + Address Sanitizer
runs-on: ubuntu-latest
defaults:
run:
working-directory: java
strategy:
matrix:
java-version: ['11', '17']
steps:
- uses: actions/checkout@v3
- name: Set up JDK ${{ matrix.java-version }}
uses: actions/setup-java@v4
with:
java-version: ${{ matrix.java-version }}
distribution: 'corretto'
gpg-private-key: ${{ secrets.SONATYPE_GPG_PRIVATE_KEY }}
gpg-passphrase: MAVEN_GPG_PASSPHRASE
- name: Compile
env:
USE_ASAN: "1"
run: make target/classes/linux-x64/libvoyager.so
- name: Run Tests
run: |
# ASAN_OPTIONS and LD_PRELOAD are per-command environment assignments for
# the mvn invocation, so every line before `mvn` must end in a backslash.
# Without the continuation, ASAN_OPTIONS would only be set as an
# unexported shell variable and would never reach the test process.
ASAN_OPTIONS=detect_leaks=0 \
LD_PRELOAD=$(clang -print-file-name=libclang_rt.asan-x86_64.so) \
mvn --batch-mode test
run-python-tests:
runs-on: ${{ matrix.os }}
continue-on-error: true
Expand Down Expand Up @@ -319,13 +350,14 @@ jobs:
file combined-jar/mac-aarch64/libvoyager.dylib | grep arm64
echo "Checking the Linux JARs to ensure a sufficiently old required GLIBC and GLIBCXX version in each..."
echo "Checking linux-aarch64/libvoyager.so for GLIBC 2.27..."
echo "Checking linux-aarch64/libvoyager.so (should require at most GLIBC 2.27...)"
objdump -T combined-jar/linux-aarch64/libvoyager.so | grep GLIBC | sed 's/.*GLIBC_\([.0-9]*\).*/\1/g' | sort -Vu | tail -n 1 | grep 2\.27 || (echo "Expected linux-aarch64 binary to require at most glibc 2.27. Ensure the dockcross container has not been updated." && false)
echo "Checking linux-aarch64/libvoyager.so for GLIBCXX 3.4.22..."
echo "Checking linux-aarch64/libvoyager.so (should require at most GLIBCXX 3.4.22...)"
objdump -T combined-jar/linux-aarch64/libvoyager.so | grep GLIBCXX | sed 's/.*GLIBCXX_\([.0-9]*\).*/\1/g' | sort -Vu | tail -n 1 | grep 3\.4\.22 || (echo "Expected linux-aarch64 binary to require at most glibc++ 3.4.22. Ensure the dockcross container has not been updated." && false)
echo "Checking linux-x64/libvoyager.so for GLIBC 2.29..."
objdump -T combined-jar/linux-x64/libvoyager.so | grep GLIBC | sed 's/.*GLIBC_\([.0-9]*\).*/\1/g' | sort -Vu | tail -n 1 | grep 2\.29 || (echo "Expected linux-x64 binary to require at most glibc 2.29. Ensure the dockcross container has not been updated." && false)
echo "Checking linux-x64/libvoyager.so for GLIBCXX 3.4.22..."
# The highest GLIBC symbol version referenced by the x64 library must be 2.27;
# the failure message names the same version that the grep below checks.
echo "Checking linux-x64/libvoyager.so (should require at most GLIBC 2.27...)"
objdump -T combined-jar/linux-x64/libvoyager.so | grep GLIBC | sed 's/.*GLIBC_\([.0-9]*\).*/\1/g' | sort -Vu | tail -n 1 | grep 2\.27 || (echo "Expected linux-x64 binary to require at most glibc 2.27. Ensure the dockcross container has not been updated." && false)
echo "Checking linux-x64/libvoyager.so (should require at most GLIBCXX 3.4.22...)"
objdump -T combined-jar/linux-x64/libvoyager.so | grep GLIBCXX | sed 's/.*GLIBCXX_\([.0-9]*\).*/\1/g' | sort -Vu | tail -n 1 | grep 3\.4\.22 || (echo "Expected linux-x64 binary to require at most glibc++ 3.4.22. Ensure the dockcross container has not been updated." && false)
zip -r $(ls java-* | grep jar | tail -n 1) combined-jar/*
Expand Down
7 changes: 6 additions & 1 deletion cpp/TypedIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,12 @@ class TypedIndex : public Index {

size_t getNumElements() const { return algorithmImpl->cur_element_count; }

int getEF() const { return defaultEF; }
// Returns ef_ from algorithmImpl when algorithmImpl is set; otherwise falls
// back to defaultEF.
int getEF() const {
  if (!algorithmImpl)
    return defaultEF;
  return algorithmImpl->ef_;
}

int getNumThreads() { return numThreadsDefault; }

Expand Down
31 changes: 19 additions & 12 deletions java/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,48 @@ WIN_SOBJ := voyager.dll
CPP_SRC_DIR := ../cpp/

SOURCE := com_spotify_voyager_jni_Index.cpp
HEADERS := $(SOURCE:.cpp=.h)
HEADERS := $(SOURCE:.cpp=.h) $(wildcard ../cpp/*.h)
JAVA_SRC := ./src/main/java

ifeq ($(UNAME_S),Linux)
CXX := g++
JAVA_INC := $(JAVA_HOME)/include $(JAVA_HOME)/include/linux
ALL_OBJS := target/classes/linux-x64/$(LINUX_SOBJ) target/classes/linux-aarch64/$(LINUX_SOBJ)
CXXFLAGS := -I. -lc -shared -std=c++17 -I ./include $(addprefix -I,$(JAVA_INC)) -I $(CPP_SRC_DIR) -fPIC -O3
PREBUILD_COMMAND := sudo apt-get update && sudo apt-get install -y ca-certificates-java openjdk-11-jre-headless; { curl https://dlcdn.apache.org/maven/maven-3/3.9.2/binaries/apache-maven-3.9.2-bin.tar.gz | sudo tar -xvzf - -C /opt ; }; export M2_HOME=/opt/apache-maven-3.9.2 && export PATH=\"${M2_HOME}/bin:${PATH}\"
else ifeq ($(UNAME_S),Darwin)
CC := clang++
CXX := clang++
JAVA_INC := $(JAVA_HOME)/include $(JAVA_HOME)/include/darwin
ALL_OBJS := target/classes/mac-x64/$(MAC_SOBJ) target/classes/mac-aarch64/$(MAC_SOBJ)
CXXFLAGS := -I. -lc++ -shared -std=c++17 -I ./include $(addprefix -I,$(JAVA_INC)) -I $(CPP_SRC_DIR) -fPIC -O3
else ifdef OS # Windows:
CC := cl.exe
CXX := cl.exe
JAVA_INC := $(JAVA_HOME)/include $(JAVA_HOME)/include/win32
ALL_OBJS := target/classes/win-x64/$(WIN_SOBJ)
CXXFLAGS := /I. /EHsc /DLL /LD /std:c++17 /I.\include $(addprefix -I,$(JAVA_INC)) /I $(CPP_SRC_DIR) /O2
endif

all: $(ALL_OBJS)
# Done!
# Allow linking against ASan if USE_ASAN is set to 1:
ifeq (${USE_ASAN},1)
CXXFLAGS := $(CXXFLAGS) -fsanitize=address
else
endif

check-java-home:
ifndef JAVA_HOME
$(error JAVA_HOME is undefined)
else
endif

classpath.txt: check-java-home
all: $(ALL_OBJS)
# Done!

classpath.txt:
mvn dependency:build-classpath | grep '\.jar[;:]' | tr ';' ':' > $@

%.h: %.cpp classpath.txt
# For each Java JNI .cpp file, we need a corresponding header to be generated.
$(eval JAVA_FILE := $(shell python3 -c 'print("$<".split(".")[0].replace("_", "/") + ".java")'))
javac -cp ./target/classes:$(shell cat classpath.txt) -h . -sourcepath $(JAVA_SRC) $(JAVA_SRC)/$(JAVA_FILE)
javac -cp ./target/classes:$(shell cat classpath.txt) -h . -Xlint:deprecation -sourcepath $(JAVA_SRC) $(JAVA_SRC)/$(JAVA_FILE)

target/classes/linux-x64/$(LINUX_SOBJ): classpath.txt $(HEADERS)
mkdir -p target/classes/linux-x64/
Expand All @@ -53,7 +60,7 @@ target/classes/linux-x64/$(LINUX_SOBJ): classpath.txt $(HEADERS)
cp -r $(addsuffix /*,$(JAVA_INC)) linux-build/include
# Why use Dockcross if we're already building on x86?
# We need to use an older version of GLIBC and GLIBCXX to ensure wide compatibility on Linux.
docker run -v $(shell realpath ../):/work dockcross/linux-x64:20230202-378e837 bash -x -c 'cd /work/java && $(PREBUILD_COMMAND) && $$CXX $(CXXFLAGS) -I linux-build/include -o $@ $(SOURCE)'
docker run -v $(shell realpath ../):/work dockcross/linux-x64:20210625-78b96c7 bash -x -c 'cd /work/java && $(PREBUILD_COMMAND) && $$CXX $(CXXFLAGS) -I linux-build/include -o $@ $(SOURCE)'

target/classes/linux-aarch64/$(LINUX_SOBJ): classpath.txt $(HEADERS)
mkdir -p target/classes/linux-aarch64/
Expand All @@ -64,15 +71,15 @@ target/classes/linux-aarch64/$(LINUX_SOBJ): classpath.txt $(HEADERS)

target/classes/mac-x64/$(MAC_SOBJ): classpath.txt $(HEADERS)
mkdir -p target/classes/mac-x64/
$(CC) $(CXXFLAGS) -o $@ $(SOURCE) -arch x86_64
$(CXX) $(CXXFLAGS) -o $@ $(SOURCE) -arch x86_64

target/classes/mac-aarch64/$(MAC_SOBJ): classpath.txt $(HEADERS)
mkdir -p target/classes/mac-aarch64/
$(CC) $(CXXFLAGS) -o $@ $(SOURCE) -arch arm64
$(CXX) $(CXXFLAGS) -o $@ $(SOURCE) -arch arm64

target/classes/win-x64/$(WIN_SOBJ): classpath.txt $(HEADERS)
mkdir -p target/classes/win-x64/
$(CC) $(CXXFLAGS) $(SOURCE)
$(CXX) $(CXXFLAGS) $(SOURCE)
cp com_spotify_voyager_jni_Index.dll $@

test: $(ALL_OBJS)
Expand Down
75 changes: 50 additions & 25 deletions java/com_spotify_voyager_jni_Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,51 @@ jfieldID getHandleFieldID(JNIEnv *env, jobject obj) {
}

// Reads the native handle stored on `obj` (while holding obj's monitor) and
// returns a copy of the std::shared_ptr<T> that the handle points to.
// Throws std::runtime_error when the handle is zero and allow_missing is
// false.
//
// NOTE(review): this span shows removed and added lines of the commit
// together (raw-pointer and shared_ptr variants of the same statements).
// NOTE(review): the monitor is released before `*pointer` is copied at the
// end of the function, so a concurrent deleteHandle() could free the
// heap-allocated shared_ptr in between — consider making the copy while the
// monitor is still held.
// NOTE(review): with allow_missing == true and a zero handle, the final
// `return *pointer;` dereferences a null pointer — confirm no caller uses
// that combination.
template <typename T>
T *getHandle(JNIEnv *env, jobject obj, bool allow_missing = false) {
std::shared_ptr<T> getHandle(JNIEnv *env, jobject obj,
bool allow_missing = false) {
env->MonitorEnter(obj);
jlong handle = env->GetLongField(obj, getHandleFieldID(env, obj));
T *pointer = reinterpret_cast<T *>(handle);
env->MonitorExit(obj);

// Yes, we're storing a pointer to a shared pointer on a Java object.
// A bit strange, but totally okay to ensure that we still get shared_ptr
// semantics while storing a single Long value.
std::shared_ptr<T> *pointer = reinterpret_cast<std::shared_ptr<T> *>(handle);

if (!allow_missing && !pointer) {
throw std::runtime_error("Native JNI object not found.");
throw std::runtime_error(
"This Voyager index has been closed and can no longer be used.");
}

return pointer;
// Return a copy of this shared pointer, thereby ensuring that it remains
// alive until the shared_ptr goes out of scope.
return *pointer;
}

// Heap-allocates a std::shared_ptr<T> that takes ownership of `t` and stores
// the address of that shared_ptr in the Java object's handle field while
// holding obj's monitor.
//
// NOTE(review): this span shows both the removed and the added SetLongField
// argument line of the commit.
// NOTE(review): a handle already stored on `obj` is overwritten here without
// being freed — presumably callers only set the handle once; verify.
template <typename T> void setHandle(JNIEnv *env, jobject obj, T *t) {
std::shared_ptr<T> *sharedPointerForJava = new std::shared_ptr<T>(t);
env->MonitorEnter(obj);
env->SetLongField(obj, getHandleFieldID(env, obj),
reinterpret_cast<jlong>(t));
reinterpret_cast<jlong>(sharedPointerForJava));
env->MonitorExit(obj);
}

/**
 * Releases the native object referenced by the Java object `obj`.
 *
 * The handle field holds the address of a heap-allocated std::shared_ptr<T>,
 * or 0 when nothing is stored. This frees that shared_ptr and resets the
 * field to 0. A zero handle is a no-op.
 *
 * The read, the reset and the delete all happen inside a single critical
 * section on obj's monitor, so two threads racing to close the same object
 * cannot both observe the same non-zero handle and delete it twice.
 */
template <typename T> void deleteHandle(JNIEnv *env, jobject obj) {
  // Resolve the field ID before taking the monitor, so that an exception
  // thrown by the lookup (if any) cannot leave the monitor held.
  const jfieldID handleField = getHandleFieldID(env, obj);

  env->MonitorEnter(obj);
  auto *pointer = reinterpret_cast<std::shared_ptr<T> *>(
      env->GetLongField(obj, handleField));
  if (pointer) {
    // Clear the field first so that no later reader can pick up the stale
    // address.
    env->SetLongField(obj, handleField, 0);
    // Note: This _may_ trigger the destructor of T, but if any other threads
    // hold copies of this shared_ptr, the destructor will be triggered on
    // the last thread that releases one of those copies.
    delete pointer;
  }
  env->MonitorExit(obj);
}

std::string toString(JNIEnv *env, jstring js) {
Expand Down Expand Up @@ -315,7 +346,7 @@ void Java_com_spotify_voyager_jni_Index_nativeConstructor(
void Java_com_spotify_voyager_jni_Index_addItem___3F(JNIEnv *env, jobject self,
jfloatArray vector) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);
index->addItem(toStdVector(env, vector), {});
} catch (std::exception const &e) {
if (!env->ExceptionCheck()) {
Expand All @@ -328,7 +359,7 @@ void Java_com_spotify_voyager_jni_Index_addItem___3FJ(JNIEnv *env, jobject self,
jfloatArray vector,
jlong id) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);
index->addItem(toStdVector(env, vector), {id});
} catch (std::exception const &e) {
if (!env->ExceptionCheck()) {
Expand All @@ -342,7 +373,7 @@ void Java_com_spotify_voyager_jni_Index_addItems___3_3FI(JNIEnv *env,
jobjectArray vectors,
jint numThreads) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);
index->addItems(toNDArray(env, vectors), {}, numThreads);
} catch (std::exception const &e) {
if (!env->ExceptionCheck()) {
Expand All @@ -355,7 +386,7 @@ void Java_com_spotify_voyager_jni_Index_addItems___3_3F_3JI(
JNIEnv *env, jobject self, jobjectArray vectors, jlongArray ids,
jint numThreads) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);
index->addItems(toNDArray(env, vectors), toUnsignedStdVector(env, ids),
numThreads);
} catch (std::exception const &e) {
Expand All @@ -374,7 +405,7 @@ jobject Java_com_spotify_voyager_jni_Index_query___3FIJ(JNIEnv *env,
jint numNeighbors,
jlong queryEf) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);

std::tuple<std::vector<hnswlib::labeltype>, std::vector<float>>
queryResults =
Expand Down Expand Up @@ -421,7 +452,7 @@ jobjectArray Java_com_spotify_voyager_jni_Index_query___3_3FIIJ(
JNIEnv *env, jobject self, jobjectArray queryVectors, jint numNeighbors,
jint numThreads, jlong queryEf) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);

int numQueries = env->GetArrayLength(queryVectors);

Expand Down Expand Up @@ -564,7 +595,7 @@ jfloatArray Java_com_spotify_voyager_jni_Index_getVector(JNIEnv *env,
jobject self,
jlong id) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);
return toFloatArray(env, index->getVector(id));
} catch (std::exception const &e) {
if (!env->ExceptionCheck()) {
Expand All @@ -578,8 +609,7 @@ jobjectArray Java_com_spotify_voyager_jni_Index_getVectors(JNIEnv *env,
jobject self,
jlongArray ids) {
try {
Index *index = getHandle<Index>(env, self);

std::shared_ptr<Index> index = getHandle<Index>(env, self);
NDArray<float, 2> vectors =
index->getVectors(toUnsignedStdVector(env, ids));

Expand Down Expand Up @@ -610,7 +640,7 @@ jobjectArray Java_com_spotify_voyager_jni_Index_getVectors(JNIEnv *env,
jlongArray Java_com_spotify_voyager_jni_Index_getIDs(JNIEnv *env,
jobject self) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);

std::vector<hnswlib::labeltype> ids = index->getIDs();

Expand Down Expand Up @@ -662,8 +692,7 @@ jint Java_com_spotify_voyager_jni_Index_getEf(JNIEnv *env, jobject self) {
void Java_com_spotify_voyager_jni_Index_markDeleted(JNIEnv *env, jobject self,
jlong label) {
try {
Index *index = getHandle<Index>(env, self);
index->markDeleted(label);
getHandle<Index>(env, self)->markDeleted(label);
} catch (std::exception const &e) {
if (!env->ExceptionCheck()) {
env->ThrowNew(env->FindClass("java/lang/RuntimeException"), e.what());
Expand All @@ -674,8 +703,7 @@ void Java_com_spotify_voyager_jni_Index_markDeleted(JNIEnv *env, jobject self,
void Java_com_spotify_voyager_jni_Index_unmarkDeleted(JNIEnv *env, jobject self,
jlong label) {
try {
Index *index = getHandle<Index>(env, self);
index->unmarkDeleted(label);
getHandle<Index>(env, self)->unmarkDeleted(label);
} catch (std::exception const &e) {
if (!env->ExceptionCheck()) {
env->ThrowNew(env->FindClass("java/lang/RuntimeException"), e.what());
Expand All @@ -691,7 +719,7 @@ void Java_com_spotify_voyager_jni_Index_unmarkDeleted(JNIEnv *env, jobject self,
void Java_com_spotify_voyager_jni_Index_saveIndex__Ljava_lang_String_2(
JNIEnv *env, jobject self, jstring filename) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);
index->saveIndex(toString(env, filename));
} catch (std::exception const &e) {
if (!env->ExceptionCheck()) {
Expand All @@ -703,7 +731,7 @@ void Java_com_spotify_voyager_jni_Index_saveIndex__Ljava_lang_String_2(
void Java_com_spotify_voyager_jni_Index_saveIndex__Ljava_io_OutputStream_2(
JNIEnv *env, jobject self, jobject outputStream) {
try {
Index *index = getHandle<Index>(env, self);
std::shared_ptr<Index> index = getHandle<Index>(env, self);
index->saveIndex(std::make_shared<JavaOutputStream>(env, outputStream));
} catch (std::exception const &e) {
if (!env->ExceptionCheck()) {
Expand Down Expand Up @@ -903,10 +931,7 @@ void Java_com_spotify_voyager_jni_Index_nativeLoadFromInputStream(
void Java_com_spotify_voyager_jni_Index_nativeDestructor(JNIEnv *env,
jobject self) {
try {
if (Index *index = getHandle<Index>(env, self, true)) {
delete index;
setHandle<Index>(env, self, nullptr);
}
deleteHandle<Index>(env, self);
} catch (std::exception const &e) {
if (!env->ExceptionCheck()) {
env->ThrowNew(env->FindClass("java/lang/RuntimeException"), e.what());
Expand Down
Loading

0 comments on commit 6ffbf68

Please sign in to comment.