Skip to content

Commit

Permalink
Patch up function metadata for arguments nullability (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
weijiii authored May 17, 2023
1 parent 9f2c9a8 commit d2d62d2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import com.linkedin.transport.test.spi.StdTester;
import java.util.List;
import java.util.Map;
import org.testng.Assert;
import org.testng.annotations.Test;


Expand All @@ -32,16 +31,9 @@ public void testFileLookup() {
tester.check(functionCall("file_lookup", null, 1), null, "boolean");
}

@Test
@Test(expectedExceptions = NullPointerException.class)
public void testFileLookupFailNull() {
try {
StdTester tester = getTester();
// in case of Trino, the execution of a query with UDF to check a null value in a file
// does not result in a NullPointerException, but returns a null value
tester.check(functionCall("file_lookup", resource("file_lookup_function/sample"), null), null, "boolean");
} catch (NullPointerException ex) {
// in case of Hive and Spark, the execution of a query with UDF to check a null value in a file results in a NullPointerException
Assert.assertFalse(isTrinoTest());
}
StdTester tester = getTester();
tester.check(functionCall("file_lookup", resource("file_lookup_function/sample"), null), null, "boolean");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import java.util.stream.IntStream;
import org.apache.commons.lang3.ClassUtils;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.linkedin.transport.trino.StdUDFUtils.quoteReservedKeywords;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.*;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
Expand All @@ -76,6 +77,7 @@ public StdUdfWrapper(StdUDF stdUDF) {
.nullable()
.nondeterministic()
.description(((TopLevelStdUDF) stdUDF).getFunctionDescription())
.argumentNullability(getArgumentNullabilityObjects(stdUDF.getNullableArguments()))
.signature(Signature.builder()
.name(((TopLevelStdUDF) stdUDF).getFunctionName())
.typeVariableConstraints(getTypeVariableConstraintsForStdUdf(stdUDF))
Expand Down Expand Up @@ -174,6 +176,12 @@ private MethodHandle getMethodHandle(StdUDF stdUDF, BoundSignature boundSignatur
outputType instanceof IntegerType, requiredFilesNextRefreshTime);
}

private List<Boolean> getArgumentNullabilityObjects(boolean[] argumentsNullability) {
return IntStream.range(0, argumentsNullability.length)
.mapToObj(idx -> argumentsNullability[idx] ? Boolean.TRUE : Boolean.FALSE)
.collect(toImmutableList());
}

private List<InvocationConvention.InvocationArgumentConvention> getNullConventionForArguments(
boolean[] nullableArguments) {
return IntStream.range(0, nullableArguments.length)
Expand Down

0 comments on commit d2d62d2

Please sign in to comment.