Skip to content

Commit 1bf8438

Browse files
panchenkokcooney
authored andcommitted
@BeforeParam/@AfterParam for Parameterized runner (#1435)
Closes #45
1 parent 3ce01b1 commit 1bf8438

File tree

5 files changed

+442
-28
lines changed

5 files changed

+442
-28
lines changed

src/main/java/org/junit/internal/runners/statements/RunAfters.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,19 @@ public void evaluate() throws Throwable {
3030
} finally {
3131
for (FrameworkMethod each : afters) {
3232
try {
33-
each.invokeExplosively(target);
33+
invokeMethod(each);
3434
} catch (Throwable e) {
3535
errors.add(e);
3636
}
3737
}
3838
}
3939
MultipleFailureException.assertEmpty(errors);
4040
}
41+
42+
/**
43+
* @since 4.13
44+
*/
45+
protected void invokeMethod(FrameworkMethod method) throws Throwable {
46+
method.invokeExplosively(target);
47+
}
4148
}

src/main/java/org/junit/internal/runners/statements/RunBefores.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,15 @@ public RunBefores(Statement next, List<FrameworkMethod> befores, Object target)
2121
@Override
2222
public void evaluate() throws Throwable {
2323
for (FrameworkMethod before : befores) {
24-
before.invokeExplosively(target);
24+
invokeMethod(before);
2525
}
2626
next.evaluate();
2727
}
28+
29+
/**
30+
* @since 4.13
31+
*/
32+
protected void invokeMethod(FrameworkMethod method) throws Throwable {
33+
method.invokeExplosively(target);
34+
}
2835
}

src/main/java/org/junit/runners/Parameterized.java

+110-22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.junit.runners;
22

3+
import java.lang.annotation.Annotation;
34
import java.lang.annotation.ElementType;
45
import java.lang.annotation.Inherited;
56
import java.lang.annotation.Retention;
@@ -8,11 +9,13 @@
89
import java.text.MessageFormat;
910
import java.util.ArrayList;
1011
import java.util.Arrays;
12+
import java.util.Collection;
1113
import java.util.Collections;
1214
import java.util.List;
1315

1416
import org.junit.runner.Runner;
1517
import org.junit.runners.model.FrameworkMethod;
18+
import org.junit.runners.model.InvalidTestClassError;
1619
import org.junit.runners.model.TestClass;
1720
import org.junit.runners.parameterized.BlockJUnit4ClassRunnerWithParametersFactory;
1821
import org.junit.runners.parameterized.ParametersRunnerFactory;
@@ -134,6 +137,19 @@
134137
* }
135138
* </pre>
136139
*
140+
* <h3>Executing code before/after executing tests for specific parameters</h3>
141+
* <p>
142+
* If your test needs to perform some preparation or cleanup based on the
143+
* parameters, this can be done by adding public static methods annotated with
144+
* {@code @BeforeParam}/{@code @AfterParam}. Such methods should either have no
145+
* parameters or the same parameters as the test.
146+
* <pre>
147+
* &#064;BeforeParam
148+
* public static void beforeTestsForParameter(String onlyParameter) {
149+
* System.out.println("Testing " + onlyParameter);
150+
* }
151+
* </pre>
152+
*
137153
* <h3>Create different runners</h3>
138154
* <p>
139155
* By default the {@code Parameterized} runner creates a slightly modified
@@ -234,32 +250,91 @@ public class Parameterized extends Suite {
234250
Class<? extends ParametersRunnerFactory> value() default BlockJUnit4ClassRunnerWithParametersFactory.class;
235251
}
236252

253+
/**
254+
* Annotation for {@code public static void} methods which should be executed before
255+
* evaluating tests with particular parameters.
256+
*
257+
* @see org.junit.BeforeClass
258+
* @see org.junit.Before
259+
* @since 4.13
260+
*/
261+
@Retention(RetentionPolicy.RUNTIME)
262+
@Target(ElementType.METHOD)
263+
public @interface BeforeParam {
264+
}
265+
266+
/**
267+
* Annotation for {@code public static void} methods which should be executed after
268+
* evaluating tests with particular parameters.
269+
*
270+
* @see org.junit.AfterClass
271+
* @see org.junit.After
272+
* @since 4.13
273+
*/
274+
@Retention(RetentionPolicy.RUNTIME)
275+
@Target(ElementType.METHOD)
276+
public @interface AfterParam {
277+
}
278+
237279
/**
238280
* Only called reflectively. Do not use programmatically.
239281
*/
240282
public Parameterized(Class<?> klass) throws Throwable {
241-
super(klass, RunnersFactory.createRunnersForClass(klass));
283+
this(klass, new RunnersFactory(klass));
284+
}
285+
286+
private Parameterized(Class<?> klass, RunnersFactory runnersFactory) throws Exception {
287+
super(klass, runnersFactory.createRunners());
288+
validateBeforeParamAndAfterParamMethods(runnersFactory.parameterCount);
289+
}
290+
291+
private void validateBeforeParamAndAfterParamMethods(Integer parameterCount)
292+
throws InvalidTestClassError {
293+
List<Throwable> errors = new ArrayList<Throwable>();
294+
validatePublicStaticVoidMethods(Parameterized.BeforeParam.class, parameterCount, errors);
295+
validatePublicStaticVoidMethods(Parameterized.AfterParam.class, parameterCount, errors);
296+
if (!errors.isEmpty()) {
297+
throw new InvalidTestClassError(getTestClass().getJavaClass(), errors);
298+
}
299+
}
300+
301+
private void validatePublicStaticVoidMethods(
302+
Class<? extends Annotation> annotation, Integer parameterCount,
303+
List<Throwable> errors) {
304+
List<FrameworkMethod> methods = getTestClass().getAnnotatedMethods(annotation);
305+
for (FrameworkMethod fm : methods) {
306+
fm.validatePublicVoid(true, errors);
307+
if (parameterCount != null) {
308+
int methodParameterCount = fm.getMethod().getParameterTypes().length;
309+
if (methodParameterCount != 0 && methodParameterCount != parameterCount) {
310+
errors.add(new Exception("Method " + fm.getName()
311+
+ "() should have 0 or " + parameterCount + " parameter(s)"));
312+
}
313+
}
314+
}
242315
}
243316

244317
private static class RunnersFactory {
245318
private static final ParametersRunnerFactory DEFAULT_FACTORY = new BlockJUnit4ClassRunnerWithParametersFactory();
246319

247320
private final TestClass testClass;
321+
private final FrameworkMethod parametersMethod;
322+
private final List<Object> allParameters;
323+
private final int parameterCount;
248324

249-
static List<Runner> createRunnersForClass(Class<?> klass)
250-
throws Throwable {
251-
return new RunnersFactory(klass).createRunners();
252-
}
253325

254-
private RunnersFactory(Class<?> klass) {
326+
private RunnersFactory(Class<?> klass) throws Throwable {
255327
testClass = new TestClass(klass);
328+
parametersMethod = getParametersMethod(testClass);
329+
allParameters = allParameters(testClass, parametersMethod);
330+
parameterCount =
331+
allParameters.isEmpty() ? 0 : normalizeParameters(allParameters.get(0)).length;
256332
}
257333

258-
private List<Runner> createRunners() throws Throwable {
259-
Parameters parameters = getParametersMethod().getAnnotation(
260-
Parameters.class);
334+
private List<Runner> createRunners() throws Exception {
335+
Parameters parameters = parametersMethod.getAnnotation(Parameters.class);
261336
return Collections.unmodifiableList(createRunnersForParameters(
262-
allParameters(), parameters.name(),
337+
allParameters, parameters.name(),
263338
getParametersRunnerFactory()));
264339
}
265340

@@ -278,25 +353,37 @@ private ParametersRunnerFactory getParametersRunnerFactory()
278353

279354
private TestWithParameters createTestWithNotNormalizedParameters(
280355
String pattern, int index, Object parametersOrSingleParameter) {
281-
Object[] parameters = (parametersOrSingleParameter instanceof Object[]) ? (Object[]) parametersOrSingleParameter
356+
Object[] parameters = normalizeParameters(parametersOrSingleParameter);
357+
return createTestWithParameters(testClass, pattern, index, parameters);
358+
}
359+
360+
private static Object[] normalizeParameters(Object parametersOrSingleParameter) {
361+
return (parametersOrSingleParameter instanceof Object[]) ? (Object[]) parametersOrSingleParameter
282362
: new Object[] { parametersOrSingleParameter };
283-
return createTestWithParameters(testClass, pattern, index,
284-
parameters);
285363
}
286364

287365
@SuppressWarnings("unchecked")
288-
private Iterable<Object> allParameters() throws Throwable {
289-
Object parameters = getParametersMethod().invokeExplosively(null);
290-
if (parameters instanceof Iterable) {
291-
return (Iterable<Object>) parameters;
366+
private static List<Object> allParameters(
367+
TestClass testClass, FrameworkMethod parametersMethod) throws Throwable {
368+
Object parameters = parametersMethod.invokeExplosively(null);
369+
if (parameters instanceof List) {
370+
return (List<Object>) parameters;
371+
} else if (parameters instanceof Collection) {
372+
return new ArrayList<Object>((Collection<Object>) parameters);
373+
} else if (parameters instanceof Iterable) {
374+
List<Object> result = new ArrayList<Object>();
375+
for (Object entry : ((Iterable<Object>) parameters)) {
376+
result.add(entry);
377+
}
378+
return result;
292379
} else if (parameters instanceof Object[]) {
293380
return Arrays.asList((Object[]) parameters);
294381
} else {
295-
throw parametersMethodReturnedWrongType();
382+
throw parametersMethodReturnedWrongType(testClass, parametersMethod);
296383
}
297384
}
298385

299-
private FrameworkMethod getParametersMethod() throws Exception {
386+
private static FrameworkMethod getParametersMethod(TestClass testClass) throws Exception {
300387
List<FrameworkMethod> methods = testClass
301388
.getAnnotatedMethods(Parameters.class);
302389
for (FrameworkMethod each : methods) {
@@ -322,7 +409,7 @@ private List<Runner> createRunnersForParameters(
322409
}
323410
return runners;
324411
} catch (ClassCastException e) {
325-
throw parametersMethodReturnedWrongType();
412+
throw parametersMethodReturnedWrongType(testClass, parametersMethod);
326413
}
327414
}
328415

@@ -338,9 +425,10 @@ private List<TestWithParameters> createTestsForParameters(
338425
return children;
339426
}
340427

341-
private Exception parametersMethodReturnedWrongType() throws Exception {
428+
private static Exception parametersMethodReturnedWrongType(
429+
TestClass testClass, FrameworkMethod parametersMethod) throws Exception {
342430
String className = testClass.getName();
343-
String methodName = getParametersMethod().getName();
431+
String methodName = parametersMethod.getName();
344432
String message = MessageFormat.format(
345433
"{0}.{1}() must return an Iterable of arrays.", className,
346434
methodName);

src/main/java/org/junit/runners/parameterized/BlockJUnit4ClassRunnerWithParameters.java

+43-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
import java.lang.reflect.Field;
55
import java.util.List;
66

7+
import org.junit.internal.runners.statements.RunAfters;
8+
import org.junit.internal.runners.statements.RunBefores;
79
import org.junit.runner.RunWith;
810
import org.junit.runner.notification.RunNotifier;
911
import org.junit.runners.BlockJUnit4ClassRunner;
12+
import org.junit.runners.Parameterized;
1013
import org.junit.runners.Parameterized.Parameter;
1114
import org.junit.runners.model.FrameworkField;
1215
import org.junit.runners.model.FrameworkMethod;
@@ -135,7 +138,46 @@ protected void validateFields(List<Throwable> errors) {
135138

136139
@Override
137140
protected Statement classBlock(RunNotifier notifier) {
138-
return childrenInvoker(notifier);
141+
Statement statement = childrenInvoker(notifier);
142+
statement = withBeforeParams(statement);
143+
statement = withAfterParams(statement);
144+
return statement;
145+
}
146+
147+
private Statement withBeforeParams(Statement statement) {
148+
List<FrameworkMethod> befores = getTestClass()
149+
.getAnnotatedMethods(Parameterized.BeforeParam.class);
150+
return befores.isEmpty() ? statement : new RunBeforeParams(statement, befores);
151+
}
152+
153+
private class RunBeforeParams extends RunBefores {
154+
RunBeforeParams(Statement next, List<FrameworkMethod> befores) {
155+
super(next, befores, null);
156+
}
157+
158+
@Override
159+
protected void invokeMethod(FrameworkMethod method) throws Throwable {
160+
int paramCount = method.getMethod().getParameterTypes().length;
161+
method.invokeExplosively(null, paramCount == 0 ? (Object[]) null : parameters);
162+
}
163+
}
164+
165+
private Statement withAfterParams(Statement statement) {
166+
List<FrameworkMethod> afters = getTestClass()
167+
.getAnnotatedMethods(Parameterized.AfterParam.class);
168+
return afters.isEmpty() ? statement : new RunAfterParams(statement, afters);
169+
}
170+
171+
private class RunAfterParams extends RunAfters {
172+
RunAfterParams(Statement next, List<FrameworkMethod> afters) {
173+
super(next, afters, null);
174+
}
175+
176+
@Override
177+
protected void invokeMethod(FrameworkMethod method) throws Throwable {
178+
int paramCount = method.getMethod().getParameterTypes().length;
179+
method.invokeExplosively(null, paramCount == 0 ? (Object[]) null : parameters);
180+
}
139181
}
140182

141183
@Override

0 commit comments

Comments
 (0)