diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index c09cc68dec2e..d9dc5c27fc34 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -102,6 +102,7 @@ * @author Rod Johnson * @author Juergen Hoeller * @author Thomas Risberg + * @author Yanming Zhou * @since May 3, 2001 * @see JdbcOperations * @see PreparedStatementCreator @@ -473,12 +474,12 @@ public String getSql() { @Override public void query(String sql, RowCallbackHandler rch) throws DataAccessException { - query(sql, new RowCallbackHandlerResultSetExtractor(rch)); + query(sql, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows)); } @Override public List query(String sql, RowMapper rowMapper) throws DataAccessException { - return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Override @@ -488,7 +489,7 @@ class StreamStatementCallback implements StatementCallback>, SqlProvid public Stream doInStatement(Statement stmt) throws SQLException { ResultSet rs = stmt.executeQuery(sql); Connection con = stmt.getConnection(); - return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> { + return new ResultSetSpliterator<>(rs, rowMapper, JdbcTemplate.this.maxRows).stream().onClose(() -> { JdbcUtils.closeResultSet(rs); JdbcUtils.closeStatement(stmt); DataSourceUtils.releaseConnection(con, getDataSource()); @@ -756,12 +757,12 @@ private String appendSql(@Nullable String sql, String statement) { @Override public void query(PreparedStatementCreator psc, RowCallbackHandler rch) throws DataAccessException { - query(psc, new RowCallbackHandlerResultSetExtractor(rch)); + query(psc, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows)); } @Override public void query(String sql, @Nullable PreparedStatementSetter pss, RowCallbackHandler rch) throws DataAccessException { - query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch)); + query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows)); } @Override @@ -782,28 +783,28 @@ public void query(String sql, RowCallbackHandler rch, @Nullable Object @Nullable @Override public List query(PreparedStatementCreator psc, RowMapper rowMapper) throws DataAccessException { - return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Override public List query(String sql, @Nullable PreparedStatementSetter pss, RowMapper rowMapper) throws DataAccessException { - return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Override public List query(String sql, @Nullable Object @Nullable [] args, int[] argTypes, RowMapper rowMapper) throws DataAccessException { - return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Deprecated @Override public List query(String sql, @Nullable Object @Nullable [] args, RowMapper rowMapper) throws DataAccessException { - return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } @Override public List query(String sql, RowMapper rowMapper, @Nullable Object @Nullable ... args) throws DataAccessException { - return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper))); + return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows))); } /** @@ -828,7 +829,7 @@ public Stream queryForStream(PreparedStatementCreator psc, @Nullable Prep } ResultSet rs = ps.executeQuery(); Connection con = ps.getConnection(); - return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> { + return new ResultSetSpliterator<>(rs, rowMapper, this.maxRows).stream().onClose(() -> { JdbcUtils.closeResultSet(rs); if (pss instanceof ParameterDisposer parameterDisposer) { parameterDisposer.cleanupParameters(); @@ -1347,7 +1348,7 @@ protected Map processResultSet( } else if (param.getRowCallbackHandler() != null) { RowCallbackHandler rch = param.getRowCallbackHandler(); - (new RowCallbackHandlerResultSetExtractor(rch)).extractData(rs); + (new RowCallbackHandlerResultSetExtractor(rch, -1)).extractData(rs); return Collections.singletonMap(param.getName(), "ResultSet returned from stored procedure was processed"); } @@ -1730,13 +1731,17 @@ private static class RowCallbackHandlerResultSetExtractor implements ResultSetEx private final RowCallbackHandler rch; - public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch) { + private final int maxRows; + + public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch, int maxRows) { this.rch = rch; + this.maxRows = maxRows; } @Override public @Nullable Object extractData(ResultSet rs) throws SQLException { - while (rs.next()) { + int processed = 0; + while (rs.next() && (this.maxRows == -1 || (processed++) < this.maxRows)) { this.rch.processRow(rs); } return null; @@ -1754,17 +1759,20 @@ private static class ResultSetSpliterator implements Spliterator { private final RowMapper rowMapper; + private final int maxRows; + private int rowNum = 0; - public ResultSetSpliterator(ResultSet rs, RowMapper rowMapper) { + public ResultSetSpliterator(ResultSet rs, RowMapper rowMapper, int maxRows) { this.rs = rs; this.rowMapper = rowMapper; + this.maxRows = maxRows; } @Override public boolean tryAdvance(Consumer action) { try { - if (this.rs.next()) { + if (this.rs.next() && (this.maxRows == -1 || this.rowNum < this.maxRows)) { action.accept(this.rowMapper.mapRow(this.rs, this.rowNum++)); return true; } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java index 70eb055bf5e7..6834ace60080 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/RowMapperResultSetExtractor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,6 +52,7 @@ * you can have executable query objects (containing row-mapping logic) there. * * @author Juergen Hoeller + * @author Yanming Zhou * @since 1.0.2 * @param the result element type * @see RowMapper @@ -64,6 +65,8 @@ public class RowMapperResultSetExtractor implements ResultSetExtractor rowMapper) { * (just used for optimized collection handling) */ public RowMapperResultSetExtractor(RowMapper rowMapper, int rowsExpected) { + this(rowMapper, rowsExpected, -1); + } + + /** + * Create a new RowMapperResultSetExtractor. + * @param rowMapper the RowMapper which creates an object for each row + * @param rowsExpected the number of expected rows + * (just used for optimized collection handling) + * @param maxRows the number of max rows + */ + public RowMapperResultSetExtractor(RowMapper rowMapper, int rowsExpected, int maxRows) { Assert.notNull(rowMapper, "RowMapper must not be null"); this.rowMapper = rowMapper; this.rowsExpected = rowsExpected; + this.maxRows = maxRows; } @@ -90,7 +105,7 @@ public RowMapperResultSetExtractor(RowMapper rowMapper, int rowsExpected) { public List extractData(ResultSet rs) throws SQLException { List results = (this.rowsExpected > 0 ? new ArrayList<>(this.rowsExpected) : new ArrayList<>()); int rowNum = 0; - while (rs.next()) { + while (rs.next() && (this.maxRows == -1 || rowNum < this.maxRows)) { results.add(this.rowMapper.mapRow(rs, rowNum++)); } return results; diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java index 4848f4709feb..b2c78dd79572 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2024 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,7 +32,9 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.stream.Stream; import javax.sql.DataSource; @@ -77,6 +79,7 @@ * @author Thomas Risberg * @author Juergen Hoeller * @author Phillip Webb + * @author Yanming Zhou */ class JdbcTemplateTests { @@ -1236,6 +1239,50 @@ public int getBatchSize() { Collections.singletonMap("someId", 456)); } + @Test + void testSkipFurtherRowsOnceMaxRowsHasBeenReachedForRowMapper() throws Exception { + testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> + template.query(sql, (rs, rowNum) -> rs.getString(1))); + } + + @Test + void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForRowCallbackHandler() throws Exception { + testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> { + List list = new ArrayList<>(); + template.query(sql, (RowCallbackHandler) rs -> list.add(rs.getString(1))); + return list; + }); + } + + @Test + void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForStream() throws Exception { + testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> { + try (Stream stream = template.queryForStream(sql, (rs, rowNum) -> rs.getString(1))) { + return stream.toList(); + } + }); + } + + private void testDiscardFurtherRowsOnceMaxRowsHasBeenReached(BiFunction> function) throws Exception { + String sql = "SELECT FORENAME FROM CUSTMR"; + String[] results = {"rod", "gary", " portia"}; + int maxRows = 2; + + given(this.resultSet.next()).willReturn(true, true, true, false); + given(this.resultSet.getString(1)).willReturn(results[0], results[1], results[2]); + given(this.connection.createStatement()).willReturn(this.preparedStatement); + + JdbcTemplate template = new JdbcTemplate(); + template.setDataSource(this.dataSource); + template.setMaxRows(maxRows); + + assertThat(function.apply(template, sql)).as("same length").hasSize(maxRows); + + verify(this.resultSet).close(); + verify(this.preparedStatement).close(); + verify(this.connection).close(); + } + private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException { DatabaseMetaData databaseMetaData = mock(); given(databaseMetaData.getDatabaseProductName()).willReturn("MySQL");