|
9 | 9 | from pandas import DataFrame
|
10 | 10 | from xarray import Dataset
|
11 | 11 |
|
12 |
| -from sgkit import variables |
13 | 12 | from sgkit.stats.association import gwas_linear_regression, linear_regression
|
14 | 13 | from sgkit.typing import ArrayLike
|
15 | 14 |
|
@@ -175,16 +174,24 @@ def validate(dfp: DataFrame, dft: DataFrame) -> None:
|
175 | 174 |
|
176 | 175 | def test_gwas_linear_regression__lazy_results(ds):
|
177 | 176 | res = gwas_linear_regression(
|
178 |
| - ds, dosage="dosage", covariates="covar_0", traits="trait_0" |
| 177 | + ds, dosage="dosage", covariates="covar_0", traits="trait_0", merge=False |
179 | 178 | )
|
180 |
| - for v in [ |
181 |
| - variables.variant_beta, |
182 |
| - variables.variant_t_value, |
183 |
| - variables.variant_p_value, |
184 |
| - ]: |
| 179 | + for v in res: |
185 | 180 | assert isinstance(res[v].data, da.Array)
|
186 | 181 |
|
187 | 182 |
|
| 183 | +@pytest.mark.parametrize("chunks", [5, -1, "auto"]) |
| 184 | +def test_gwas_linear_regression__variable_shapes(ds, chunks): |
| 185 | + ds = ds.chunk(chunks=chunks) |
| 186 | + res = gwas_linear_regression( |
| 187 | + ds, dosage="dosage", covariates="covar_0", traits="trait_0", merge=False |
| 188 | + ) |
| 189 | + shape = (ds.dims["variants"], 1) |
| 190 | + for v in res: |
| 191 | + assert res[v].data.shape == shape |
| 192 | + assert res[v].data.compute().shape == shape |
| 193 | + |
| 194 | + |
188 | 195 | def test_gwas_linear_regression__multi_trait(ds):
|
189 | 196 | def run(traits: Sequence[str]) -> Dataset:
|
190 | 197 | return gwas_linear_regression(
|
|
0 commit comments