1
1
import warnings
2
+ from typing import Any , Dict , List , Optional , Tuple
2
3
3
4
import numpy as np
4
5
import pandas as pd
5
6
import pytest
6
7
import xarray as xr
8
+ from pandas import DataFrame
9
+ from xarray import Dataset
7
10
8
11
from sgkit .stats .association import gwas_linear_regression
12
+ from sgkit .typing import ArrayLike
9
13
10
14
with warnings .catch_warnings ():
11
15
warnings .simplefilter ("ignore" , DeprecationWarning )
12
16
# Ignore: DeprecationWarning: Using or importing the ABCs from 'collections'
13
17
# instead of from 'collections.abc' is deprecated since Python 3.3,
14
18
# and in 3.9 it will stop working
15
19
import statsmodels .api as sm
20
+ from statsmodels .regression .linear_model import RegressionResultsWrapper
16
21
17
22
18
- def _generate_test_data (n = 100 , m = 10 , p = 3 , e_std = 0.001 , b_zero_slice = None , seed = 1 ):
23
+ def _generate_test_data (
24
+ n : int = 100 ,
25
+ m : int = 10 ,
26
+ p : int = 3 ,
27
+ e_std : float = 0.001 ,
28
+ b_zero_slice : Optional [slice ] = None ,
29
+ seed : Optional [int ] = 1 ,
30
+ ) -> Tuple [ArrayLike , ArrayLike , ArrayLike , ArrayLike ]:
19
31
"""Test data simulator for multiple variant associations to a continuous outcome
20
32
21
33
Outcomes for each variant are simulated separately based on linear combinations
@@ -40,20 +52,12 @@ def _generate_test_data(n=100, m=10, p=3, e_std=0.001, b_zero_slice=None, seed=1
40
52
41
53
Returns
42
54
-------
43
- n : int
44
- Number of samples
45
- m : int
46
- Number of variants
47
- p : int
48
- Number of covariates
49
55
g : (n, m) array-like
50
56
Simulated genotype dosage
51
57
x : (n, p) array-like
52
58
Simulated covariates
53
59
bg : (m,) array-like
54
60
Variant betas
55
- bx : (p,) array-like
56
- Covariate betas
57
61
ys : (m, n) array-like
58
62
Outcomes for each column in genotypes i.e. variant
59
63
"""
@@ -69,28 +73,29 @@ def _generate_test_data(n=100, m=10, p=3, e_std=0.001, b_zero_slice=None, seed=1
69
73
70
74
# Simulate y values using each variant independently
71
75
ys = np .array ([g [:, i ] * bg [i ] + x @ bx + e for i in range (m )])
72
- return n , m , p , g , x , bg , bx , ys
76
+ return g , x , bg , ys
73
77
74
78
75
def _generate_test_dataset(**kwargs: Any) -> Dataset:
    """Wrap simulated association data in an xarray Dataset.

    Parameters
    ----------
    **kwargs
        Forwarded unchanged to `_generate_test_data`.

    Returns
    -------
    Dataset
        Dataset holding a ``dosage`` variable with dimensions
        ``("variant", "sample")``, one ``covar_{i}`` variable per covariate
        column, one ``trait_{i}`` variable per simulated outcome (both over
        ``"sample"``), and the variant betas stored in ``attrs["beta"]``.
    """
    g, x, bg, ys = _generate_test_data(**kwargs)
    # Genotypes come back as (n, m); transpose so variants are the first axis
    data_vars = {"dosage": (["variant", "sample"], g.T)}
    for i in range(x.shape[1]):
        data_vars[f"covar_{i}"] = (["sample"], x[:, i])
    # One outcome vector was simulated per variant
    for i, y in enumerate(ys):
        data_vars[f"trait_{i}"] = (["sample"], y)
    return xr.Dataset(data_vars, attrs={"beta": bg})  # type: ignore[arg-type]
86
89
87
90
88
@pytest.fixture  # type: ignore[misc]
def ds() -> Dataset:
    """Provide a simulated dataset built with the generator's default arguments."""
    dataset = _generate_test_dataset()
    return dataset
91
94
92
95
93
- def _sm_statistics (ds , i , add_intercept ):
96
+ def _sm_statistics (
97
+ ds : Dataset , i : int , add_intercept : bool
98
+ ) -> RegressionResultsWrapper :
94
99
X = []
95
100
# Make sure first independent variable is variant
96
101
X .append (ds ["dosage" ].values [i ])
@@ -104,8 +109,11 @@ def _sm_statistics(ds, i, add_intercept):
104
109
return sm .OLS (y , X , hasconst = True ).fit ()
105
110
106
111
107
- def _get_statistics (ds , add_intercept , ** kwargs ):
108
- df_pred , df_true = [], []
112
+ def _get_statistics (
113
+ ds : Dataset , add_intercept : bool , ** kwargs : Any
114
+ ) -> Tuple [DataFrame , DataFrame ]:
115
+ df_pred : List [Dict [str , Any ]] = []
116
+ df_true : List [Dict [str , Any ]] = []
109
117
for i in range (ds .dims ["variant" ]):
110
118
dsr = gwas_linear_regression (
111
119
ds ,
@@ -116,7 +124,7 @@ def _get_statistics(ds, add_intercept, **kwargs):
116
124
)
117
125
res = _sm_statistics (ds , i , add_intercept )
118
126
df_pred .append (
119
- dsr .to_dataframe ()
127
+ dsr .to_dataframe () # type: ignore[no-untyped-call]
120
128
.rename (columns = lambda c : c .replace ("variant/" , "" ))
121
129
.iloc [i ]
122
130
.to_dict ()
@@ -126,7 +134,7 @@ def _get_statistics(ds, add_intercept, **kwargs):
126
134
127
135
128
136
def test_linear_regression_statistics (ds ):
129
- def validate (dfp , dft ) :
137
+ def validate (dfp : DataFrame , dft : DataFrame ) -> None :
130
138
print (dfp )
131
139
print (dft )
132
140
0 commit comments