Skip to content

Commit

Permalink
Hibernate: Set the order using the Hibernate API instead of modifying…
Browse files Browse the repository at this point in the history
… the query (#3156)
  • Loading branch information
dstepanov authored Oct 1, 2024
1 parent 2943b03 commit 2a9d77e
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,6 @@
import io.micronaut.data.runtime.operations.internal.query.DefaultBindableParametersStoredQuery;
import io.micronaut.data.runtime.query.PreparedQueryDecorator;
import io.micronaut.data.runtime.query.StoredQueryDecorator;
import org.hibernate.graph.AttributeNode;
import org.hibernate.graph.Graph;
import org.hibernate.graph.RootGraph;
import org.hibernate.graph.SubGraph;

import jakarta.persistence.FlushModeType;
import jakarta.persistence.Tuple;
import jakarta.persistence.TupleElement;
Expand All @@ -63,6 +58,12 @@
import jakarta.persistence.criteria.Order;
import jakarta.persistence.criteria.Path;
import jakarta.persistence.criteria.Root;
import jakarta.validation.constraints.NotNull;
import org.hibernate.graph.AttributeNode;
import org.hibernate.graph.Graph;
import org.hibernate.graph.RootGraph;
import org.hibernate.graph.SubGraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -262,6 +263,15 @@ public Map<String, Object> getQueryHints(@NonNull StoredQuery<?, ?> storedQuery)
*/
protected abstract void setOffset(P query, int offset);

/**
* Sets the order.
*
* @param query The query
* @param orders The orders
* @since 4.10
*/
protected abstract void setOrder(P query, List<org.hibernate.query.Order<?>> orders);

/**
* Gets an entity graph.
*
Expand Down Expand Up @@ -322,7 +332,7 @@ public Map<String, Object> getQueryHints(@NonNull StoredQuery<?, ?> storedQuery)
*/
protected <R> void collectFindOne(S session, PreparedQuery<?, R> preparedQuery, ResultCollector<R> collector) {
String query = preparedQuery.getQuery();
collectResults(session, query, preparedQuery, collector);
collectResults(session, query, preparedQuery, preparedQuery.getPageable(), collector);
}

/**
Expand All @@ -340,30 +350,38 @@ protected <R> void collectFindAll(S session, PreparedQuery<?, R> preparedQuery,
if (pageable.getMode() != Mode.OFFSET) {
throw new UnsupportedOperationException("Pageable mode " + pageable.getMode() + " is not supported by hibernate operations");
}
Sort sort = pageable.getSort();
if (sort.isSorted()) {
queryStr += QUERY_BUILDER.buildOrderBy(queryStr, getEntity(preparedQuery.getRootEntity()), AnnotationMetadata.EMPTY_METADATA, sort,
preparedQuery.isNative()).getQuery();
if (preparedQuery.isNative()) {
// Native queries don't support setting the order
Sort sort = pageable.getSort();
if (sort.isSorted()) {
queryStr += QUERY_BUILDER.buildOrderBy(queryStr, getEntity(preparedQuery.getRootEntity()), AnnotationMetadata.EMPTY_METADATA, sort,
preparedQuery.isNative()).getQuery();
}
pageable = pageable.withoutSort();
}
}
collectResults(session, queryStr, preparedQuery, collector);
collectResults(session, queryStr, preparedQuery, pageable, collector);
}

private <T, R> void collectResults(S session, String queryStr, PreparedQuery<T, R> preparedQuery, ResultCollector<R> resultCollector) {
private <T, R> void collectResults(S session,
String queryStr,
PreparedQuery<T, R> preparedQuery,
Pageable pageable,
ResultCollector<R> resultCollector) {
if (preparedQuery.isDtoProjection()) {
P q;
if (preparedQuery.isNative()) {
q = createNativeQuery(session, queryStr, Tuple.class);
} else if (queryStr.toLowerCase(Locale.ENGLISH).startsWith("select new ")) {
@SuppressWarnings("unchecked") Class<R> wrapperType = (Class<R>) ReflectionUtils.getWrapperType(preparedQuery.getResultType());
P query = createQuery(session, queryStr, wrapperType);
bindPreparedQuery(query, preparedQuery, session);
bindPreparedQuery(query, preparedQuery, pageable, session);
resultCollector.collect(query);
return;
} else {
q = createQuery(session, queryStr, Tuple.class);
}
bindPreparedQuery(q, preparedQuery, session);
bindPreparedQuery(q, preparedQuery, pageable, session);
resultCollector.collectTuple(q, tuple -> {
Set<String> properties = tuple.getElements().stream().map(TupleElement::getAlias).collect(Collectors.toCollection(() -> new TreeSet<>(String.CASE_INSENSITIVE_ORDER)));
return (new BeanIntrospectionMapper<Tuple, R>() {
Expand All @@ -389,7 +407,7 @@ public ConversionService getConversionService() {
Class<T> rootEntity = preparedQuery.getRootEntity();
if (wrapperType != rootEntity) {
P nativeQuery = createNativeQuery(session, queryStr, Tuple.class);
bindPreparedQuery(nativeQuery, preparedQuery, session);
bindPreparedQuery(nativeQuery, preparedQuery, pageable, session);
resultCollector.collectTuple(nativeQuery, tuple -> {
Object o = tuple.get(0);
if (wrapperType.isInstance(o)) {
Expand All @@ -404,7 +422,7 @@ public ConversionService getConversionService() {
} else {
q = createQuery(session, queryStr, wrapperType);
}
bindPreparedQuery(q, preparedQuery, session);
bindPreparedQuery(q, preparedQuery, pageable, session);
resultCollector.collect(q);
}
}
Expand All @@ -420,34 +438,29 @@ public ConversionService getConversionService() {
*/
protected <T, R> void bindParameters(Q q, @NonNull PreparedQuery<T, R> preparedQuery, boolean bindNamed) {
BindableParametersPreparedQuery<T, R> bindableParametersPreparedQuery = getBindableParametersPreparedQuery(preparedQuery);
BindableParametersStoredQuery.Binder binder = createBinder(q, preparedQuery, preparedQuery.getArguments(), bindNamed);
BindableParametersStoredQuery.Binder binder = createBinder(q, preparedQuery.getArguments(), bindNamed);
bindableParametersPreparedQuery.bindParameters(binder);
}

/**
* Bind parameters into query.
*
* @param <T> The entity type
* @param <R> The result type
* @param q The query
* @param storedQuery The stored query
* @param invocationContext The invocationContext
* @param bindNamed If parameter should be bind by the name
* @param entity The entity
* @param <T> The entity type
* @param <R> The result type
*/
protected <T, R> void bindParameters(Q q, @NonNull StoredQuery<T, R> storedQuery,
InvocationContext<?, ?> invocationContext,
boolean bindNamed,
T entity) {
BindableParametersStoredQuery<T, R> bindableParametersPreparedQuery = (BindableParametersStoredQuery<T, R>) storedQuery;
BindableParametersStoredQuery.Binder binder = createBinder(q, storedQuery, invocationContext.getArguments(), bindNamed);
BindableParametersStoredQuery.Binder binder = createBinder(q, invocationContext.getArguments(), true);
bindableParametersPreparedQuery.bindParameters(binder, invocationContext, entity, null);
}

private <T, R> BindableParametersStoredQuery.Binder createBinder(Q q,
StoredQuery<T, R> storedQuery,
Argument<?>[] arguments,
boolean bindNamed) {
private BindableParametersStoredQuery.Binder createBinder(Q q, Argument<?>[] arguments, boolean bindNamed) {
return new BindableParametersStoredQuery.Binder() {

int index = 1;
Expand Down Expand Up @@ -515,9 +528,9 @@ public void bindMany(QueryParameterBinding binding, Collection<Object> values) {
};
}

private <T, R> void bindPreparedQuery(P q, @NonNull PreparedQuery<T, R> preparedQuery, S currentSession) {
private <T, R> void bindPreparedQuery(P q, @NonNull PreparedQuery<T, R> preparedQuery, Pageable pageable, S currentSession) {
bindParameters(q, preparedQuery, true);
bindPageable(q, preparedQuery.getPageable());
bindPageable(q, pageable, preparedQuery.getRootEntity());
bindQueryHints(q, preparedQuery, currentSession);
}

Expand Down Expand Up @@ -599,7 +612,7 @@ protected final FlushModeType getFlushModeType(AnnotationMetadata annotationMeta
return annotationMetadata.getAnnotationValuesByType(QueryHint.class).stream().filter(av -> FlushModeType.class.getName().equals(av.stringValue("name").orElse(null))).map(av -> av.enumValue("value", FlushModeType.class)).findFirst().orElse(Optional.empty()).orElse(null);
}

private void bindPageable(P q, @NonNull Pageable pageable) {
private void bindPageable(P q, @NonNull Pageable pageable, @NotNull Class<?> entityClass) {
if (pageable == Pageable.UNPAGED) {
// no pagination
return;
Expand All @@ -616,6 +629,19 @@ private void bindPageable(P q, @NonNull Pageable pageable) {
if (offset > 0) {
setOffset(q, (int) offset);
}
Sort sort = pageable.getSort();
if (sort.isSorted()) {
List<Sort.Order> orderBy = sort.getOrderBy();
List<org.hibernate.query.Order<?>> orders = new ArrayList<>(orderBy.size());
for (Sort.Order order : orderBy) {
if (order.isAscending()) {
orders.add(org.hibernate.query.Order.asc(entityClass, order.getProperty()));
} else {
orders.add(org.hibernate.query.Order.desc(entityClass, order.getProperty()));
}
}
setOrder(q, orders);
}
}

protected final <T> void collectPagedResults(CriteriaBuilder criteriaBuilder, S session, PagedQuery<T> pagedQuery, ResultCollector<T> resultCollector) {
Expand All @@ -625,7 +651,7 @@ protected final <T> void collectPagedResults(CriteriaBuilder criteriaBuilder, S
Root<T> root = query.from(entity);
bindCriteriaSort(query, root, criteriaBuilder, pageable);
P q = createQuery(session, query);
bindPageable(q, pageable);
bindPageable(q, pageable.withoutSort(), entity);
bindQueryHints(q, pagedQuery, session);
resultCollector.collect(q);
}
Expand All @@ -635,7 +661,7 @@ protected final <R> void collectCountOf(CriteriaBuilder criteriaBuilder, S sessi
countQuery.select(criteriaBuilder.count(countQuery.from(entity)));
P countQ = createQuery(session, countQuery);
if (pageable != null) {
bindPageable(countQ, pageable);
bindPageable(countQ, pageable.withoutSort(), entity);
}
resultCollector.collect(countQ);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import org.hibernate.procedure.ProcedureCall;
import org.hibernate.query.CommonQueryContract;
import org.hibernate.query.MutationQuery;
import org.hibernate.query.Order;
import org.hibernate.query.Query;

import javax.sql.DataSource;
Expand All @@ -79,7 +80,6 @@
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -231,14 +231,13 @@ protected void setOffset(Query<?> query, int offset) {
}

@Override
protected void setMaxResults(Query<?> query, int max) {
query.setMaxResults(max);
protected void setOrder(Query<?> query, List<Order<?>> orders) {
query.setOrder((List) orders);
}

@NonNull
@Override
public Map<String, Object> getQueryHints(@NonNull StoredQuery<?, ?> storedQuery) {
return super.getQueryHints(storedQuery);
protected void setMaxResults(Query<?> query, int max) {
query.setMaxResults(max);
}

@Nullable
Expand Down Expand Up @@ -549,7 +548,7 @@ public <T> Optional<Number> deleteAll(@NonNull DeleteBatchOperation<T> operation

private <T> int executeUpdate(Session session, StoredQuery<T, ?> storedQuery, InvocationContext<?, ?> invocationContext, T entity) {
MutationQuery query = session.createMutationQuery(storedQuery.getQuery());
bindParameters(query, storedQuery, invocationContext, true, entity);
bindParameters(query, storedQuery, invocationContext, entity);
return query.executeUpdate();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.criteria.CriteriaUpdate;
import org.hibernate.SessionFactory;
import org.hibernate.query.Order;
import org.hibernate.reactive.stage.Stage;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.Collection;
import java.util.List;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -150,6 +152,11 @@ protected void setOffset(Stage.SelectionQuery<?> query, int offset) {
query.setFirstResult(offset);
}

@Override
protected void setOrder(Stage.SelectionQuery<?> query, List<Order<?>> orders) {
query.setOrder((List) orders);
}

@Override
protected <T> EntityGraph<T> getEntityGraph(Stage.Session session, Class<T> entityType, String graphName) {
return session.getEntityGraph(entityType, graphName);
Expand Down Expand Up @@ -308,7 +315,7 @@ private <T> Mono<Integer> executeEntityUpdate(Stage.Session session,
InvocationContext<?, ?> invocationContext,
T entity) {
Stage.MutationQuery query = session.createMutationQuery(storedQuery.getQuery());
bindParameters(query, storedQuery, invocationContext, true, entity);
bindParameters(query, storedQuery, invocationContext, entity);
return helper.executeUpdate(query);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ default Pageable orders(@NonNull List<Order> orders) {
*/
@NonNull
default Pageable withoutSort() {
return Pageable.from(getNumber(), getSize());
if (isSorted()) {
return Pageable.from(getNumber(), getSize());
}
return this;
}

/**
Expand Down

0 comments on commit 2a9d77e

Please sign in to comment.