Skip to content

Commit

Permalink
Fix query execution mode detection for aggregate types that implement…
Browse files Browse the repository at this point in the history
… Streamable.

We now short-circuit the QueryMethod.isCollectionQuery() algorithm in case we find the concrete domain type or any subclass of it.

Fixes #2869.
  • Loading branch information
odrotbohm committed Jun 30, 2023
1 parent 05dd7ae commit ca9f9bf
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,13 @@ public TypeInformation<?> getReturnType(Method method) {
return returnType;
}

@Override
public Class<?> getReturnedDomainClass(Method method) {

TypeInformation<?> returnType = getReturnType(method);
returnType = ReactiveWrapperConverters.unwrapWrapperTypes(returnType);

return QueryExecutionConverters.unwrapWrapperTypes(ReactiveWrapperConverters.unwrapWrapperTypes(returnType))
.getType();
return QueryExecutionConverters.unwrapWrapperTypes(returnType, getDomainTypeInformation()).getType();
}

public Class<?> getRepositoryInterface() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@

import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Window;
import org.springframework.data.domain.ScrollPosition;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import org.springframework.data.domain.Window;
import org.springframework.data.projection.ProjectionFactory;
import org.springframework.data.repository.core.EntityMetadata;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.util.QueryExecutionConverters;
import org.springframework.data.repository.util.ReactiveWrapperConverters;
import org.springframework.data.util.Lazy;
import org.springframework.data.util.NullableWrapperConverters;
import org.springframework.data.util.ReactiveWrappers;
import org.springframework.data.util.TypeInformation;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -296,7 +297,15 @@ private boolean calculateIsCollectionQuery() {
return false;
}

Class<?> returnType = metadata.getReturnType(method).getType();
TypeInformation<?> returnTypeInformation = metadata.getReturnType(method);

// Check against simple wrapper types first
if (metadata.getDomainTypeInformation()
.isAssignableFrom(NullableWrapperConverters.unwrapActualType(returnTypeInformation))) {
return false;
}

Class<?> returnType = returnTypeInformation.getType();

if (QueryExecutionConverters.supports(returnType) && !QueryExecutionConverters.isSingleValue(returnType)) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
import org.springframework.core.convert.support.ConfigurableConversionService;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Window;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Window;
import org.springframework.data.geo.GeoResults;
import org.springframework.data.util.CustomCollections;
import org.springframework.data.util.NullableWrapper;
Expand Down Expand Up @@ -85,6 +85,7 @@ public abstract class QueryExecutionConverters {
private static final Set<Class<?>> ALLOWED_PAGEABLE_TYPES = new HashSet<>();
private static final Map<Class<?>, ExecutionAdapter> EXECUTION_ADAPTER = new HashMap<>();
private static final Map<Class<?>, Boolean> supportsCache = new ConcurrentReferenceHashMap<>();
private static final TypeInformation<Void> VOID_INFORMATION = TypeInformation.of(Void.class);

static {

Expand Down Expand Up @@ -235,15 +236,21 @@ public static Object unwrap(@Nullable Object source) {
}

/**
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
* Recursively unwraps well known wrapper types from the given {@link TypeInformation} but aborts at the given
* reference type.
*
* @param type must not be {@literal null}.
* @param reference must not be {@literal null}.
* @return will never be {@literal null}.
*/
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type, TypeInformation<?> reference) {

Assert.notNull(type, "type must not be null");

if (reference.isAssignableFrom(type)) {
return type;
}

Class<?> rawType = type.getType();

boolean needToUnwrap = type.isCollectionLike() //
Expand All @@ -253,7 +260,17 @@ public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
|| supports(rawType) //
|| Stream.class.isAssignableFrom(rawType);

return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType()) : type;
return needToUnwrap ? unwrapWrapperTypes(type.getRequiredComponentType(), reference) : type;
}

/**
* Recursively unwraps well known wrapper types from the given {@link TypeInformation}.
*
* @param type must not be {@literal null}.
* @return will never be {@literal null}.
*/
public static TypeInformation<?> unwrapWrapperTypes(TypeInformation<?> type) {
return unwrapWrapperTypes(type, VOID_INFORMATION);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;

import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestFactory;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.querydsl.User;
import org.springframework.data.repository.PagingAndSortingRepository;
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.util.Streamable;

/**
* Unit tests for {@link AbstractRepositoryMetadata}.
Expand Down Expand Up @@ -111,6 +116,25 @@ void doesNotUnwrapCustomTypeImplementingIterable() throws Exception {
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(Container.class);
}

@TestFactory // GH-2869
Stream<DynamicTest> detectsReturnTypesForStreamableAggregates() throws Exception {

var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
var methods = Stream.of(
Map.entry("findBy", StreamableAggregate.class),
Map.entry("findSubTypeBy", StreamableAggregateSubType.class),
Map.entry("findAllBy", StreamableAggregate.class),
Map.entry("findOptional", StreamableAggregate.class));

return DynamicTest.stream(methods, //
it -> it.getKey() + "'s returned domain class is " + it.getValue(), //
it -> {

var method = StreamableAggregateRepository.class.getMethod(it.getKey());
assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(it.getValue());
});
}

interface UserRepository extends Repository<User, Long> {

User findSingle();
Expand Down Expand Up @@ -155,4 +179,20 @@ interface ContainerRepository extends Repository<Container, Long> {

interface CompletePageableAndSortingRepository extends PagingAndSortingRepository<Container, Long> {}

// GH-2869

static abstract class StreamableAggregate implements Streamable<Object> {}

interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {

StreamableAggregate findBy();

StreamableAggregateSubType findSubTypeBy();

Streamable<StreamableAggregate> findAllBy();

Optional<StreamableAggregate> findOptional();
}

static abstract class StreamableAggregateSubType extends StreamableAggregate {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.stream.Stream;

import org.eclipse.collections.api.list.ImmutableList;
import org.junit.jupiter.api.DynamicTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestFactory;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.ScrollPosition;
Expand All @@ -41,6 +45,7 @@
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
import org.springframework.data.util.Streamable;

/**
* Unit tests for {@link QueryMethod}.
Expand Down Expand Up @@ -302,6 +307,28 @@ void considersEclipseCollectionCollectionQuery() throws Exception {
assertThat(queryMethod.isCollectionQuery()).isTrue();
}

@TestFactory // GH-2869
Stream<DynamicTest> doesNotConsiderQueryMethodReturningAggregateImplementingStreamableACollectionQuery()
throws Exception {

var metadata = AbstractRepositoryMetadata.getMetadata(StreamableAggregateRepository.class);
var stream = Stream.of(
Map.entry("findBy", false),
Map.entry("findSubTypeBy", false),
Map.entry("findAllBy", true),
Map.entry("findOptionalBy", false));

return DynamicTest.stream(stream, //
it -> it.getKey() + " considered collection query -> " + it.getValue(), //
it -> {

var method = StreamableAggregateRepository.class.getMethod(it.getKey());
var queryMethod = new QueryMethod(method, metadata, factory);

assertThat(queryMethod.isCollectionQuery()).isEqualTo(it.getValue());
});
}

interface SampleRepository extends Repository<User, Serializable> {

String pagingMethodWithInvalidReturnType(Pageable pageable);
Expand Down Expand Up @@ -379,4 +406,21 @@ abstract class Container implements Iterable<Element> {}
interface ContainerRepository extends Repository<Container, Long> {
Container someMethod();
}

// GH-2869

static abstract class StreamableAggregate implements Streamable<Object> {}

interface StreamableAggregateRepository extends Repository<StreamableAggregate, Object> {

StreamableAggregate findBy();

StreamableAggregateSubType findSubTypeBy();

Optional<StreamableAggregate> findOptionalBy();

Streamable<StreamableAggregate> findAllBy();
}

static abstract class StreamableAggregateSubType extends StreamableAggregate {}
}

0 comments on commit ca9f9bf

Please sign in to comment.