Skip to content

Commit d77b715

Browse files
asymprosbrannen
authored andcommitted
Merge class-level and method-level @SQL declarations
See gh-1835
1 parent b0939a8 commit d77b715

File tree

5 files changed

+137
-20
lines changed

5 files changed

+137
-20
lines changed

spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
* SQL {@link #scripts} and {@link #statements} to be executed against a given
3232
* database during integration tests.
3333
*
34-
* <p>Method-level declarations override class-level declarations.
34+
* <p>Method-level declarations override class-level declarations by default.
35+
* This behaviour can be adjusted via {@link MergeMode}
3536
*
3637
* <p>Script execution is performed by the {@link SqlScriptsTestExecutionListener},
3738
* which is enabled by default.
@@ -146,6 +147,13 @@
146147
*/
147148
SqlConfig config() default @SqlConfig;
148149

150+
/**
151+
* Indicates whether this annotation should be merged with upper-level annotations
152+
* or override them.
153+
* <p>Defaults to {@link MergeMode#OVERRIDE}.
154+
*/
155+
MergeMode mergeMode() default MergeMode.OVERRIDE;
156+
149157

150158
/**
151159
* Enumeration of <em>phases</em> that dictate when SQL scripts are executed.
@@ -165,4 +173,23 @@ enum ExecutionPhase {
165173
AFTER_TEST_METHOD
166174
}
167175

176+
/**
177+
* Enumeration of <em>modes</em> that dictate whether or not
178+
* declared SQL {@link #scripts} and {@link #statements} are merged
179+
* with the upper-level annotations.
180+
*/
181+
enum MergeMode {
182+
183+
/**
184+
* Indicates that locally declared SQL {@link #scripts} and {@link #statements}
185+
* should override the upper-level (e.g. Class-level) annotations.
186+
*/
187+
OVERRIDE,
188+
189+
/**
190+
* Indicates that locally declared SQL {@link #scripts} and {@link #statements}
191+
* should be merged the upper-level (e.g. Class-level) annotations.
192+
*/
193+
MERGE
194+
}
168195
}

spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@
1616

1717
package org.springframework.test.context.jdbc;
1818

19+
import java.lang.reflect.AnnotatedElement;
1920
import java.lang.reflect.Method;
2021
import java.util.List;
2122
import java.util.Set;
23+
import java.util.stream.Collectors;
2224
import javax.sql.DataSource;
2325

2426
import org.apache.commons.logging.Log;
2527
import org.apache.commons.logging.LogFactory;
2628

29+
import org.jetbrains.annotations.NotNull;
2730
import org.springframework.context.ApplicationContext;
2831
import org.springframework.core.annotation.AnnotatedElementUtils;
2932
import org.springframework.core.io.ByteArrayResource;
@@ -126,19 +129,35 @@ public void afterTestMethod(TestContext testContext) throws Exception {
126129
* {@link TestContext} and {@link ExecutionPhase}.
127130
*/
128131
private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) throws Exception {
129-
boolean classLevel = false;
130-
131-
Set<Sql> sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
132-
testContext.getTestMethod(), Sql.class, SqlGroup.class);
133-
if (sqlAnnotations.isEmpty()) {
134-
sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations(
135-
testContext.getTestClass(), Sql.class, SqlGroup.class);
136-
if (!sqlAnnotations.isEmpty()) {
137-
classLevel = true;
138-
}
132+
Set<Sql> methodLevelSqls = getScriptsFromElement(testContext.getTestMethod());
133+
List<Sql> methodLevelOverrides = methodLevelSqls.stream()
134+
.filter(s -> s.executionPhase() == executionPhase)
135+
.filter(s -> s.mergeMode() == Sql.MergeMode.OVERRIDE)
136+
.collect(Collectors.toList());
137+
if (methodLevelOverrides.isEmpty()) {
138+
executeScripts(getScriptsFromElement(testContext.getTestClass()), testContext, executionPhase, true);
139+
executeScripts(methodLevelSqls, testContext, executionPhase, false);
140+
} else {
141+
executeScripts(methodLevelOverrides, testContext, executionPhase, false);
139142
}
143+
}
144+
145+
/**
146+
* Get SQL scripts configured via {@link Sql @Sql} for the supplied
147+
* {@link AnnotatedElement}.
148+
*/
149+
private Set<Sql> getScriptsFromElement(AnnotatedElement annotatedElement) throws Exception {
150+
return AnnotatedElementUtils.getMergedRepeatableAnnotations(annotatedElement, Sql.class, SqlGroup.class);
151+
}
140152

141-
for (Sql sql : sqlAnnotations) {
153+
/**
154+
* Execute given {@link Sql @Sql} scripts.
155+
* {@link AnnotatedElement}.
156+
*/
157+
private void executeScripts(Iterable<Sql> scripts, TestContext testContext, ExecutionPhase executionPhase,
158+
boolean classLevel)
159+
throws Exception {
160+
for (Sql sql : scripts) {
142161
executeSqlScripts(sql, executionPhase, testContext, classLevel);
143162
}
144163
}
@@ -166,14 +185,7 @@ private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestConte
166185
mergedSqlConfig, executionPhase, testContext));
167186
}
168187

169-
final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
170-
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
171-
populator.setSeparator(mergedSqlConfig.getSeparator());
172-
populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
173-
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
174-
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
175-
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
176-
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
188+
final ResourceDatabasePopulator populator = configurePopulator(mergedSqlConfig);
177189

178190
String[] scripts = getScripts(sql, testContext, classLevel);
179191
scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
@@ -232,6 +244,19 @@ private void executeSqlScripts(Sql sql, ExecutionPhase executionPhase, TestConte
232244
}
233245
}
234246

247+
@NotNull
248+
private ResourceDatabasePopulator configurePopulator(MergedSqlConfig mergedSqlConfig) {
249+
final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
250+
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
251+
populator.setSeparator(mergedSqlConfig.getSeparator());
252+
populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
253+
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
254+
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
255+
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
256+
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
257+
return populator;
258+
}
259+
235260
@Nullable
236261
private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
237262
try {

spring-test/src/test/java/org/springframework/test/context/jdbc/RepeatableSqlAnnotationSqlScriptsTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.springframework.test.annotation.DirtiesContext;
2626
import org.springframework.test.context.ContextConfiguration;
2727
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
28+
import org.springframework.test.jdbc.JdbcTestUtils;
2829

2930
import static org.assertj.core.api.Assertions.assertThat;
3031

@@ -58,6 +59,10 @@ public void test02_methodLevelScripts() {
5859
assertNumUsers(2);
5960
}
6061

62+
protected int countRowsInTable(String tableName) {
63+
return JdbcTestUtils.countRowsInTable(this.jdbcTemplate, tableName);
64+
}
65+
6166
protected void assertNumUsers(int expected) {
6267
assertThat(countRowsInTable("user")).as("Number of rows in the 'user' table.").isEqualTo(expected);
6368
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package org.springframework.test.context.jdbc;
2+
3+
import org.junit.Test;
4+
import org.springframework.test.annotation.DirtiesContext;
5+
import org.springframework.test.context.ContextConfiguration;
6+
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
7+
8+
import static org.junit.Assert.assertEquals;
9+
10+
/**
11+
* Test to verify method level merge of @Sql annotations.
12+
*
13+
* @author Dmitry Semukhin
14+
*/
15+
@ContextConfiguration(classes = EmptyDatabaseConfig.class)
16+
@Sql(value = {"schema.sql", "data-add-catbert.sql"})
17+
@DirtiesContext
18+
public class SqlMethodMergeTest extends AbstractTransactionalJUnit4SpringContextTests {
19+
20+
@Test
21+
@Sql(value = "data-add-dogbert.sql", mergeMode = Sql.MergeMode.MERGE)
22+
public void testMerge() {
23+
assertNumUsers(2);
24+
}
25+
26+
protected void assertNumUsers(int expected) {
27+
assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
28+
}
29+
30+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package org.springframework.test.context.jdbc;
2+
3+
import org.junit.Test;
4+
import org.springframework.test.annotation.DirtiesContext;
5+
import org.springframework.test.context.ContextConfiguration;
6+
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
7+
8+
import static org.junit.Assert.assertEquals;
9+
10+
/**
11+
* Test to verify method level override of @Sql annotations.
12+
*
13+
* @author Dmitry Semukhin
14+
*/
15+
@ContextConfiguration(classes = EmptyDatabaseConfig.class)
16+
@Sql(value = {"schema.sql", "data-add-catbert.sql"})
17+
@DirtiesContext
18+
public class SqlMethodOverrideTest extends AbstractTransactionalJUnit4SpringContextTests {
19+
20+
@Test
21+
@Sql(value = {"schema.sql", "data.sql", "data-add-dogbert.sql", "data-add-catbert.sql"}, mergeMode = Sql.MergeMode.OVERRIDE)
22+
public void testMerge() {
23+
assertNumUsers(3);
24+
}
25+
26+
protected void assertNumUsers(int expected) {
27+
assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
28+
}
29+
30+
}

0 commit comments

Comments
 (0)