@@ -193,7 +193,7 @@ class SegTransactionStats:
193
193
def __init__ (
194
194
self ,
195
195
data : pd .DataFrame | ibis .Table ,
196
- segment_col : str = "segment_name" ,
196
+ segment_col : str | list [ str ] = "segment_name" ,
197
197
extra_aggs : dict [str , tuple [str , str ]] | None = None ,
198
198
) -> None :
199
199
"""Calculates transaction statistics by segment.
@@ -203,7 +203,8 @@ def __init__(
203
203
customer_id, unit_spend and transaction_id. If the dataframe contains the column unit_quantity, then
204
204
the columns unit_spend and unit_quantity are used to calculate the price_per_unit and
205
205
units_per_transaction.
206
- segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_name".
206
+ segment_col (str | list[str], optional): The column or list of columns to use for the segmentation.
207
+ Defaults to "segment_name".
207
208
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
208
209
The keys in the dictionary will be the column names for the aggregation results.
209
210
The values are tuples with (column_name, aggregation_function), where:
@@ -212,11 +213,14 @@ def __init__(
212
213
Example: {"stores": ("store_id", "nunique")} would count unique store_ids.
213
214
"""
214
215
cols = ColumnHelper ()
216
+
217
+ if isinstance (segment_col , str ):
218
+ segment_col = [segment_col ]
215
219
required_cols = [
216
220
cols .customer_id ,
217
221
cols .unit_spend ,
218
222
cols .transaction_id ,
219
- segment_col ,
223
+ * segment_col ,
220
224
]
221
225
if cols .unit_qty in data .columns :
222
226
required_cols .append (cols .unit_qty )
@@ -274,14 +278,14 @@ def _get_col_order(include_quantity: bool) -> list[str]:
274
278
@staticmethod
275
279
def _calc_seg_stats (
276
280
data : pd .DataFrame | ibis .Table ,
277
- segment_col : str ,
281
+ segment_col : list [ str ] ,
278
282
extra_aggs : dict [str , tuple [str , str ]] | None = None ,
279
283
) -> ibis .Table :
280
284
"""Calculates the transaction statistics by segment.
281
285
282
286
Args:
283
287
data (pd.DataFrame | ibis.Table): The transaction data.
284
- segment_col (str): The column to use for the segmentation.
288
+ segment_col (list[ str] ): The columns to use for the segmentation.
285
289
extra_aggs (dict[str, tuple[str, str]], optional): Additional aggregations to perform.
286
290
The keys in the dictionary will be the column names for the aggregation results.
287
291
The values are tuples with (column_name, aggregation_function).
@@ -315,7 +319,7 @@ def _calc_seg_stats(
315
319
316
320
# Calculate metrics for segments and total
317
321
segment_metrics = data .group_by (segment_col ).aggregate (** aggs )
318
- total_metrics = data .aggregate (** aggs ).mutate (segment_name = ibis .literal ("Total" ))
322
+ total_metrics = data .aggregate (** aggs ).mutate ({ col : ibis .literal ("Total" ) for col in segment_col } )
319
323
total_customers = data [cols .customer_id ].nunique ()
320
324
321
325
# Cross join with total_customers to make it available for percentage calculation
@@ -344,7 +348,7 @@ def df(self) -> pd.DataFrame:
344
348
if self ._df is None :
345
349
cols = ColumnHelper ()
346
350
col_order = [
347
- self .segment_col ,
351
+ * self .segment_col ,
348
352
* SegTransactionStats ._get_col_order (include_quantity = cols .agg_unit_qty in self .table .columns ),
349
353
]
350
354
@@ -393,18 +397,23 @@ def plot(
393
397
Raises:
394
398
ValueError: If the sort_order is not "ascending", "descending" or None.
395
399
ValueError: If the orientation is not "vertical" or "horizontal".
400
+ ValueError: If multiple segment columns are used, as plotting is only supported for a single segment column.
396
401
"""
397
402
if sort_order not in ["ascending" , "descending" , None ]:
398
403
raise ValueError ("sort_order must be either 'ascending' or 'descending' or None" )
399
404
if orientation not in ["vertical" , "horizontal" ]:
400
405
raise ValueError ("orientation must be either 'vertical' or 'horizontal'" )
406
+ if len (self .segment_col ) > 1 :
407
+ raise ValueError ("Plotting is only supported for a single segment column" )
401
408
402
409
default_title = f"{ value_col .title ()} by Segment"
403
410
kind = "bar"
404
411
if orientation == "horizontal" :
405
412
kind = "barh"
406
413
407
- val_s = self .df .set_index (self .segment_col )[value_col ]
414
+ # Use the first segment column for plotting
415
+ plot_segment_col = self .segment_col [0 ]
416
+ val_s = self .df .set_index (plot_segment_col )[value_col ]
408
417
if hide_total :
409
418
val_s = val_s [val_s .index != "Total" ]
410
419
@@ -462,7 +471,7 @@ class RFMSegmentation:
462
471
463
472
_df : pd .DataFrame | None = None
464
473
465
- def __init__ (self , df : pd .DataFrame | ibis .Table , current_date : str | None = None ) -> None :
474
+ def __init__ (self , df : pd .DataFrame | ibis .Table , current_date : str | datetime . date | None = None ) -> None :
466
475
"""Initializes the RFM segmentation process.
467
476
468
477
Args:
@@ -472,8 +481,8 @@ def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | None = Non
472
481
- transaction_date
473
482
- unit_spend
474
483
- transaction_id
475
- current_date (Optional[str] ): The reference date for calculating recency (format: "YYYY-MM-DD") .
476
- If not provided, the current system date will be used .
484
+ current_date (Optional[Union[ str, datetime.date]] ): The reference date for calculating recency.
485
+ Can be a string (format: "YYYY-MM-DD"), a date object, or None (defaults to the current system date) .
477
486
478
487
Raises:
479
488
ValueError: If the dataframe is missing required columns.
@@ -486,14 +495,22 @@ def __init__(self, df: pd.DataFrame | ibis.Table, current_date: str | None = Non
486
495
cols .unit_spend ,
487
496
cols .transaction_id ,
488
497
]
498
+ if isinstance (df , pd .DataFrame ):
499
+ df = ibis .memtable (df )
500
+ elif not isinstance (df , ibis .Table ):
501
+ raise TypeError ("df must be either a pandas DataFrame or an Ibis Table" )
489
502
490
503
missing_cols = set (required_cols ) - set (df .columns )
491
504
if missing_cols :
492
505
error_message = f"Missing required columns: { missing_cols } "
493
506
raise ValueError (error_message )
494
- current_date = (
495
- datetime .date .fromisoformat (current_date ) if current_date else datetime .datetime .now (datetime .UTC ).date ()
496
- )
507
+
508
+ if isinstance (current_date , str ):
509
+ current_date = datetime .date .fromisoformat (current_date )
510
+ elif current_date is None :
511
+ current_date = datetime .datetime .now (datetime .UTC ).date ()
512
+ elif not isinstance (current_date , datetime .date ):
513
+ raise TypeError ("current_date must be a string in 'YYYY-MM-DD' format, a datetime.date object, or None" )
497
514
498
515
self .table = self ._compute_rfm (df , current_date )
499
516
@@ -507,11 +524,6 @@ def _compute_rfm(self, df: ibis.Table, current_date: datetime.date) -> ibis.Tabl
507
524
Returns:
508
525
ibis.Table: A table with RFM scores and segment values.
509
526
"""
510
- if isinstance (df , pd .DataFrame ):
511
- df = ibis .memtable (df )
512
- elif not isinstance (df , ibis .Table ):
513
- raise TypeError ("df must be either a pandas DataFrame or an Ibis Table" )
514
-
515
527
cols = ColumnHelper ()
516
528
current_date_expr = ibis .literal (current_date )
517
529
@@ -537,13 +549,19 @@ def _compute_rfm(self, df: ibis.Table, current_date: datetime.date) -> ibis.Tabl
537
549
m_score = (ibis .ntile (10 ).over (window_monetary )),
538
550
)
539
551
540
- rfm_segment = (rfm_scores .r_score * 100 + rfm_scores .f_score * 10 + rfm_scores .m_score ).name ("rfm_segment" )
541
-
542
- return rfm_scores .mutate (rfm_segment = rfm_segment )
552
+ return rfm_scores .mutate (
553
+ rfm_segment = (rfm_scores .r_score * 100 + rfm_scores .f_score * 10 + rfm_scores .m_score ),
554
+ fm_segment = (rfm_scores .f_score * 10 + rfm_scores .m_score ),
555
+ )
543
556
544
557
@property
def df(self) -> pd.DataFrame:
    """Return the executed RFM table as a pandas DataFrame.

    The table held in ``self.table`` is executed on first access and the
    resulting frame, indexed by the column named by the
    ``"column.customer_id"`` option, is cached on the instance. Subsequent
    accesses return the cached frame without re-executing.
    """
    # Fast path: reuse the frame materialised by an earlier access.
    if self._df is not None:
        return self._df

    executed = self.table.execute()
    self._df = executed.set_index(get_option("column.customer_id"))
    return self._df
563
+
564
@property
def ibis_table(self) -> ibis.Table:
    """Expose the Ibis table stored on this instance as ``self.table``.

    No execution is triggered here; the stored table object is handed back
    as-is.
    """
    rfm_table = self.table
    return rfm_table
0 commit comments