Skip to content

Commit

Permalink
[Java] support bytes string serialization for jdk8 (#1039)
Browse files Browse the repository at this point in the history
* use the String value field type instead of the JDK version to decide how string values are extracted/populated

* fix naming for jdk11

* fix string builder serialization for jdk8 with byte[]-backed strings

* rename writeJDK11String to writeBytesString
  • Loading branch information
chaokunyang authored Oct 30, 2023
1 parent 1a75cbe commit 355aaf8
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public Object createJDK11StringByUnsafe() {

// @Benchmark
public Object createJDK8StringByMethodHandle() {
return StringSerializer.newJava11StringByZeroCopy(coder, strBytes);
return StringSerializer.newBytesStringZeroCopy(coder, strBytes);
}

// @Benchmark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public Object createJDK8StringByUnsafe() {

// @Benchmark
public Object createJDK8StringByMethodHandle() {
return StringSerializer.newJava8StringByZeroCopy(strData);
return StringSerializer.newCharsStringZeroCopy(strData);
}

// @Benchmark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,8 @@ private static synchronized Tuple2<ToIntFunction, Function> getBuilderFunc() {
(ToIntFunction<CharSequence>) makeGetterFunction(getCoderMethod, int.class);
builderCache = Tuple2.of(getCoder, getValue);
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
builderCache = Tuple2.of(null, getValue);
}

} else {
builderCache = Tuple2.of(null, getValue);
}
Expand Down Expand Up @@ -220,7 +219,7 @@ public short getXtypeId() {

@Override
public void write(MemoryBuffer buffer, T value) {
if (Platform.JAVA_VERSION > 8) {
if (getCoder != null) {
int coder = getCoder.applyAsInt(value);
byte[] v = (byte[]) getValue.apply(value);
buffer.writeByte(coder);
Expand All @@ -236,9 +235,9 @@ public void write(MemoryBuffer buffer, T value) {
} else {
char[] v = (char[]) getValue.apply(value);
if (StringSerializer.isLatin(v)) {
stringSerializer.writeJDK8Latin(buffer, v, value.length());
stringSerializer.writeCharsLatin(buffer, v, value.length());
} else {
stringSerializer.writeJDK8UTF16(buffer, v, value.length());
stringSerializer.writeCharsUTF16(buffer, v, value.length());
}
}
}
Expand Down
112 changes: 57 additions & 55 deletions java/fury-core/src/main/java/io/fury/serializer/StringSerializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,15 @@ public void writeString(MemoryBuffer buffer, String value) {
public Expression writeStringExpr(Expression strSerializer, Expression buffer, Expression str) {
if (isJava) {
if (STRING_VALUE_FIELD_IS_BYTES) {
return new StaticInvoke(StringSerializer.class, "writeJDK11String", buffer, str);
return new StaticInvoke(StringSerializer.class, "writeBytesString", buffer, str);
} else {
if (!STRING_VALUE_FIELD_IS_CHARS) {
throw new UnsupportedOperationException();
}
if (compressString) {
return new Invoke(strSerializer, "writeJava8StringCompressed", buffer, str);
return new Invoke(strSerializer, "writeCharsStringCompressed", buffer, str);
} else {
return new Invoke(strSerializer, "writeJava8StringUncompressed", buffer, str);
return new Invoke(strSerializer, "writeCharsStringUncompressed", buffer, str);
}
}
} else {
Expand All @@ -153,17 +153,17 @@ public Expression writeStringExpr(Expression strSerializer, Expression buffer, E
}

// Invoked by jit
public void writeJava8StringCompressed(MemoryBuffer buffer, String value) {
public void writeCharsStringCompressed(MemoryBuffer buffer, String value) {
final char[] chars = (char[]) Platform.getObject(value, STRING_VALUE_FIELD_OFFSET);
if (isLatin(chars)) {
writeJDK8Latin(buffer, chars, chars.length);
writeCharsLatin(buffer, chars, chars.length);
} else {
writeJDK8UTF16(buffer, chars, chars.length);
writeCharsUTF16(buffer, chars, chars.length);
}
}

// Invoked by jit
public void writeJava8StringUncompressed(MemoryBuffer buffer, String value) {
public void writeCharsStringUncompressed(MemoryBuffer buffer, String value) {
int numBytes = MathUtils.doubleExact(value.length());
final char[] chars = (char[]) Platform.getObject(value, STRING_VALUE_FIELD_OFFSET);
buffer.writePrimitiveArrayWithSizeEmbedded(chars, Platform.CHAR_ARRAY_OFFSET, numBytes);
Expand All @@ -183,19 +183,19 @@ public Expression readStringExpr(Expression strSerializer, Expression buffer) {
// Expression coder = inlineInvoke(buffer, "readByte", BYTE_TYPE);
// Expression value = inlineInvoke(buffer, "readBytesWithSizeEmbedded", BINARY_TYPE);
// return new StaticInvoke(
// StringSerializer.class, "newJava11StringByZeroCopy", STRING_TYPE, coder, value);
return new Invoke(strSerializer, "readJava11String", STRING_TYPE, buffer);
// StringSerializer.class, "newBytesStringZeroCopy", STRING_TYPE, coder, value);
return new Invoke(strSerializer, "readBytesString", STRING_TYPE, buffer);
} else {
if (!STRING_VALUE_FIELD_IS_CHARS) {
throw new UnsupportedOperationException();
}
if (compressString) {
return new Invoke(strSerializer, "readJava8CompressedString", STRING_TYPE, buffer);
return new Invoke(strSerializer, "readCompressedCharsString", STRING_TYPE, buffer);
} else {
Expression chars =
new Invoke(buffer, "readCharsWithSizeEmbedded", PRIMITIVE_CHAR_ARRAY_TYPE);
return new StaticInvoke(
StringSerializer.class, "newJava8StringByZeroCopy", STRING_TYPE, chars);
StringSerializer.class, "newCharsStringZeroCopy", STRING_TYPE, chars);
}
}
} else {
Expand All @@ -204,7 +204,7 @@ public Expression readStringExpr(Expression strSerializer, Expression buffer) {
}

// Invoked by jit.
public String readJava11String(MemoryBuffer buffer) {
public String readBytesString(MemoryBuffer buffer) {
byte[] heapMemory = buffer.getHeapMemory();
if (heapMemory != null) {
final int targetIndex = buffer.unsafeHeapReaderIndex();
Expand Down Expand Up @@ -240,25 +240,25 @@ public String readJava11String(MemoryBuffer buffer) {
final byte[] bytes = new byte[numBytes];
System.arraycopy(heapMemory, arrIndex, bytes, 0, numBytes);
buffer.increaseReaderIndexUnsafe(arrIndex - targetIndex + numBytes);
return newJava11StringByZeroCopy(coder, bytes);
return newBytesStringZeroCopy(coder, bytes);
} else {
byte coder = buffer.readByte();
final int numBytes = buffer.readPositiveVarInt();
byte[] bytes = buffer.readBytes(numBytes);
if (coder == UTF8) {
return new String(bytes, 0, numBytes, StandardCharsets.UTF_8);
}
return newJava11StringByZeroCopy(coder, bytes);
return newBytesStringZeroCopy(coder, bytes);
}
}

// Invoked by jit
public String readJava8CompressedString(MemoryBuffer buffer) {
public String readCompressedCharsString(MemoryBuffer buffer) {
byte coder = buffer.readByte();
if (coder == LATIN1) {
return newJava8StringByZeroCopy(readLatinChars(buffer));
return newCharsStringZeroCopy(readLatinChars(buffer));
} else {
return newJava8StringByZeroCopy(readUTF16Chars(buffer, coder));
return newCharsStringZeroCopy(readUTF16Chars(buffer, coder));
}
}

Expand All @@ -281,17 +281,17 @@ private byte[] getByteArray(int numElements) {
// Invoked by fury JIT
public void writeJavaString(MemoryBuffer buffer, String value) {
if (STRING_VALUE_FIELD_IS_BYTES) {
writeJDK11String(buffer, value);
writeBytesString(buffer, value);
} else {
if (!STRING_VALUE_FIELD_IS_CHARS) {
throw new UnsupportedOperationException();
}
final char[] chars = (char[]) Platform.getObject(value, STRING_VALUE_FIELD_OFFSET);
if (compressString) {
if (isLatin(chars)) {
writeJDK8Latin(buffer, chars, chars.length);
writeCharsLatin(buffer, chars, chars.length);
} else {
writeJDK8UTF16(buffer, chars, chars.length);
writeCharsUTF16(buffer, chars, chars.length);
}
} else {
int numBytes = MathUtils.doubleExact(value.length());
Expand Down Expand Up @@ -329,30 +329,30 @@ public static boolean isLatin(char[] chars) {
// Invoked by fury JIT
public String readJavaString(MemoryBuffer buffer) {
if (STRING_VALUE_FIELD_IS_BYTES) {
return readJava11String(buffer);
return readBytesString(buffer);
} else {
if (!STRING_VALUE_FIELD_IS_CHARS) {
throw new UnsupportedOperationException();
}
if (compressString) {
byte coder = buffer.readByte();
if (coder == LATIN1) {
return newJava8StringByZeroCopy(readLatinChars(buffer));
return newCharsStringZeroCopy(readLatinChars(buffer));
} else if (coder == UTF16) {
return newJava8StringByZeroCopy(readUTF16Chars(buffer, coder));
return newCharsStringZeroCopy(readUTF16Chars(buffer, coder));
} else {
if (coder != UTF8) {
throw new UnsupportedOperationException("Unsupported encoding: " + coder);
}
return readUTF8String(buffer);
}
} else {
return newJava8StringByZeroCopy(buffer.readCharsWithSizeEmbedded());
return newCharsStringZeroCopy(buffer.readCharsWithSizeEmbedded());
}
}
}

public static void writeJDK11String(MemoryBuffer buffer, String value) {
public static void writeBytesString(MemoryBuffer buffer, String value) {
byte[] bytes = (byte[]) Platform.getObject(value, STRING_VALUE_FIELD_OFFSET);
byte coder = Platform.getByte(value, STRING_CODER_FIELD_OFFSET);
int bytesLen = bytes.length;
Expand Down Expand Up @@ -382,7 +382,7 @@ public static void writeJDK11String(MemoryBuffer buffer, String value) {
buffer.unsafeWriterIndex(writerIndex);
}

public void writeJDK8Latin(MemoryBuffer buffer, char[] chars, final int strLen) {
public void writeCharsLatin(MemoryBuffer buffer, char[] chars, final int strLen) {
int writerIndex = buffer.writerIndex();
// `ensure` guarantees that the following operations are safe without bounds checks,
// and that the inner heap buffer doesn't change.
Expand Down Expand Up @@ -412,7 +412,7 @@ public void writeJDK8Latin(MemoryBuffer buffer, char[] chars, final int strLen)
}
}

public void writeJDK8UTF16(MemoryBuffer buffer, char[] chars, int strLen) {
public void writeCharsUTF16(MemoryBuffer buffer, char[] chars, int strLen) {
int numBytes = MathUtils.doubleExact(strLen);
if (Platform.IS_LITTLE_ENDIAN) {
buffer.writeByte(UTF16);
Expand Down Expand Up @@ -515,20 +515,21 @@ private char[] readUTF16Chars(MemoryBuffer buffer, byte coder) {

private static final MethodHandles.Lookup STRING_LOOK_UP =
_JDKAccess._trustedLookup(String.class);
private static final BiFunction<char[], Boolean, String> JAVA8_STRING_ZERO_COPY_CTR =
getJava8StringZeroCopyCtr();
private static final BiFunction<byte[], Byte, String> JAVA11_STRING_ZERO_COPY_CTR =
getJava11StringZeroCopyCtr();
private static final Function<byte[], String> JAVA11_LATIN_STRING_ZERO_COPY_CTR =
getJava11LatinStringZeroCopyCtr();

public static String newJava8StringByZeroCopy(char[] data) {
if (Platform.JAVA_VERSION != 8) {
private static final BiFunction<char[], Boolean, String> CHARS_STRING_ZERO_COPY_CTR =
getCharsStringZeroCopyCtr();
private static final BiFunction<byte[], Byte, String> BYTES_STRING_ZERO_COPY_CTR =
getBytesStringZeroCopyCtr();
private static final Function<byte[], String> LATIN_BYTES_STRING_ZERO_COPY_CTR =
getLatinBytesStringZeroCopyCtr();

public static String newCharsStringZeroCopy(char[] data) {
if (!STRING_VALUE_FIELD_IS_CHARS) {
throw new IllegalStateException(
String.format("Current java version is %s", Platform.JAVA_VERSION));
String.format(
"String value isn't char[], current java %s isn't supported", Platform.JAVA_VERSION));
}
try {
if (JAVA8_STRING_ZERO_COPY_CTR == null) {
if (CHARS_STRING_ZERO_COPY_CTR == null) {
// 1. As documented in `Subsequent Modification of final Fields` in
// https://docs.oracle.com/javase/specs/jls/se8/html/jls-17.html#d5e34106
// Maybe we can use `UNSAFE.putObject` to update String field to avoid reflection overhead.
Expand All @@ -543,7 +544,7 @@ public static String newJava8StringByZeroCopy(char[] data) {
return str;
} else {
// 25% faster than unsafe put field, only 10% slower than `new String(str)`
return JAVA8_STRING_ZERO_COPY_CTR.apply(data, Boolean.TRUE);
return CHARS_STRING_ZERO_COPY_CTR.apply(data, Boolean.TRUE);
}
} catch (Throwable e) {
throw new RuntimeException(e);
Expand All @@ -552,36 +553,37 @@ public static String newJava8StringByZeroCopy(char[] data) {

// coder param first to make inline call args
// `(buffer.readByte(), buffer.readBytesWithSizeEmbedded())` work.
public static String newJava11StringByZeroCopy(byte coder, byte[] data) {
if (Platform.JAVA_VERSION < 9) {
public static String newBytesStringZeroCopy(byte coder, byte[] data) {
if (!STRING_VALUE_FIELD_IS_BYTES) {
throw new IllegalStateException(
String.format("Current java version is %s", Platform.JAVA_VERSION));
String.format(
"String value isn't byte[], current java %s isn't supported", Platform.JAVA_VERSION));
}
if (coder == LATIN1) {
// 700% faster than unsafe put field in java11, only 10% slower than `new String(str)` for
// string length 230.
// 50% faster than unsafe put field in java11 for string length 10.
if (JAVA11_LATIN_STRING_ZERO_COPY_CTR != null) {
return JAVA11_LATIN_STRING_ZERO_COPY_CTR.apply(data);
if (LATIN_BYTES_STRING_ZERO_COPY_CTR != null) {
return LATIN_BYTES_STRING_ZERO_COPY_CTR.apply(data);
} else {
// JDK17 removed newStringLatin1
return JAVA11_STRING_ZERO_COPY_CTR.apply(data, LATIN1_BOXED);
return BYTES_STRING_ZERO_COPY_CTR.apply(data, LATIN1_BOXED);
}
} else if (coder == UTF16) {
// avoid byte box cost.
return JAVA11_STRING_ZERO_COPY_CTR.apply(data, UTF16_BOXED);
return BYTES_STRING_ZERO_COPY_CTR.apply(data, UTF16_BOXED);
} else {
// 700% faster than unsafe put field in java11, only 10% slower than `new String(str)` for
// string length 230.
// 50% faster than unsafe put field in java11 for string length 10.
// `invokeExact` must pass exact params with exact types:
// `(Object) data, coder` will throw WrongMethodTypeException
return JAVA11_STRING_ZERO_COPY_CTR.apply(data, coder);
return BYTES_STRING_ZERO_COPY_CTR.apply(data, coder);
}
}

private static BiFunction<char[], Boolean, String> getJava8StringZeroCopyCtr() {
if (Platform.JAVA_VERSION > 8) {
private static BiFunction<char[], Boolean, String> getCharsStringZeroCopyCtr() {
if (!STRING_VALUE_FIELD_IS_CHARS) {
return null;
}
MethodHandle handle = getJavaStringZeroCopyCtrHandle();
Expand All @@ -604,8 +606,8 @@ private static BiFunction<char[], Boolean, String> getJava8StringZeroCopyCtr() {
}
}

private static BiFunction<byte[], Byte, String> getJava11StringZeroCopyCtr() {
if (Platform.JAVA_VERSION < 9) {
private static BiFunction<byte[], Byte, String> getBytesStringZeroCopyCtr() {
if (!STRING_VALUE_FIELD_IS_BYTES) {
return null;
}
MethodHandle handle = getJavaStringZeroCopyCtrHandle();
Expand All @@ -630,15 +632,15 @@ private static BiFunction<byte[], Byte, String> getJava11StringZeroCopyCtr() {
}
}

private static Function<byte[], String> getJava11LatinStringZeroCopyCtr() {
if (Platform.JAVA_VERSION < 9) {
private static Function<byte[], String> getLatinBytesStringZeroCopyCtr() {
if (!STRING_VALUE_FIELD_IS_BYTES) {
return null;
}
if (STRING_LOOK_UP == null) {
return null;
}
try {
Class clazz = Class.forName("java.lang.StringCoding");
Class<?> clazz = Class.forName("java.lang.StringCoding");
MethodHandles.Lookup caller = STRING_LOOK_UP.in(clazz);
// JDK17 removed this method.
MethodHandle handle =
Expand All @@ -657,7 +659,7 @@ private static MethodHandle getJavaStringZeroCopyCtrHandle() {
return null;
}
try {
if (Platform.JAVA_VERSION == 8) {
if (STRING_VALUE_FIELD_IS_CHARS) {
return STRING_LOOK_UP.findConstructor(
String.class, MethodType.methodType(void.class, char[].class, boolean.class));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package io.fury.serializer;

import static io.fury.serializer.StringSerializer.newJava11StringByZeroCopy;
import static io.fury.serializer.StringSerializer.newBytesStringZeroCopy;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
Expand Down Expand Up @@ -83,7 +83,7 @@ private static String readJavaStringZeroCopy(MemoryBuffer buffer) {
if (STRING_VALUE_FIELD_IS_BYTES) {
return readJDK11String(buffer);
} else if (STRING_VALUE_FIELD_IS_CHARS) {
return StringSerializer.newJava8StringByZeroCopy(buffer.readCharsWithSizeEmbedded());
return StringSerializer.newCharsStringZeroCopy(buffer.readCharsWithSizeEmbedded());
}
return null;
} catch (Exception e) {
Expand All @@ -94,7 +94,7 @@ private static String readJavaStringZeroCopy(MemoryBuffer buffer) {
static String readJDK11String(MemoryBuffer buffer) {
byte coder = buffer.readByte();
byte[] value = buffer.readBytesWithSizeEmbedded();
return newJava11StringByZeroCopy(coder, value);
return newBytesStringZeroCopy(coder, value);
}

private static boolean writeJavaStringZeroCopy(MemoryBuffer buffer, String value) {
Expand All @@ -108,7 +108,7 @@ private static boolean writeJavaStringZeroCopy(MemoryBuffer buffer, String value
valueIsCharsField.setAccessible(true);
boolean STRING_VALUE_FIELD_IS_CHARS = (Boolean) valueIsCharsField.get(null);
if (STRING_VALUE_FIELD_IS_BYTES) {
StringSerializer.writeJDK11String(buffer, value);
StringSerializer.writeBytesString(buffer, value);
} else if (STRING_VALUE_FIELD_IS_CHARS) {
writeJDK8String(buffer, value);
} else {
Expand Down

0 comments on commit 355aaf8

Please sign in to comment.