Skip to content

Commit cd4c093

Browse files
committed
disable ScalarSubqueryToJoin
1 parent a7d985c commit cd4c093

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

wren-core/core/src/mdl/context.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ use datafusion::optimizer::extract_equijoin_predicate::ExtractEquijoinPredicate;
3434
use datafusion::optimizer::filter_null_join_keys::FilterNullJoinKeys;
3535
use datafusion::optimizer::propagate_empty_relation::PropagateEmptyRelation;
3636
use datafusion::optimizer::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
37-
use datafusion::optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
3837
use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
3938
use datafusion::optimizer::{AnalyzerRule, OptimizerRule};
4039
use datafusion::physical_plan::ExecutionPlan;
@@ -158,7 +157,8 @@ fn optimize_rule_for_unparsing() -> Vec<Arc<dyn OptimizerRule + Send + Sync>> {
158157
Arc::new(ReplaceDistinctWithAggregate::new()),
159158
Arc::new(EliminateJoin::new()),
160159
Arc::new(DecorrelatePredicateSubquery::new()),
161-
Arc::new(ScalarSubqueryToJoin::new()),
160+
// Disable ScalarSubqueryToJoin to avoid generate invalid sql (join without condition)
161+
// Arc::new(ScalarSubqueryToJoin::new()),
162162
Arc::new(ExtractEquijoinPredicate::new()),
163163
// Disable SimplifyExpressions to avoid apply some function locally
164164
// Arc::new(SimplifyExpressions::new()),

wren-core/core/src/mdl/mod.rs

+27
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,33 @@ mod test {
13671367
Ok(())
13681368
}
13691369

1370+
#[tokio::test]
1371+
async fn test_disable_scalar_subquery() -> Result<()> {
1372+
let ctx = SessionContext::new();
1373+
let manifest = ManifestBuilder::new()
1374+
.catalog("wren")
1375+
.schema("test")
1376+
.model(
1377+
ModelBuilder::new("customer")
1378+
.table_reference("customer")
1379+
.column(ColumnBuilder::new("c_custkey", "int").build())
1380+
.column(ColumnBuilder::new("c_name", "string").build())
1381+
.build(),
1382+
)
1383+
.build();
1384+
let sql = r#"SELECT c_custkey, (SELECT c_name FROM customer WHERE c_custkey = 1) FROM customer"#;
1385+
let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
1386+
let result =
1387+
transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?;
1388+
assert_eq!(
1389+
result,
1390+
"SELECT customer.c_custkey, (SELECT customer.c_name FROM (SELECT customer.c_custkey, customer.c_name \
1391+
FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer \
1392+
WHERE customer.c_custkey = 1) FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer) AS customer"
1393+
);
1394+
Ok(())
1395+
}
1396+
13701397
/// Return a RecordBatch with made up data about customer
13711398
fn customer() -> RecordBatch {
13721399
let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));

0 commit comments

Comments
 (0)