diff --git a/dask_sql/context.py b/dask_sql/context.py
index 77ae10b6f..220790136 100644
--- a/dask_sql/context.py
+++ b/dask_sql/context.py
@@ -86,6 +86,7 @@ def __init__(self):
         RelConverter.add_plugin_class(logical.LogicalSortPlugin, replace=False)
         RelConverter.add_plugin_class(logical.LogicalTableScanPlugin, replace=False)
         RelConverter.add_plugin_class(logical.LogicalUnionPlugin, replace=False)
+        RelConverter.add_plugin_class(logical.LogicalMinusPlugin, replace=False)
         RelConverter.add_plugin_class(logical.LogicalValuesPlugin, replace=False)
         RelConverter.add_plugin_class(logical.SamplePlugin, replace=False)
         RelConverter.add_plugin_class(custom.AnalyzeTablePlugin, replace=False)
@@ -515,6 +516,10 @@ def _get_ral(self, sql):
             nonOptimizedRelNode = generator.getRelationalAlgebra(validatedSqlNode)
             rel = generator.getOptimizedRelationalAlgebra(nonOptimizedRelNode)
             rel_string = str(generator.getRelationalAlgebraString(rel))
+            logger.debug(
+                f"Non optimised query plan: \n "
+                f"{str(generator.getRelationalAlgebraString(nonOptimizedRelNode))}"
+            )
         except (ValidationException, SqlParseException) as e:
             logger.debug(f"Original exception raised by Java:\n {e}")
             # We do not want to re-raise an exception here
diff --git a/dask_sql/physical/rel/logical/__init__.py b/dask_sql/physical/rel/logical/__init__.py
index 99698157c..9d429e6ac 100644
--- a/dask_sql/physical/rel/logical/__init__.py
+++ b/dask_sql/physical/rel/logical/__init__.py
@@ -6,6 +6,7 @@
 from .sort import LogicalSortPlugin
 from .table_scan import LogicalTableScanPlugin
 from .union import LogicalUnionPlugin
+from .minus import LogicalMinusPlugin
 from .values import LogicalValuesPlugin
 
 __all__ = [
@@ -16,6 +17,7 @@
     LogicalSortPlugin,
     LogicalTableScanPlugin,
     LogicalUnionPlugin,
+    LogicalMinusPlugin,
     LogicalValuesPlugin,
     SamplePlugin,
 ]
diff --git a/dask_sql/physical/rel/logical/minus.py b/dask_sql/physical/rel/logical/minus.py
new file mode 100644
index 000000000..c611ce050
--- /dev/null
+++ b/dask_sql/physical/rel/logical/minus.py
@@ -0,0 +1,72 @@
+import dask.dataframe as dd
+
+from dask_sql.physical.rex import RexConverter
+from dask_sql.physical.rel.base import BaseRelPlugin
+from dask_sql.datacontainer import DataContainer, ColumnContainer
+
+
+class LogicalMinusPlugin(BaseRelPlugin):
+    """
+    LogicalMinus is used on EXCEPT clauses.
+    It returns all rows of the first input which do not appear in the second.
+ """ + + class_name = "org.apache.calcite.rel.logical.LogicalMinus" + + def convert( + self, rel: "org.apache.calcite.rel.RelNode", context: "dask_sql.Context" + ) -> DataContainer: + first_dc, second_dc = self.assert_inputs(rel, 2, context) + + first_df = first_dc.df + first_cc = first_dc.column_container + + second_df = second_dc.df + second_cc = second_dc.column_container + + # For concatenating, they should have exactly the same fields + output_field_names = [str(x) for x in rel.getRowType().getFieldNames()] + assert len(first_cc.columns) == len(output_field_names) + first_cc = first_cc.rename( + columns={ + col: output_col + for col, output_col in zip(first_cc.columns, output_field_names) + } + ) + first_dc = DataContainer(first_df, first_cc) + + assert len(second_cc.columns) == len(output_field_names) + second_cc = second_cc.rename( + columns={ + col: output_col + for col, output_col in zip(second_cc.columns, output_field_names) + } + ) + second_dc = DataContainer(second_df, second_cc) + + # To concat the to dataframes, we need to make sure the + # columns actually have the specified names in the + # column containers + # Otherwise the concat won't work + first_df = first_dc.assign() + second_df = second_dc.assign() + + self.check_columns_from_row_type(first_df, rel.getExpectedInputRowType(0)) + self.check_columns_from_row_type(second_df, rel.getExpectedInputRowType(1)) + + df = first_df.merge( + second_df, + how="left", + indicator=True, + ) + + df = df[df.iloc[:, -1] == "left_only"].iloc[:, :-1] + + cc = ColumnContainer(df.columns) + cc = self.fix_column_to_row_type(cc, rel.getRowType()) + dc = DataContainer(df, cc) + dc = self.fix_dtype_to_row_type(dc, rel.getRowType()) + return dc diff --git a/tests/integration/test_except.py b/tests/integration/test_except.py new file mode 100644 index 000000000..9fe98131a --- /dev/null +++ b/tests/integration/test_except.py @@ -0,0 +1,29 @@ +def test_except_empty(c, df): + result_df = c.sql( + """ + SELECT * FROM df + EXCEPT + SELECT * FROM df + """ + ) + result_df = result_df.compute() + assert len(result_df) == 0 + + +def test_except_non_empty(c, df): + result_df = c.sql( + """ + ( + SELECT 1 as "a" + UNION + SELECT 2 as "a" + UNION + SELECT 3 as "a" + ) + EXCEPT + SELECT 2 as "a" + """ + ) + result_df = result_df.compute() + assert result_df.columns == "a" + assert set(result_df["a"]) == set([1, 3])