diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SerializableComparator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SerializableComparator.java index c66fbb7d7497..16304633c993 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SerializableComparator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SerializableComparator.java @@ -19,10 +19,29 @@ import java.io.Serializable; import java.util.Comparator; +import java.util.Objects; +import java.util.function.Function; /** * A {@code Comparator} that is also {@code Serializable}. * * @param type of values being compared */ -public interface SerializableComparator extends Comparator, Serializable {} +public interface SerializableComparator extends Comparator, Serializable { + /** + * Analogous to {@link Comparator#comparing(Function)}, except that it takes in a {@link + * SerializableFunction} as the key extractor and returns a {@link SerializableComparator}. + * + * @param keyExtractor the function used to extract the {@link java.lang.Comparable} sort key + * @return A {@link SerializableComparator} that compares by an extracted key + * @param the type of element to be compared + * @param the type of the {@code Comparable} sort key + * @see Comparator#comparing(Function) + */ + static > SerializableComparator comparing( + SerializableFunction keyExtractor) { + Objects.requireNonNull(keyExtractor); + return (SerializableComparator) + (c1, c2) -> keyExtractor.apply(c1).compareTo(keyExtractor.apply(c2)); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SerializableComparatorTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SerializableComparatorTest.java new file mode 100644 index 000000000000..09583ec44f28 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SerializableComparatorTest.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.transforms; + +import java.io.Serializable; +import java.util.function.Function; +import org.apache.beam.sdk.util.SerializableUtils; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link SerializableComparator}. */ +@RunWith(JUnit4.class) +public class SerializableComparatorTest { + + /** + * Tests if the {@link SerializableComparator} returned by {@link + * SerializableComparator#comparing(SerializableFunction)} using {@link + * SerializableUtils#ensureSerializable(Serializable)}. + */ + @Test + public void testSerializable() { + SerializableFunction fn = Integer::parseInt; + + SerializableComparator cmp = SerializableComparator.comparing(fn); + SerializableUtils.ensureSerializable(cmp); + } + + /** + * Tests if {@link SerializableComparator#comparing(Function)} throws a {@link + * java.lang.NullPointerException} if null is passed to it. + */ + @Test(expected = NullPointerException.class) + public void testIfNPEThrownForNullFunction() { + SerializableComparator.comparing(null); + } + + /** Tests the basic comparison function of the {@link SerializableComparator} returned. */ + @Test + public void testBasicComparison() { + SerializableFunction fn = Integer::parseInt; + SerializableComparator cmp = SerializableComparator.comparing(fn); + + Assert.assertTrue(cmp.compare("1", "10") < 0); + Assert.assertTrue(cmp.compare("9", "6") > 0); + } +}