@@ -183,11 +183,14 @@ def count_variant_alleles(
183
183
variables .validate (ds , {call_genotype : variables .call_genotype_spec })
184
184
n_alleles = ds .sizes ["alleles" ]
185
185
n_variant = ds .sizes ["variants" ]
186
- G = da .asarray (ds [call_genotype ]).reshape ((n_variant , - 1 ))
186
+ G = da .asarray (ds [call_genotype ])
187
+ G = da .reshape (G , (n_variant , - 1 ))
187
188
shape = (G .chunks [0 ], n_alleles )
188
189
# use uint64 dummy array to return uin64 counts array
189
190
N = np .empty (n_alleles , dtype = np .uint64 )
190
- AC = da .map_blocks (count_alleles , G , N , chunks = shape , drop_axis = 1 , new_axis = 1 )
191
+ AC = da .map_blocks (
192
+ count_alleles , G , N , chunks = shape , dtype = np .uint64 , drop_axis = 1 , new_axis = 1
193
+ )
191
194
AC = xr .DataArray (AC , dims = ["variants" , "alleles" ])
192
195
else :
193
196
options = {variables .call_genotype , variables .call_allele_count }
@@ -692,22 +695,23 @@ def variant_stats(
692
695
using = variables .call_genotype , # improved performance
693
696
merge = False ,
694
697
)[variant_allele_count ]
695
- G = da .array (ds [call_genotype ].data )
698
+ G = da .asarray (ds [call_genotype ].data )
696
699
H = xr .DataArray (
697
700
da .map_blocks (
698
- count_hom ,
701
+ lambda * args : count_hom ( * args )[:, np . newaxis , :] ,
699
702
G ,
700
703
np .zeros (3 , np .uint64 ),
701
- drop_axis = ( 1 , 2 ) ,
702
- new_axis = 1 ,
704
+ drop_axis = 2 ,
705
+ new_axis = 2 ,
703
706
dtype = np .int64 ,
704
- chunks = (G .chunks [0 ], 3 ),
707
+ chunks = (G .chunks [0 ], 1 , 3 ),
705
708
),
706
- dims = ["variants" , "categories" ],
709
+ dims = ["variants" , "samples" , " categories" ],
707
710
)
711
+ H = H .sum (axis = 1 )
708
712
_ , n_sample , _ = G .shape
709
713
n_called = H .sum (axis = - 1 )
710
- call_rate = n_called / n_sample
714
+ call_rate = n_called . astype ( float ) / float ( n_sample )
711
715
n_hom_ref = H [:, 0 ]
712
716
n_hom_alt = H [:, 1 ]
713
717
n_het = H [:, 2 ]
@@ -723,7 +727,8 @@ def variant_stats(
723
727
variables .variant_n_non_ref : n_non_ref ,
724
728
variables .variant_allele_count : AC ,
725
729
variables .variant_allele_total : allele_total ,
726
- variables .variant_allele_frequency : AC / allele_total ,
730
+ variables .variant_allele_frequency : AC .astype (float )
731
+ / allele_total .astype (float ),
727
732
}
728
733
)
729
734
# for backwards compatible behavior
@@ -798,7 +803,7 @@ def sample_stats(
798
803
mixed_ploidy = ds [call_genotype ].attrs .get ("mixed_ploidy" , False )
799
804
if mixed_ploidy :
800
805
raise ValueError ("Mixed-ploidy dataset" )
801
- G = da .array (ds [call_genotype ].data )
806
+ G = da .asarray (ds [call_genotype ].data )
802
807
H = xr .DataArray (
803
808
da .map_blocks (
804
809
count_hom ,
0 commit comments