
Commit 4f77514

Merge branch 'main' of github.com:data-simply/pyretailscience into feature/rfm-segmentation

2 parents: 314ea0a + 5ed5feb

File tree: 4 files changed, +150 -99 lines


pyproject.toml

Lines changed: 13 additions & 91 deletions
@@ -1,129 +1,51 @@
 [project]
 name = "pyretailscience"
-version = "0.9.0"
+version = "0.10.0"
 description = "Retail Data Science Tools"
 requires-python = ">=3.10,<3.13"
 readme = "README.md"
 license = "Elastic-2.0"
-dependencies = [
-    "pandas>=2.1.4,<3",
-    "pyarrow>=14.0.2,<15",
-    "matplotlib>=3.9.1,<4",
-    "numpy>=1.26.3,<2",
-    "loguru>=0.7.2,<0.8",
-    "tqdm>=4.66.1,<5",
-    "scipy>=1.13.0,<2",
-    "scikit-learn>=1.4.2,<2",
-    "matplotlib-set-diagrams~=0.0.2",
-    "toml>=0.10.2,<0.11",
-    "duckdb>=1.0.0,<2",
-    "graphviz>=0.20.3,<0.21",
-    "ibis-framework[duckdb]>=9.5.0,<10",
-]
+dependencies = [ "pandas>=2.1.4,<3", "pyarrow>=14.0.2,<15", "matplotlib>=3.9.1,<4", "numpy>=1.26.3,<2", "loguru>=0.7.2,<0.8", "tqdm>=4.66.1,<5", "scipy>=1.13.0,<2", "scikit-learn>=1.4.2,<2", "matplotlib-set-diagrams~=0.0.2", "toml>=0.10.2,<0.11", "duckdb>=1.0.0,<2", "graphviz>=0.20.3,<0.21", "ibis-framework[duckdb]>=9.5.0,<10",]
 [[project.authors]]
 name = "Murray Vanwyk"
 
 
 [dependency-groups]
-dev = [
-    "pytest>=8.0.0,<9",
-    "pytest-cov>=4.1.0,<5",
-    "nbstripout>=0.7.1,<0.8",
-    "ruff>=0.9,<0.10",
-    "pre-commit>=3.6.2,<4",
-    "pytest-mock>=3.14.0,<4",
-]
-examples = ["jupyterlab>=4.2.5,<5", "tqdm>=4.66.1,<5"]
-docs = [
-    "mkdocs-material>=9.5.4,<10",
-    "mkdocstrings[python]>=0.24.0,<0.25",
-    "mkdocs>=1.5.3,<2",
-    "mkdocs-jupyter>=0.24.6,<0.25",
-]
+dev = [ "pytest>=8.0.0,<9", "pytest-cov>=4.1.0,<5", "nbstripout>=0.7.1,<0.8", "ruff>=0.9,<0.10", "pre-commit>=3.6.2,<4", "pytest-mock>=3.14.0,<4",]
+examples = [ "jupyterlab>=4.2.5,<5", "tqdm>=4.66.1,<5",]
+docs = [ "mkdocs-material>=9.5.4,<10", "mkdocstrings[python]>=0.24.0,<0.25", "mkdocs>=1.5.3,<2", "mkdocs-jupyter>=0.24.6,<0.25",]
 
 [build-system]
-requires = ["hatchling"]
+requires = [ "hatchling",]
 build-backend = "hatchling.build"
 
 [tool.uv]
-default-groups = ["dev", "examples", "docs"]
+default-groups = [ "dev", "examples", "docs",]
 
 [tool.ruff]
 target-version = "py310"
 line-length = 120
 show-fixes = true
 
 [tool.ruff.lint]
-ignore = ["ANN101", "ANN102", "EM101", "TRY003", "PT011", "PTH123", "SLF001"]
-select = [
-    "A",
-    "ANN",
-    "ARG",
-    "B",
-    "BLE",
-    "C4",
-    "C90",
-    "COM",
-    "D",
-    "D1",
-    "D2",
-    "D3",
-    "D4",
-    "DTZ",
-    "EM",
-    "ERA",
-    "EXE",
-    "F",
-    "FA",
-    "FLY",
-    "G",
-    "I",
-    "ICN",
-    "INP",
-    "INT",
-    "ISC",
-    "N",
-    "NPY",
-    "PERF",
-    "PGH",
-    "PIE",
-    "PL",
-    "PT",
-    "PTH",
-    "PYI",
-    "Q",
-    "RET",
-    "RUF",
-    "RSE",
-    "S",
-    "SIM",
-    "SLF",
-    "SLOT",
-    "T10",
-    "T20",
-    "TCH",
-    "TID",
-    "TRY",
-    "UP",
-    "W",
-    "YTT",
-]
+ignore = [ "ANN101", "ANN102", "EM101", "TRY003", "PT011", "PTH123", "SLF001",]
+select = [ "A", "ANN", "ARG", "B", "BLE", "C4", "C90", "COM", "D", "D1", "D2", "D3", "D4", "DTZ", "EM", "ERA", "EXE", "F", "FA", "FLY", "G", "I", "ICN", "INP", "INT", "ISC", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RUF", "RSE", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT",]
 
 [tool.pytest.ini_options]
 addopts = "--cov=pyretailscience --cov-report=term-missing --cov-branch"
 
 [tool.coverage.run]
 branch = true
-source = ["pyretailscience"]
+source = [ "pyretailscience",]
 
 [tool.coverage.report]
 show_missing = true
 skip_covered = true
 
 [tool.ruff.lint.per-file-ignores]
-"__init__.py" = ["F401", "F403", "F405", "D104"]
-"tests/*" = ["ANN", "ARG", "INP001", "S101", "SLF001"]
-"*.ipynb" = ["T201"]
+"__init__.py" = [ "F401", "F403", "F405", "D104",]
+"tests/*" = [ "ANN", "ARG", "INP001", "S101", "SLF001",]
+"*.ipynb" = [ "T201",]
 
 [tool.ruff.lint.pylint]
 max-args = 15

pyretailscience/analysis/segmentation.py

Lines changed: 46 additions & 5 deletions
@@ -190,7 +190,12 @@ class SegTransactionStats:
 
     _df: pd.DataFrame | None = None
 
-    def __init__(self, data: pd.DataFrame | ibis.Table, segment_col: str = "segment_name") -> None:
+    def __init__(
+        self,
+        data: pd.DataFrame | ibis.Table,
+        segment_col: str = "segment_name",
+        extra_aggs: dict[str, tuple[str, str]] | None = None,
+    ) -> None:
         """Calculates transaction statistics by segment.
 
         Args:
@@ -199,6 +204,12 @@ def __init__(self, data: pd.DataFrame | ibis.Table, segment_col: str = "segment_
                 the columns unit_spend and unit_quantity are used to calculate the price_per_unit and
                 units_per_transaction.
             segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_name".
+            extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
+                The keys in the dictionary will be the column names for the aggregation results.
+                The values are tuples with (column_name, aggregation_function), where:
+                - column_name is the name of the column to aggregate
+                - aggregation_function is a string name of an Ibis aggregation function (e.g., "nunique", "sum")
+                Example: {"stores": ("store_id", "nunique")} would count unique store_ids.
         """
         cols = ColumnHelper()
         required_cols = [
@@ -215,9 +226,21 @@ def __init__(self, data: pd.DataFrame | ibis.Table, segment_col: str = "segment_
             msg = f"The following columns are required but missing: {missing_cols}"
             raise ValueError(msg)
 
+        # Validate extra_aggs if provided
+        if extra_aggs:
+            for col_tuple in extra_aggs.values():
+                col, func = col_tuple
+                if col not in data.columns:
+                    msg = f"Column '{col}' specified in extra_aggs does not exist in the data"
+                    raise ValueError(msg)
+                if not hasattr(data[col], func):
+                    msg = f"Aggregation function '{func}' not available for column '{col}'"
+                    raise ValueError(msg)
+
         self.segment_col = segment_col
+        self.extra_aggs = {} if extra_aggs is None else extra_aggs
 
-        self.table = self._calc_seg_stats(data, segment_col)
+        self.table = self._calc_seg_stats(data, segment_col, self.extra_aggs)
 
     @staticmethod
     def _get_col_order(include_quantity: bool) -> list[str]:
@@ -249,12 +272,19 @@ def _get_col_order(include_quantity: bool) -> list[str]:
         return col_order
 
     @staticmethod
-    def _calc_seg_stats(data: pd.DataFrame | ibis.Table, segment_col: str) -> ibis.Table:
+    def _calc_seg_stats(
+        data: pd.DataFrame | ibis.Table,
+        segment_col: str,
+        extra_aggs: dict[str, tuple[str, str]] | None = None,
+    ) -> ibis.Table:
         """Calculates the transaction statistics by segment.
 
         Args:
             data (pd.DataFrame | ibis.Table): The transaction data.
             segment_col (str): The column to use for the segmentation.
+            extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
+                The keys in the dictionary will be the column names for the aggregation results.
+                The values are tuples with (column_name, aggregation_function).
 
         Returns:
             pd.DataFrame: The transaction statistics by segment.
@@ -277,6 +307,12 @@ def _calc_seg_stats(data: pd.DataFrame | ibis.Table, segment_col: str) -> ibis.T
         if cols.unit_qty in data.columns:
             aggs[cols.agg_unit_qty] = data[cols.unit_qty].sum()
 
+        # Add extra aggregations if provided
+        if extra_aggs:
+            for agg_name, col_tuple in extra_aggs.items():
+                col, func = col_tuple
+                aggs[agg_name] = getattr(data[col], func)()
+
         # Calculate metrics for segments and total
         segment_metrics = data.group_by(segment_col).aggregate(**aggs)
         total_metrics = data.aggregate(**aggs).mutate(segment_name=ibis.literal("Total"))
@@ -311,6 +347,11 @@ def df(self) -> pd.DataFrame:
             self.segment_col,
             *SegTransactionStats._get_col_order(include_quantity=cols.agg_unit_qty in self.table.columns),
         ]
+
+        # Add any extra aggregation columns to the column order
+        if hasattr(self, "extra_aggs") and self.extra_aggs:
+            col_order.extend(self.extra_aggs.keys())
+
         self._df = self.table.execute()[col_order]
         return self._df
 
@@ -484,10 +525,10 @@ def _compute_rfm(self, df: ibis.Table, current_date: datetime.date) -> ibis.Tabl
             order_by=[ibis.asc(customer_metrics.recency_days), ibis.asc(customer_metrics.customer_id)],
         )
         window_frequency = ibis.window(
-            order_by=[ibis.desc(customer_metrics.frequency), ibis.asc(customer_metrics.customer_id)],
+            order_by=[ibis.asc(customer_metrics.frequency), ibis.asc(customer_metrics.customer_id)],
        )
         window_monetary = ibis.window(
-            order_by=[ibis.desc(customer_metrics.monetary), ibis.asc(customer_metrics.customer_id)],
+            order_by=[ibis.asc(customer_metrics.monetary), ibis.asc(customer_metrics.customer_id)],
         )
 
         rfm_scores = customer_metrics.mutate(
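
The extra_aggs plumbing above simply merges user-supplied aggregations into the same aggs dict that feeds group_by().aggregate(). A minimal usage sketch, assuming the default column names "customer_id", "unit_spend", and "transaction_id" (what ColumnHelper appears to resolve to); the "store_id" column and "distinct_stores" output name are illustrative, mirroring the tests below:

import pandas as pd

from pyretailscience.analysis.segmentation import SegTransactionStats

df = pd.DataFrame(
    {
        "customer_id": [1, 1, 2, 2, 3, 3],
        "unit_spend": [100.0, 150.0, 200.0, 250.0, 300.0, 350.0],
        "transaction_id": [101, 102, 103, 104, 105, 106],
        "segment_name": ["A", "A", "B", "B", "A", "A"],
        "store_id": [1, 2, 1, 3, 2, 4],
    },
)

# Each extra_aggs entry maps an output column name to a
# (source_column, ibis_aggregation_name) tuple.
seg_stats = SegTransactionStats(
    df,
    segment_col="segment_name",
    extra_aggs={"distinct_stores": ("store_id", "nunique")},
)

# The result gains a distinct_stores column for each segment plus the Total row.
print(seg_stats.df[["segment_name", "distinct_stores"]])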

tests/analysis/test_segmentation.py

Lines changed: 89 additions & 2 deletions
@@ -319,6 +319,93 @@ def test_handles_empty_dataframe_with_errors(self):
         with pytest.raises(ValueError):
             SegTransactionStats(df, "segment_name")
 
+    def test_extra_aggs_functionality(self):
+        """Test that the extra_aggs parameter works correctly."""
+        # Constants for expected values
+        segment_a_store_count = 3  # Segment A has stores 1, 2, 4
+        segment_b_store_count = 2  # Segment B has stores 1, 3
+        total_store_count = 4  # Total has stores 1, 2, 3, 4
+
+        segment_a_product_count = 3  # Segment A has products 10, 20, 40
+        segment_b_product_count = 2  # Segment B has products 10, 30
+        total_product_count = 4  # Total has products 10, 20, 30, 40
+        df = pd.DataFrame(
+            {
+                cols.customer_id: [1, 1, 2, 2, 3, 3],
+                cols.unit_spend: [100.0, 150.0, 200.0, 250.0, 300.0, 350.0],
+                cols.transaction_id: [101, 102, 103, 104, 105, 106],
+                "segment_name": ["A", "A", "B", "B", "A", "A"],
+                "store_id": [1, 2, 1, 3, 2, 4],
+                "product_id": [10, 20, 10, 30, 20, 40],
+            },
+        )
+
+        # Test with a single extra aggregation
+        seg_stats = SegTransactionStats(df, "segment_name", extra_aggs={"distinct_stores": ("store_id", "nunique")})
+
+        # Verify the extra column exists and has correct values
+        assert "distinct_stores" in seg_stats.df.columns
+
+        # Sort by segment_name to ensure consistent order
+        result_df = seg_stats.df.sort_values("segment_name").reset_index(drop=True)
+
+        assert result_df.loc[0, "distinct_stores"] == segment_a_store_count  # Segment A
+        assert result_df.loc[1, "distinct_stores"] == segment_b_store_count  # Segment B
+        assert result_df.loc[2, "distinct_stores"] == total_store_count  # Total
+
+        # Test with multiple extra aggregations
+        seg_stats_multi = SegTransactionStats(
+            df,
+            "segment_name",
+            extra_aggs={
+                "distinct_stores": ("store_id", "nunique"),
+                "distinct_products": ("product_id", "nunique"),
+            },
+        )
+
+        # Verify both extra columns exist
+        assert "distinct_stores" in seg_stats_multi.df.columns
+        assert "distinct_products" in seg_stats_multi.df.columns
+
+        # Sort by segment_name to ensure consistent order
+        result_df_multi = seg_stats_multi.df.sort_values("segment_name").reset_index(drop=True)
+
+        assert result_df_multi.loc[0, "distinct_products"] == segment_a_product_count  # Segment A
+        assert result_df_multi.loc[1, "distinct_products"] == segment_b_product_count  # Segment B
+        assert result_df_multi.loc[2, "distinct_products"] == total_product_count  # Total
+
+    def test_extra_aggs_with_invalid_column(self):
+        """Test that an error is raised when an invalid column is specified in extra_aggs."""
+        df = pd.DataFrame(
+            {
+                cols.customer_id: [1, 2, 3],
+                cols.unit_spend: [100.0, 200.0, 300.0],
+                cols.transaction_id: [101, 102, 103],
+                "segment_name": ["A", "B", "A"],
+            },
+        )
+
+        with pytest.raises(ValueError) as excinfo:
+            SegTransactionStats(df, "segment_name", extra_aggs={"invalid_agg": ("nonexistent_column", "nunique")})
+
+        assert "does not exist in the data" in str(excinfo.value)
+
+    def test_extra_aggs_with_invalid_function(self):
+        """Test that an error is raised when an invalid function is specified in extra_aggs."""
+        df = pd.DataFrame(
+            {
+                cols.customer_id: [1, 2, 3],
+                cols.unit_spend: [100.0, 200.0, 300.0],
+                cols.transaction_id: [101, 102, 103],
+                "segment_name": ["A", "B", "A"],
+            },
+        )
+
+        with pytest.raises(ValueError) as excinfo:
+            SegTransactionStats(df, "segment_name", extra_aggs={"invalid_agg": (cols.customer_id, "invalid_function")})
+
+        assert "not available for column" in str(excinfo.value)
+
 
 class TestHMLSegmentation:
     """Tests for the HMLSegmentation class."""
@@ -428,7 +515,7 @@ def test_correct_rfm_segmentation(self, base_df):
         expected_df = pd.DataFrame(
             {
                 "customer_id": [1, 2, 3, 4, 5],
-                "rfm_segment": [104, 312, 423, 30, 241],
+                "rfm_segment": [100, 312, 421, 34, 243],
             },
         ).set_index("customer_id")
 
@@ -505,7 +592,7 @@ def test_rfm_segmentation_with_no_date(self, base_df):
         expected_df = pd.DataFrame(
             {
                 "customer_id": [1, 2, 3, 4, 5],
-                "rfm_segment": [104, 312, 423, 30, 241],
+                "rfm_segment": [100, 312, 421, 34, 243],
             },
        ).set_index("customer_id")
 
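
The revised rfm_segment expectations follow from the window change in segmentation.py: frequency and monetary are now ordered ascending, like recency_days, so a higher raw value maps to a higher score digit. A rough sketch of that direction flip, assuming ntile-style bucketing over the ascending order; the monetary values here are illustrative, not the fixture's base_df:

import pandas as pd

monetary = pd.Series({1: 100.0, 2: 250.0, 3: 400.0, 4: 175.0, 5: 325.0})

# Ascending order: rank 1 is the smallest spend. With five customers and
# five buckets, each gets a distinct 0-4 score; the top spender scores 4.
score = pd.qcut(monetary.rank(method="first"), q=5, labels=False)
print(score.sort_index())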

uv.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default.
