Skip to content

Commit f2c0352

Browse files
committed
Fixing mypy errors
1 parent d82ce72 commit f2c0352

File tree

5 files changed

+48
-26
lines changed

5 files changed

+48
-26
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ repos:
2929
rev: v0.782
3030
hooks:
3131
- id: mypy
32-
args: ["--strict"]
32+
args: ["--strict", "--show-error-codes"]
3333
additional_dependencies: ["numpy", "xarray"]

setup.cfg

+6
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,14 @@ line_length = 88
6060

6161
[mypy-numpy.*]
6262
ignore_missing_imports = True
63+
[mypy-pandas.*]
64+
ignore_missing_imports = True
65+
[mypy-dask.*]
66+
ignore_missing_imports = True
6367
[mypy-pytest.*]
6468
ignore_missing_imports = True
69+
[mypy-statsmodels.*]
70+
ignore_missing_imports = True
6571
[mypy-setuptools]
6672
ignore_missing_imports = True
6773
[mypy-sgkit.tests.*]

sgkit/stats/association.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
import collections
2-
from typing import Sequence
2+
from typing import Optional, Sequence
33

44
import dask.array as da
55
import numpy as np
66
import xarray as xr
77
from dask.array import Array, stats
88
from xarray import Dataset
99

10+
from ..typing import ArrayLike
11+
1012
LinearRegressionResult = collections.namedtuple(
1113
"LinearRegressionResult", ["beta", "t_value", "p_value"]
1214
)
1315

1416

15-
def _gwas_linear_regression(G, X, y) -> LinearRegressionResult:
17+
def _gwas_linear_regression(
18+
G: ArrayLike, X: ArrayLike, y: ArrayLike
19+
) -> LinearRegressionResult:
1620
"""Efficient linear regression estimation for multiple covariate sets
1721
1822
Parameters
@@ -59,7 +63,7 @@ def _gwas_linear_regression(G, X, y) -> LinearRegressionResult:
5963
return LinearRegressionResult(beta=b, t_value=t_val, p_value=p_val)
6064

6165

62-
def _get_loop_covariates(ds: Dataset, dosage: str = None) -> Array:
66+
def _get_loop_covariates(ds: Dataset, dosage: Optional[str] = None) -> Array:
6367
if dosage is None:
6468
# TODO: This should be (probably gwas-specific) allele
6569
# count with sex chromosome considerations

sgkit/tests/test_association.py

+29-21
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,33 @@
11
import warnings
2+
from typing import Any, Dict, List, Optional, Tuple
23

34
import numpy as np
45
import pandas as pd
56
import pytest
67
import xarray as xr
8+
from pandas import DataFrame
9+
from xarray import Dataset
710

811
from sgkit.stats.association import gwas_linear_regression
12+
from sgkit.typing import ArrayLike
913

1014
with warnings.catch_warnings():
1115
warnings.simplefilter("ignore", DeprecationWarning)
1216
# Ignore: DeprecationWarning: Using or importing the ABCs from 'collections'
1317
# instead of from 'collections.abc' is deprecated since Python 3.3,
1418
# and in 3.9 it will stop working
1519
import statsmodels.api as sm
20+
from statsmodels.regression.linear_model import RegressionResultsWrapper
1621

1722

18-
def _generate_test_data(n=100, m=10, p=3, e_std=0.001, b_zero_slice=None, seed=1):
23+
def _generate_test_data(
24+
n: int = 100,
25+
m: int = 10,
26+
p: int = 3,
27+
e_std: float = 0.001,
28+
b_zero_slice: Optional[slice] = None,
29+
seed: Optional[int] = 1,
30+
) -> Tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike]:
1931
"""Test data simulator for multiple variant associations to a continuous outcome
2032
2133
Outcomes for each variant are simulated separately based on linear combinations
@@ -40,20 +52,12 @@ def _generate_test_data(n=100, m=10, p=3, e_std=0.001, b_zero_slice=None, seed=1
4052
4153
Returns
4254
-------
43-
n : int
44-
Number of samples
45-
m : int
46-
Number of variants
47-
p : int
48-
Number of covariates
4955
g : (n, m) array-like
5056
Simulated genotype dosage
5157
x : (n, p) array-like
5258
Simulated covariates
5359
bg : (m,) array-like
5460
Variant betas
55-
bx : (p,) array-like
56-
Covariate betas
5761
ys : (m, n) array-like
5862
Outcomes for each column in genotypes i.e. variant
5963
"""
@@ -69,28 +73,29 @@ def _generate_test_data(n=100, m=10, p=3, e_std=0.001, b_zero_slice=None, seed=1
6973

7074
# Simulate y values using each variant independently
7175
ys = np.array([g[:, i] * bg[i] + x @ bx + e for i in range(m)])
72-
return n, m, p, g, x, bg, bx, ys
76+
return g, x, bg, ys
7377

7478

75-
def _generate_test_dataset(**kwargs):
76-
n, m, p, g, x, bg, bx, ys = _generate_test_data(**kwargs)
79+
def _generate_test_dataset(**kwargs: Any) -> Dataset:
80+
g, x, bg, ys = _generate_test_data(**kwargs)
7781
data_vars = {}
78-
# TODO: use literals or constants for dimension names?
7982
data_vars["dosage"] = (["variant", "sample"], g.T)
8083
for i in range(x.shape[1]):
8184
data_vars[f"covar_{i}"] = (["sample"], x[:, i])
8285
for i in range(len(ys)):
8386
data_vars[f"trait_{i}"] = (["sample"], ys[i])
8487
attrs = dict(beta=bg)
85-
return xr.Dataset(data_vars, attrs=attrs)
88+
return xr.Dataset(data_vars, attrs=attrs) # type: ignore[arg-type]
8689

8790

88-
@pytest.fixture
89-
def ds():
91+
@pytest.fixture # type: ignore[misc]
92+
def ds() -> Dataset:
9093
return _generate_test_dataset()
9194

9295

93-
def _sm_statistics(ds, i, add_intercept):
96+
def _sm_statistics(
97+
ds: Dataset, i: int, add_intercept: bool
98+
) -> RegressionResultsWrapper:
9499
X = []
95100
# Make sure first independent variable is variant
96101
X.append(ds["dosage"].values[i])
@@ -104,8 +109,11 @@ def _sm_statistics(ds, i, add_intercept):
104109
return sm.OLS(y, X, hasconst=True).fit()
105110

106111

107-
def _get_statistics(ds, add_intercept, **kwargs):
108-
df_pred, df_true = [], []
112+
def _get_statistics(
113+
ds: Dataset, add_intercept: bool, **kwargs: Any
114+
) -> Tuple[DataFrame, DataFrame]:
115+
df_pred: List[Dict[str, Any]] = []
116+
df_true: List[Dict[str, Any]] = []
109117
for i in range(ds.dims["variant"]):
110118
dsr = gwas_linear_regression(
111119
ds,
@@ -116,7 +124,7 @@ def _get_statistics(ds, add_intercept, **kwargs):
116124
)
117125
res = _sm_statistics(ds, i, add_intercept)
118126
df_pred.append(
119-
dsr.to_dataframe()
127+
dsr.to_dataframe() # type: ignore[no-untyped-call]
120128
.rename(columns=lambda c: c.replace("variant/", ""))
121129
.iloc[i]
122130
.to_dict()
@@ -126,7 +134,7 @@ def _get_statistics(ds, add_intercept, **kwargs):
126134

127135

128136
def test_linear_regression_statistics(ds):
129-
def validate(dfp, dft):
137+
def validate(dfp: DataFrame, dft: DataFrame) -> None:
130138
print(dfp)
131139
print(dft)
132140

sgkit/typing.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1-
from typing import Any
1+
from typing import Any, Union
2+
3+
import dask.array as da
4+
import numpy as np
25

36
DType = Any
7+
ArrayLike = Union[np.ndarray, da.Array]

0 commit comments

Comments
 (0)