@@ -663,30 +663,33 @@ def test_groupby_dataset_reduce() -> None:
663
663
assert_allclose (expected , actual )
664
664
665
665
666
- def test_groupby_dataset_math () -> None :
666
+ @pytest .mark .parametrize ("squeeze" , [True , False ])
667
+ def test_groupby_dataset_math (squeeze ) -> None :
667
668
def reorder_dims (x ):
668
669
return x .transpose ("dim1" , "dim2" , "dim3" , "time" )
669
670
670
671
ds = create_test_data ()
671
672
ds ["dim1" ] = ds ["dim1" ]
672
- for squeeze in [True , False ]:
673
- grouped = ds .groupby ("dim1" , squeeze = squeeze )
673
+ grouped = ds .groupby ("dim1" , squeeze = squeeze )
674
674
675
- expected = reorder_dims (ds + ds .coords ["dim1" ])
676
- actual = grouped + ds .coords ["dim1" ]
677
- assert_identical (expected , reorder_dims (actual ))
675
+ expected = reorder_dims (ds + ds .coords ["dim1" ])
676
+ actual = grouped + ds .coords ["dim1" ]
677
+ assert_identical (expected , reorder_dims (actual ))
678
678
679
- actual = ds .coords ["dim1" ] + grouped
680
- assert_identical (expected , reorder_dims (actual ))
679
+ actual = ds .coords ["dim1" ] + grouped
680
+ assert_identical (expected , reorder_dims (actual ))
681
681
682
- ds2 = 2 * ds
683
- expected = reorder_dims (ds + ds2 )
684
- actual = grouped + ds2
685
- assert_identical (expected , reorder_dims (actual ))
682
+ ds2 = 2 * ds
683
+ expected = reorder_dims (ds + ds2 )
684
+ actual = grouped + ds2
685
+ assert_identical (expected , reorder_dims (actual ))
686
686
687
- actual = ds2 + grouped
688
- assert_identical (expected , reorder_dims (actual ))
687
+ actual = ds2 + grouped
688
+ assert_identical (expected , reorder_dims (actual ))
689
689
690
+
691
+ def test_groupby_math_more () -> None :
692
+ ds = create_test_data ()
690
693
grouped = ds .groupby ("numbers" )
691
694
zeros = DataArray ([0 , 0 , 0 , 0 ], [("numbers" , range (4 ))])
692
695
expected = (ds + Variable ("dim3" , np .zeros (10 ))).transpose (
@@ -719,6 +722,58 @@ def reorder_dims(x):
719
722
ds + ds .groupby ("time.month" )
720
723
721
724
725
+ @pytest .mark .parametrize ("indexed_coord" , [True , False ])
726
+ def test_groupby_bins_math (indexed_coord ) -> None :
727
+ N = 7
728
+ da = DataArray (np .random .random ((N , N )), dims = ("x" , "y" ))
729
+ if indexed_coord :
730
+ da ["x" ] = np .arange (N )
731
+ da ["y" ] = np .arange (N )
732
+ g = da .groupby_bins ("x" , np .arange (0 , N + 1 , 3 ))
733
+ mean = g .mean ()
734
+ expected = da .isel (x = slice (1 , None )) - mean .isel (x_bins = ("x" , [0 , 0 , 0 , 1 , 1 , 1 ]))
735
+ actual = g - mean
736
+ assert_identical (expected , actual )
737
+
738
+
739
+ def test_groupby_math_nD_group () -> None :
740
+ N = 40
741
+ da = DataArray (
742
+ np .random .random ((N , N )),
743
+ dims = ("x" , "y" ),
744
+ coords = {
745
+ "labels" : (
746
+ "x" ,
747
+ np .repeat (["a" , "b" , "c" , "d" , "e" , "f" , "g" , "h" ], repeats = N // 8 ),
748
+ ),
749
+ },
750
+ )
751
+ da ["labels2d" ] = xr .broadcast (da .labels , da )[0 ]
752
+
753
+ g = da .groupby ("labels2d" )
754
+ mean = g .mean ()
755
+ expected = da - mean .sel (labels2d = da .labels2d )
756
+ expected ["labels" ] = expected .labels .broadcast_like (expected .labels2d )
757
+ actual = g - mean
758
+ assert_identical (expected , actual )
759
+
760
+ da ["num" ] = (
761
+ "x" ,
762
+ np .repeat ([1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ], repeats = N // 8 ),
763
+ )
764
+ da ["num2d" ] = xr .broadcast (da .num , da )[0 ]
765
+ g = da .groupby_bins ("num2d" , bins = [0 , 4 , 6 ])
766
+ mean = g .mean ()
767
+ idxr = np .digitize (da .num2d , bins = (0 , 4 , 6 ), right = True )[:30 , :] - 1
768
+ expanded_mean = mean .drop ("num2d_bins" ).isel (num2d_bins = (("x" , "y" ), idxr ))
769
+ expected = da .isel (x = slice (30 )) - expanded_mean
770
+ expected ["labels" ] = expected .labels .broadcast_like (expected .labels2d )
771
+ expected ["num" ] = expected .num .broadcast_like (expected .num2d )
772
+ expected ["num2d_bins" ] = (("x" , "y" ), mean .num2d_bins .data [idxr ])
773
+ actual = g - mean
774
+ assert_identical (expected , actual )
775
+
776
+
722
777
def test_groupby_dataset_math_virtual () -> None :
723
778
ds = Dataset ({"x" : ("t" , [1 , 2 , 3 ])}, {"t" : pd .date_range ("20100101" , periods = 3 )})
724
779
grouped = ds .groupby ("t.day" )
0 commit comments