diff --git a/torch_frame/data/stats.py b/torch_frame/data/stats.py index fbf6b6331..5ca42d04e 100644 --- a/torch_frame/data/stats.py +++ b/torch_frame/data/stats.py @@ -75,7 +75,7 @@ def stats_for_stype(stype: torch_frame.stype) -> list[StatType]: ], torch_frame.embedding: [ StatType.EMB_DIM, - ] + ], } return stats_type.get(stype, []) @@ -85,7 +85,7 @@ def compute( sep: str | None = None, ) -> Any: if self == StatType.MEAN: - flattened = np.hstack(np.hstack(ser.values)) + flattened = np.ravel(ser.values) finite_mask = np.isfinite(flattened) if not finite_mask.any(): # NOTE: We may just error out here if eveything is NaN @@ -93,14 +93,14 @@ def compute( return np.mean(flattened[finite_mask]).item() elif self == StatType.STD: - flattened = np.hstack(np.hstack(ser.values)) + flattened = np.ravel(ser.values) finite_mask = np.isfinite(flattened) if not finite_mask.any(): return np.nan return np.std(flattened[finite_mask]).item() elif self == StatType.QUANTILES: - flattened = np.hstack(np.hstack(ser.values)) + flattened = np.ravel(ser.values) finite_mask = np.isfinite(flattened) if not finite_mask.any(): return [np.nan, np.nan, np.nan, np.nan, np.nan]