@@ -2,7 +2,7 @@ use arrow::legacy::error::PolarsResult;
2
2
use either:: Either ;
3
3
use polars_core:: chunked_array:: cast:: CastOptions ;
4
4
use polars_core:: error:: feature_gated;
5
- use polars_core:: utils:: get_numeric_upcast_supertype_lossless;
5
+ use polars_core:: utils:: { get_numeric_upcast_supertype_lossless, try_get_supertype } ;
6
6
use polars_utils:: format_pl_smallstr;
7
7
use polars_utils:: itertools:: Itertools ;
8
8
@@ -424,29 +424,33 @@ fn resolve_join_where(
424
424
ctxt,
425
425
) ?;
426
426
427
- let mut ae_nodes_stack = Vec :: new ( ) ;
428
-
429
427
let schema_merged = ctxt
430
428
. lp_arena
431
429
. get ( last_node)
432
430
. schema ( ctxt. lp_arena )
433
431
. into_owned ( ) ;
434
- let schema_merged = schema_merged. as_ref ( ) ;
435
432
433
+ // Perform predicate validation.
434
+ let mut upcast_exprs = Vec :: < ( Node , DataType ) > :: new ( ) ;
436
435
for e in predicates {
437
- let predicate = to_expr_ir_ignore_alias ( e, ctxt. expr_arena ) ?;
438
-
439
- debug_assert ! ( ae_nodes_stack. is_empty( ) ) ;
440
- ae_nodes_stack. clear ( ) ;
441
- ae_nodes_stack. push ( predicate. node ( ) ) ;
442
-
443
- process_join_where_predicate (
444
- & mut ae_nodes_stack,
445
- 0 ,
446
- schema_left. as_ref ( ) ,
447
- schema_merged,
448
- ctxt. expr_arena ,
449
- & mut ExprOrigin :: None ,
436
+ let arena = & mut ctxt. expr_arena ;
437
+ let predicate = to_expr_ir_ignore_alias ( e, arena) ?;
438
+ let node = predicate. node ( ) ;
439
+
440
+ // Ensure the predicate dtype output of the root node is Boolean
441
+ let ae = arena. get ( node) ;
442
+ let dt_out = ae. to_dtype ( & schema_merged, Context :: Default , arena) ?;
443
+ polars_ensure ! (
444
+ dt_out == DataType :: Boolean ,
445
+ ComputeError : "'join_where' predicates must resolve to boolean"
446
+ ) ;
447
+
448
+ ensure_lossless_binary_comparisons (
449
+ & node,
450
+ & schema_left,
451
+ & schema_merged,
452
+ arena,
453
+ & mut upcast_exprs,
450
454
) ?;
451
455
452
456
ctxt. conversion_optimizer
@@ -467,137 +471,125 @@ fn resolve_join_where(
467
471
Ok ( ( last_node, join_node) )
468
472
}
469
473
470
- /// Performs validation and type-coercion on join_where predicates.
471
- ///
472
- /// Validates for all comparison expressions / subexpressions, that:
473
- /// 1. They reference columns from both sides.
474
- /// 2. The dtypes of the LHS and RHS are match, or can be casted to a lossless
475
- /// supertype (and inserts the necessary casting).
476
- ///
477
- /// We perform (1) by recursing whenever we encounter a comparison expression.
478
- fn process_join_where_predicate (
479
- stack : & mut Vec < Node > ,
480
- prev_comparison_expr_stack_offset : usize ,
474
+ /// Locate nodes that are operands in a binary comparison involving both tables, and ensure that
475
+ /// these nodes are losslessly upcast to a safe dtype.
476
+ fn ensure_lossless_binary_comparisons (
477
+ node : & Node ,
481
478
schema_left : & Schema ,
482
479
schema_merged : & Schema ,
483
480
expr_arena : & mut Arena < AExpr > ,
484
- column_origins : & mut ExprOrigin ,
481
+ upcast_exprs : & mut Vec < ( Node , DataType ) > ,
485
482
) -> PolarsResult < ( ) > {
486
- while stack. len ( ) > prev_comparison_expr_stack_offset {
487
- let ae_node = stack. pop ( ) . unwrap ( ) ;
488
- let ae = expr_arena. get ( ae_node) . clone ( ) ;
489
-
490
- match ae {
491
- AExpr :: Column ( ref name) => {
492
- let origin = if schema_left. contains ( name) {
493
- ExprOrigin :: Left
494
- } else if schema_merged. contains ( name) {
495
- ExprOrigin :: Right
496
- } else {
497
- polars_bail ! ( ColumnNotFound : "{}" , name) ;
498
- } ;
499
-
500
- * column_origins |= origin;
501
- } ,
502
- // This is not actually Origin::Both, but we set this because the test suite expects
503
- // this predicate to pass:
504
- // * `pl.col("flag_right") == 1`
505
- // Observe that it only has a column from one side because it is comparing to a literal.
506
- AExpr :: Literal ( _) => * column_origins = ExprOrigin :: Both ,
507
- AExpr :: BinaryExpr {
508
- left : left_node,
509
- op,
510
- right : right_node,
511
- } if op. is_comparison_or_bitwise ( ) => {
512
- {
513
- let new_stack_offset = stack. len ( ) ;
514
- stack. extend ( [ right_node, left_node] ) ;
515
-
516
- // Reset `column_origins` to a `None` state. We will only have 2 possible return states from
517
- // this point:
518
- // * Ok(()), with column_origins @ ExprOrigin::Both
519
- // * Err(_), in which case the value of column_origins doesn't matter.
520
- * column_origins = ExprOrigin :: None ;
521
-
522
- process_join_where_predicate (
523
- stack,
524
- new_stack_offset,
525
- schema_left,
526
- schema_merged,
527
- expr_arena,
528
- column_origins,
529
- ) ?;
530
-
531
- if * column_origins != ExprOrigin :: Both {
532
- polars_bail ! (
533
- InvalidOperation :
534
- "'join_where' predicate only refers to columns from a single table: {}" ,
535
- node_to_expr( ae_node, expr_arena) ,
536
- )
537
- }
538
- }
539
-
540
- // Fetch them again in case they were rewritten.
541
- let left = expr_arena. get ( left_node) . clone ( ) ;
542
- let right = expr_arena. get ( right_node) . clone ( ) ;
543
-
544
- let resolve_dtype = |ae : & AExpr , node : Node | -> PolarsResult < DataType > {
545
- ae. to_dtype ( schema_merged, Context :: Default , expr_arena)
546
- . map_err ( |e| {
547
- e. context (
548
- format ! (
549
- "could not resolve dtype of join_where predicate (expr: {})" ,
550
- node_to_expr( node, expr_arena) ,
551
- )
552
- . into ( ) ,
553
- )
554
- } )
555
- } ;
556
-
557
- let dtype_left = resolve_dtype ( & left, left_node) ?;
558
- let dtype_right = resolve_dtype ( & right, right_node) ?;
559
-
560
- // Note: We only upcast the sides if the expr output dtype is Boolean (i.e. `op` is
561
- // a comparison), otherwise the output may change.
562
-
563
- if let Some ( dtype) =
564
- get_numeric_upcast_supertype_lossless ( & dtype_left, & dtype_right)
565
- . filter ( |_| op. is_comparison ( ) )
566
- {
567
- // We have unique references to these nodes (they are created by this function),
568
- // so we can mutate in-place without causing side effects somewhere else.
569
- let expr = expr_arena. add ( expr_arena. get ( left_node) . clone ( ) ) ;
570
- expr_arena. replace (
571
- left_node,
572
- AExpr :: Cast {
573
- expr,
574
- dtype : dtype. clone ( ) ,
575
- options : CastOptions :: Overflowing ,
576
- } ,
577
- ) ;
578
-
579
- let expr = expr_arena. add ( expr_arena. get ( right_node) . clone ( ) ) ;
580
- expr_arena. replace (
581
- right_node,
582
- AExpr :: Cast {
583
- expr,
584
- dtype,
585
- options : CastOptions :: Overflowing ,
586
- } ,
587
- ) ;
588
- } else {
589
- polars_ensure ! (
590
- dtype_left == dtype_right,
591
- SchemaMismatch :
592
- "datatypes of join_where comparison don't match - {} on left does not match {} on right \
593
- (expr: {})",
594
- dtype_left, dtype_right, node_to_expr( ae_node, expr_arena) ,
595
- )
596
- }
597
- } ,
598
- ae => ae. inputs_rev ( stack) ,
599
- }
483
+ // let mut upcast_exprs = Vec::<(Node, DataType)>::new();
484
+ // Ensure that all binary comparisons that use both tables are lossless.
485
+ build_upcast_node_list ( node, schema_left, schema_merged, expr_arena, upcast_exprs) ?;
486
+ // Replace each node with its casted counterpart
487
+ for ( expr, dtype) in upcast_exprs. drain ( ..) {
488
+ let old_expr = expr_arena. duplicate ( expr) ;
489
+ let new_aexpr = AExpr :: Cast {
490
+ expr : old_expr,
491
+ dtype,
492
+ options : CastOptions :: Overflowing ,
493
+ } ;
494
+ expr_arena. replace ( expr, new_aexpr) ;
600
495
}
601
-
602
496
Ok ( ( ) )
603
497
}
498
+
499
+ /// If we are dealing with a binary comparison involving columns from exclusively the left table
500
+ /// on the LHS and the right table on the RHS side, ensure that the cast is lossless.
501
+ /// Expressions involving binaries using either table alone we leave up to the user to verify
502
+ /// that they are valid, as they could theoretically be pushed outside of the join.
503
+ #[ recursive]
504
+ fn build_upcast_node_list (
505
+ node : & Node ,
506
+ schema_left : & Schema ,
507
+ schema_merged : & Schema ,
508
+ expr_arena : & Arena < AExpr > ,
509
+ to_replace : & mut Vec < ( Node , DataType ) > ,
510
+ ) -> PolarsResult < ExprOrigin > {
511
+ let expr_origin = match expr_arena. get ( * node) {
512
+ AExpr :: Column ( name) => {
513
+ if schema_left. contains ( name) {
514
+ ExprOrigin :: Left
515
+ } else if schema_merged. contains ( name) {
516
+ ExprOrigin :: Right
517
+ } else {
518
+ ExprOrigin :: None
519
+ }
520
+ } ,
521
+ AExpr :: Literal ( ..) => ExprOrigin :: None ,
522
+ AExpr :: Cast { expr : node, .. } => {
523
+ build_upcast_node_list ( node, schema_left, schema_merged, expr_arena, to_replace) ?
524
+ } ,
525
+ AExpr :: BinaryExpr {
526
+ left : left_node,
527
+ op,
528
+ right : right_node,
529
+ } => {
530
+ // If left and right node has both, ensure the dtypes are valid.
531
+ let left_origin = build_upcast_node_list (
532
+ left_node,
533
+ schema_left,
534
+ schema_merged,
535
+ expr_arena,
536
+ to_replace,
537
+ ) ?;
538
+ let right_origin = build_upcast_node_list (
539
+ right_node,
540
+ schema_left,
541
+ schema_merged,
542
+ expr_arena,
543
+ to_replace,
544
+ ) ?;
545
+ // We only update casts during comparisons if the operands are from different tables.
546
+ if op. is_comparison ( ) {
547
+ match ( left_origin, right_origin) {
548
+ ( ExprOrigin :: Left , ExprOrigin :: Right )
549
+ | ( ExprOrigin :: Right , ExprOrigin :: Left ) => {
550
+ // Ensure our dtype casts are lossless
551
+ let left = expr_arena. get ( * left_node) ;
552
+ let right = expr_arena. get ( * right_node) ;
553
+ let dtype_left =
554
+ left. to_dtype ( schema_merged, Context :: Default , expr_arena) ?;
555
+ let dtype_right =
556
+ right. to_dtype ( schema_merged, Context :: Default , expr_arena) ?;
557
+ if dtype_left != dtype_right {
558
+ // Ensure that we have a lossless cast between the two types.
559
+ let dt = if dtype_left. is_primitive_numeric ( )
560
+ || dtype_right. is_primitive_numeric ( )
561
+ {
562
+ get_numeric_upcast_supertype_lossless ( & dtype_left, & dtype_right)
563
+ . ok_or ( PolarsError :: SchemaMismatch (
564
+ format ! (
565
+ "'join_where' cannot compare {:?} with {:?}" ,
566
+ dtype_left, dtype_right
567
+ )
568
+ . into ( ) ,
569
+ ) )
570
+ } else {
571
+ try_get_supertype ( & dtype_left, & dtype_right)
572
+ } ?;
573
+
574
+ // Store the nodes and their replacements if a cast is required.
575
+ let replace_left = dt != dtype_left;
576
+ let replace_right = dt != dtype_right;
577
+ if replace_left && replace_right {
578
+ to_replace. push ( ( * left_node, dt. clone ( ) ) ) ;
579
+ to_replace. push ( ( * right_node, dt) ) ;
580
+ } else if replace_left {
581
+ to_replace. push ( ( * left_node, dt) ) ;
582
+ } else if replace_right {
583
+ to_replace. push ( ( * right_node, dt) ) ;
584
+ }
585
+ }
586
+ } ,
587
+ _ => ( ) ,
588
+ }
589
+ }
590
+ left_origin | right_origin
591
+ } ,
592
+ _ => ExprOrigin :: None ,
593
+ } ;
594
+ Ok ( expr_origin)
595
+ }
0 commit comments