Skip to content

Commit 860b406

Browse files
authored
fix: Enhance predicate validation and cast safety in join_where (#22112)
1 parent 0ebf0f3 commit 860b406

File tree

3 files changed

+281
-163
lines changed

3 files changed

+281
-163
lines changed

crates/polars-plan/src/plans/conversion/join.rs

+136-144
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use arrow::legacy::error::PolarsResult;
22
use either::Either;
33
use polars_core::chunked_array::cast::CastOptions;
44
use polars_core::error::feature_gated;
5-
use polars_core::utils::get_numeric_upcast_supertype_lossless;
5+
use polars_core::utils::{get_numeric_upcast_supertype_lossless, try_get_supertype};
66
use polars_utils::format_pl_smallstr;
77
use polars_utils::itertools::Itertools;
88

@@ -424,29 +424,33 @@ fn resolve_join_where(
424424
ctxt,
425425
)?;
426426

427-
let mut ae_nodes_stack = Vec::new();
428-
429427
let schema_merged = ctxt
430428
.lp_arena
431429
.get(last_node)
432430
.schema(ctxt.lp_arena)
433431
.into_owned();
434-
let schema_merged = schema_merged.as_ref();
435432

433+
// Perform predicate validation.
434+
let mut upcast_exprs = Vec::<(Node, DataType)>::new();
436435
for e in predicates {
437-
let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?;
438-
439-
debug_assert!(ae_nodes_stack.is_empty());
440-
ae_nodes_stack.clear();
441-
ae_nodes_stack.push(predicate.node());
442-
443-
process_join_where_predicate(
444-
&mut ae_nodes_stack,
445-
0,
446-
schema_left.as_ref(),
447-
schema_merged,
448-
ctxt.expr_arena,
449-
&mut ExprOrigin::None,
436+
let arena = &mut ctxt.expr_arena;
437+
let predicate = to_expr_ir_ignore_alias(e, arena)?;
438+
let node = predicate.node();
439+
440+
// Ensure the output dtype of the predicate's root node is Boolean.
441+
let ae = arena.get(node);
442+
let dt_out = ae.to_dtype(&schema_merged, Context::Default, arena)?;
443+
polars_ensure!(
444+
dt_out == DataType::Boolean,
445+
ComputeError: "'join_where' predicates must resolve to boolean"
446+
);
447+
448+
ensure_lossless_binary_comparisons(
449+
&node,
450+
&schema_left,
451+
&schema_merged,
452+
arena,
453+
&mut upcast_exprs,
450454
)?;
451455

452456
ctxt.conversion_optimizer
@@ -467,137 +471,125 @@ fn resolve_join_where(
467471
Ok((last_node, join_node))
468472
}
469473

470-
/// Performs validation and type-coercion on join_where predicates.
471-
///
472-
/// Validates for all comparison expressions / subexpressions, that:
473-
/// 1. They reference columns from both sides.
474-
/// 2. The dtypes of the LHS and RHS are match, or can be casted to a lossless
475-
/// supertype (and inserts the necessary casting).
476-
///
477-
/// We perform (1) by recursing whenever we encounter a comparison expression.
478-
fn process_join_where_predicate(
479-
stack: &mut Vec<Node>,
480-
prev_comparison_expr_stack_offset: usize,
474+
/// Locate nodes that are operands in a binary comparison involving both tables, and ensure that
475+
/// these nodes are losslessly upcast to a safe dtype.
476+
fn ensure_lossless_binary_comparisons(
477+
node: &Node,
481478
schema_left: &Schema,
482479
schema_merged: &Schema,
483480
expr_arena: &mut Arena<AExpr>,
484-
column_origins: &mut ExprOrigin,
481+
upcast_exprs: &mut Vec<(Node, DataType)>,
485482
) -> PolarsResult<()> {
486-
while stack.len() > prev_comparison_expr_stack_offset {
487-
let ae_node = stack.pop().unwrap();
488-
let ae = expr_arena.get(ae_node).clone();
489-
490-
match ae {
491-
AExpr::Column(ref name) => {
492-
let origin = if schema_left.contains(name) {
493-
ExprOrigin::Left
494-
} else if schema_merged.contains(name) {
495-
ExprOrigin::Right
496-
} else {
497-
polars_bail!(ColumnNotFound: "{}", name);
498-
};
499-
500-
*column_origins |= origin;
501-
},
502-
// This is not actually Origin::Both, but we set this because the test suite expects
503-
// this predicate to pass:
504-
// * `pl.col("flag_right") == 1`
505-
// Observe that it only has a column from one side because it is comparing to a literal.
506-
AExpr::Literal(_) => *column_origins = ExprOrigin::Both,
507-
AExpr::BinaryExpr {
508-
left: left_node,
509-
op,
510-
right: right_node,
511-
} if op.is_comparison_or_bitwise() => {
512-
{
513-
let new_stack_offset = stack.len();
514-
stack.extend([right_node, left_node]);
515-
516-
// Reset `column_origins` to a `None` state. We will only have 2 possible return states from
517-
// this point:
518-
// * Ok(()), with column_origins @ ExprOrigin::Both
519-
// * Err(_), in which case the value of column_origins doesn't matter.
520-
*column_origins = ExprOrigin::None;
521-
522-
process_join_where_predicate(
523-
stack,
524-
new_stack_offset,
525-
schema_left,
526-
schema_merged,
527-
expr_arena,
528-
column_origins,
529-
)?;
530-
531-
if *column_origins != ExprOrigin::Both {
532-
polars_bail!(
533-
InvalidOperation:
534-
"'join_where' predicate only refers to columns from a single table: {}",
535-
node_to_expr(ae_node, expr_arena),
536-
)
537-
}
538-
}
539-
540-
// Fetch them again in case they were rewritten.
541-
let left = expr_arena.get(left_node).clone();
542-
let right = expr_arena.get(right_node).clone();
543-
544-
let resolve_dtype = |ae: &AExpr, node: Node| -> PolarsResult<DataType> {
545-
ae.to_dtype(schema_merged, Context::Default, expr_arena)
546-
.map_err(|e| {
547-
e.context(
548-
format!(
549-
"could not resolve dtype of join_where predicate (expr: {})",
550-
node_to_expr(node, expr_arena),
551-
)
552-
.into(),
553-
)
554-
})
555-
};
556-
557-
let dtype_left = resolve_dtype(&left, left_node)?;
558-
let dtype_right = resolve_dtype(&right, right_node)?;
559-
560-
// Note: We only upcast the sides if the expr output dtype is Boolean (i.e. `op` is
561-
// a comparison), otherwise the output may change.
562-
563-
if let Some(dtype) =
564-
get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right)
565-
.filter(|_| op.is_comparison())
566-
{
567-
// We have unique references to these nodes (they are created by this function),
568-
// so we can mutate in-place without causing side effects somewhere else.
569-
let expr = expr_arena.add(expr_arena.get(left_node).clone());
570-
expr_arena.replace(
571-
left_node,
572-
AExpr::Cast {
573-
expr,
574-
dtype: dtype.clone(),
575-
options: CastOptions::Overflowing,
576-
},
577-
);
578-
579-
let expr = expr_arena.add(expr_arena.get(right_node).clone());
580-
expr_arena.replace(
581-
right_node,
582-
AExpr::Cast {
583-
expr,
584-
dtype,
585-
options: CastOptions::Overflowing,
586-
},
587-
);
588-
} else {
589-
polars_ensure!(
590-
dtype_left == dtype_right,
591-
SchemaMismatch:
592-
"datatypes of join_where comparison don't match - {} on left does not match {} on right \
593-
(expr: {})",
594-
dtype_left, dtype_right, node_to_expr(ae_node, expr_arena),
595-
)
596-
}
597-
},
598-
ae => ae.inputs_rev(stack),
599-
}
483+
// let mut upcast_exprs = Vec::<(Node, DataType)>::new();
484+
// Ensure that all binary comparisons that use both tables are lossless.
485+
build_upcast_node_list(node, schema_left, schema_merged, expr_arena, upcast_exprs)?;
486+
// Replace each node with its cast counterpart.
487+
for (expr, dtype) in upcast_exprs.drain(..) {
488+
let old_expr = expr_arena.duplicate(expr);
489+
let new_aexpr = AExpr::Cast {
490+
expr: old_expr,
491+
dtype,
492+
options: CastOptions::Overflowing,
493+
};
494+
expr_arena.replace(expr, new_aexpr);
600495
}
601-
602496
Ok(())
603497
}
498+
499+
/// If we are dealing with a binary comparison whose operands come from different tables (one
500+
/// operand from the left table and the other from the right), ensure that the cast is lossless.
501+
/// Binary expressions that use either table alone are left to the user to verify
502+
/// as valid, since they could theoretically be pushed outside of the join.
503+
#[recursive]
504+
fn build_upcast_node_list(
505+
node: &Node,
506+
schema_left: &Schema,
507+
schema_merged: &Schema,
508+
expr_arena: &Arena<AExpr>,
509+
to_replace: &mut Vec<(Node, DataType)>,
510+
) -> PolarsResult<ExprOrigin> {
511+
let expr_origin = match expr_arena.get(*node) {
512+
AExpr::Column(name) => {
513+
if schema_left.contains(name) {
514+
ExprOrigin::Left
515+
} else if schema_merged.contains(name) {
516+
ExprOrigin::Right
517+
} else {
518+
ExprOrigin::None
519+
}
520+
},
521+
AExpr::Literal(..) => ExprOrigin::None,
522+
AExpr::Cast { expr: node, .. } => {
523+
build_upcast_node_list(node, schema_left, schema_merged, expr_arena, to_replace)?
524+
},
525+
AExpr::BinaryExpr {
526+
left: left_node,
527+
op,
528+
right: right_node,
529+
} => {
530+
// If the left and right operands come from different tables, ensure their dtypes are valid.
531+
let left_origin = build_upcast_node_list(
532+
left_node,
533+
schema_left,
534+
schema_merged,
535+
expr_arena,
536+
to_replace,
537+
)?;
538+
let right_origin = build_upcast_node_list(
539+
right_node,
540+
schema_left,
541+
schema_merged,
542+
expr_arena,
543+
to_replace,
544+
)?;
545+
// We only update casts during comparisons if the operands are from different tables.
546+
if op.is_comparison() {
547+
match (left_origin, right_origin) {
548+
(ExprOrigin::Left, ExprOrigin::Right)
549+
| (ExprOrigin::Right, ExprOrigin::Left) => {
550+
// Ensure our dtype casts are lossless
551+
let left = expr_arena.get(*left_node);
552+
let right = expr_arena.get(*right_node);
553+
let dtype_left =
554+
left.to_dtype(schema_merged, Context::Default, expr_arena)?;
555+
let dtype_right =
556+
right.to_dtype(schema_merged, Context::Default, expr_arena)?;
557+
if dtype_left != dtype_right {
558+
// Ensure that we have a lossless cast between the two types.
559+
let dt = if dtype_left.is_primitive_numeric()
560+
|| dtype_right.is_primitive_numeric()
561+
{
562+
get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right)
563+
.ok_or(PolarsError::SchemaMismatch(
564+
format!(
565+
"'join_where' cannot compare {:?} with {:?}",
566+
dtype_left, dtype_right
567+
)
568+
.into(),
569+
))
570+
} else {
571+
try_get_supertype(&dtype_left, &dtype_right)
572+
}?;
573+
574+
// Store the nodes and their replacements if a cast is required.
575+
let replace_left = dt != dtype_left;
576+
let replace_right = dt != dtype_right;
577+
if replace_left && replace_right {
578+
to_replace.push((*left_node, dt.clone()));
579+
to_replace.push((*right_node, dt));
580+
} else if replace_left {
581+
to_replace.push((*left_node, dt));
582+
} else if replace_right {
583+
to_replace.push((*right_node, dt));
584+
}
585+
}
586+
},
587+
_ => (),
588+
}
589+
}
590+
left_origin | right_origin
591+
},
592+
_ => ExprOrigin::None,
593+
};
594+
Ok(expr_origin)
595+
}

py-polars/tests/unit/operations/test_inequality_join.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -452,15 +452,6 @@ def test_ie_join_with_floats(
452452
assert_frame_equal(actual, expected, check_row_order=False, check_exact=True)
453453

454454

455-
def test_raise_on_ambiguous_name() -> None:
456-
df = pl.DataFrame({"id": [1, 2]})
457-
with pytest.raises(
458-
pl.exceptions.InvalidOperationError,
459-
match="'join_where' predicate only refers to columns from a single table",
460-
):
461-
df.join_where(df, pl.col("id") >= pl.col("id"))
462-
463-
464455
def test_raise_invalid_input_join_where() -> None:
465456
df = pl.DataFrame({"id": [1, 2]})
466457
with pytest.raises(
@@ -570,15 +561,23 @@ def test_ie_join_projection_pd_19005() -> None:
570561
assert out.shape == (0, 2)
571562

572563

573-
def test_raise_invalid_predicate() -> None:
574-
left = pl.LazyFrame({"a": [1, 2]}).with_row_index()
575-
right = pl.LazyFrame({"b": [1, 2]}).with_row_index()
564+
def test_single_sided_predicate() -> None:
565+
left = pl.LazyFrame({"a": [1, -1, 2]}).with_row_index()
566+
right = pl.LazyFrame({"b": [1, 2]})
576567

577-
with pytest.raises(
578-
pl.exceptions.InvalidOperationError,
579-
match="'join_where' predicate only refers to columns from a single table",
580-
):
581-
left.join_where(right, pl.col.index >= pl.col.a).collect()
568+
result = (
569+
left.join_where(right, pl.col.index >= pl.col.a)
570+
.collect()
571+
.sort("index", "a", "b")
572+
)
573+
expected = pl.DataFrame(
574+
{
575+
"index": pl.Series([1, 1, 2, 2], dtype=pl.get_index_type()),
576+
"a": [-1, -1, 2, 2],
577+
"b": [1, 2, 1, 2],
578+
}
579+
)
580+
assert_frame_equal(result, expected)
582581

583582

584583
def test_join_on_strings() -> None:

0 commit comments

Comments
 (0)