dukethrash
dukethrash

Reputation: 1529

How to only allow specific fields to sort by in a Spring Data JPA Repository Pageable?

Using a Pageable parameter in a Spring Data JPA Repository allows for specifying fields to sort by like: PageRequest.of(0, 50, Sort.by("field1", "field2")), which would sort by field1 and field2 ascending.

It works by appending an ORDER BY clause built from those field names directly to the query (the names are effectively injected into the query string), which results in a JPA query like SELECT a FROM SomeEntity a ORDER BY a.field1 asc, a.field2 asc. However, if a non-existent field name is passed in, the result is an org.springframework.dao.InvalidDataAccessApiUsageException, as seen below.

How do you whitelist the sort fields (i.e. only allow specific fields), or otherwise validate the sorting, without adding boilerplate code in a service that wraps the repository? Likewise, when the Pageable arrives through a @RestController, how do you ensure that a 400-level status (HttpStatus.BAD_REQUEST) is returned to the API caller?

org.springframework.dao.InvalidDataAccessApiUsageException: org.hibernate.QueryException: could not resolve property: field1 of: com.example.SomeEntity [SELECT a FROM com.example.SomeEntity a order by a.field1 asc, a.field2 asc]
        at org.springframework.orm.jpa.EntityManagerFactoryUtils.convertJpaAccessExceptionIfPossible(EntityManagerFactoryUtils.java:374)
        at org.springframework.orm.jpa.vendor.HibernateJpaDialect.translateExceptionIfPossible(HibernateJpaDialect.java:257)
        at org.springframework.orm.jpa.AbstractEntityManagerFactoryBean.translateExceptionIfPossible(AbstractEntityManagerFactoryBean.java:531)
        at org.springframework.dao.support.ChainedPersistenceExceptionTranslator.translateExceptionIfPossible(ChainedPersistenceExceptionTranslator.java:61)
        at org.springframework.dao.support.DataAccessUtils.translateIfNecessary(DataAccessUtils.java:242)
        at org.springframework.dao.support.PersistenceExceptionTranslationInterceptor.invoke(PersistenceExceptionTranslationInterceptor.java:154)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.data.jpa.repository.support.CrudMethodMetadataPostProcessor$CrudMethodMetadataPopulatingMethodInterceptor.invoke(CrudMethodMetadataPostProcessor.java:149)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.aop.interceptor.ExposeInvocationInterceptor.invoke(ExposeInvocationInterceptor.java:95)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.aop.framework.JdkDynamicAopProxy.invoke(JdkDynamicAopProxy.java:212)
        at com.sun.proxy.$Proxy340.searchPaged(Unknown Source)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.zeroturnaround.jrebel.integration.springdata.RepositoryReloadingProxyFactoryBuilder$ReloadingMethodHandler.invoke(RepositoryReloadingProxyFactoryBuilder.java:80)
...
Caused by: java.lang.IllegalArgumentException: org.hibernate.QueryException: could not resolve property: field1 of: com.example.SomeEntity [SELECT a FROM com.example.SomeEntity a order by a.field1 asc, a.field2 asc]
        at org.hibernate.internal.ExceptionConverterImpl.convert(ExceptionConverterImpl.java:138)
        at org.hibernate.internal.ExceptionConverterImpl.convert(ExceptionConverterImpl.java:181)
        at org.hibernate.internal.ExceptionConverterImpl.convert(ExceptionConverterImpl.java:188)
        at org.hibernate.internal.AbstractSharedSessionContract.createQuery(AbstractSharedSessionContract.java:725)
        at org.hibernate.internal.AbstractSharedSessionContract.createQuery(AbstractSharedSessionContract.java:113)
        at jdk.internal.reflect.GeneratedMethodAccessor749.invoke(Unknown Source)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.springframework.orm.jpa.ExtendedEntityManagerCreator$ExtendedEntityManagerInvocationHandler.invoke(ExtendedEntityManagerCreator.java:366)
        at com.sun.proxy.$Proxy265.createQuery(Unknown Source)
        at jdk.internal.reflect.GeneratedMethodAccessor749.invoke(Unknown Source)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.springframework.orm.jpa.SharedEntityManagerCreator$SharedEntityManagerInvocationHandler.invoke(SharedEntityManagerCreator.java:314)
        at com.sun.proxy.$Proxy265.createQuery(Unknown Source)
        at org.springframework.data.jpa.repository.query.AbstractStringBasedJpaQuery.createJpaQuery(AbstractStringBasedJpaQuery.java:150)
        at org.springframework.data.jpa.repository.query.AbstractStringBasedJpaQuery.doCreateQuery(AbstractStringBasedJpaQuery.java:86)
        at org.springframework.data.jpa.repository.query.AbstractJpaQuery.createQuery(AbstractJpaQuery.java:226)
        at org.springframework.data.jpa.repository.query.JpaQueryExecution$PagedExecution.doExecute(JpaQueryExecution.java:175)
        at org.springframework.data.jpa.repository.query.JpaQueryExecution.execute(JpaQueryExecution.java:88)
        at org.springframework.data.jpa.repository.query.AbstractJpaQuery.doExecute(AbstractJpaQuery.java:154)
        at org.springframework.data.jpa.repository.query.AbstractJpaQuery.execute(AbstractJpaQuery.java:142)
        at org.springframework.data.repository.core.support.RepositoryFactorySupport$QueryExecutorMethodInterceptor.doInvoke(RepositoryFactorySupport.java:618)
        at org.springframework.data.repository.core.support.RepositoryFactorySupport$QueryExecutorMethodInterceptor.invoke(RepositoryFactorySupport.java:605)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.data.projection.DefaultMethodInvokingMethodInterceptor.invoke(DefaultMethodInvokingMethodInterceptor.java:80)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.transaction.interceptor.TransactionAspectSupport.invokeWithinTransaction(TransactionAspectSupport.java:367)
        at org.springframework.transaction.interceptor.TransactionInterceptor.invoke(TransactionInterceptor.java:118)
        at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:186)
        at org.springframework.dao.support.PersistenceExceptionTranslationInterceptor.invoke(PersistenceExceptionTranslationInterceptor.java:139)
        ... 127 more
Caused by: org.hibernate.QueryException: could not resolve property: field1 of: com.example.SomeEntity [SELECT a FROM com.example.SomeEntity a order by a.field1 asc, a.field2 asc]
        at org.hibernate.QueryException.generateQueryException(QueryException.java:120)
        at org.hibernate.QueryException.wrapWithQueryString(QueryException.java:103)
        at org.hibernate.hql.internal.ast.QueryTranslatorImpl.doCompile(QueryTranslatorImpl.java:220)
        at org.hibernate.hql.internal.ast.QueryTranslatorImpl.compile(QueryTranslatorImpl.java:144)
        at org.hibernate.engine.query.spi.HQLQueryPlan.<init>(HQLQueryPlan.java:113)
        at org.hibernate.engine.query.spi.HQLQueryPlan.<init>(HQLQueryPlan.java:73)
        at org.hibernate.engine.query.spi.QueryPlanCache.getHQLQueryPlan(QueryPlanCache.java:162)
        at org.hibernate.internal.AbstractSharedSessionContract.getQueryPlan(AbstractSharedSessionContract.java:604)
        at org.hibernate.internal.AbstractSharedSessionContract.createQuery(AbstractSharedSessionContract.java:716)
        ... 154 more
Caused by: org.hibernate.QueryException: could not resolve property: field1 of: com.example.SomeEntity
        at org.hibernate.persister.entity.AbstractPropertyMapping.propertyException(AbstractPropertyMapping.java:77)
        at org.hibernate.persister.entity.AbstractPropertyMapping.toType(AbstractPropertyMapping.java:71)
        at org.hibernate.persister.entity.AbstractEntityPersister.toType(AbstractEntityPersister.java:2043)
        at org.hibernate.hql.internal.ast.tree.FromElementType.getPropertyType(FromElementType.java:412)
        at org.hibernate.hql.internal.ast.tree.FromElement.getPropertyType(FromElement.java:520)
        at org.hibernate.hql.internal.ast.tree.DotNode.getDataType(DotNode.java:694)
        at org.hibernate.hql.internal.ast.tree.DotNode.prepareLhs(DotNode.java:269)
        at org.hibernate.hql.internal.ast.tree.DotNode.resolve(DotNode.java:209)
        at org.hibernate.hql.internal.ast.HqlSqlWalker.resolve(HqlSqlWalker.java:1053)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.expr(HqlSqlBaseWalker.java:1303)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.orderExpr(HqlSqlBaseWalker.java:1887)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.orderExprs(HqlSqlBaseWalker.java:1681)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.orderClause(HqlSqlBaseWalker.java:1654)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.query(HqlSqlBaseWalker.java:666)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.selectStatement(HqlSqlBaseWalker.java:325)
        at org.hibernate.hql.internal.antlr.HqlSqlBaseWalker.statement(HqlSqlBaseWalker.java:273)
        at org.hibernate.hql.internal.ast.QueryTranslatorImpl.analyze(QueryTranslatorImpl.java:276)
        at org.hibernate.hql.internal.ast.QueryTranslatorImpl.doCompile(QueryTranslatorImpl.java:192)
        ... 160 more


Upvotes: 3

Views: 4239

Answers (1)

dukethrash
dukethrash

Reputation: 1529

I ended up using JSR-303 validations on the repository methods to whitelist the sort fields.

Enable the method validation post processor so that JSR-303 validation annotations are enforced at the method level.

ValidationConfig.java

/**
 * Registers the post processor that enforces Bean Validation (JSR-303) constraints declared on
 * the methods and method parameters of Spring beans.
 */
@Configuration
public class ValidationConfig {

    /**
     * Creates the method-level validation post processor.
     *
     * <p>The factory method is {@code static} because the returned object is a
     * {@code BeanPostProcessor}: a static factory lets the container create it without first
     * instantiating this configuration class, so the configuration class itself is still
     * post-processed normally.
     *
     * @return a new {@link MethodValidationPostProcessor}
     */
    @Bean
    public static MethodValidationPostProcessor methodValidationPostProcessor() {
        return new MethodValidationPostProcessor();
    }
}

Create a constraint annotation that takes the list of sort fields to validate against.

AllowSortFields.java

/**
 * Constraint restricting which properties a {@code Pageable} is allowed to sort by.
 *
 * <p>The check itself is performed by {@link AllowSortFieldsValidator}.
 */
@Target({ANNOTATION_TYPE, TYPE, FIELD, PARAMETER})
@Retention(RUNTIME)
@Constraint(validatedBy = AllowSortFieldsValidator.class)
@Documented
public @interface AllowSortFields {

    /**
     * The sort fields that are permitted.
     *
     * @return the allowed sort fields; empty by default
     */
    String[] value() default {};

    /**
     * @return the default violation message
     */
    String message() default "Sort field values provided are not within the allowed fields that are sortable.";

    /**
     * @return the validation groups this constraint belongs to
     */
    Class<?>[] groups() default {};

    /**
     * @return the payload types attached to this constraint
     */
    Class<? extends Payload>[] payload() default {};

}

AllowSortFieldsValidator.java

/**
 * Validates a list of sort fields within a Pageable against an allowed list.
 */
public class AllowSortFieldsValidator implements ConstraintValidator<AllowSortFields, Pageable> {

    private List<String> allowedSortFields;

    static final String PROPERTY_NOT_FOUND_MESSAGE = "The following sort fields [%s] are not within the allowed fields. "
            + "Allowed sort fields are: [%s]";

    @Override
    public void initialize(AllowSortFields constraintAnnotation) {
        allowedSortFields = Arrays.asList(constraintAnnotation.value());
    }

    @Override
    public boolean isValid(Pageable value, ConstraintValidatorContext context) {
        if (value == null) {
            return true;
        }

        if (CollectionUtils.isEmpty(allowedSortFields)) {
            return true;
        }

        // ignore unsorted
        Sort sort = value.getSort();
        if (sort.isUnsorted()) {
            return true;
        }

        String fieldsNotFound = fieldsNotFoundAsCommaDelimited(sort);

        // all found fields are allowed
        if (StringUtils.isEmpty(fieldsNotFound)) {
            return true;
        }

        context.disableDefaultConstraintViolation();
        context.buildConstraintViolationWithTemplate(String.format(PROPERTY_NOT_FOUND_MESSAGE, fieldsNotFound, String.join(",", allowedSortFields)))
                .addConstraintViolation();
        return false;

    }

    private String fieldsNotFoundAsCommaDelimited(Sort sort) {
        String fieldsNotFound = sort.stream()
                .map(order -> order.getProperty())
                .filter(property -> !allowedSortFields.contains(property))
                .collect(joining(","));
        return fieldsNotFound;
    }
}

AllowSortFieldsValidatorSmallTest.java

/**
 * Small tests for {@link AllowSortFieldsValidator}, exercised through a Bean Validation
 * {@link Validator} against the two fixture beans declared at the bottom of this class.
 */
public class AllowSortFieldsValidatorSmallTest {

    private static final String[] ALLOWED_SORT_FIELDS = new String[]{"allowed1", "allowed2"};

    private static final String ALLOWED_SORT_FIELDS_DELIMITED = String.join(",", ALLOWED_SORT_FIELDS);

    private static Validator validator;

    @BeforeClass
    public static void setupValidator() throws Exception {
        validator = Validation.buildDefaultValidatorFactory().getValidator();
    }

    /** Two disallowed fields among four must be reported in a single explanatory message. */
    @Test
    public void isValid_TwoOfFourFieldsAllowed_FalseWithExpectedMessageExplainingDisallowedFields() {
        AllowedSortFields bean =
                newAllowedSortFields(List.of("allowed1", "allowed2|desc", "notfound1", "not.found2"));

        Set<ConstraintViolation<AllowedSortFields>> violations = validator.validate(bean, Default.class);

        String expectedMessage = String.format(
                AllowSortFieldsValidator.PROPERTY_NOT_FOUND_MESSAGE,
                "notfound1,not.found2",
                ALLOWED_SORT_FIELDS_DELIMITED);
        assertEquals(expectedMessage, getConstraintMessages(violations));
    }

    /** A pageable built from a null sort-field list yields no violations. */
    @Test
    public void isValid_NoSortFields_True() {
        AllowedSortFields bean = newAllowedSortFields(null);

        Set<ConstraintViolation<AllowedSortFields>> violations = validator.validate(bean, Default.class);

        assertTrue(violations.isEmpty());
    }

    /** An annotation without any allowed fields accepts every sort field. */
    @Test
    public void isValid_EmptyAllowedSortFields_True() {
        EmptyAllowedSortFields bean =
                newEmptyAllowedSortFields(List.of("allowed1", "allowed2|desc", "notfound1", "not.found2"));

        Set<ConstraintViolation<EmptyAllowedSortFields>> violations = validator.validate(bean, Default.class);

        assertTrue(violations.isEmpty());
    }

    /** Sorting by exactly the allowed fields yields no violations. */
    @Test
    public void isValid_AllSortFieldsFoundAsAllowed_True() {
        AllowedSortFields bean = newAllowedSortFields(Arrays.asList(ALLOWED_SORT_FIELDS));

        Set<ConstraintViolation<AllowedSortFields>> violations = validator.validate(bean, Default.class);

        assertTrue(violations.isEmpty());
    }

    /** A null pageable is considered valid. */
    @Test
    public void isValid_NullValue_True() {
        AllowedSortFields bean = new AllowedSortFields();
        bean.pageable = null;

        Set<ConstraintViolation<AllowedSortFields>> violations = validator.validate(bean, Default.class);

        assertTrue(violations.isEmpty());
    }

    /** Joins the messages of all given violations with commas. */
    private String getConstraintMessages(Set<ConstraintViolation<AllowedSortFields>> constraintViolations) {
        return constraintViolations.stream()
                .map(ConstraintViolation::getMessage)
                .collect(joining(","));
    }

    /** Builds the restricted fixture bean with a pageable sorted by the given fields. */
    private AllowedSortFields newAllowedSortFields(List<String> sortFields) {
        AllowedSortFields bean = new AllowedSortFields();
        bean.pageable = new CustomPageable().sort(sortFields);
        return bean;
    }

    /** Builds the unrestricted fixture bean with a pageable sorted by the given fields. */
    private EmptyAllowedSortFields newEmptyAllowedSortFields(List<String> sortFields) {
        EmptyAllowedSortFields bean = new EmptyAllowedSortFields();
        bean.pageable = new CustomPageable().sort(sortFields);
        return bean;
    }

    /** Fixture whose pageable may only be sorted by {@code allowed1} and {@code allowed2}. */
    public class AllowedSortFields {

        @AllowSortFields({"allowed1", "allowed2"})
        public Pageable pageable;

    }

    /** Fixture whose constraint declares no allowed fields. */
    public class EmptyAllowedSortFields {

        @AllowSortFields
        public Pageable pageable;

    }
}

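Finally, use the annotation on the Pageable parameter of the repository method. Be sure to put @Validated on the repository interface itself; the method validation post processor only applies to types annotated with it.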

ExampleSearchRepository.java

public interface ExampleSearchRepository extends JpaRepository<ExampleSearch, Integer>,
    JpaSpecificationExecutor<ExampleSearch>, PagingAndSortingRepository<ExampleSearch, Integer> {

    public Page<ExampleSearch> search(
        @Param("searchCriteria") ExampleSearchCriteria searchCriteria, 
        @AllowSortFields({"field1","subfield.name"}) Pageable pageable);

Upvotes: 5

Related Questions