1
1
package org .junit .runners ;
2
2
3
+ import java .lang .annotation .Annotation ;
3
4
import java .lang .annotation .ElementType ;
4
5
import java .lang .annotation .Inherited ;
5
6
import java .lang .annotation .Retention ;
8
9
import java .text .MessageFormat ;
9
10
import java .util .ArrayList ;
10
11
import java .util .Arrays ;
12
+ import java .util .Collection ;
11
13
import java .util .Collections ;
12
14
import java .util .List ;
13
15
14
16
import org .junit .runner .Runner ;
15
17
import org .junit .runners .model .FrameworkMethod ;
18
+ import org .junit .runners .model .InvalidTestClassError ;
16
19
import org .junit .runners .model .TestClass ;
17
20
import org .junit .runners .parameterized .BlockJUnit4ClassRunnerWithParametersFactory ;
18
21
import org .junit .runners .parameterized .ParametersRunnerFactory ;
134
137
* }
135
138
* </pre>
136
139
*
140
+ * <h3>Executing code before/after executing tests for specific parameters</h3>
141
+ * <p>
142
+ * If your test needs to perform some preparation or cleanup based on the
143
+ * parameters, this can be done by adding public static methods annotated with
144
+ * {@code @BeforeParam}/{@code @AfterParam}. Such methods should either have no
145
+ * parameters or the same parameters as the test.
146
+ * <pre>
147
+ * @BeforeParam
148
+ * public static void beforeTestsForParameter(String onlyParameter) {
149
+ * System.out.println("Testing " + onlyParameter);
150
+ * }
151
+ * </pre>
152
+ *
137
153
* <h3>Create different runners</h3>
138
154
* <p>
139
155
* By default the {@code Parameterized} runner creates a slightly modified
@@ -234,32 +250,91 @@ public class Parameterized extends Suite {
234
250
Class <? extends ParametersRunnerFactory > value () default BlockJUnit4ClassRunnerWithParametersFactory .class ;
235
251
}
236
252
253
+ /**
254
+ * Annotation for {@code public static void} methods which should be executed before
255
+ * evaluating tests with particular parameters.
256
+ *
257
+ * @see org.junit.BeforeClass
258
+ * @see org.junit.Before
259
+ * @since 4.13
260
+ */
261
+ @ Retention (RetentionPolicy .RUNTIME )
262
+ @ Target (ElementType .METHOD )
263
+ public @interface BeforeParam {
264
+ }
265
+
266
+ /**
267
+ * Annotation for {@code public static void} methods which should be executed after
268
+ * evaluating tests with particular parameters.
269
+ *
270
+ * @see org.junit.AfterClass
271
+ * @see org.junit.After
272
+ * @since 4.13
273
+ */
274
+ @ Retention (RetentionPolicy .RUNTIME )
275
+ @ Target (ElementType .METHOD )
276
+ public @interface AfterParam {
277
+ }
278
+
237
279
/**
238
280
* Only called reflectively. Do not use programmatically.
239
281
*/
240
282
public Parameterized (Class <?> klass ) throws Throwable {
241
- super (klass , RunnersFactory .createRunnersForClass (klass ));
283
+ this (klass , new RunnersFactory (klass ));
284
+ }
285
+
286
+ private Parameterized (Class <?> klass , RunnersFactory runnersFactory ) throws Exception {
287
+ super (klass , runnersFactory .createRunners ());
288
+ validateBeforeParamAndAfterParamMethods (runnersFactory .parameterCount );
289
+ }
290
+
291
+ private void validateBeforeParamAndAfterParamMethods (Integer parameterCount )
292
+ throws InvalidTestClassError {
293
+ List <Throwable > errors = new ArrayList <Throwable >();
294
+ validatePublicStaticVoidMethods (Parameterized .BeforeParam .class , parameterCount , errors );
295
+ validatePublicStaticVoidMethods (Parameterized .AfterParam .class , parameterCount , errors );
296
+ if (!errors .isEmpty ()) {
297
+ throw new InvalidTestClassError (getTestClass ().getJavaClass (), errors );
298
+ }
299
+ }
300
+
301
+ private void validatePublicStaticVoidMethods (
302
+ Class <? extends Annotation > annotation , Integer parameterCount ,
303
+ List <Throwable > errors ) {
304
+ List <FrameworkMethod > methods = getTestClass ().getAnnotatedMethods (annotation );
305
+ for (FrameworkMethod fm : methods ) {
306
+ fm .validatePublicVoid (true , errors );
307
+ if (parameterCount != null ) {
308
+ int methodParameterCount = fm .getMethod ().getParameterTypes ().length ;
309
+ if (methodParameterCount != 0 && methodParameterCount != parameterCount ) {
310
+ errors .add (new Exception ("Method " + fm .getName ()
311
+ + "() should have 0 or " + parameterCount + " parameter(s)" ));
312
+ }
313
+ }
314
+ }
242
315
}
243
316
244
317
private static class RunnersFactory {
245
318
private static final ParametersRunnerFactory DEFAULT_FACTORY = new BlockJUnit4ClassRunnerWithParametersFactory ();
246
319
247
320
private final TestClass testClass ;
321
+ private final FrameworkMethod parametersMethod ;
322
+ private final List <Object > allParameters ;
323
+ private final int parameterCount ;
248
324
249
- static List <Runner > createRunnersForClass (Class <?> klass )
250
- throws Throwable {
251
- return new RunnersFactory (klass ).createRunners ();
252
- }
253
325
254
- private RunnersFactory (Class <?> klass ) {
326
+ private RunnersFactory (Class <?> klass ) throws Throwable {
255
327
testClass = new TestClass (klass );
328
+ parametersMethod = getParametersMethod (testClass );
329
+ allParameters = allParameters (testClass , parametersMethod );
330
+ parameterCount =
331
+ allParameters .isEmpty () ? 0 : normalizeParameters (allParameters .get (0 )).length ;
256
332
}
257
333
258
- private List <Runner > createRunners () throws Throwable {
259
- Parameters parameters = getParametersMethod ().getAnnotation (
260
- Parameters .class );
334
+ private List <Runner > createRunners () throws Exception {
335
+ Parameters parameters = parametersMethod .getAnnotation (Parameters .class );
261
336
return Collections .unmodifiableList (createRunnersForParameters (
262
- allParameters () , parameters .name (),
337
+ allParameters , parameters .name (),
263
338
getParametersRunnerFactory ()));
264
339
}
265
340
@@ -278,25 +353,37 @@ private ParametersRunnerFactory getParametersRunnerFactory()
278
353
279
354
private TestWithParameters createTestWithNotNormalizedParameters (
280
355
String pattern , int index , Object parametersOrSingleParameter ) {
281
- Object [] parameters = (parametersOrSingleParameter instanceof Object []) ? (Object []) parametersOrSingleParameter
356
+ Object [] parameters = normalizeParameters (parametersOrSingleParameter );
357
+ return createTestWithParameters (testClass , pattern , index , parameters );
358
+ }
359
+
360
+ private static Object [] normalizeParameters (Object parametersOrSingleParameter ) {
361
+ return (parametersOrSingleParameter instanceof Object []) ? (Object []) parametersOrSingleParameter
282
362
: new Object [] { parametersOrSingleParameter };
283
- return createTestWithParameters (testClass , pattern , index ,
284
- parameters );
285
363
}
286
364
287
365
@ SuppressWarnings ("unchecked" )
288
- private Iterable <Object > allParameters () throws Throwable {
289
- Object parameters = getParametersMethod ().invokeExplosively (null );
290
- if (parameters instanceof Iterable ) {
291
- return (Iterable <Object >) parameters ;
366
+ private static List <Object > allParameters (
367
+ TestClass testClass , FrameworkMethod parametersMethod ) throws Throwable {
368
+ Object parameters = parametersMethod .invokeExplosively (null );
369
+ if (parameters instanceof List ) {
370
+ return (List <Object >) parameters ;
371
+ } else if (parameters instanceof Collection ) {
372
+ return new ArrayList <Object >((Collection <Object >) parameters );
373
+ } else if (parameters instanceof Iterable ) {
374
+ List <Object > result = new ArrayList <Object >();
375
+ for (Object entry : ((Iterable <Object >) parameters )) {
376
+ result .add (entry );
377
+ }
378
+ return result ;
292
379
} else if (parameters instanceof Object []) {
293
380
return Arrays .asList ((Object []) parameters );
294
381
} else {
295
- throw parametersMethodReturnedWrongType ();
382
+ throw parametersMethodReturnedWrongType (testClass , parametersMethod );
296
383
}
297
384
}
298
385
299
- private FrameworkMethod getParametersMethod () throws Exception {
386
+ private static FrameworkMethod getParametersMethod (TestClass testClass ) throws Exception {
300
387
List <FrameworkMethod > methods = testClass
301
388
.getAnnotatedMethods (Parameters .class );
302
389
for (FrameworkMethod each : methods ) {
@@ -322,7 +409,7 @@ private List<Runner> createRunnersForParameters(
322
409
}
323
410
return runners ;
324
411
} catch (ClassCastException e ) {
325
- throw parametersMethodReturnedWrongType ();
412
+ throw parametersMethodReturnedWrongType (testClass , parametersMethod );
326
413
}
327
414
}
328
415
@@ -338,9 +425,10 @@ private List<TestWithParameters> createTestsForParameters(
338
425
return children ;
339
426
}
340
427
341
- private Exception parametersMethodReturnedWrongType () throws Exception {
428
+ private static Exception parametersMethodReturnedWrongType (
429
+ TestClass testClass , FrameworkMethod parametersMethod ) throws Exception {
342
430
String className = testClass .getName ();
343
- String methodName = getParametersMethod () .getName ();
431
+ String methodName = parametersMethod .getName ();
344
432
String message = MessageFormat .format (
345
433
"{0}.{1}() must return an Iterable of arrays." , className ,
346
434
methodName );
0 commit comments