Skip to content

Commit

Permalink
[Blazebit#564] Add support for specifications
Browse files Browse the repository at this point in the history
  • Loading branch information
Giovanni Lovato committed Aug 1, 2018
1 parent 4334304 commit ae0ea2e
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
import com.blazebit.persistence.PagedList;
import com.blazebit.persistence.PaginatedCriteriaBuilder;
import com.blazebit.persistence.criteria.BlazeCriteria;
import com.blazebit.persistence.criteria.BlazeCriteriaBuilder;
import com.blazebit.persistence.criteria.BlazeCriteriaQuery;
import com.blazebit.persistence.spring.data.base.query.JpaParameters.JpaParameter;
import com.blazebit.persistence.view.EntityViewManager;
import com.blazebit.persistence.view.EntityViewSetting;

import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.provider.PersistenceProvider;
import org.springframework.data.jpa.repository.query.AbstractJpaQuery;
import org.springframework.data.jpa.repository.query.FixedJpaCountQueryCreator;
Expand All @@ -44,8 +46,12 @@
import javax.persistence.TypedQuery;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;

import java.util.Collections;
import java.util.List;
import java.util.regex.Pattern;

/**
* Implementation is similar to {@link PartTreeJpaQuery} but was modified to work with entity views.
Expand All @@ -56,6 +62,8 @@
*/
public abstract class AbstractPartTreeBlazePersistenceQuery extends AbstractJpaQuery {

private static final Pattern QUERY_PATTERN = Pattern.compile("^(find|read|get|query|stream)All$");

private final Class<?> domainClass;
private final Class<?> entityViewClass;
private final PartTree tree;
Expand All @@ -73,12 +81,20 @@ public AbstractPartTreeBlazePersistenceQuery(EntityViewAwareJpaQueryMethod metho
this.evm = evm;

this.entityViewClass = method.getEntityViewClass();

this.domainClass = method.getEntityInformation().getJavaType();
this.tree = new PartTree(method.getName(), domainClass);
this.parameters = method.getJpaParameters();

boolean recreateQueries = parameters.potentiallySortsDynamically() || entityViewClass != null;
this.parameters = method.getJpaParameters();
String methodName = method.getName();
boolean matchesSpecificationSignature = parameters.hasSpecificationParameter()
&& QUERY_PATTERN.matcher(methodName).matches();
String source = matchesSpecificationSignature ? "" : methodName;
this.tree = new PartTree(source, domainClass);

/*-
boolean recreateQueries = parameters.potentiallySortsDynamically() || entityViewClass != null
|| matchesSpecificationSignature;
-*/
boolean recreateQueries = true;
this.query = isCountProjection(tree) ? new AbstractPartTreeBlazePersistenceQuery.CountQueryPreparer(persistenceProvider,
recreateQueries) : new AbstractPartTreeBlazePersistenceQuery.QueryPreparer(persistenceProvider, recreateQueries);
}
Expand Down Expand Up @@ -238,7 +254,10 @@ private TypedQuery<?> createQuery(CriteriaQuery<?> criteriaQuery, Object[] value
}

protected TypedQuery<?> createQuery0(CriteriaQuery<?> criteriaQuery, Object[] values) {
processSpecification(criteriaQuery, values);

com.blazebit.persistence.CriteriaBuilder<?> cb = ((BlazeCriteriaQuery<?>) criteriaQuery).createCriteriaBuilder();

if (entityViewClass == null) {
return cb.getQuery();
} else {
Expand All @@ -259,6 +278,8 @@ Query createPaginatedQuery(Object[] values, boolean withCount) {
expressions = creator.getParameterExpressions();
}

processSpecification(criteriaQuery, values);

com.blazebit.persistence.CriteriaBuilder<?> cb = ((BlazeCriteriaQuery<?>) criteriaQuery).createCriteriaBuilder();
TypedQuery<Object> jpaQuery;
ParameterBinder binder = getBinder(values, expressions);
Expand Down Expand Up @@ -294,6 +315,19 @@ protected void processSetting(EntityViewSetting<?, ?> setting, Object[] values)
}
}

@SuppressWarnings({ "rawtypes", "unchecked" })
protected void processSpecification(CriteriaQuery<?> criteriaQuery, Object[] values) {
BlazeCriteriaQuery<?> blazeCriteriaQuery = (BlazeCriteriaQuery<?>) criteriaQuery;
int specificationIndex = parameters.getSpecificationIndex();
if (specificationIndex >= 0) {
Specification<?> specification = (Specification<?>) values[specificationIndex];
Root root = criteriaQuery.getRoots().iterator().next();
BlazeCriteriaBuilder criteriaBuilder = blazeCriteriaQuery.getCriteriaBuilder();
Predicate predicate = specification.toPredicate(root, criteriaQuery, criteriaBuilder);
criteriaQuery.where(predicate);
}
}

protected FixedJpaQueryCreator createCreator(ParametersParameterAccessor accessor,
PersistenceProvider persistenceProvider) {
BlazeCriteriaQuery<Long> cq = BlazeCriteria.get(getEntityManager(), cbf, Long.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.blazebit.persistence.spring.data.base.query;

import org.springframework.core.MethodParameter;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.repository.Temporal;

import com.blazebit.persistence.spring.data.annotation.OptionalParam;
Expand Down Expand Up @@ -71,6 +72,34 @@ public JpaParameters getOptionalParameters() {
return createFrom(parameters);
}

/**
* Returns the index of the {@link Specification} {@link Method} parameter if available. Will return {@literal -1} if there
* is no {@link Specification} parameter in the {@link Method}'s parameter list.
*
* @return the index of the specification parameter, or -1 if not present
*/
public int getSpecificationIndex() {
int index = 0;

for (JpaParameter candidate : this) {
if (candidate.isSpecificationParameter()) {
return index;
}
++index;
}

return -1;
}

/**
* Returns whether the method the {@link Parameters} was created for contains a {@link Specification} parameter.
*
* @return true if the methods has a specification parameter
*/
public boolean hasSpecificationParameter() {
return getSpecificationIndex() >= 0;
}

/*
* (non-Javadoc)
* @see org.springframework.data.repository.query.Parameters#createParameter(org.springframework.core.MethodParameter)
Expand Down Expand Up @@ -141,13 +170,17 @@ public boolean isBindable() {

@Override
public boolean isSpecialParameter() {
return super.isSpecialParameter() || isOptionalParameter();
return super.isSpecialParameter() || isOptionalParameter() || isSpecificationParameter();
}

boolean isOptionalParameter() {
return optional != null;
}

boolean isSpecificationParameter() {
return Specification.class.isAssignableFrom(parameter.getParameterType());
}

/**
* @return {@literal true} if this parameter is of type {@link Date} and has an {@link Temporal} annotation.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,29 @@ public void testFindWithOptionalParameterAndPageable() {
assertEquals(param, optionalParameter);
}

@Test
public void testFindAllBySpecWithOptionalParameter() {
// Given
final Document d3 = createDocument("d3", null, 3L, null);
final Document d2 = createDocument("d2", null, 2L, null);
final Document d1 = createDocument("d1", null, 1L, null);

final String param = "Foo";

// When
List<DocumentView> actual = documentRepository.findAll(new Specification<Document>() {

@Override
public Predicate toPredicate(Root<Document> root, CriteriaQuery<?> criteriaQuery, CriteriaBuilder criteriaBuilder) {
return criteriaBuilder.ge(root.<Long>get("age"), 2L);
}
}, param);

// Then
assertEquals(2, actual.size());
assertEquals(actual.get(0).getOptionalParameter(), param);
}

private List<Long> getIdsFromViews(Iterable<DocumentAccessor> views) {
List<Long> ids = new ArrayList<>();
for (DocumentAccessor view : views) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.data.jpa.domain.Specification;
import org.springframework.data.jpa.repository.EntityGraph;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.NoRepositoryBean;
Expand Down Expand Up @@ -75,4 +76,6 @@ public interface DocumentRepository<T> extends EntityViewRepository<T, Long>, En
List<DocumentView> findByName(String name, @OptionalParam("optionalParameter") String optionalParameter);

Page<DocumentView> findByNameOrderById(String name, Pageable pageable, @OptionalParam("optionalParameter") String optionalParameter);

List<DocumentView> findAll(Specification<Document> specification, @OptionalParam("optionalParameter") String optionalParameter);
}

0 comments on commit ae0ea2e

Please sign in to comment.