@@ -403,7 +403,7 @@ def time_srs_bfill(self):
403
403
404
404
class GroupByMethods :
405
405
406
- param_names = ["dtype" , "method" , "application" ]
406
+ param_names = ["dtype" , "method" , "application" , "ncols" ]
407
407
params = [
408
408
["int" , "float" , "object" , "datetime" , "uint" ],
409
409
[
@@ -443,15 +443,23 @@ class GroupByMethods:
443
443
"var" ,
444
444
],
445
445
["direct" , "transformation" ],
446
+ [1 , 2 , 5 , 10 ],
446
447
]
447
448
448
- def setup (self , dtype , method , application ):
449
+ def setup (self , dtype , method , application , ncols ):
449
450
if method in method_blocklist .get (dtype , {}):
450
451
raise NotImplementedError # skip benchmark
452
+
453
+ if ncols != 1 and method in ["value_counts" , "unique" ]:
454
+ # DataFrameGroupBy doesn't have these methods
455
+ raise NotImplementedError
456
+
451
457
ngroups = 1000
452
458
size = ngroups * 2
453
- rng = np .arange (ngroups )
454
- values = rng .take (np .random .randint (0 , ngroups , size = size ))
459
+ rng = np .arange (ngroups ).reshape (- 1 , 1 )
460
+ rng = np .broadcast_to (rng , (len (rng ), ncols ))
461
+ taker = np .random .randint (0 , ngroups , size = size )
462
+ values = rng .take (taker , axis = 0 )
455
463
if dtype == "int" :
456
464
key = np .random .randint (0 , size , size = size )
457
465
elif dtype == "uint" :
@@ -465,22 +473,27 @@ def setup(self, dtype, method, application):
465
473
elif dtype == "datetime" :
466
474
key = date_range ("1/1/2011" , periods = size , freq = "s" )
467
475
468
- df = DataFrame ({"values" : values , "key" : key })
476
+ cols = [f"values{ n } " for n in range (ncols )]
477
+ df = DataFrame (values , columns = cols )
478
+ df ["key" ] = key
479
+
480
+ if len (cols ) == 1 :
481
+ cols = cols [0 ]
469
482
470
483
if application == "transform" :
471
484
if method == "describe" :
472
485
raise NotImplementedError
473
486
474
- self .as_group_method = lambda : df .groupby ("key" )["values" ].transform (method )
475
- self .as_field_method = lambda : df .groupby ("values" )["key" ].transform (method )
487
+ self .as_group_method = lambda : df .groupby ("key" )[cols ].transform (method )
488
+ self .as_field_method = lambda : df .groupby (cols )["key" ].transform (method )
476
489
else :
477
- self .as_group_method = getattr (df .groupby ("key" )["values" ], method )
478
- self .as_field_method = getattr (df .groupby ("values" )["key" ], method )
490
+ self .as_group_method = getattr (df .groupby ("key" )[cols ], method )
491
+ self .as_field_method = getattr (df .groupby (cols )["key" ], method )
479
492
480
- def time_dtype_as_group (self , dtype , method , application ):
493
+ def time_dtype_as_group (self , dtype , method , application , ncols ):
481
494
self .as_group_method ()
482
495
483
- def time_dtype_as_field (self , dtype , method , application ):
496
+ def time_dtype_as_field (self , dtype , method , application , ncols ):
484
497
self .as_field_method ()
485
498
486
499
0 commit comments