|
20 | 20 | import java.lang.reflect.Method;
|
21 | 21 | import java.util.List;
|
22 | 22 | import java.util.Set;
|
23 |
| -import java.util.stream.Collectors; |
24 | 23 | import javax.sql.DataSource;
|
25 | 24 |
|
26 | 25 | import org.apache.commons.logging.Log;
|
|
38 | 37 | import org.springframework.test.context.jdbc.Sql.ExecutionPhase;
|
39 | 38 | import org.springframework.test.context.jdbc.SqlConfig.ErrorMode;
|
40 | 39 | import org.springframework.test.context.jdbc.SqlConfig.TransactionMode;
|
| 40 | +import org.springframework.test.context.jdbc.SqlMergeMode.MergeMode; |
41 | 41 | import org.springframework.test.context.support.AbstractTestExecutionListener;
|
42 | 42 | import org.springframework.test.context.transaction.TestContextTransactionUtils;
|
43 | 43 | import org.springframework.test.context.util.TestContextResourceUtils;
|
@@ -130,36 +130,57 @@ public void afterTestMethod(TestContext testContext) {
|
130 | 130 | * {@link TestContext} and {@link ExecutionPhase}.
|
131 | 131 | */
|
132 | 132 | private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) {
|
133 |
| - Set<Sql> methodLevelSqls = getSqlAnnotationsFor(testContext.getTestMethod()); |
134 |
| - List<Sql> methodLevelOverrides = methodLevelSqls.stream() |
135 |
| - .filter(s -> s.executionPhase() == executionPhase) |
136 |
| - .filter(s -> s.mergeMode() == Sql.MergeMode.OVERRIDE) |
137 |
| - .collect(Collectors.toList()); |
138 |
| - if (methodLevelOverrides.isEmpty()) { |
139 |
| - executeScripts(getSqlAnnotationsFor(testContext.getTestClass()), testContext, executionPhase, true); |
140 |
| - executeScripts(methodLevelSqls, testContext, executionPhase, false); |
141 |
| - } else { |
142 |
| - executeScripts(methodLevelOverrides, testContext, executionPhase, false); |
| 133 | + Method testMethod = testContext.getTestMethod(); |
| 134 | + Class<?> testClass = testContext.getTestClass(); |
| 135 | + |
| 136 | + if (mergeSqlAnnotations(testContext)) { |
| 137 | + executeSqlScripts(getSqlAnnotationsFor(testClass), testContext, executionPhase, true); |
| 138 | + executeSqlScripts(getSqlAnnotationsFor(testMethod), testContext, executionPhase, false); |
| 139 | + } |
| 140 | + else { |
| 141 | + Set<Sql> methodLevelSqlAnnotations = getSqlAnnotationsFor(testMethod); |
| 142 | + if (!methodLevelSqlAnnotations.isEmpty()) { |
| 143 | + executeSqlScripts(methodLevelSqlAnnotations, testContext, executionPhase, false); |
| 144 | + } |
| 145 | + else { |
| 146 | + executeSqlScripts(getSqlAnnotationsFor(testClass), testContext, executionPhase, true); |
| 147 | + } |
| 148 | + } |
| 149 | + } |
| 150 | + |
| 151 | + /** |
| 152 | + * Determine if method-level {@code @Sql} annotations should be merged with |
| 153 | + * class-level {@code @Sql} annotations. |
| 154 | + */ |
| 155 | + private boolean mergeSqlAnnotations(TestContext testContext) { |
| 156 | + SqlMergeMode sqlMergeMode = getSqlMergeModeFor(testContext.getTestMethod()); |
| 157 | + if (sqlMergeMode == null) { |
| 158 | + sqlMergeMode = getSqlMergeModeFor(testContext.getTestClass()); |
143 | 159 | }
|
| 160 | + return (sqlMergeMode != null && sqlMergeMode.value() == MergeMode.MERGE); |
| 161 | + } |
| 162 | + |
| 163 | + /** |
| 164 | + * Get the {@code @SqlMergeMode} annotation declared on the supplied {@code element}. |
| 165 | + */ |
| 166 | + private SqlMergeMode getSqlMergeModeFor(AnnotatedElement element) { |
| 167 | + return AnnotatedElementUtils.findMergedAnnotation(element, SqlMergeMode.class); |
144 | 168 | }
|
145 | 169 |
|
146 | 170 | /**
|
147 |
| - * Get the {@link Sql @Sql} annotations declared on the supplied |
148 |
| - * {@link AnnotatedElement}. |
| 171 | + * Get the {@code @Sql} annotations declared on the supplied {@code element}. |
149 | 172 | */
|
150 |
| - private Set<Sql> getSqlAnnotationsFor(AnnotatedElement annotatedElement) { |
151 |
| - return AnnotatedElementUtils.getMergedRepeatableAnnotations(annotatedElement, Sql.class, SqlGroup.class); |
| 173 | + private Set<Sql> getSqlAnnotationsFor(AnnotatedElement element) { |
| 174 | + return AnnotatedElementUtils.getMergedRepeatableAnnotations(element, Sql.class, SqlGroup.class); |
152 | 175 | }
|
153 | 176 |
|
154 | 177 | /**
|
155 | 178 | * Execute SQL scripts for the supplied {@link Sql @Sql} annotations.
|
156 | 179 | */
|
157 |
| - private void executeScripts( |
158 |
| - Iterable<Sql> scripts, TestContext testContext, ExecutionPhase executionPhase, boolean classLevel) { |
| 180 | + private void executeSqlScripts( |
| 181 | + Set<Sql> sqlAnnotations, TestContext testContext, ExecutionPhase executionPhase, boolean classLevel) { |
159 | 182 |
|
160 |
| - for (Sql sql : scripts) { |
161 |
| - executeSqlScripts(sql, executionPhase, testContext, classLevel); |
162 |
| - } |
| 183 | + sqlAnnotations.forEach(sql -> executeSqlScripts(sql, executionPhase, testContext, classLevel)); |
163 | 184 | }
|
164 | 185 |
|
165 | 186 | /**
|
@@ -196,7 +217,7 @@ private void executeSqlScripts(
|
196 | 217 | }
|
197 | 218 | }
|
198 | 219 |
|
199 |
| - ResourceDatabasePopulator populator = configurePopulator(mergedSqlConfig); |
| 220 | + ResourceDatabasePopulator populator = createDatabasePopulator(mergedSqlConfig); |
200 | 221 | populator.setScripts(scriptResources.toArray(new Resource[0]));
|
201 | 222 | if (logger.isDebugEnabled()) {
|
202 | 223 | logger.debug("Executing SQL scripts: " + ObjectUtils.nullSafeToString(scriptResources));
|
@@ -242,7 +263,7 @@ private void executeSqlScripts(
|
242 | 263 | }
|
243 | 264 |
|
244 | 265 | @NonNull
|
245 |
| - private ResourceDatabasePopulator configurePopulator(MergedSqlConfig mergedSqlConfig) { |
| 266 | + private ResourceDatabasePopulator createDatabasePopulator(MergedSqlConfig mergedSqlConfig) { |
246 | 267 | ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
|
247 | 268 | populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
|
248 | 269 | populator.setSeparator(mergedSqlConfig.getSeparator());
|
|
0 commit comments