Skip to content

Commit

Permalink
[Java] Support jdk record metashare mode (#855)
Browse files Browse the repository at this point in the history
* add record info

* refactor record utils

* support meta share for record and support record default value
  • Loading branch information
chaokunyang authored Aug 7, 2023
1 parent 8394dd2 commit 6343585
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static io.fury.collection.Maps.ofHashMap;

import io.fury.Fury;
import io.fury.resolver.MetaContext;
import io.fury.serializer.CompatibleMode;
import io.fury.test.bean.Struct;
import io.fury.util.RecordComponent;
Expand Down Expand Up @@ -105,4 +106,49 @@ public void testRecordCompatible() throws Throwable {
// test compatible
Assert.assertEquals(fury2.deserialize(bytes1), record2);
}

@Test
public void testRecordMetaShare() throws Throwable {
String code1 =
"import java.util.*;"
+ "public record TestRecord(int f1, String f2, List<String> f3, char f4, Map<String, Integer> f5) {}";
Class<?> cls1 = Struct.createStructClass("TestRecord", code1);
Object record1 =
RecordUtils.getRecordConstructor(cls1)
.f1
.invoke(1, "abc", ofArrayList("a", "b"), 'a', ofHashMap("a", 1));
Fury fury1 =
Fury.builder()
.requireClassRegistration(false)
.withCodegen(false)
.withCompatibleMode(CompatibleMode.COMPATIBLE)
.withMetaContextShare(true)
.withClassLoader(cls1.getClassLoader())
.build();
String code2 =
"import java.util.*;"
+ "public record TestRecord(String f2, char f4, Map<String, Integer> f5) {}";
Class<?> cls2 = Struct.createStructClass("TestRecord", code2);
Object record2 =
RecordUtils.getRecordConstructor(cls2).f1.invoke("abc", 'a', ofHashMap("a", 1));
Fury fury2 =
Fury.builder()
.requireClassRegistration(false)
.withCodegen(false)
.withCompatibleMode(CompatibleMode.COMPATIBLE)
.withMetaContextShare(true)
.withClassLoader(cls2.getClassLoader())
.build();
MetaContext metaContext1 = new MetaContext();
MetaContext metaContext2 = new MetaContext();
fury1.getSerializationContext().setMetaContext(metaContext1);
byte[] bytes1 = fury1.serialize(record1);
fury2.getSerializationContext().setMetaContext(metaContext2);
Object o21 = fury2.deserialize(bytes1);
fury2.getSerializationContext().setMetaContext(metaContext2);
byte[] bytes2 = fury2.serialize(o21);
fury1.getSerializationContext().setMetaContext(metaContext1);
Object o12 = fury1.deserialize(bytes2);
System.out.println(o12);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.fury.resolver.RefResolver;
import io.fury.util.FieldAccessor;
import io.fury.util.Platform;
import io.fury.util.RecordInfo;
import io.fury.util.RecordUtils;
import io.fury.util.ReflectionUtils;
import java.lang.invoke.MethodHandle;
Expand Down Expand Up @@ -52,8 +53,7 @@ public final class CompatibleSerializer<T> extends CompatibleSerializerBase<T> {
private final boolean isRecord;
private final MethodHandle constructor;
private final boolean compressNumber;
private final int[] recordComponentsIndex;
private final Object[] recordComponents;
private final RecordInfo recordInfo;

public CompatibleSerializer(Fury fury, Class<T> cls) {
super(fury, cls);
Expand All @@ -62,20 +62,19 @@ public CompatibleSerializer(Fury fury, Class<T> cls) {
// Use `setSerializerIfAbsent` to avoid overwriting existing serializer for class when used
// as data serializer.
classResolver.setSerializerIfAbsent(cls, this);
fieldResolver = classResolver.getFieldResolver(cls);
isRecord = RecordUtils.isRecord(type);
if (isRecord) {
constructor = RecordUtils.getRecordConstructor(type).f1;
List<String> fieldNames =
fieldResolver.getAllFieldsList().stream()
.map(FieldResolver.FieldInfo::getName)
.collect(Collectors.toList());
recordInfo = new RecordInfo(cls, fieldNames);
} else {
this.constructor = ReflectionUtils.getExecutableNoArgConstructorHandle(type);
recordInfo = null;
}
fieldResolver = classResolver.getFieldResolver(cls);
List<String> fieldNames =
fieldResolver.getAllFieldsList().stream()
.map(FieldResolver.FieldInfo::getName)
.collect(Collectors.toList());
recordComponentsIndex = Serializers.buildRecordComponentMapping(cls, fieldNames);
assert recordComponentsIndex != null;
recordComponents = new Object[recordComponentsIndex.length];
compressNumber = fury.compressNumber();
}

Expand All @@ -85,8 +84,7 @@ public CompatibleSerializer(Fury fury, Class<T> cls, FieldResolver fieldResolver
this.classResolver = fury.getClassResolver();
isRecord = RecordUtils.isRecord(type);
Preconditions.checkArgument(!isRecord, cls);
recordComponentsIndex = null;
recordComponents = null;
recordInfo = null;
this.constructor = null;
this.fieldResolver = fieldResolver;
compressNumber = fury.compressNumber();
Expand Down Expand Up @@ -300,11 +298,11 @@ public T read(MemoryBuffer buffer) {
if (isRecord) {
Object[] fieldValues = new Object[fieldResolver.getNumFields()];
readFields(buffer, fieldValues);
Serializers.remapping(recordComponentsIndex, fieldValues, recordComponents);
RecordUtils.remapping(recordInfo, fieldValues);
assert constructor != null;
try {
T t = (T) constructor.invokeWithArguments(recordComponents);
Arrays.fill(recordComponents, null);
T t = (T) constructor.invokeWithArguments(recordInfo.getRecordComponents());
Arrays.fill(recordInfo.getRecordComponents(), null);
return t;
} catch (Throwable e) {
Platform.throwException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package io.fury.serializer;

import static io.fury.serializer.ObjectSerializer.getSortedDescriptors;

import com.google.common.base.Preconditions;
import io.fury.Fury;
import io.fury.builder.MetaSharedCodecBuilder;
Expand All @@ -31,16 +33,19 @@
import io.fury.type.Generics;
import io.fury.util.FieldAccessor;
import io.fury.util.Platform;
import io.fury.util.RecordInfo;
import io.fury.util.RecordUtils;
import io.fury.util.ReflectionUtils;
import java.lang.invoke.MethodHandle;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.stream.Collectors;

/**
* A meta-shared compatible deserializer builder based on {@link ClassDef}. This serializer will
Expand Down Expand Up @@ -73,6 +78,7 @@ public class MetaSharedSerializer<T> extends Serializer<T> {
private final ObjectSerializer.GenericTypeField[] containerFields;
private final boolean isRecord;
private final MethodHandle constructor;
private final RecordInfo recordInfo;
private Serializer<T> serializer;
private final ClassInfoCache classInfoCache;

Expand Down Expand Up @@ -102,6 +108,15 @@ public MetaSharedSerializer(Fury fury, Class<T> type, ClassDef classDef) {
otherFields = infos.f1;
containerFields = infos.f2;
classInfoCache = fury.getClassResolver().nilClassInfoCache();
if (isRecord) {
List<String> fieldNames =
getSortedDescriptors(descriptorGrouper).stream()
.map(Descriptor::getName)
.collect(Collectors.toList());
recordInfo = new RecordInfo(type, fieldNames);
} else {
recordInfo = null;
}
}

@Override
Expand All @@ -120,8 +135,11 @@ public T read(MemoryBuffer buffer) {
Object[] fieldValues =
new Object[finalFields.length + otherFields.length + containerFields.length];
readFields(buffer, fieldValues);
RecordUtils.remapping(recordInfo, fieldValues);
try {
return (T) constructor.invoke(fieldValues);
T t = (T) constructor.invokeWithArguments(recordInfo.getRecordComponents());
Arrays.fill(recordInfo.getRecordComponents(), null);
return t;
} catch (Throwable e) {
Platform.throwException(e);
}
Expand Down Expand Up @@ -213,21 +231,18 @@ private void readFields(MemoryBuffer buffer, Object[] fields) {
fury, refResolver, classResolver, fieldInfo, isFinal, buffer);
}
}
fields[counter++] = null;
}
}
for (ObjectSerializer.GenericTypeField fieldInfo : otherFields) {
Object fieldValue = ObjectSerializer.readOtherFieldValue(fury, fieldInfo, buffer);
if (fieldInfo.fieldAccessor != null) {
fields[counter++] = fieldValue;
}
fields[counter++] = fieldValue;
}
Generics generics = fury.getGenerics();
for (ObjectSerializer.GenericTypeField fieldInfo : containerFields) {
Object fieldValue =
ObjectSerializer.readContainerFieldValue(fury, generics, fieldInfo, buffer);
if (fieldInfo.fieldAccessor != null) {
fields[counter++] = fieldValue;
}
fields[counter++] = fieldValue;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.fury.type.Generics;
import io.fury.util.FieldAccessor;
import io.fury.util.Platform;
import io.fury.util.RecordInfo;
import io.fury.util.RecordUtils;
import io.fury.util.ReflectionUtils;
import java.lang.invoke.MethodHandle;
Expand Down Expand Up @@ -66,8 +67,7 @@ public final class ObjectSerializer<T> extends Serializer<T> {
private final RefResolver refResolver;
private final ClassResolver classResolver;
private final boolean isRecord;
private final Object[] recordComponents;
private final int[] recordComponentsIndex;
private final RecordInfo recordInfo;
private final FinalTypeField[] finalFields;
/**
* Whether write class def for non-inner final types.
Expand Down Expand Up @@ -104,13 +104,10 @@ public ObjectSerializer(Fury fury, Class<T> cls, boolean resolveParent) {
getSortedDescriptors(descriptorGrouper).stream()
.map(Descriptor::getName)
.collect(Collectors.toList());
recordComponentsIndex = Serializers.buildRecordComponentMapping(cls, fieldNames);
assert recordComponentsIndex != null;
recordComponents = new Object[recordComponentsIndex.length];
recordInfo = new RecordInfo(cls, fieldNames);
} else {
this.constructor = ReflectionUtils.getExecutableNoArgConstructorHandle(cls);
recordComponentsIndex = null;
recordComponents = null;
recordInfo = null;
}
if (fury.checkClassVersion()) {
classVersionHash = computeVersionHash(descriptors);
Expand Down Expand Up @@ -313,10 +310,10 @@ static void writeContainerFieldValue(
public T read(MemoryBuffer buffer) {
if (isRecord) {
Object[] fields = readFields(buffer);
Serializers.remapping(recordComponentsIndex, fields, recordComponents);
RecordUtils.remapping(recordInfo, fields);
try {
T obj = (T) constructor.invokeWithArguments(recordComponents);
Arrays.fill(recordComponents, null);
T obj = (T) constructor.invokeWithArguments(recordInfo.getRecordComponents());
Arrays.fill(recordInfo.getRecordComponents(), null);
return obj;
} catch (Throwable e) {
Platform.throwException(e);
Expand Down
32 changes: 0 additions & 32 deletions java/fury-core/src/main/java/io/fury/serializer/Serializers.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import io.fury.resolver.ClassResolver;
import io.fury.type.Type;
import io.fury.util.Platform;
import io.fury.util.RecordComponent;
import io.fury.util.RecordUtils;
import io.fury.util.Utils;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
Expand All @@ -33,10 +31,7 @@
import java.net.URI;
import java.nio.charset.Charset;
import java.util.Currency;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -139,33 +134,6 @@ public static Object readPrimitiveValue(Fury fury, MemoryBuffer buffer, short cl
}
}

static int[] buildRecordComponentMapping(Class<?> cls, List<String> fields) {
Map<String, Integer> fieldOrderIndex = new HashMap<>(fields.size());
int counter = 0;
for (String fieldName : fields) {
fieldOrderIndex.put(fieldName, counter++);
}
RecordComponent[] components = RecordUtils.getRecordComponents(cls);
if (components == null) {
return null;
}
int[] mapping = new int[components.length];
for (int i = 0; i < mapping.length; i++) {
RecordComponent component = components[i];
Integer index = fieldOrderIndex.get(component.getName());
mapping[i] = index;
}
return mapping;
}

public static void remapping(
int[] recordComponentsIndex, Object[] fields, Object[] recordComponents) {
for (int i = 0; i < recordComponentsIndex.length; i++) {
int index = recordComponentsIndex[i];
recordComponents[i] = fields[index];
}
}

public abstract static class CrossLanguageCompatibleSerializer<T> extends Serializer<T> {
private final short typeId;

Expand Down
49 changes: 49 additions & 0 deletions java/fury-core/src/main/java/io/fury/util/RecordInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright 2023 The Fury Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.fury.util;

import java.util.List;

/**
* Record build information.
*
* @author chaokunyang
*/
public class RecordInfo {
private final int[] recordComponentsIndex;
private final Object[] recordComponentsDefaultValues;
private final Object[] recordComponents;

public RecordInfo(Class<?> cls, List<String> fieldNames) {
recordComponentsDefaultValues = RecordUtils.buildRecordComponentDefaultValues(cls);
recordComponentsIndex = RecordUtils.buildRecordComponentMapping(cls, fieldNames);
assert recordComponentsIndex != null;
recordComponents = new Object[recordComponentsIndex.length];
}

public int[] getRecordComponentsIndex() {
return recordComponentsIndex;
}

public Object[] getRecordComponentsDefaultValues() {
return recordComponentsDefaultValues;
}

public Object[] getRecordComponents() {
return recordComponents;
}
}
Loading

0 comments on commit 6343585

Please sign in to comment.