diff --git a/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java b/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java index 047a98e7a9..e17c962022 100644 --- a/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java +++ b/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java @@ -96,12 +96,13 @@ public TypeInformation getReturnType(Method method) { return returnType; } + @Override public Class getReturnedDomainClass(Method method) { TypeInformation returnType = getReturnType(method); + returnType = ReactiveWrapperConverters.unwrapWrapperTypes(returnType); - return QueryExecutionConverters.unwrapWrapperTypes(ReactiveWrapperConverters.unwrapWrapperTypes(returnType)) - .getType(); + return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()).getType(); } public Class getRepositoryInterface() { diff --git a/src/main/java/org/springframework/data/repository/query/QueryMethod.java b/src/main/java/org/springframework/data/repository/query/QueryMethod.java index 1d12f533d8..05fc4439fe 100644 --- a/src/main/java/org/springframework/data/repository/query/QueryMethod.java +++ b/src/main/java/org/springframework/data/repository/query/QueryMethod.java @@ -24,16 +24,17 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; -import org.springframework.data.domain.Window; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Window; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.repository.core.EntityMetadata; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.util.QueryExecutionConverters; import org.springframework.data.repository.util.ReactiveWrapperConverters; import org.springframework.data.util.Lazy; +import org.springframework.data.util.NullableWrapperConverters; import org.springframework.data.util.ReactiveWrappers; import org.springframework.data.util.TypeInformation; import org.springframework.util.Assert; @@ -296,7 +297,15 @@ private boolean calculateIsCollectionQuery() { return false; } - Class returnType = metadata.getReturnType(method).getType(); + TypeInformation returnTypeInformation = metadata.getReturnType(method); + + // Check against simple wrapper types first + if (metadata.getDomainTypeInformation() + .isAssignableFrom(NullableWrapperConverters.unwrapActualType(returnTypeInformation))) { + return false; + } + + Class returnType = returnTypeInformation.getType(); if (QueryExecutionConverters.supports(returnType) && !QueryExecutionConverters.isSingleValue(returnType)) { return true; diff --git a/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java b/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java index b900480c21..b6beff82c6 100644 --- a/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java +++ b/src/main/java/org/springframework/data/repository/util/QueryExecutionConverters.java @@ -36,8 +36,8 @@ import org.springframework.core.convert.support.ConfigurableConversionService; import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.data.domain.Page; -import org.springframework.data.domain.Window; import org.springframework.data.domain.Slice; +import org.springframework.data.domain.Window; import org.springframework.data.geo.GeoResults; import org.springframework.data.util.CustomCollections; import org.springframework.data.util.NullableWrapper; @@ -85,6 +85,7 @@ public abstract class QueryExecutionConverters { private static final Set> ALLOWED_PAGEABLE_TYPES = new HashSet<>(); private static final Map, ExecutionAdapter> EXECUTION_ADAPTER = new HashMap<>(); private static final Map, Boolean> supportsCache = new ConcurrentReferenceHashMap<>(); + private static final TypeInformation VOID_INFORMATION = TypeInformation.of(Void.class); static { @@ -235,15 +236,21 @@ public static Object unwrap(@Nullable Object source) { } /** - * Recursively unwraps well known wrapper types from the given {@link TypeInformation}. + * Recursively unwraps well known wrapper types from the given {@link TypeInformation} but aborts at the given + * reference type. * * @param type must not be {@literal null}. + * @param reference must not be {@literal null}. * @return will never be {@literal null}. */ - public static TypeInformation unwrapWrapperTypes(TypeInformation type) { + public static TypeInformation unwrapWrapperTypes(TypeInformation type, TypeInformation reference) { Assert.notNull(type, "type must not be null"); + if (reference.isAssignableFrom(type)) { + return type; + } + Class rawType = type.getType(); boolean needToUnwrap = type.isCollectionLike() // @@ -253,7 +260,17 @@ public static TypeInformation unwrapWrapperTypes(TypeInformation type) { || supports(rawType) // || Stream.class.isAssignableFrom(rawType); - return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType()) : type; + return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType(), reference) : type; + } + + /** + * Recursively unwraps well known wrapper types from the given {@link TypeInformation}. + * + * @param type must not be {@literal null}. + * @return will never be {@literal null}. + */ + public static TypeInformation unwrapWrapperTypes(TypeInformation type) { + return unwrapWrapperTypes(type, VOID_INFORMATION); } /** diff --git a/src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java b/src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java index 234e872c22..5b45a5a09d 100755 --- a/src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java +++ b/src/test/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadataUnitTests.java @@ -21,14 +21,19 @@ import java.lang.reflect.Method; import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; +import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestFactory; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.querydsl.User; import org.springframework.data.repository.PagingAndSortingRepository; import org.springframework.data.repository.Repository; import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.util.Streamable; /** * Unit tests for {@link AbstractRepositoryMetadata}. @@ -111,6 +116,25 @@ void doesNotUnwrapCustomTypeImplementingIterable() throws Exception { assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(Container.class); } + @TestFactory // GH-2869 + Stream detectsReturnTypesForStreamableAggregates() throws Exception { + + var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class); + var methods = Stream.of( + Map.entry("findBy", StreamableAggregate.class), + Map.entry("findSubTypeBy", StreamableAggregateSubType.class), + Map.entry("findAllBy", StreamableAggregate.class), + Map.entry("findOptional", StreamableAggregate.class)); + + return DynamicTest.stream(methods, // + it -> it.getKey() + "'s returned domain class is " + it.getValue(), // + it -> { + + var method = StreamableAggregateRepository.class.getMethod(it.getKey()); + assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(it.getValue()); + }); + } + interface UserRepository extends Repository { User findSingle(); @@ -155,4 +179,20 @@ interface ContainerRepository extends Repository { interface CompletePageableAndSortingRepository extends PagingAndSortingRepository {} + // GH-2869 + + static abstract class StreamableAggregate implements Streamable {} + + interface StreamableAggregateRepository extends Repository { + + StreamableAggregate findBy(); + + StreamableAggregateSubType findSubTypeBy(); + + Streamable findAllBy(); + + Optional findOptional(); + } + + static abstract class StreamableAggregateSubType extends StreamableAggregate {} } diff --git a/src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java b/src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java index baa9cf6f60..ec1d0ed349 100755 --- a/src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java +++ b/src/test/java/org/springframework/data/repository/query/QueryMethodUnitTests.java @@ -24,12 +24,16 @@ import java.io.Serializable; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.stream.Stream; import org.eclipse.collections.api.list.ImmutableList; +import org.junit.jupiter.api.DynamicTest; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestFactory; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.ScrollPosition; @@ -41,6 +45,7 @@ import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.AbstractRepositoryMetadata; import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; +import org.springframework.data.util.Streamable; /** * Unit tests for {@link QueryMethod}. @@ -302,6 +307,28 @@ void considersEclipseCollectionCollectionQuery() throws Exception { assertThat(queryMethod.isCollectionQuery()).isTrue(); } + @TestFactory // GH-2869 + Stream doesNotConsiderQueryMethodReturningAggregateImplementingStreamableACollectionQuery() + throws Exception { + + var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class); + var stream = Stream.of( + Map.entry("findBy", false), + Map.entry("findSubTypeBy", false), + Map.entry("findAllBy", true), + Map.entry("findOptionalBy", false)); + + return DynamicTest.stream(stream, // + it -> it.getKey() + " considered collection query -> " + it.getValue(), // + it -> { + + var method = StreamableAggregateRepository.class.getMethod(it.getKey()); + var queryMethod = new QueryMethod(method, metadata, factory); + + assertThat(queryMethod.isCollectionQuery()).isEqualTo(it.getValue()); + }); + } + interface SampleRepository extends Repository { String pagingMethodWithInvalidReturnType(Pageable pageable); @@ -379,4 +406,21 @@ abstract class Container implements Iterable {} interface ContainerRepository extends Repository { Container someMethod(); } + + // GH-2869 + + static abstract class StreamableAggregate implements Streamable {} + + interface StreamableAggregateRepository extends Repository { + + StreamableAggregate findBy(); + + StreamableAggregateSubType findSubTypeBy(); + + Optional findOptionalBy(); + + Streamable findAllBy(); + } + + static abstract class StreamableAggregateSubType extends StreamableAggregate {} }