From cd4c093f09b7148095742c2314c072ddd81b3f1f Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 25 Mar 2025 12:27:09 +0800 Subject: [PATCH] disable ScalarSubqueryToJoin --- wren-core/core/src/mdl/context.rs | 4 ++-- wren-core/core/src/mdl/mod.rs | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 81b0f4be3..4bd37be08 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -34,7 +34,6 @@ use datafusion::optimizer::extract_equijoin_predicate::ExtractEquijoinPredicate; use datafusion::optimizer::filter_null_join_keys::FilterNullJoinKeys; use datafusion::optimizer::propagate_empty_relation::PropagateEmptyRelation; use datafusion::optimizer::replace_distinct_aggregate::ReplaceDistinctWithAggregate; -use datafusion::optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion::optimizer::{AnalyzerRule, OptimizerRule}; use datafusion::physical_plan::ExecutionPlan; @@ -158,7 +157,8 @@ fn optimize_rule_for_unparsing() -> Vec> { Arc::new(ReplaceDistinctWithAggregate::new()), Arc::new(EliminateJoin::new()), Arc::new(DecorrelatePredicateSubquery::new()), - Arc::new(ScalarSubqueryToJoin::new()), + // Disable ScalarSubqueryToJoin to avoid generate invalid sql (join without condition) + // Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), // Disable SimplifyExpressions to avoid apply some function locally // Arc::new(SimplifyExpressions::new()), diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index f669a2fd6..a1aa5832f 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -1367,6 +1367,33 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_disable_scalar_subquery() -> Result<()> { + let ctx = SessionContext::new(); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .build(), + ) + .build(); + let sql = r#"SELECT c_custkey, (SELECT c_name FROM customer WHERE c_custkey = 1) FROM customer"#; + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let result = + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + assert_eq!( + result, + "SELECT customer.c_custkey, (SELECT customer.c_name FROM (SELECT customer.c_custkey, customer.c_name \ + FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer \ + WHERE customer.c_custkey = 1) FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer) AS customer" + ); + Ok(()) + } + /// Return a RecordBatch with made up data about customer fn customer() -> RecordBatch { let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));