
Commit d005d3d

fix(eda.distribution): delay scipy computations
1 parent 4bba52e commit d005d3d

3 files changed: 37 additions & 23 deletions

dataprep/eda/correlation/compute/common.py

Lines changed: 3 additions & 1 deletion
@@ -43,7 +43,9 @@ def kendalltau(  # pylint: disable=invalid-name
     return np.float64(corr)  # Sometimes corr is a float, causes dask error


-@dask.delayed
+@dask.delayed(  # pylint: disable=no-value-for-parameter
+    name="kendalltau-scipy", pure=True
+)
 def corrcoef(arr: np.ndarray) -> np.ndarray:
     """delayed version of np.corrcoef."""
     _, (corr, _) = np.corrcoef(arr, rowvar=False)
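
For context: dask.delayed is curried, so calling it with only keyword arguments (name=..., pure=True) returns a decorator; pylint cannot see this and reports no-value-for-parameter, hence the disable comment. The sketch below is illustrative only and not part of the commit (it uses a hypothetical kendalltau wrapper rather than the repo's own helper); it shows how such a decorated function builds a lazy task instead of running scipy immediately:

# Minimal sketch of the delayed-decorator pattern used above (illustrative only).
import dask
import numpy as np
from scipy.stats import kendalltau as kendalltau_

@dask.delayed(  # pylint: disable=no-value-for-parameter
    name="kendalltau-scipy", pure=True
)
def kendalltau(a: np.ndarray, b: np.ndarray) -> np.float64:
    """Delayed Kendall's tau; scipy only runs when the task is computed."""
    corr, _ = kendalltau_(a, b)
    return np.float64(corr)

rng = np.random.default_rng(0)
x, y = rng.normal(size=100), rng.normal(size=100)
task = kendalltau(x, y)   # a Delayed object; no scipy call has happened yet
result = task.compute()   # scipy runs only here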

dataprep/eda/distribution/compute/common.py

Lines changed: 20 additions & 3 deletions
@@ -6,6 +6,7 @@
 import numpy as np
 import pandas as pd
 import dask.dataframe as dd
+from scipy.stats import normaltest as normaltest_, ks_2samp as ks_2samp_

 from ...dtypes import drop_null, is_dtype, detect_dtype, Continuous, DTypeDef

@@ -271,9 +272,8 @@ def _calc_line_dt(
 def _calc_groups(
     df: dd.DataFrame, x: str, ngroups: int, largest: bool = True
 ) -> Tuple[dd.DataFrame, Dict[str, int], List[str]]:
-    """
-    Auxillary function to parse the dataframe to consist of only the
-    groups with the largest counts
+    """Auxillary function to parse the dataframe to consist of only the
+    groups with the largest counts.
     """

     # group count statistics to inform the user of the sampled output
@@ -292,3 +292,20 @@ def _calc_groups(
     grp_cnt_stats[f"{x}_shw"] = len(largest_grps)

     return df, grp_cnt_stats, largest_grps
+
+
+@dask.delayed(  # pylint: disable=no-value-for-parameter
+    name="scipy-normaltest", pure=True, nout=2
+)
+def normaltest(arr: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    """Delayed version of scipy normaltest. Due to the dask version will
+    trigger a compute."""
+    return normaltest_(arr)
+
+
+@dask.delayed(  # pylint: disable=no-value-for-parameter
+    name="scipy-ks_2samp", pure=True, nout=2
+)
+def ks_2samp(data1: np.ndarray, data2: np.ndarray) -> Tuple[float, float]:
+    """Delayed version of scipy ks_2samp."""
+    return ks_2samp_(data1, data2)
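
The nout=2 argument tells dask that each wrapped scipy call returns a pair, so the resulting Delayed can be unpacked into two lazy values (statistic and p-value) before anything is computed. A minimal usage sketch, illustrative only and not part of the commit:

# Illustrative only: unpacking a delayed scipy result declared with nout=2.
import dask
import numpy as np
from scipy.stats import normaltest as normaltest_

@dask.delayed(  # pylint: disable=no-value-for-parameter
    name="scipy-normaltest", pure=True, nout=2
)
def normaltest(arr: np.ndarray):
    """Delayed version of scipy normaltest."""
    return normaltest_(arr)

arr = np.random.default_rng(1).normal(size=500)
stat, pvalue = normaltest(arr)             # two Delayed objects, nothing runs yet
stat, pvalue = dask.compute(stat, pvalue)  # the single scipy call happens here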

dataprep/eda/distribution/compute/overview.py

Lines changed: 14 additions & 19 deletions
@@ -1,15 +1,14 @@
 """Computations for plot(df) function."""

-from typing import Any, Dict, List, Optional, Tuple
 from itertools import combinations
+from typing import Any, Dict, List, Optional, Tuple

 import dask
 import dask.array as da
 import dask.dataframe as dd
 import numpy as np
 import pandas as pd
-from dask.array.stats import chisquare, normaltest, skew
-from scipy.stats import ks_2samp
+from dask.array.stats import chisquare, skew

 from ....errors import UnreachableError
 from ...dtypes import (
@@ -24,7 +23,7 @@
     is_dtype,
 )
 from ...intermediate import Intermediate
-from .common import _calc_line_dt
+from .common import _calc_line_dt, ks_2samp, normaltest


 def compute_overview(
@@ -81,9 +80,7 @@ def compute_overview(
                 first_rows[col].apply(hash)
             except TypeError:
                 srs = df[col] = srs.astype(str)
-            datas.append(
-                calc_nom_col(drop_null(srs), first_rows[col], ngroups, largest)
-            )
+            datas.append(calc_nom_col(drop_null(srs), ngroups, largest))
             col_names_dtypes.append((col, Nominal()))
         elif is_dtype(col_dtype, Continuous()):
             ## if cfg.hist_enable or cfg.any_insights("hist"):
@@ -179,9 +176,7 @@ def calc_cont_col(srs: dd.Series, bins: int) -> Dict[str, Any]:


 ## def calc_nom_col(srs: dd.Series, first_rows: pd.Series, cfg: Config)
-def calc_nom_col(
-    srs: dd.Series, first_rows: pd.Series, ngroups: int, largest: bool
-) -> Dict[str, Any]:
+def calc_nom_col(srs: dd.Series, ngroups: int, largest: bool) -> Dict[str, Any]:
     """
     Computations for a categorical column in plot(df)

@@ -227,9 +222,7 @@ def calc_nom_col(
     ## data["npresent"] = srs.shape[0]

     ## if cfg.insight.constant_length_enable:
-    if not first_rows.apply(lambda x: isinstance(x, str)).all():
-        srs = srs.astype(str)  # srs must be a string to compute the value lengths
-    length = srs.str.len()
+    length = srs.apply(lambda v: len(str(v)), meta=(srs.name, np.int64))
     data["min_len"], data["max_len"] = length.min(), length.max()

     return data
@@ -269,12 +262,13 @@ def calc_stats(
     # compute distribution similarity on a data sample
     # TODO .map_partitions() fails for create_report since it calls calc_stats() with a pd dataframe
     # df_smp = df.map_partitions(lambda x: x.sample(min(1000, x.shape[0])), meta=df)
-    # NOTE ks_2samp triggers a .compute(), could use .delayed()
+
     if num_cols:  # remove this if statement when create_report is refactored
         stats["ks_tests"] = []
         for col1, col2 in list(combinations(num_cols, 2)):
-            if ks_2samp(df[col1], df[col2])[1] > 0.05:
-                stats["ks_tests"].append((col1, col2))
+            stats["ks_tests"].append(
+                (col1, col2, ks_2samp(df[col1], df[col2])[1] > 0.05)
+            )

     return stats

@@ -299,9 +293,10 @@ def format_overview(data: Dict[str, Any]) -> List[Dict[str, str]]:
     ins.append({"Duplicates": f"Dataset has {ndup} ({pdup}%) duplicate rows"})

     ## if cfg.insight.similar_distribution_enable
-    for cols in data.get("ks_tests", []):
-        msg = f"{cols[0]} and {cols[1]} have similar distributions"
-        ins.append({"Similar Distribution": msg})
+    for (*cols, test_result) in data.get("ks_tests", []):
+        if test_result:
+            msg = f"{cols[0]} and {cols[1]} have similar distributions"
+            ins.append({"Similar Distribution": msg})

     data.pop("ks_tests", None)

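
Taken together, the call-site change means calc_stats now stores a lazy boolean per column pair instead of forcing a scipy computation inside the loop, and format_overview filters on the materialized result later. A rough end-to-end sketch of that pattern follows; it is hypothetical standalone code under the assumed delayed ks_2samp wrapper above, not the repo's own functions:

# Illustrative only: collect delayed KS-test comparisons, compute them in one
# pass, then report the pairs with similar distributions.
from itertools import combinations

import dask
import numpy as np
import pandas as pd
from scipy.stats import ks_2samp as ks_2samp_

@dask.delayed(  # pylint: disable=no-value-for-parameter
    name="scipy-ks_2samp", pure=True, nout=2
)
def ks_2samp(data1, data2):
    """Delayed version of scipy ks_2samp."""
    return ks_2samp_(data1, data2)

df = pd.DataFrame(
    np.random.default_rng(2).normal(size=(200, 3)), columns=["a", "b", "c"]
)
ks_tests = [
    (c1, c2, ks_2samp(df[c1], df[c2])[1] > 0.05)  # lazy comparison, no compute
    for c1, c2 in combinations(df.columns, 2)
]
(ks_tests,) = dask.compute(ks_tests)  # all scipy calls run in a single pass
for (*cols, similar) in ks_tests:
    if similar:
        print(f"{cols[0]} and {cols[1]} have similar distributions")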
